Commit
start matching the layer names to the reference
zachcp committed Dec 13, 2024
1 parent 70ca962 commit 7faa8a7
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions ferritin-esm/src/esm2/models/multihead_attention.rs
@@ -6,7 +6,7 @@
 use super::esm2::ESM2Config;
 use super::rotary_embedding::RotaryEmbedding;
 use candle_core::{Device, Module, Result, Tensor};
-use candle_nn::{init, linear, ops, VarBuilder};
+use candle_nn::{self as nn, linear, ops, VarBuilder};
 use std::collections::HashMap;

 // use uuid::Uuid;
@@ -89,20 +89,24 @@ impl MultiheadAttention {
     pub fn load(vb: VarBuilder, config: &ESM2Config) -> Result<Self> {
         let ESM2Config { .. } = config;

-        let kdim = kdim.unwrap_or(embed_dim);
-        let vdim = vdim.unwrap_or(embed_dim);
-        let qkv_same_dim = kdim == embed_dim && vdim == embed_dim;
+        // let kdim = kdim.unwrap_or(embed_dim);
+        // let vdim = vdim.unwrap_or(embed_dim);
+        // let qkv_same_dim = kdim == embed_dim && vdim == embed_dim;

+        let kdim = 100;
+        let vdim = 100;
+        let qkv_same_dim = true;
+        let num_heads = 10;
+        let embed_dim = 10;
         let head_dim = embed_dim / num_heads;
         assert!(
             head_dim * num_heads == embed_dim,
             "embed_dim must be divisible by num_heads"
         );
         let scaling = (head_dim as f64).powf(-0.5);

-        let q_proj = nn::linear(embed_dim, embed_dim, bias, vb.pp("q_proj"))?;
-        let k_proj = nn::linear(kdim, embed_dim, bias, vb.pp("k_proj"))?;
-        let v_proj = nn::linear(vdim, embed_dim, bias, vb.pp("v_proj"))?;
+        let q_proj = nn::linear(embed_dim, embed_dim, vb.pp("self.query"))?;
+        let k_proj = nn::linear(kdim, embed_dim, vb.pp("self.key"))?;
+        let v_proj = nn::linear(vdim, embed_dim, vb.pp("self.value"))?;

         // MultiheadAttention::new(
         // embed_dim,
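For context on the renamed projection paths: in candle, nn::linear reads its weight and bias from whatever prefix the VarBuilder carries, so vb.pp("self.query") makes the query projection load tensors named <prefix>.self.query.weight and <prefix>.self.query.bias. That is how this change starts matching the layer names to the reference checkpoint. Below is a minimal sketch of that mechanism, not code from this repo; the zero-initialized VarBuilder and the exact checkpoint layout are assumptions.

use candle_core::{DType, Device, Result};
use candle_nn::{self as nn, VarBuilder};

fn main() -> Result<()> {
    // A zero-initialized builder stands in for a real checkpoint here;
    // loading reference weights would instead build the VarBuilder from
    // the checkpoint file (e.g. safetensors).
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);

    let embed_dim = 10;
    // pp() pushes a path segment, so these layers resolve to
    // "self.query.weight" / "self.query.bias", etc., under the current
    // prefix -- the naming this commit begins to mirror (assumed to match
    // the reference ESM2 layout).
    let q_proj = nn::linear(embed_dim, embed_dim, vb.pp("self.query"))?;
    let k_proj = nn::linear(embed_dim, embed_dim, vb.pp("self.key"))?;
    let v_proj = nn::linear(embed_dim, embed_dim, vb.pp("self.value"))?;
    let _ = (q_proj, k_proj, v_proj);
    Ok(())
}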
