diff --git a/ferritin-esm/src/esm2/models/esm2.rs b/ferritin-esm/src/esm2/models/esm2.rs
index b30ad557..8f712021 100644
--- a/ferritin-esm/src/esm2/models/esm2.rs
+++ b/ferritin-esm/src/esm2/models/esm2.rs
@@ -6,30 +6,30 @@ use tokenizers::Tokenizer;
 
 #[derive(Deserialize, Clone)]
 pub struct ESM2Config {
-    num_attention_heads: i32,
-    attention_probs_dropout_prob: f32,
-    classifier_dropout: Option<f32>,
-    emb_layer_norm_before: bool,
-    esmfold_config: Option<String>,
-    hidden_act: String,
-    hidden_dropout_prob: f32,
-    hidden_size: i32,
-    initializer_range: f32,
-    intermediate_size: i32,
-    is_folding_model: bool,
-    layer_norm_eps: f32,
-    mask_token_id: i32,
-    max_position_embeddings: i32,
-    model_type: String,
-    num_hidden_layers: i32,
-    pad_token_id: i32,
-    position_embedding_type: String,
-    token_dropout: bool,
-    torch_dtype: String,
-    transformers_version: String,
-    use_cache: bool,
-    vocab_list: Option<Vec<String>>,
-    vocab_size: i32,
+    pub(crate) num_attention_heads: i32,
+    pub(crate) attention_probs_dropout_prob: f32,
+    pub(crate) classifier_dropout: Option<f32>,
+    pub(crate) emb_layer_norm_before: bool,
+    pub(crate) esmfold_config: Option<String>,
+    pub(crate) hidden_act: String,
+    pub(crate) hidden_dropout_prob: f32,
+    pub(crate) hidden_size: i32,
+    pub(crate) initializer_range: f32,
+    pub(crate) intermediate_size: i32,
+    pub(crate) is_folding_model: bool,
+    pub(crate) layer_norm_eps: f32,
+    pub(crate) mask_token_id: i32,
+    pub(crate) max_position_embeddings: i32,
+    pub(crate) model_type: String,
+    pub(crate) num_hidden_layers: i32,
+    pub(crate) pad_token_id: i32,
+    pub(crate) position_embedding_type: String,
+    pub(crate) token_dropout: bool,
+    pub(crate) torch_dtype: String,
+    pub(crate) transformers_version: String,
+    pub(crate) use_cache: bool,
+    pub(crate) vocab_list: Option<Vec<String>>,
+    pub(crate) vocab_size: i32,
 }
 
 impl ESM2Config {
diff --git a/ferritin-esm/src/esm2/models/modules.rs b/ferritin-esm/src/esm2/models/modules.rs
index 085ffc44..386b2d24 100644
--- a/ferritin-esm/src/esm2/models/modules.rs
+++ b/ferritin-esm/src/esm2/models/modules.rs
@@ -376,10 +376,8 @@ pub struct RobertaLMHead {
 
 impl RobertaLMHead {
     pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
-        // Todo: Fix this!
-        let embed_dim = 100;
-        let output_dim = 100;
-        let dense = candle_nn::linear(embed_dim, output_dim, vb.pp("dense"))?;
+        let ESM2Config { hidden_size, .. } = config;
+        let dense = nn::linear(*hidden_size as usize, *hidden_size as usize, vb.pp("dense"))?;
         let layer_norm = ESM1bLayerNorm::load(vb.pp("LayerNorm"), config)?;
         Ok(Self { dense, layer_norm })
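
Reviewer note, not part of the patch: the modules.rs hunk replaces the hard-coded
`embed_dim = 100` / `output_dim = 100` with the config's `hidden_size`, destructuring
`ESM2Config` (made possible by the `pub(crate)` fields in the esm2.rs hunk) and sizing
the dense projection as a square `hidden_size` x `hidden_size` layer. Below is a minimal
standalone sketch of that pattern, assuming `candle_nn` is aliased as `nn` in modules.rs
(as the new hunk implies) and that the code lives inside the ferritin-esm crate;
`dense_from_config` is a hypothetical helper for illustration, not a function in the repo:

    use candle_core::Result;
    use candle_nn as nn;

    // Hypothetical helper, for illustration only: build the LM-head dense
    // layer from the config instead of a hard-coded dimension.
    fn dense_from_config(vb: nn::VarBuilder, config: &ESM2Config) -> Result<nn::Linear> {
        // The esm2.rs hunk makes the fields pub(crate), so same-crate code
        // can destructure the config directly.
        let ESM2Config { hidden_size, .. } = config;
        // RobertaLMHead's dense layer maps hidden states back onto the same
        // width, so input and output dims are both hidden_size.
        nn::linear(*hidden_size as usize, *hidden_size as usize, vb.pp("dense"))
    }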