diff --git a/ferritin-esm/src/esm2/models/modules.rs b/ferritin-esm/src/esm2/models/modules.rs
index 4424cb09..085ffc44 100644
--- a/ferritin-esm/src/esm2/models/modules.rs
+++ b/ferritin-esm/src/esm2/models/modules.rs
@@ -3,10 +3,9 @@
 // This source code is licensed under the MIT license found in the
 // LICENSE file in the root directory of this source tree.
 
-use crate::ESM2Config;
-
 use super::axial_attention::{ColumnSelfAttention, RowSelfAttention};
 use super::multihead_attention::MultiheadAttention;
+use crate::ESM2Config;
 use candle_core::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{self as nn, VarBuilder, VarMap};
 use std::f64::consts::PI;
@@ -38,7 +37,6 @@ fn apc(x: &Tensor) -> Result<()> {
 pub struct ESM1LayerNorm {
     // weight: Tensor,
     // bias: Option<Tensor>,
-    // eps: f64,
 }
 
 impl ESM1LayerNorm {
@@ -86,8 +84,8 @@ pub type ESM1bLayerNorm = ESM1LayerNorm;
 pub struct TransformerLayer {
     self_attn: MultiheadAttention,
     self_attn_layer_norm: ESM1bLayerNorm,
-    fc1: nn::Linear,
-    fc2: nn::Linear,
+    // fc1: nn::Linear,
+    // fc2: nn::Linear,
     final_layer_norm: ESM1bLayerNorm,
 }
 
@@ -103,21 +101,21 @@ impl TransformerLayer {
             ..
         } = config;
 
-        // Todo: Fix this
+        // Todo: Fix this!
         let embed_dim = 100;
         let ffn_embed_dim = 100;
 
         let layer_norm = ESM1LayerNorm::load(vb.pp("Layer_Norm"), config)?;
         let multi_head = MultiheadAttention::load(vb.pp("attention"), config)?;
-        let fc1 = nn::linear(embed_dim, ffn_embed_dim, vb.pp("fc1"))?;
-        let fc2 = nn::linear(ffn_embed_dim, embed_dim, vb.pp("fc2"))?;
+        // let fc1 = nn::linear(embed_dim, ffn_embed_dim, vb.pp("fc1"))?;
+        // let fc2 = nn::linear(ffn_embed_dim, embed_dim, vb.pp("fc2"))?;
         let final_layer_norm = ESM1LayerNorm::load(vb.pp("LayerNorm"), config)?;
 
         Ok(Self {
             self_attn: multi_head,
             self_attn_layer_norm: layer_norm,
-            fc1,
-            fc2,
+            // fc1,
+            // fc2,
             final_layer_norm,
         })
     }
@@ -372,15 +370,19 @@ impl SinusoidalPositionalEmbedding {
 
 #[derive(Debug)]
 pub struct RobertaLMHead {
-    // dense: candle_nn::Linear,
-    // layer_norm: ESM1bLayerNorm,
-    // weight: Tensor,
-    // bias: Tensor,
+    dense: candle_nn::Linear,
+    layer_norm: ESM1bLayerNorm,
 }
 
 impl RobertaLMHead {
     pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
-        Ok(Self {})
+        // Todo: Fix this!
+        let embed_dim = 100;
+        let output_dim = 100;
+        let dense = candle_nn::linear(embed_dim, output_dim, vb.pp("dense"))?;
+        let layer_norm = ESM1bLayerNorm::load(vb.pp("LayerNorm"), config)?;
+
+        Ok(Self { dense, layer_norm })
     }
     // pub fn new(
     //     embed_dim: usize,