start passing in values from config
zachcp committed Dec 13, 2024
1 parent fddf038 commit e1da76d
Showing 1 changed file with 36 additions and 9 deletions.
ferritin-esm/src/esm2/models/esm2.rs
@@ -61,6 +61,34 @@ impl ESM2Config {
             vocab_size: 33,
         }
     }
+    pub fn esm2_t6_8M_ur50() -> Self {
+        Self {
+            num_attention_heads: 20,
+            attention_probs_dropout_prob: 0.0,
+            classifier_dropout: None,
+            emb_layer_norm_before: false,
+            esmfold_config: None,
+            hidden_act: "gelu".to_string(),
+            hidden_dropout_prob: 0.0,
+            hidden_size: 320,
+            initializer_range: 0.02,
+            intermediate_size: 1280,
+            is_folding_model: false,
+            layer_norm_eps: 1e-5,
+            mask_token_id: 32,
+            max_position_embeddings: 1026,
+            model_type: "esm".to_string(),
+            num_hidden_layers: 6,
+            pad_token_id: 1,
+            position_embedding_type: "rotary".to_string(),
+            token_dropout: true,
+            torch_dtype: "float32".to_string(),
+            transformers_version: "4.25.0.dev0".to_string(),
+            use_cache: true,
+            vocab_list: None,
+            vocab_size: 33,
+        }
+    }
 }

/// ESM2 Architecture
@@ -77,12 +105,11 @@ impl ESM2 {
     // note: in this load function we do NOT handle the embedding code, which is invoked only when the model is called with tokens
     pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
         let ESM2Config {
-            num_attention_heads,
-            ..
+            num_hidden_layers, ..
         } = config;
-        let num_layers = num_attention_heads.clone() as usize;
-        let mut layers = Vec::with_capacity(num_layers as usize);
-        for i in 0..num_layers {
+
+        let mut layers = Vec::with_capacity(*num_hidden_layers as usize);
+        for i in 0..*num_hidden_layers {
             let transformer_layer = TransformerLayer::load(vb.pp(format!("layer.{}", i)), config)?;
             layers.push(transformer_layer);
         }
@@ -92,10 +119,10 @@ impl ESM2 {

         Ok(Self {
             embed_tokens: None,
-            layers: layers,
-            contact_head: contact_head,
-            emb_layer_norm_after: emb_layer_norm_after,
-            lm_head: lm_head,
+            layers,
+            contact_head,
+            emb_layer_norm_after,
+            lm_head,
             config: config.clone(),
         })
     }
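For context, a minimal usage sketch of the change: the new preset constructor supplies the hyperparameters (num_hidden_layers = 6, hidden_size = 320, and so on) that ESM2::load now reads from the config instead of hard-coding. The crate import paths and the safetensors weight file below are assumptions for illustration; only ESM2Config::esm2_t6_8M_ur50() and the ESM2::load(vb, &config) signature come from this diff.

// Hypothetical usage (not part of the commit).
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;
use ferritin_esm::{ESM2, ESM2Config}; // assumed export paths

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Preset added in this commit: 6 layers, 320 hidden dims, 20 heads.
    let config = ESM2Config::esm2_t6_8M_ur50();
    // Weight file name is an assumption; candle's mmap loader is unsafe
    // because the file must not be mutated while mapped.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(
            &["esm2_t6_8M_UR50D.safetensors"],
            DType::F32,
            &device,
        )?
    };
    // load() now sizes the transformer stack from config.num_hidden_layers.
    let model = ESM2::load(vb, &config)?;
    let _ = model;
    Ok(())
}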
