Skip to content

Commit

Permalink
update the modules
Browse files Browse the repository at this point in the history
  • Loading branch information
zachcp committed Dec 13, 2024
1 parent e1da76d commit 243d23c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 28 deletions.
48 changes: 24 additions & 24 deletions ferritin-esm/src/esm2/models/esm2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,30 @@ use tokenizers::Tokenizer;

#[derive(Deserialize, Clone)]
pub struct ESM2Config {
num_attention_heads: i32,
attention_probs_dropout_prob: f32,
classifier_dropout: Option<f32>,
emb_layer_norm_before: bool,
esmfold_config: Option<String>,
hidden_act: String,
hidden_dropout_prob: f32,
hidden_size: i32,
initializer_range: f32,
intermediate_size: i32,
is_folding_model: bool,
layer_norm_eps: f32,
mask_token_id: i32,
max_position_embeddings: i32,
model_type: String,
num_hidden_layers: i32,
pad_token_id: i32,
position_embedding_type: String,
token_dropout: bool,
torch_dtype: String,
transformers_version: String,
use_cache: bool,
vocab_list: Option<Vec<String>>,
vocab_size: i32,
pub(crate) num_attention_heads: i32,
pub(crate) attention_probs_dropout_prob: f32,
pub(crate) classifier_dropout: Option<f32>,
pub(crate) emb_layer_norm_before: bool,
pub(crate) esmfold_config: Option<String>,
pub(crate) hidden_act: String,
pub(crate) hidden_dropout_prob: f32,
pub(crate) hidden_size: i32,
pub(crate) initializer_range: f32,
pub(crate) intermediate_size: i32,
pub(crate) is_folding_model: bool,
pub(crate) layer_norm_eps: f32,
pub(crate) mask_token_id: i32,
pub(crate) max_position_embeddings: i32,
pub(crate) model_type: String,
pub(crate) num_hidden_layers: i32,
pub(crate) pad_token_id: i32,
pub(crate) position_embedding_type: String,
pub(crate) token_dropout: bool,
pub(crate) torch_dtype: String,
pub(crate) transformers_version: String,
pub(crate) use_cache: bool,
pub(crate) vocab_list: Option<Vec<String>>,
pub(crate) vocab_size: i32,
}

impl ESM2Config {
Expand Down
6 changes: 2 additions & 4 deletions ferritin-esm/src/esm2/models/modules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,10 +376,8 @@ pub struct RobertaLMHead {

impl RobertaLMHead {
pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
// Todo: Fix this!
let embed_dim = 100;
let output_dim = 100;
let dense = candle_nn::linear(embed_dim, output_dim, vb.pp("dense"))?;
let ESM2Config { hidden_size, .. } = config;
let dense = nn::linear(*hidden_size as usize, *hidden_size as usize, vb.pp("dense"))?;
let layer_norm = ESM1bLayerNorm::load(vb.pp("LayerNorm"), config)?;

Ok(Self { dense, layer_norm })
Expand Down

0 comments on commit 243d23c

Please sign in to comment.