diff --git a/Cargo.lock b/Cargo.lock
index 1cd03dc8..41068254 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2512,6 +2512,7 @@ dependencies = [
  "bevy",
  "bon",
  "ferritin-core",
+ "ferritin-test-data",
  "pdbtbx",
 ]
 
diff --git a/ferritin-bevy/Cargo.toml b/ferritin-bevy/Cargo.toml
index ed67ff37..060261ba 100644
--- a/ferritin-bevy/Cargo.toml
+++ b/ferritin-bevy/Cargo.toml
@@ -1,4 +1,4 @@
- [package]
+[package]
 name = "ferritin-bevy"
 version.workspace = true
 edition.workspace = true
@@ -13,3 +13,4 @@ ferritin-core = { path = "../ferritin-core" }
 pdbtbx.workspace = true
 
 [dev-dependencies]
+ferritin-test-data = { path = "../ferritin-test-data" }
diff --git a/ferritin-bevy/src/structure.rs b/ferritin-bevy/src/structure.rs
index bcd8c190..902ced00 100644
--- a/ferritin-bevy/src/structure.rs
+++ b/ferritin-bevy/src/structure.rs
@@ -268,10 +268,15 @@ impl Structure {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use feritin_test_data::TestData;
+    use ferritin_test_data::TestFile;
+
     #[test]
     fn test_pdb_to_mesh() {
-        let (pdb, _errors) = pdbtbx::open("examples/1fap.cif").unwrap();
+        // Load the test structure from ferritin-test-data instead of a hard-coded path.
+        let (molfile, _handle) = TestFile::protein_04().create_temp().expect("test file");
+
+        let (pdb, _errors) = pdbtbx::open(&molfile).unwrap();
+
         let structure = Structure::builder().pdb(AtomCollection::from(&pdb)).build();
         assert_eq!(structure.pdb.get_size(), 2154);
         let mesh = structure.to_mesh();
diff --git a/ferritin-esm/src/esm2/models/esm2.rs b/ferritin-esm/src/esm2/models/esm2.rs
index 869853a9..dc2e53b0 100644
--- a/ferritin-esm/src/esm2/models/esm2.rs
+++ b/ferritin-esm/src/esm2/models/esm2.rs
@@ -199,107 +199,107 @@ impl ESM2 {
     //     })
     // }
 
-    // fn forward(
-    //     &self,
-    //     tokens: &Tensor,
-    //     repr_layers: &[i32],
-    //     need_head_weights: bool,
-    //     return_contacts: bool,
-    // ) -> Result<BTreeMap<String, Tensor>> {
-    //     let need_head_weights = need_head_weights || return_contacts;
-
-    //     let padding_mask = tokens.eq(self.padding_idx)?;
-
-    //     let mut x = self
-    //         .embed_tokens
-    //         .forward(tokens)?
-    //         .mul_scalar(self.embed_scale)?;
-
-    //     if self.token_dropout {
-    //         let mask = tokens.eq(self.mask_idx)?.unsqueeze(-1)?;
-    //         x = x.masked_fill(&mask, 0.0)?;
-
-    //         let mask_ratio_train = 0.15 * 0.8;
-    //         let src_lengths = padding_mask.logical_not()?.sum_keepdim(-1)?;
-    //         let mask_ratio_observed = tokens
-    //             .eq(self.mask_idx)?
-    //             .sum_keepdim(-1)?
-    //             .to_dtype(x.dtype())?
-    //             .div(&src_lengths)?;
-    //         let scale = (1.0 - mask_ratio_train) / (1.0 - mask_ratio_observed)?;
-    //         x = x.mul(&scale.unsqueeze(-1)?)?;
-    //     }
+    fn forward(
+        &self,
+        tokens: &Tensor,
+        repr_layers: &[i32],
+        need_head_weights: bool,
+        return_contacts: bool,
+    ) -> Result<BTreeMap<String, Tensor>> {
+        let need_head_weights = need_head_weights || return_contacts;
+
+        let padding_mask = tokens.eq(self.padding_idx)?;
+
+        let mut x = self
+            .embed_tokens
+            .forward(tokens)?
+            .mul_scalar(self.embed_scale)?;
+
+        if self.token_dropout {
+            let mask = tokens.eq(self.mask_idx)?.unsqueeze(-1)?;
+            x = x.masked_fill(&mask, 0.0)?;
+
+            let mask_ratio_train = 0.15 * 0.8;
+            let src_lengths = padding_mask.logical_not()?.sum_keepdim(-1)?;
+            let mask_ratio_observed = tokens
+                .eq(self.mask_idx)?
+                .sum_keepdim(-1)?
+                .to_dtype(x.dtype())?
+                .div(&src_lengths)?;
+            let scale = (1.0 - mask_ratio_train) / (1.0 - mask_ratio_observed)?;
+            x = x.mul(&scale.unsqueeze(-1)?)?;
+        }
 
-    //     if !padding_mask.all()? {
-    //         let not_padding = padding_mask.logical_not()?.to_dtype(x.dtype())?;
-    //         x = x.mul(&not_padding.unsqueeze(-1)?)?;
-    //     }
+        if !padding_mask.all()? {
+            let not_padding = padding_mask.logical_not()?.to_dtype(x.dtype())?;
+            x = x.mul(&not_padding.unsqueeze(-1)?)?;
+        }
 
-    //     let repr_layers: HashSet<_> = repr_layers.iter().cloned().collect();
-    //     let mut hidden_representations = BTreeMap::new();
-    //     if repr_layers.contains(&0) {
-    //         hidden_representations.insert("0".to_string(), x.clone());
-    //     }
+        let repr_layers: HashSet<_> = repr_layers.iter().cloned().collect();
+        let mut hidden_representations = BTreeMap::new();
+        if repr_layers.contains(&0) {
+            hidden_representations.insert("0".to_string(), x.clone());
+        }
 
-    //     let mut attn_weights = Vec::new();
-    //     x = x.transpose(0, 1)?;
+        let mut attn_weights = Vec::new();
+        x = x.transpose(0, 1)?;
 
-    //     let padding_mask = if !padding_mask.any()? {
-    //         None
-    //     } else {
-    //         Some(padding_mask)
-    //     };
+        let padding_mask = if !padding_mask.any()? {
+            None
+        } else {
+            Some(padding_mask)
+        };
 
-    //     for (layer_idx, layer) in self.layers.iter().enumerate() {
-    //         let (new_x, attn) = layer.forward(&x, padding_mask.as_ref(), need_head_weights)?;
-    //         x = new_x;
+        for (layer_idx, layer) in self.layers.iter().enumerate() {
+            let (new_x, attn) = layer.forward(&x, padding_mask.as_ref(), need_head_weights)?;
+            x = new_x;
 
-    //         if repr_layers.contains(&(layer_idx as i32 + 1)) {
-    //             hidden_representations
-    //                 .insert((layer_idx + 1).to_string(), x.transpose(0, 1)?.clone());
-    //         }
+            if repr_layers.contains(&(layer_idx as i32 + 1)) {
+                hidden_representations
+                    .insert((layer_idx + 1).to_string(), x.transpose(0, 1)?.clone());
+            }
 
-    //         if need_head_weights {
-    //             attn_weights.push(attn.transpose(1, 0)?);
-    //         }
-    //     }
+            if need_head_weights {
+                attn_weights.push(attn.transpose(1, 0)?);
+            }
+        }
 
-    //     x = self.emb_layer_norm_after.forward(&x)?;
-    //     x = x.transpose(0, 1)?;
+        x = self.emb_layer_norm_after.forward(&x)?;
+        x = x.transpose(0, 1)?;
 
-    //     if repr_layers.contains(&(self.layers.len() as i32)) {
-    //         hidden_representations.insert(self.layers.len().to_string(), x.clone());
-    //     }
+        if repr_layers.contains(&(self.layers.len() as i32)) {
+            hidden_representations.insert(self.layers.len().to_string(), x.clone());
+        }
 
-    //     let logits = self.lm_head.forward(&x)?;
-
-    //     let mut result = BTreeMap::new();
-    //     result.insert("logits".to_string(), logits);
-    //     result.insert("representations".to_string(), x);
-
-    //     if need_head_weights {
-    //         let attentions = Tensor::stack(&attn_weights, 1)?;
-    //         if let Some(padding_mask) = padding_mask {
-    //             let attention_mask = padding_mask.logical_not()?.to_dtype(attentions.dtype())?;
-    //             let attention_mask = attention_mask
-    //                 .unsqueeze(1)?
-    //                 .mul(&attention_mask.unsqueeze(2)?)?;
-    //             result.insert(
-    //                 "attentions".to_string(),
-    //                 attentions.mul(&attention_mask.unsqueeze(1)?.unsqueeze(1)?)?,
-    //             );
-    //         } else {
-    //             result.insert("attentions".to_string(), attentions);
-    //         }
-
-    //         if return_contacts {
-    //             let contacts = self.contact_head.forward(tokens, &attentions)?;
-    //             result.insert("contacts".to_string(), contacts);
-    //         }
-    //     }
+        let logits = self.lm_head.forward(&x)?;
+
+        let mut result = BTreeMap::new();
+        result.insert("logits".to_string(), logits);
+        result.insert("representations".to_string(), x);
+
+        if need_head_weights {
+            let attentions = Tensor::stack(&attn_weights, 1)?;
+            if let Some(padding_mask) = padding_mask {
+                let attention_mask = padding_mask.logical_not()?.to_dtype(attentions.dtype())?;
+                let attention_mask = attention_mask
+                    .unsqueeze(1)?
+                    .mul(&attention_mask.unsqueeze(2)?)?;
+                result.insert(
+                    "attentions".to_string(),
+                    attentions.mul(&attention_mask.unsqueeze(1)?.unsqueeze(1)?)?,
+                );
+            } else {
+                result.insert("attentions".to_string(), attentions);
+            }
+
+            if return_contacts {
+                let contacts = self.contact_head.forward(tokens, &attentions)?;
+                result.insert("contacts".to_string(), contacts);
+            }
+        }
 
-    //     Ok(result)
-    // }
+        Ok(result)
+    }
 
     // pub fn predict_contacts(&self, tokens: &Tensor) -> Result<Tensor> {
     //     let mut result = self.forward(tokens, &[], false, true)?;