From 70ca962fa899d4d3b75fb84664ea8eb7ec6722c0 Mon Sep 17 00:00:00 2001
From: Zachary Charlop-Powers
Date: Fri, 13 Dec 2024 11:53:56 -0500
Subject: [PATCH] update MHA

---
 ferritin-esm/src/esm2/models/modules.rs            |  1 -
 .../src/esm2/models/multihead_attention.rs         | 21 +++++++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/ferritin-esm/src/esm2/models/modules.rs b/ferritin-esm/src/esm2/models/modules.rs
index 386b2d24..1263b8e6 100644
--- a/ferritin-esm/src/esm2/models/modules.rs
+++ b/ferritin-esm/src/esm2/models/modules.rs
@@ -379,7 +379,6 @@ impl RobertaLMHead {
         let ESM2Config { hidden_size, .. } = config;
         let dense = nn::linear(*hidden_size as usize, *hidden_size as usize, vb.pp("dense"))?;
         let layer_norm = ESM1bLayerNorm::load(vb.pp("LayerNorm"), config)?;
-
         Ok(Self { dense, layer_norm })
     }
     // pub fn new(
diff --git a/ferritin-esm/src/esm2/models/multihead_attention.rs b/ferritin-esm/src/esm2/models/multihead_attention.rs
index c995517e..4cab61b2 100644
--- a/ferritin-esm/src/esm2/models/multihead_attention.rs
+++ b/ferritin-esm/src/esm2/models/multihead_attention.rs
@@ -87,6 +87,27 @@ pub struct MultiheadAttention {
 
 impl MultiheadAttention {
     pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
+        let ESM2Config { hidden_size, attention_heads, .. } = config;
+        let embed_dim = *hidden_size as usize;
+        let num_heads = *attention_heads as usize;
+
+        // ESM2 attention is self-attention: key/value dims default to embed_dim.
+        let kdim = embed_dim;
+        let vdim = embed_dim;
+        let qkv_same_dim = kdim == embed_dim && vdim == embed_dim;
+
+        let head_dim = embed_dim / num_heads;
+        assert!(
+            head_dim * num_heads == embed_dim,
+            "embed_dim must be divisible by num_heads"
+        );
+        let scaling = (head_dim as f64).powf(-0.5);
+
+        // candle's nn::linear includes the bias term used by the PyTorch projections.
+        let q_proj = nn::linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
+        let k_proj = nn::linear(kdim, embed_dim, vb.pp("k_proj"))?;
+        let v_proj = nn::linear(vdim, embed_dim, vb.pp("v_proj"))?;
+
     // MultiheadAttention::new(
     //     embed_dim,
     //     attention_heads,
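
Note: the q_proj/k_proj/v_proj layers and the `scaling` factor built in `load` feed a
standard scaled dot-product attention. Below is a minimal sketch of that step in candle.
It is not part of the patch: the function name and the (batch * num_heads, seq_len,
head_dim) layout are illustrative assumptions, not the crate's actual forward pass.

    use candle_core::{Result, Tensor};
    use candle_nn::ops;

    /// Scaled dot-product attention over already-projected tensors.
    /// Expected shapes: (batch * num_heads, seq_len, head_dim).
    fn scaled_dot_product_attention(
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        scaling: f64,
    ) -> Result<Tensor> {
        // Scores: q . k^T, scaled by 1/sqrt(head_dim), i.e. the `scaling` from `load`.
        let scores = (q.matmul(&k.transpose(1, 2)?.contiguous()?)? * scaling)?;
        // Normalize over the key dimension, then take the weighted sum of values.
        let weights = ops::softmax_last_dim(&scores)?;
        weights.matmul(v)
    }

Scaling the logits by 1/sqrt(head_dim) keeps their variance roughly independent of the
head width, which stops the softmax from saturating as embed_dim grows.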