diff --git a/ferritin-esm/src/esm2/models/esm2.rs b/ferritin-esm/src/esm2/models/esm2.rs
index b60c1d17..b30ad557 100644
--- a/ferritin-esm/src/esm2/models/esm2.rs
+++ b/ferritin-esm/src/esm2/models/esm2.rs
@@ -61,6 +61,34 @@ impl ESM2Config {
             vocab_size: 33,
         }
     }
+    pub fn esm2_t6_8M_ur50() -> Self {
+        Self {
+            num_attention_heads: 20,
+            attention_probs_dropout_prob: 0.0,
+            classifier_dropout: None,
+            emb_layer_norm_before: false,
+            esmfold_config: None,
+            hidden_act: "gelu".to_string(),
+            hidden_dropout_prob: 0.0,
+            hidden_size: 320,
+            initializer_range: 0.02,
+            intermediate_size: 1280,
+            is_folding_model: false,
+            layer_norm_eps: 1e-5,
+            mask_token_id: 32,
+            max_position_embeddings: 1026,
+            model_type: "esm".to_string(),
+            num_hidden_layers: 6,
+            pad_token_id: 1,
+            position_embedding_type: "rotary".to_string(),
+            token_dropout: true,
+            torch_dtype: "float32".to_string(),
+            transformers_version: "4.25.0.dev0".to_string(),
+            use_cache: true,
+            vocab_list: None,
+            vocab_size: 33,
+        }
+    }
 }
 
 /// ESM2 Architecture
@@ -77,12 +105,11 @@ impl ESM2 {
     // note: in this load function we do NOT handle the embedding code, which is only invoked when the model is called with tokens
     pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
         let ESM2Config {
-            num_attention_heads,
-            ..
+            num_hidden_layers, ..
         } = config;
-        let num_layers = num_attention_heads.clone() as usize;
-        let mut layers = Vec::with_capacity(num_layers as usize);
-        for i in 0..num_layers {
+
+        let mut layers = Vec::with_capacity(*num_hidden_layers as usize);
+        for i in 0..*num_hidden_layers {
             let transformer_layer = TransformerLayer::load(vb.pp(format!("layer.{}", i)), config)?;
             layers.push(transformer_layer);
         }
@@ -92,10 +119,10 @@ impl ESM2 {
 
         Ok(Self {
             embed_tokens: None,
-            layers: layers,
-            contact_head: contact_head,
-            emb_layer_norm_after: emb_layer_norm_after,
-            lm_head: lm_head,
+            layers,
+            contact_head,
+            emb_layer_norm_after,
+            lm_head,
             config: config.clone(),
         })
     }
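
For reference, a minimal usage sketch of the constructor and load path touched by this diff. The import path (ferritin_esm::esm2::models::esm2) and the zero-initialized VarBuilder are assumptions for illustration only; a real caller would resolve the actual re-export path and build the VarBuilder from checkpoint weights (e.g. safetensors).

use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;
// Assumed module path; adjust to wherever the crate actually exposes these types.
use ferritin_esm::esm2::models::esm2::{ESM2, ESM2Config};

fn main() -> Result<()> {
    // A zero-initialized VarBuilder stands in for real checkpoint weights here.
    let device = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &device);

    // New constructor added in this diff: 6 layers, hidden size 320, 20 heads.
    let config = ESM2Config::esm2_t6_8M_ur50();

    // The layer stack is now sized from `num_hidden_layers` (6) rather than
    // `num_attention_heads` (20), so exactly six TransformerLayers are built.
    let _model = ESM2::load(vb, &config)?;
    Ok(())
}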