Commit

update MHA
zachcp committed Dec 13, 2024
1 parent 243d23c commit 70ca962
Showing 2 changed files with 17 additions and 1 deletion.
1 change: 0 additions & 1 deletion ferritin-esm/src/esm2/models/modules.rs
@@ -379,7 +379,6 @@ impl RobertaLMHead {
        let ESM2Config { hidden_size, .. } = config;
        let dense = nn::linear(*hidden_size as usize, *hidden_size as usize, vb.pp("dense"))?;
        let layer_norm = ESM1bLayerNorm::load(vb.pp("LayerNorm"), config)?;

        Ok(Self { dense, layer_norm })
    }
    // pub fn new(
17 changes: 17 additions & 0 deletions ferritin-esm/src/esm2/models/multihead_attention.rs
@@ -87,6 +87,23 @@ pub struct MultiheadAttention {

impl MultiheadAttention {
    pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
        let ESM2Config { .. } = config;

        let kdim = kdim.unwrap_or(embed_dim);
        let vdim = vdim.unwrap_or(embed_dim);
        let qkv_same_dim = kdim == embed_dim && vdim == embed_dim;

        let head_dim = embed_dim / num_heads;
        assert!(
            head_dim * num_heads == embed_dim,
            "embed_dim must be divisible by num_heads"
        );
        let scaling = (head_dim as f64).powf(-0.5);

        let q_proj = nn::linear(embed_dim, embed_dim, bias, vb.pp("q_proj"))?;
        let k_proj = nn::linear(kdim, embed_dim, bias, vb.pp("k_proj"))?;
        let v_proj = nn::linear(vdim, embed_dim, bias, vb.pp("v_proj"))?;

        // MultiheadAttention::new(
        //     embed_dim,
        //     attention_heads,
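Note: the new load body above refers to embed_dim, num_heads, kdim, vdim, and bias, none of which are bound by the empty "let ESM2Config { .. } = config;" destructure, so the hunk does not compile as committed. Below is a minimal sketch of how those values might be derived from the config. It is an assumption, not part of the commit: the field name hidden_size is taken from the modules.rs hunk above, attention_heads from the commented-out constructor call, the bias flag is omitted, and the ESM2-t33-style sizes in the example are for illustration only.

// Hypothetical sketch only: a stand-in config with assumed field names, showing how
// the values used by the new load() body could be bound. hidden_size mirrors the
// modules.rs hunk above; attention_heads mirrors the commented-out constructor call.
struct ESM2Config {
    hidden_size: i64,
    attention_heads: i64,
}

fn attention_dims(config: &ESM2Config) -> (usize, usize, usize, f64) {
    let ESM2Config { hidden_size, attention_heads } = config;

    // Self-attention case: key/value dims default to the embedding dim,
    // so kdim == vdim == embed_dim and qkv_same_dim would be true.
    let embed_dim = *hidden_size as usize;
    let num_heads = *attention_heads as usize;
    let kdim = embed_dim;
    let vdim = embed_dim;

    // Per-head width and the 1/sqrt(head_dim) scaling computed by the new hunk.
    let head_dim = embed_dim / num_heads;
    assert!(
        head_dim * num_heads == embed_dim,
        "embed_dim must be divisible by num_heads"
    );
    let scaling = (head_dim as f64).powf(-0.5);

    (kdim, vdim, head_dim, scaling)
}

fn main() {
    // Illustration with ESM2-t33-style sizes: 1280 hidden units and 20 heads
    // give head_dim = 64 and scaling = 0.125.
    let config = ESM2Config { hidden_size: 1280, attention_heads: 20 };
    let (kdim, vdim, head_dim, scaling) = attention_dims(&config);
    println!("kdim={kdim} vdim={vdim} head_dim={head_dim} scaling={scaling}");
}

With bindings like these in place, the q_proj / k_proj / v_proj projections in the hunk would have concrete input and output dimensions to load against.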
