fix 1FAP import by adding it to Test Data
zachcp committed Dec 13, 2024
1 parent 48e352e commit a437ad8
Showing 4 changed files with 99 additions and 92 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion ferritin-bevy/Cargo.toml
@@ -1,4 +1,4 @@
[package]
name = "ferritin-bevy"
version.workspace = true
edition.workspace = true
@@ -13,3 +13,4 @@ ferritin-core = { path = "../ferritin-core" }
pdbtbx.workspace = true

[dev-dependencies]
ferritin-test-data = { path = "../ferritin-test-data" }
7 changes: 6 additions & 1 deletion ferritin-bevy/src/structure.rs
@@ -268,10 +268,15 @@ impl Structure {
#[cfg(test)]
mod tests {
use super::*;
use ferritin_test_data::TestFile;

#[test]
fn test_pdb_to_mesh() {
let (molfile, _handle) = TestFile::protein_04().create_temp().unwrap();

// let (pdb, _errors) = pdbtbx::open("examples/1fap.cif").unwrap();
let (pdb, _errors) = pdbtbx::open(molfile).unwrap();

let structure = Structure::builder().pdb(AtomCollection::from(&pdb)).build();
assert_eq!(structure.pdb.get_size(), 2154);
let mesh = structure.to_mesh();
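For readers unfamiliar with the ferritin-test-data crate, a minimal sketch of the pattern the updated test relies on: `TestFile::protein_04()` materializes the bundled 1FAP record as a temporary file, whose path is then handed to `pdbtbx::open`. The call shapes are copied from the diff above; the helper function name and the use of `expect` for error handling are illustrative assumptions, not part of the commit.

```rust
use ferritin_test_data::TestFile;

// Illustrative helper (not in the commit): load 1FAP from the test-data
// crate instead of a checked-in examples/1fap.cif.
fn load_1fap() -> pdbtbx::PDB {
    // Write the bundled 1FAP record to a temporary file. `_handle` must stay
    // in scope so the temp file is not cleaned up before pdbtbx reads it.
    let (molfile, _handle) = TestFile::protein_04()
        .create_temp()
        .expect("test data should materialize to a temp file");

    // Same call shape as in structure.rs: pdbtbx::open returns the parsed
    // structure plus any non-fatal parse errors.
    let (pdb, _errors) = pdbtbx::open(molfile).expect("1FAP should parse");
    pdb
}
```

Because ferritin-test-data sits under `[dev-dependencies]`, this pattern is available to tests and examples without adding the crate to ferritin-bevy's normal dependency tree.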
180 changes: 90 additions & 90 deletions ferritin-esm/src/esm2/models/esm2.rs
@@ -199,107 +199,107 @@ impl ESM2 {
// })
// }

// fn forward(
// &self,
// tokens: &Tensor,
// repr_layers: &[i32],
// need_head_weights: bool,
// return_contacts: bool,
// ) -> Result<BTreeMap<String, Tensor>> {
// let need_head_weights = need_head_weights || return_contacts;

// let padding_mask = tokens.eq(self.padding_idx)?;

// let mut x = self
// .embed_tokens
// .forward(tokens)?
// .mul_scalar(self.embed_scale)?;

// if self.token_dropout {
// let mask = tokens.eq(self.mask_idx)?.unsqueeze(-1)?;
// x = x.masked_fill(&mask, 0.0)?;

// let mask_ratio_train = 0.15 * 0.8;
// let src_lengths = padding_mask.logical_not()?.sum_keepdim(-1)?;
// let mask_ratio_observed = tokens
// .eq(self.mask_idx)?
// .sum_keepdim(-1)?
// .to_dtype(x.dtype())?
// .div(&src_lengths)?;
// let scale = (1.0 - mask_ratio_train) / (1.0 - mask_ratio_observed)?;
// x = x.mul(&scale.unsqueeze(-1)?)?;
// }
fn forward(
&self,
tokens: &Tensor,
repr_layers: &[i32],
need_head_weights: bool,
return_contacts: bool,
) -> Result<BTreeMap<String, Tensor>> {
let need_head_weights = need_head_weights || return_contacts;

let padding_mask = tokens.eq(self.padding_idx)?;

let mut x = self
.embed_tokens
.forward(tokens)?
.mul_scalar(self.embed_scale)?;

if self.token_dropout {
let mask = tokens.eq(self.mask_idx)?.unsqueeze(-1)?;
x = x.masked_fill(&mask, 0.0)?;

let mask_ratio_train = 0.15 * 0.8;
let src_lengths = padding_mask.logical_not()?.sum_keepdim(-1)?;
let mask_ratio_observed = tokens
.eq(self.mask_idx)?
.sum_keepdim(-1)?
.to_dtype(x.dtype())?
.div(&src_lengths)?;
let scale = (1.0 - mask_ratio_train) / (1.0 - mask_ratio_observed)?;
x = x.mul(&scale.unsqueeze(-1)?)?;
}

// if !padding_mask.all()? {
// let not_padding = padding_mask.logical_not()?.to_dtype(x.dtype())?;
// x = x.mul(&not_padding.unsqueeze(-1)?)?;
// }
if !padding_mask.all()? {
let not_padding = padding_mask.logical_not()?.to_dtype(x.dtype())?;
x = x.mul(&not_padding.unsqueeze(-1)?)?;
}

// let repr_layers: HashSet<_> = repr_layers.iter().cloned().collect();
// let mut hidden_representations = BTreeMap::new();
// if repr_layers.contains(&0) {
// hidden_representations.insert("0".to_string(), x.clone());
// }
let repr_layers: HashSet<_> = repr_layers.iter().cloned().collect();
let mut hidden_representations = BTreeMap::new();
if repr_layers.contains(&0) {
hidden_representations.insert("0".to_string(), x.clone());
}

// let mut attn_weights = Vec::new();
// x = x.transpose(0, 1)?;
let mut attn_weights = Vec::new();
x = x.transpose(0, 1)?;

// let padding_mask = if !padding_mask.any()? {
// None
// } else {
// Some(padding_mask)
// };
let padding_mask = if !padding_mask.any()? {
None
} else {
Some(padding_mask)
};

// for (layer_idx, layer) in self.layers.iter().enumerate() {
// let (new_x, attn) = layer.forward(&x, padding_mask.as_ref(), need_head_weights)?;
// x = new_x;
for (layer_idx, layer) in self.layers.iter().enumerate() {
let (new_x, attn) = layer.forward(&x, padding_mask.as_ref(), need_head_weights)?;
x = new_x;

// if repr_layers.contains(&(layer_idx as i32 + 1)) {
// hidden_representations
// .insert((layer_idx + 1).to_string(), x.transpose(0, 1)?.clone());
// }
if repr_layers.contains(&(layer_idx as i32 + 1)) {
hidden_representations
.insert((layer_idx + 1).to_string(), x.transpose(0, 1)?.clone());
}

// if need_head_weights {
// attn_weights.push(attn.transpose(1, 0)?);
// }
// }
if need_head_weights {
attn_weights.push(attn.transpose(1, 0)?);
}
}

// x = self.emb_layer_norm_after.forward(&x)?;
// x = x.transpose(0, 1)?;
x = self.emb_layer_norm_after.forward(&x)?;
x = x.transpose(0, 1)?;

// if repr_layers.contains(&(self.layers.len() as i32)) {
// hidden_representations.insert(self.layers.len().to_string(), x.clone());
// }
if repr_layers.contains(&(self.layers.len() as i32)) {
hidden_representations.insert(self.layers.len().to_string(), x.clone());
}

// let logits = self.lm_head.forward(&x)?;

// let mut result = BTreeMap::new();
// result.insert("logits".to_string(), logits);
// result.insert("representations".to_string(), x);

// if need_head_weights {
// let attentions = Tensor::stack(&attn_weights, 1)?;
// if let Some(padding_mask) = padding_mask {
// let attention_mask = padding_mask.logical_not()?.to_dtype(attentions.dtype())?;
// let attention_mask = attention_mask
// .unsqueeze(1)?
// .mul(&attention_mask.unsqueeze(2)?)?;
// result.insert(
// "attentions".to_string(),
// attentions.mul(&attention_mask.unsqueeze(1)?.unsqueeze(1)?)?,
// );
// } else {
// result.insert("attentions".to_string(), attentions);
// }

// if return_contacts {
// let contacts = self.contact_head.forward(tokens, &attentions)?;
// result.insert("contacts".to_string(), contacts);
// }
// }
let logits = self.lm_head.forward(&x)?;

let mut result = BTreeMap::new();
result.insert("logits".to_string(), logits);
result.insert("representations".to_string(), x);

if need_head_weights {
let attentions = Tensor::stack(&attn_weights, 1)?;
if let Some(padding_mask) = padding_mask {
let attention_mask = padding_mask.logical_not()?.to_dtype(attentions.dtype())?;
let attention_mask = attention_mask
.unsqueeze(1)?
.mul(&attention_mask.unsqueeze(2)?)?;
result.insert(
"attentions".to_string(),
attentions.mul(&attention_mask.unsqueeze(1)?.unsqueeze(1)?)?,
);
} else {
result.insert("attentions".to_string(), attentions);
}

if return_contacts {
let contacts = self.contact_head.forward(tokens, &attentions)?;
result.insert("contacts".to_string(), contacts);
}
}

// Ok(result)
// }
Ok(result)
}

// pub fn predict_contacts(&self, tokens: &Tensor) -> Result<Tensor> {
// let mut result = self.forward(tokens, &[], false, true)?;
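Since the uncommented forward pass returns its outputs as a string-keyed `BTreeMap` rather than a dedicated struct, a short sketch of how a caller might unpack it. The key names are exactly the ones inserted above ("logits" and "representations" are always present; "attentions" and "contacts" only when head weights or contacts were requested). The function name is illustrative, and the import assumes the candle `Tensor` type used elsewhere in ferritin-esm.

```rust
use std::collections::BTreeMap;

use candle_core::Tensor;

// Illustrative only: pull the individual outputs out of the map returned by
// the ESM2 forward pass in this diff.
fn unpack_esm2_outputs(
    outputs: &BTreeMap<String, Tensor>,
) -> (Tensor, Tensor, Option<Tensor>, Option<Tensor>) {
    // Always inserted: per-token logits from the LM head and the final
    // hidden states.
    let logits = outputs["logits"].clone();
    let representations = outputs["representations"].clone();

    // Only inserted when need_head_weights (or return_contacts) was true.
    let attentions = outputs.get("attentions").cloned();

    // Only inserted when return_contacts was true.
    let contacts = outputs.get("contacts").cloned();

    (logits, representations, attentions, contacts)
}
```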
