Skip to content

Commit

Permalink
update comment out forward
Browse files Browse the repository at this point in the history
  • Loading branch information
zachcp committed Dec 13, 2024
1 parent a437ad8 commit b5aa9c3
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 91 deletions.
2 changes: 1 addition & 1 deletion ferritin-bevy/src/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ mod tests {

#[test]
fn test_pdb_to_mesh() {
let (molfile, _handle) = TestFile::protein_04().create_temp()?;
let (molfile, _handle) = TestFile::protein_04().create_temp();

// let (pdb, _errors) = pdbtbx::open("examples/1fap.cif").unwrap();
let (pdb, _errors) = pdbtbx::open("examples/1fap.cif").unwrap();
Expand Down
180 changes: 90 additions & 90 deletions ferritin-esm/src/esm2/models/esm2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,107 +199,107 @@ impl ESM2 {
// })
// }

fn forward(
&self,
tokens: &Tensor,
repr_layers: &[i32],
need_head_weights: bool,
return_contacts: bool,
) -> Result<BTreeMap<String, Tensor>> {
let need_head_weights = need_head_weights || return_contacts;

let padding_mask = tokens.eq(self.padding_idx)?;

let mut x = self
.embed_tokens
.forward(tokens)?
.mul_scalar(self.embed_scale)?;

if self.token_dropout {
let mask = tokens.eq(self.mask_idx)?.unsqueeze(-1)?;
x = x.masked_fill(&mask, 0.0)?;

let mask_ratio_train = 0.15 * 0.8;
let src_lengths = padding_mask.logical_not()?.sum_keepdim(-1)?;
let mask_ratio_observed = tokens
.eq(self.mask_idx)?
.sum_keepdim(-1)?
.to_dtype(x.dtype())?
.div(&src_lengths)?;
let scale = (1.0 - mask_ratio_train) / (1.0 - mask_ratio_observed)?;
x = x.mul(&scale.unsqueeze(-1)?)?;
}
// fn forward(
// &self,
// tokens: &Tensor,
// repr_layers: &[i32],
// need_head_weights: bool,
// return_contacts: bool,
// ) -> Result<BTreeMap<String, Tensor>> {
// let need_head_weights = need_head_weights || return_contacts;

// let padding_mask = tokens.eq(self.padding_idx)?;

// let mut x = self
// .embed_tokens
// .forward(tokens)?
// .mul_scalar(self.embed_scale)?;

// if self.token_dropout {
// let mask = tokens.eq(self.mask_idx)?.unsqueeze(-1)?;
// x = x.masked_fill(&mask, 0.0)?;

// let mask_ratio_train = 0.15 * 0.8;
// let src_lengths = padding_mask.logical_not()?.sum_keepdim(-1)?;
// let mask_ratio_observed = tokens
// .eq(self.mask_idx)?
// .sum_keepdim(-1)?
// .to_dtype(x.dtype())?
// .div(&src_lengths)?;
// let scale = (1.0 - mask_ratio_train) / (1.0 - mask_ratio_observed)?;
// x = x.mul(&scale.unsqueeze(-1)?)?;
// }

if !padding_mask.all()? {
let not_padding = padding_mask.logical_not()?.to_dtype(x.dtype())?;
x = x.mul(&not_padding.unsqueeze(-1)?)?;
}
// if !padding_mask.all()? {
// let not_padding = padding_mask.logical_not()?.to_dtype(x.dtype())?;
// x = x.mul(&not_padding.unsqueeze(-1)?)?;
// }

let repr_layers: HashSet<_> = repr_layers.iter().cloned().collect();
let mut hidden_representations = BTreeMap::new();
if repr_layers.contains(&0) {
hidden_representations.insert("0".to_string(), x.clone());
}
// let repr_layers: HashSet<_> = repr_layers.iter().cloned().collect();
// let mut hidden_representations = BTreeMap::new();
// if repr_layers.contains(&0) {
// hidden_representations.insert("0".to_string(), x.clone());
// }

let mut attn_weights = Vec::new();
x = x.transpose(0, 1)?;
// let mut attn_weights = Vec::new();
// x = x.transpose(0, 1)?;

let padding_mask = if !padding_mask.any()? {
None
} else {
Some(padding_mask)
};
// let padding_mask = if !padding_mask.any()? {
// None
// } else {
// Some(padding_mask)
// };

for (layer_idx, layer) in self.layers.iter().enumerate() {
let (new_x, attn) = layer.forward(&x, padding_mask.as_ref(), need_head_weights)?;
x = new_x;
// for (layer_idx, layer) in self.layers.iter().enumerate() {
// let (new_x, attn) = layer.forward(&x, padding_mask.as_ref(), need_head_weights)?;
// x = new_x;

if repr_layers.contains(&(layer_idx as i32 + 1)) {
hidden_representations
.insert((layer_idx + 1).to_string(), x.transpose(0, 1)?.clone());
}
// if repr_layers.contains(&(layer_idx as i32 + 1)) {
// hidden_representations
// .insert((layer_idx + 1).to_string(), x.transpose(0, 1)?.clone());
// }

if need_head_weights {
attn_weights.push(attn.transpose(1, 0)?);
}
}
// if need_head_weights {
// attn_weights.push(attn.transpose(1, 0)?);
// }
// }

x = self.emb_layer_norm_after.forward(&x)?;
x = x.transpose(0, 1)?;
// x = self.emb_layer_norm_after.forward(&x)?;
// x = x.transpose(0, 1)?;

if repr_layers.contains(&(self.layers.len() as i32)) {
hidden_representations.insert(self.layers.len().to_string(), x.clone());
}
// if repr_layers.contains(&(self.layers.len() as i32)) {
// hidden_representations.insert(self.layers.len().to_string(), x.clone());
// }

let logits = self.lm_head.forward(&x)?;

let mut result = BTreeMap::new();
result.insert("logits".to_string(), logits);
result.insert("representations".to_string(), x);

if need_head_weights {
let attentions = Tensor::stack(&attn_weights, 1)?;
if let Some(padding_mask) = padding_mask {
let attention_mask = padding_mask.logical_not()?.to_dtype(attentions.dtype())?;
let attention_mask = attention_mask
.unsqueeze(1)?
.mul(&attention_mask.unsqueeze(2)?)?;
result.insert(
"attentions".to_string(),
attentions.mul(&attention_mask.unsqueeze(1)?.unsqueeze(1)?)?,
);
} else {
result.insert("attentions".to_string(), attentions);
}

if return_contacts {
let contacts = self.contact_head.forward(tokens, &attentions)?;
result.insert("contacts".to_string(), contacts);
}
}
// let logits = self.lm_head.forward(&x)?;

// let mut result = BTreeMap::new();
// result.insert("logits".to_string(), logits);
// result.insert("representations".to_string(), x);

// if need_head_weights {
// let attentions = Tensor::stack(&attn_weights, 1)?;
// if let Some(padding_mask) = padding_mask {
// let attention_mask = padding_mask.logical_not()?.to_dtype(attentions.dtype())?;
// let attention_mask = attention_mask
// .unsqueeze(1)?
// .mul(&attention_mask.unsqueeze(2)?)?;
// result.insert(
// "attentions".to_string(),
// attentions.mul(&attention_mask.unsqueeze(1)?.unsqueeze(1)?)?,
// );
// } else {
// result.insert("attentions".to_string(), attentions);
// }

// if return_contacts {
// let contacts = self.contact_head.forward(tokens, &attentions)?;
// result.insert("contacts".to_string(), contacts);
// }
// }

Ok(result)
}
// Ok(result)
// }

// pub fn predict_contacts(&self, tokens: &Tensor) -> Result<Tensor> {
// let mut result = self.forward(tokens, &[], false, true)?;
Expand Down

0 comments on commit b5aa9c3

Please sign in to comment.