Commit fddf038
update
zachcp committed Dec 13, 2024
1 parent b7c073c commit fddf038
Showing 1 changed file with 17 additions and 15 deletions.
ferritin-esm/src/esm2/models/modules.rs (32 changes: 17 additions & 15 deletions)
@@ -3,10 +3,9 @@
 // This source code is licensed under the MIT license found in the
 // LICENSE file in the root directory of this source tree.
 
-use crate::ESM2Config;
-
 use super::axial_attention::{ColumnSelfAttention, RowSelfAttention};
 use super::multihead_attention::MultiheadAttention;
+use crate::ESM2Config;
 use candle_core::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{self as nn, VarBuilder, VarMap};
 use std::f64::consts::PI;
@@ -38,7 +37,6 @@ fn apc(x: &Tensor) -> Result<()> {
 pub struct ESM1LayerNorm {
     // weight: Tensor,
     // bias: Option<Tensor>,
-    // eps: f64,
 }
 
 impl ESM1LayerNorm {
@@ -86,8 +84,8 @@ pub type ESM1bLayerNorm = ESM1LayerNorm;
 pub struct TransformerLayer {
     self_attn: MultiheadAttention,
     self_attn_layer_norm: ESM1bLayerNorm,
-    fc1: nn::Linear,
-    fc2: nn::Linear,
+    // fc1: nn::Linear,
+    // fc2: nn::Linear,
     final_layer_norm: ESM1bLayerNorm,
 }
 
@@ -103,21 +101,21 @@ impl TransformerLayer {
             ..
         } = config;
 
-        // Todo: Fix this
+        // Todo: Fix this!
         let embed_dim = 100;
         let ffn_embed_dim = 100;
 
         let layer_norm = ESM1LayerNorm::load(vb.pp("Layer_Norm"), config)?;
         let multi_head = MultiheadAttention::load(vb.pp("attention"), config)?;
-        let fc1 = nn::linear(embed_dim, ffn_embed_dim, vb.pp("fc1"))?;
-        let fc2 = nn::linear(ffn_embed_dim, embed_dim, vb.pp("fc2"))?;
+        // let fc1 = nn::linear(embed_dim, ffn_embed_dim, vb.pp("fc1"))?;
+        // let fc2 = nn::linear(ffn_embed_dim, embed_dim, vb.pp("fc2"))?;
         let final_layer_norm = ESM1LayerNorm::load(vb.pp("LayerNorm"), config)?;
 
         Ok(Self {
             self_attn: multi_head,
             self_attn_layer_norm: layer_norm,
-            fc1,
-            fc2,
+            // fc1,
+            // fc2,
             final_layer_norm,
         })
     }
@@ -372,15 +370,19 @@ impl SinusoidalPositionalEmbedding {
 
 #[derive(Debug)]
 pub struct RobertaLMHead {
-    // dense: candle_nn::Linear,
-    // layer_norm: ESM1bLayerNorm,
-    // weight: Tensor,
-    // bias: Tensor,
+    dense: candle_nn::Linear,
+    layer_norm: ESM1bLayerNorm,
 }
 
 impl RobertaLMHead {
     pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
-        Ok(Self {})
+        // Todo: Fix this!
+        let embed_dim = 100;
+        let output_dim = 100;
+        let dense = candle_nn::linear(embed_dim, output_dim, vb.pp("dense"))?;
+        let layer_norm = ESM1bLayerNorm::load(vb.pp("LayerNorm"), config)?;
+
+        Ok(Self { dense, layer_norm })
     }
     // pub fn new(
     //     embed_dim: usize,
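For context on the `// Todo: Fix this!` markers above: below is a minimal sketch (not part of the commit) of how `RobertaLMHead::load` might look once the hardcoded dimensions of 100 are replaced by values read from `ESM2Config`. The `hidden_size` and `vocab_size` field names are assumptions made purely for illustration, not confirmed fields of `ESM2Config`; the rest mirrors the committed code and candle's `candle_nn::linear(in_dim, out_dim, vb)` API.

// Sketch only: assumes ESM2Config exposes public `hidden_size` and `vocab_size`
// usize fields (hypothetical names), shown to illustrate the shape of the fix
// for the hardcoded 100s. Relies on the imports already present in modules.rs.
impl RobertaLMHead {
    pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
        // Dimensions driven by the config instead of placeholder constants.
        let embed_dim = config.hidden_size; // hypothetical field
        let output_dim = config.vocab_size; // hypothetical field
        let dense = candle_nn::linear(embed_dim, output_dim, vb.pp("dense"))?;
        let layer_norm = ESM1bLayerNorm::load(vb.pp("LayerNorm"), config)?;
        Ok(Self { dense, layer_norm })
    }
}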
