
Commit

update
zachcp committed Dec 6, 2024
1 parent f14f0c0 commit 1bc5c65
Showing 11 changed files with 316 additions and 56 deletions.
251 changes: 250 additions & 1 deletion ferritin-esm/Readme.md

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions ferritin-esm/examples/esmc/main.rs
@@ -4,7 +4,6 @@ use candle_core::{DType, Device, D};
 use candle_hf_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use ferritin_esm::{ESMCConfig, ESMC};
-use safetensors::SafeTensors;

 // pub fn esmc_300m_202412(device: &Device) -> Result<Box<dyn Model>> {
 //     let tokenizer = get_model_tokenizers(ESM3_OPEN_SMALL)?.sequence;
@@ -38,8 +37,15 @@ fn main() -> Result<()> {

     let vb = VarBuilder::from_backend(Box::new(pth), DType::F32, Device::Cpu);
     let config = ESMCConfig::esmc_300m();
-    let esmc = ESMC::load(vb, config);
+    let esmc = ESMC::load(vb.clone(), config)?;
+    // println!("ESMC Loaded: {}", esmc);
+
+    // Error: cannot find tensor transformer.layer.attention.layer_norm.weight
+
+    println!(
+        "VB: {}",
+        vb.contains_tensor("transformer.blocks.6.attn.layernorm_qkv.1.weight")
+    );

     println!("ESMC Loaded");
     Ok(())
 }
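Note: the contains_tensor call above is a way to reverse-engineer the checkpoint's naming scheme when a load fails with "cannot find tensor ...". A minimal sketch of the same idea generalized to several candidates (the names below are illustrative guesses implied by this commit, not a verified map of the ESMC weights):

    // Hypothetical helper: report which candidate names exist in the checkpoint
    // backing `vb`. Names are examples only.
    fn probe_tensor_names(vb: &candle_nn::VarBuilder<'_>) {
        let candidates = [
            "embed.weight",
            "transformer.blocks.0.attn.layernorm_qkv.1.weight",
            "transformer.blocks.0.ffn.1.weight",
            "transformer.norm.weight",
            "sequence_head.0.weight",
        ];
        for name in candidates {
            println!("{name}: {}", vb.contains_tensor(name));
        }
    }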
15 changes: 10 additions & 5 deletions ferritin-esm/src/esm/layers/attention.rs
@@ -55,16 +55,21 @@ impl MultiHeadAttention {
         } = config;

         let d_head = d_model / n_heads;
-        let ln_conf = LayerNormConfig::from(1e-5);
-        let layernorm = nn::layer_norm(*d_model, ln_conf, vb.pp("layer_norm"))?;
-        let linear = nn::linear(*d_model, d_model * 3, vb.pp("linear1"))?;
+        // let ln_conf = LayerNormConfig::from(1e-5);
+        let ln_conf = LayerNormConfig {
+            eps: 1e-5,
+            remove_mean: true,
+            affine: false,
+        };
+        let layernorm = nn::layer_norm(*d_model, ln_conf, vb.pp("layernorm_qkv.0"))?;
+        let linear = nn::linear_no_bias(*d_model, d_model * 3, vb.pp("layernorm_qkv.1"))?;
         let layernorm_qkv = nn::seq().add(layernorm).add(linear);
-        let out_proj = nn::linear(*d_model, *d_model, vb.pp("out_proj"))?;

+        let out_proj = nn::linear_no_bias(*d_model, *d_model, vb.pp("out_proj"))?;
         // note: only handling the True case for the moment
         // let qk_layernorm = true
         let q_ln = Box::new(nn::layer_norm(*d_model, ln_conf, vb.pp("q_ln"))?);
         let k_ln = Box::new(nn::layer_norm(*d_model, ln_conf, vb.pp("k_ln"))?);

         let rotary = RotaryEmbedding::load(vb.pp("rotary"), config)?;

         Ok(Self {
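For orientation, the fused layernorm_qkv sequence above emits a single [batch, seq, 3 * d_model] tensor. A sketch of how it might be split back into q/k/v before the q_ln / k_ln norms; the (q | k | v) layout is an assumption taken from the reference ESM attention, not confirmed by this diff:

    // Sketch only: split the fused qkv output on the last dimension and apply
    // the per-tensor layer norms loaded above.
    use candle_core::{Module, Result, Tensor, D};

    fn split_qkv(
        layernorm_qkv: &candle_nn::Sequential,
        q_ln: &candle_nn::LayerNorm,
        k_ln: &candle_nn::LayerNorm,
        x: &Tensor,
    ) -> Result<(Tensor, Tensor, Tensor)> {
        let qkv = layernorm_qkv.forward(x)?;   // [b, s, 3 * d_model]
        let chunks = qkv.chunk(3, D::Minus1)?; // assumed (q | k | v) order
        let q = q_ln.forward(&chunks[0])?;
        let k = k_ln.forward(&chunks[1])?;
        Ok((q, k, chunks[2].clone()))
    }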
34 changes: 19 additions & 15 deletions ferritin-esm/src/esm/layers/blocks.rs
@@ -1,7 +1,7 @@
 use super::attention::MultiHeadAttention;
 use super::geom_attention::GeometricReasoningOriginalImpl;
 use crate::esm::models::esmc::{ESMCConfig, Ffn_Type};
-use crate::esm::utils::structure::affine3d::Affine3D;
+// use crate::esm::utils::structure::affine3d::Affine3D;
 use candle_core::{Module, Result, Tensor, D};
 use candle_nn::ops::silu;
 use candle_nn::{self as nn, VarBuilder};
@@ -29,9 +29,9 @@ impl SwiGLU {
         let hidden_dim = Self::swiglu_correction_fn(*expansion_ratio, *d_model);

         Ok(Self {
-            layer_norm: nn::layer_norm(*d_model, 1e-5, vb.pp("layer_norm"))?,
-            linear1: nn::linear(*d_model, hidden_dim * 2, vb.pp("linear1"))?,
-            linear2: nn::linear(hidden_dim, *d_model, vb.pp("linear2"))?,
+            layer_norm: nn::layer_norm(*d_model, 1e-5, vb.pp("0"))?,
+            linear1: nn::linear_no_bias(*d_model, hidden_dim * 2, vb.pp("1"))?,
+            linear2: nn::linear_no_bias(hidden_dim, *d_model, vb.pp("3"))?,
         })
     }
 }
@@ -131,29 +131,33 @@ impl UnifiedTransformerBlock {
     pub fn load(vb: VarBuilder, config: &ESMCConfig, layer: usize) -> Result<Self> {
         let ESMCConfig {
             ffn_type,
-            v_head_transformer,
             use_plain_attn,
             n_layers_geom,
             residue_scaling_factor,
             ..
         } = config;

-        let use_geom_attn: bool = layer < *n_layers_geom;
-
         let attn = match use_plain_attn {
             false => None,
-            true => Some(MultiHeadAttention::load(vb.pp("attention"), config)?),
+            true => Some(MultiHeadAttention::load(vb.pp("attn"), config)?),
         };

-        let geom_attn = match use_geom_attn {
-            false => None,
-            true => Some(GeometricReasoningOriginalImpl::load(
-                vb.pp("geometric"),
-                config,
-            )?),
-        };
+        // println!("LAYER; GEOM: {}, {}", layer, n_layers_geom);
+        let use_geom_attn: bool = layer < *n_layers_geom;
+        // println!("Geom ATTN {}", use_geom_attn);
+        // let geom_attn = match use_geom_attn {
+        //     false => None,
+        //     true => Some(GeometricReasoningOriginalImpl::load(
+        //         vb.pp("geometric"),
+        //         config,
+        //     )?),
+        // };
+
+        let geom_attn = None;

         let ffn = match ffn_type {
-            Ffn_Type::SWIGLU => SwiGLU::load(vb.pp("swiglue"), config)?,
+            Ffn_Type::SWIGLU => SwiGLU::load(vb.pp("ffn"), config)?,
             _ => unimplemented!(), // Ffn_Type::GLU => unimplemented!(),
         };
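The numeric sub-paths ("0", "1", "3") used for SwiGLU above suggest the upstream block was serialized as a sequential container whose index 2 is a parameter-free activation. For context, a sketch of the forward pass these three loaded modules imply; the split on the last dimension mirrors the reference ESM SwiGLU and should be treated as an assumption here:

    // Sketch: pre-norm, fused up-projection, SwiGLU gating, then down-projection.
    use candle_core::{Module, Result, Tensor, D};
    use candle_nn::ops::silu;

    fn swiglu_forward(
        layer_norm: &candle_nn::LayerNorm,
        linear1: &candle_nn::Linear,
        linear2: &candle_nn::Linear,
        x: &Tensor,
    ) -> Result<Tensor> {
        let x = layer_norm.forward(x)?;     // [.., d_model]
        let x = linear1.forward(&x)?;       // [.., 2 * hidden_dim]
        let parts = x.chunk(2, D::Minus1)?; // gate and value halves
        let gated = silu(&parts[0])?.mul(&parts[1])?;
        linear2.forward(&gated)             // back to [.., d_model]
    }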
4 changes: 3 additions & 1 deletion ferritin-esm/src/esm/layers/geom_attention.rs
@@ -48,7 +48,9 @@ impl GeometricReasoningOriginalImpl {
         } = config;

         let num_vector_messages = 1usize;
-        let v_heads = v_head_transformer.unwrap();
+
+        // todo: this is a hidden param. Needs to be fixed
+        let v_heads = v_head_transformer.unwrap_or(128);

         let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
         let channels_out = v_heads * 3 * num_vector_messages;
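With the assumed fallback of v_heads = 128 and num_vector_messages = 1, the derived sizes work out as follows (just the arithmetic from the two lines above):

    // Worked example under the assumed default of 128 value heads.
    let v_heads = 128usize;
    let num_vector_messages = 1usize;
    let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages; // 1536 + 384 = 1920
    let channels_out = v_heads * 3 * num_vector_messages;               // 384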
10 changes: 3 additions & 7 deletions ferritin-esm/src/esm/layers/regression_head.rs
@@ -27,18 +27,14 @@ impl RegressionHead {
             ..
         } = config;

-        let linear1 = nn::linear(
-            *d_model,
-            *regression_head_hidden_dim,
-            vb.pp("regression_linear"),
-        )?;
+        let linear1 = nn::linear(*d_model, *regression_head_hidden_dim, vb.pp("0"))?;
         let gelu = candle_nn::Activation::Gelu;
         let ln_conf = LayerNormConfig::from(1e-5);
-        let norm = nn::layer_norm(*regression_head_hidden_dim, ln_conf, vb.pp("layer_norm"))?;
+        let norm = nn::layer_norm(*regression_head_hidden_dim, ln_conf, vb.pp("2"))?;
         let linear2 = nn::linear(
             *regression_head_hidden_dim,
             *regression_head_output_dim,
-            vb.pp("linear2"),
+            vb.pp("3"),
         )?;

         let model = nn::seq().add(linear1).add(gelu).add(norm).add(linear2);
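The prefix "1" never appears because the GELU between linear1 and norm carries no parameters, so presumably only "0", "2", and "3" exist in the checkpoint. For reference, what the nn::seq() chain above computes step by step; `hidden_states` is a hypothetical [batch, seq, d_model] tensor, the other names are the locals from the diff:

    // Sketch of the head's forward pass, spelled out module by module.
    let x = linear1.forward(&hidden_states)?; // -> regression_head_hidden_dim (prefix "0")
    let x = gelu.forward(&x)?;                // Activation::Gelu, parameter-free, no prefix "1"
    let x = norm.forward(&x)?;                // layer norm (prefix "2")
    let logits = linear2.forward(&x)?;        // -> regression_head_output_dim (prefix "3")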
17 changes: 13 additions & 4 deletions ferritin-esm/src/esm/layers/transformer_stack.rs
@@ -1,6 +1,6 @@
 use crate::esm::layers::blocks::UnifiedTransformerBlock;
 use crate::esm::models::esmc::ESMCConfig;
-use crate::esm::utils::structure::affine3d::Affine3D;
+// use crate::esm::utils::structure::affine3d::Affine3D;
 use candle_core::{Module, Result, Tensor, D};
 use candle_nn::{self as nn, LayerNorm, LayerNormConfig};

@@ -31,11 +31,20 @@ impl TransformerStack {

         let mut blocks = Vec::with_capacity(*n_layers as usize);
         for i in 0..*n_layers {
-            blocks.push(UnifiedTransformerBlock::load(vb.pp("layer"), &config, i)?);
+            blocks.push(UnifiedTransformerBlock::load(
+                vb.pp(format!("blocks.{}", i)),
+                &config,
+                i,
+            )?);
         }

-        let ln_conf = LayerNormConfig::from(1e-5);
-        let norm = nn::layer_norm(*d_model, ln_conf, vb.pp("layer_norm"))?;
+        // let ln_conf = LayerNormConfig::from(1e-5);
+        let ln_conf = LayerNormConfig {
+            eps: 1e-5,
+            remove_mean: true,
+            affine: false,
+        };
+        let norm = nn::layer_norm(*d_model, ln_conf, vb.pp("norm"))?;

         Ok(Self { blocks, norm })
     }
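A sketch of how the stack loaded above would typically be driven at inference time; UnifiedTransformerBlock::forward and its one-argument signature are assumptions for illustration, not this crate's confirmed API (the imports are the ones already at the top of the file):

    // Sketch only: run the input through every block, then the final non-affine norm.
    impl TransformerStack {
        pub fn forward_sketch(&self, x: &Tensor) -> Result<Tensor> {
            let mut x = x.clone();
            for block in self.blocks.iter() {
                x = block.forward(&x)?; // hypothetical per-block forward
            }
            self.norm.forward(&x)
        }
    }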
19 changes: 3 additions & 16 deletions ferritin-esm/src/esm/models/esmc.rs
@@ -122,30 +122,17 @@ impl ESMC {
     pub fn load(vb: VarBuilder, config: ESMCConfig) -> Result<Self> {
         let ESMCConfig {
             d_model,
-            n_heads,
-            n_layers,
-            v_head_transformer,
-            ffn_type,
             tokenizer,
-            use_plain_attn,
-            n_layers_geom,
-            scale_residue,
-            residue_scaling_factor,
-            mask_and_zero_frameless,
-            bias,
-            qk_layernorm,
-            expansion_ratio,
-            regression_head_output_dim,
-            regression_head_hidden_dim,
             embedding_dim,
+            ..
         } = config;

         let tokenizer_collection = tokenizer.get_model_tokenizers();

         Ok(Self {
-            embed: nn::embedding(embedding_dim, d_model as usize, vb.pp("embedding"))?,
+            embed: nn::embedding(embedding_dim, d_model as usize, vb.pp("embed"))?,
             transformer: TransformerStack::load(vb.pp("transformer"), &config)?,
-            sequence_head: RegressionHead::load(vb.pp("regression"), &config)?,
+            sequence_head: RegressionHead::load(vb.pp("sequence_head"), &config)?,
             tokenizer: tokenizer_collection.sequence,
         })
     }
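Put together, the renames above make the VarBuilder prefixes compose into the checkpoint names probed in main.rs; pp() joins segments with a dot. A small illustration, using only names implied by this commit (not an exhaustive list):

    // Illustrative only: how the nested pp() calls resolve to full tensor names.
    let name = format!("transformer.blocks.{}.attn.layernorm_qkv.1.weight", 6);
    println!("{name}: {}", vb.contains_tensor(&name));
    // Top-level modules now look up "embed.*", "transformer.*" and "sequence_head.*".
    println!("embed.weight: {}", vb.contains_tensor("embed.weight"));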
3 changes: 2 additions & 1 deletion ferritin-esm/src/esm/utils/structure/affine3D.rs
@@ -24,7 +24,8 @@ pub trait Rotation: Sized + Clone {
     }

     fn requires_grad(&self) -> bool {
-        self.tensor().requires_grad()
+        // self.tensor().requires_grad()
+        unimplemented!()
     }

     fn to_dtype(&self, dtype: DType) -> Result<Self>;
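requires_grad has no direct counterpart here: candle's Tensor does not carry a torch-style requires_grad flag; trainable parameters are instead created as candle_core::Var values, and gradients come from calling backward on a result tensor. A minimal illustration of that pattern (not tied to this trait):

    // Illustration: candle marks trainable values with Var rather than a flag on Tensor.
    use candle_core::{DType, Device, Result, Tensor, Var};

    fn var_example() -> Result<()> {
        let w = Var::zeros((4, 4), DType::F32, &Device::Cpu)?;
        let x = Tensor::ones((1, 4), DType::F32, &Device::Cpu)?;
        let y = x.matmul(w.as_tensor())?;     // Var derefs to a Tensor for ops
        let grads = y.sum_all()?.backward()?; // gradients are collected per Var
        let _dw = grads.get(w.as_tensor());
        Ok(())
    }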
6 changes: 3 additions & 3 deletions ferritin-esm/src/esm/utils/structure/mod.rs
@@ -1,4 +1,4 @@
 pub mod affine3d;
-mod protein_chain;
-mod protein_complex;
-mod protein_structure;
+// mod protein_chain;
+// mod protein_complex;
+// mod protein_structure;
1 change: 1 addition & 0 deletions justfile
@@ -45,4 +45,5 @@ test-ligandmpnn:
     cargo test --features metal -p ferritin-ligandmpnn test_cli_command_run_example_06 -- --nocapture

 esmc:
+    #RUST_BACKTRACE=1 cargo run --example esmc
     cargo run --example esmc