
Commit

fix vb name
zachcp committed Dec 13, 2024
1 parent b59b622 commit 0e5d0f4
Showing 4 changed files with 47 additions and 45 deletions.
ferritin-esm/src/esm2/models/esm2.rs: 2 additions & 1 deletion
@@ -115,7 +115,8 @@ impl ESM2 {
 layers.push(transformer_layer);
 }
 let contact_head = ContactPredictionHead::load(vb.pp("esm.contact_head"), config)?;
-let emb_layer_norm_after = ESM1bLayerNorm::load(vb.pp("emb_layer_norm_after"), config)?;
+let emb_layer_norm_after =
+    ESM1bLayerNorm::load(vb.pp("esm.encoder.emb_layer_norm_after"), config)?;
 let lm_head = RobertaLMHead::load(vb.pp("lm_head"), config)?;
 
 Ok(Self {
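
Note on the change above: candle's VarBuilder joins pp() prefixes with a dot, so the prefix passed to ESM1bLayerNorm::load has to match the fully qualified tensor name stored in the checkpoint. A minimal sketch of the prefix behaviour, assuming the weights live under esm.encoder.emb_layer_norm_after.{weight,bias} and using a hypothetical hidden size of 320:

use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;

// Sketch only: `vb` is assumed to be a VarBuilder over the ESM2 checkpoint.
fn locate_layer_norm(vb: &VarBuilder) -> Result<Tensor> {
    // Both prefixes resolve to the same fully qualified name,
    // "esm.encoder.emb_layer_norm_after.weight":
    let direct = vb.pp("esm.encoder.emb_layer_norm_after");
    let _nested = vb.pp("esm.encoder").pp("emb_layer_norm_after");
    // The old prefix "emb_layer_norm_after" dropped the "esm.encoder" scope,
    // so the lookup would not find the stored tensor.
    direct.get(320, "weight")
}
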
ferritin-esm/src/esm2/models/modules.rs: 36 additions & 36 deletions
@@ -471,44 +471,44 @@ pub struct NormalizedResidualBlock<T: Module> {
 layer_norm: ESM1bLayerNorm,
 }
 
-// impl<T: Module> NormalizedResidualBlock<T> {
-// pub fn new(layer: T, embedding_dim: usize, dropout: f64) -> Result<Self> {
-// let vb = VarBuilder::zeros();
-// Ok(Self {
-// layer,
-// dropout,
-// layer_norm: ESM1bLayerNorm::new(embedding_dim, 1e-12, true, vb)?,
-// })
-// }
+impl<T: Module> NormalizedResidualBlock<T> {
+    // pub fn new(layer: T, embedding_dim: usize, dropout: f64) -> Result<Self> {
+    // let vb = VarBuilder::zeros();
+    // Ok(Self {
+    // layer,
+    // dropout,
+    // layer_norm: ESM1bLayerNorm::new(embedding_dim, 1e-12, true, vb)?,
+    // })
+    // }
 
-// pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
-// let residual = x;
-// let x = self.layer_norm.forward(x)?;
-// let x = self.layer.forward(&x)?;
-// let x = if self.dropout > 0. {
-// x.dropout(self.dropout)?
-// } else {
-// x
-// };
-// x.add(residual)
-// }
+    // pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    // let residual = x;
+    // let x = self.layer_norm.forward(x)?;
+    // let x = self.layer.forward(&x)?;
+    // let x = if self.dropout > 0. {
+    // x.dropout(self.dropout)?
+    // } else {
+    // x
+    // };
+    // x.add(residual)
+    // }
 
-// pub fn forward_t<A, B>(&self, x: &Tensor, a: A, b: B) -> Result<(Tensor, Tensor)>
-// where
-// T: ModuleWithAttention<A, B>,
-// {
-// let residual = x;
-// let x = self.layer_norm.forward(x)?;
-// let (x, attn) = self.layer.forward_t(&x, a, b)?;
-// let x = if self.dropout > 0. {
-// x.dropout(self.dropout)?
-// } else {
-// x
-// };
-// let x = x.add(residual)?;
-// Ok((x, attn))
-// }
-// }
+    // pub fn forward_t<A, B>(&self, x: &Tensor, a: A, b: B) -> Result<(Tensor, Tensor)>
+    // where
+    // T: ModuleWithAttention<A, B>,
+    // {
+    // let residual = x;
+    // let x = self.layer_norm.forward(x)?;
+    // let (x, attn) = self.layer.forward_t(&x, a, b)?;
+    // let x = if self.dropout > 0. {
+    // x.dropout(self.dropout)?
+    // } else {
+    // x
+    // };
+    // let x = x.add(residual)?;
+    // Ok((x, attn))
+    // }
+}
 
 pub trait ModuleWithAttention<A, B> {
 fn forward_t(&self, x: &Tensor, a: A, b: B) -> Result<(Tensor, Tensor)>;
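
For context, the commented-out methods describe a pre-LayerNorm residual wrapper: normalize the input, run the wrapped layer, optionally apply dropout, then add the original input back. A minimal free-function sketch of that pattern, using candle_nn::LayerNorm and ops::dropout in place of the crate's ESM1bLayerNorm (an assumption for illustration, not the crate's actual API):

use candle_core::{Module, Result, Tensor};
use candle_nn::{ops, LayerNorm};

// Pre-LayerNorm residual step: x + layer(layer_norm(x)), with optional dropout.
fn residual_forward<M: Module>(
    layer: &M,
    layer_norm: &LayerNorm,
    x: &Tensor,
    dropout: f32,
) -> Result<Tensor> {
    let residual = x;
    let h = layer_norm.forward(x)?; // normalize before the wrapped sub-layer
    let h = layer.forward(&h)?;     // e.g. attention or feed-forward
    let h = if dropout > 0. {
        ops::dropout(&h, dropout)?  // training-time regularization
    } else {
        h
    };
    h.add(residual)                 // residual connection
}
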
ferritin-esm/src/esm2/models/multihead_attention.rs: 6 additions & 8 deletions
@@ -7,9 +7,6 @@ use super::esm2::ESM2Config;
 use super::rotary_embedding::RotaryEmbedding;
 use candle_core::{Device, Module, Result, Tensor};
 use candle_nn::{self as nn, linear, ops, VarBuilder};
-use std::collections::HashMap;
-
-// use uuid::Uuid;
 
 // pub fn utils_softmax(x: &Tensor, dim: i64, onnx_trace: bool) -> Result<Tensor> {
 // if onnx_trace {
@@ -72,16 +69,16 @@ pub struct MultiheadAttention {
 // scaling: f64,
 // self_attention: bool,
 // encoder_decoder_attention: bool,
-q_proj: nn::Linear,
-k_proj: nn::Linear,
-v_proj: nn::Linear,
-out_proj: nn::Linear,
 // bias_k: Option<Tensor>,
 // bias_v: Option<Tensor>,
 // add_zero_attn: bool,
-rot_emb: Option<RotaryEmbedding>,
 // onnx_trace: bool,
 // enable_torch_version: bool,
+q_proj: nn::Linear,
+k_proj: nn::Linear,
+v_proj: nn::Linear,
+out_proj: nn::Linear,
+rot_emb: Option<RotaryEmbedding>,
 // incremental_state: FairseqIncrementalState,
 }
 
@@ -93,6 +90,7 @@ impl MultiheadAttention {
 ..
 } = config;
 
+// "num_attention_heads": 20,
 let embed_dim = *hidden_size as usize;
 let num_heads = *num_attention_heads as usize;
 let head_dim = embed_dim / num_heads;
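
The hunk above derives head_dim as hidden_size / num_attention_heads; with the 20 heads noted in the new comment and the 320-dimensional hidden size of the smallest ESM2 checkpoint, that gives 16. A hedged sketch of loading the four projection layers; the "q_proj", "k_proj", "v_proj", and "out_proj" prefixes are illustrative placeholders, not the checkpoint's actual tensor names:

use candle_core::Result;
use candle_nn::{linear, Linear, VarBuilder};

// Sketch: load Q/K/V/output projections of size hidden_size x hidden_size.
fn load_projections(
    vb: VarBuilder,
    hidden_size: usize,
    num_heads: usize,
) -> Result<(Linear, Linear, Linear, Linear)> {
    let head_dim = hidden_size / num_heads; // e.g. 320 / 20 = 16
    assert_eq!(head_dim * num_heads, hidden_size, "hidden_size must divide evenly");
    let q_proj = linear(hidden_size, hidden_size, vb.pp("q_proj"))?;
    let k_proj = linear(hidden_size, hidden_size, vb.pp("k_proj"))?;
    let v_proj = linear(hidden_size, hidden_size, vb.pp("v_proj"))?;
    let out_proj = linear(hidden_size, hidden_size, vb.pp("out_proj"))?;
    Ok((q_proj, k_proj, v_proj, out_proj))
}
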
ferritin-esm/src/esm2/models/rotary_embedding.rs: 3 additions & 0 deletions
@@ -28,6 +28,9 @@ pub struct RotaryEmbedding {
 
 impl RotaryEmbedding {
 pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
+// todo: I am pulling out the num_hidden_layers here but the real shape is
+// Name: esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq, Shape: [8]
+// where is that `8` coming from?
 let ESM2Config {
 num_hidden_layers, ..
 } = config;
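
On the TODO above: in the reference ESM2 implementation the rotary embedding is built per attention head, and inv_freq takes every other index of head_dim, so its length appears to be head_dim / 2 rather than anything derived from num_hidden_layers. With hidden_size 320 and 20 heads, head_dim is 16 and inv_freq has shape [8], matching the tensor named in the comment. A hedged sketch of that computation, assuming the standard 10_000 base:

use candle_core::{Device, Result, Tensor};

// inv_freq[i] = 1 / 10000^(2i / head_dim), giving head_dim / 2 entries.
fn rotary_inv_freq(head_dim: usize, device: &Device) -> Result<Tensor> {
    let freqs: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1.0 / 10_000f32.powf(i as f32 / head_dim as f32))
        .collect();
    // Shape [head_dim / 2], e.g. [8] when head_dim = 16.
    Tensor::from_vec(freqs, (head_dim / 2,), device)
}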

0 comments on commit 0e5d0f4
