From b5aa9c3319c6cfb49811c8a19e914fe04cbf8be6 Mon Sep 17 00:00:00 2001
From: Zachary Charlop-Powers
Date: Fri, 13 Dec 2024 12:56:26 -0500
Subject: [PATCH] update comment out forward

---
 ferritin-bevy/src/structure.rs       |   2 +-
 ferritin-esm/src/esm2/models/esm2.rs | 180 +++++++++++++--------------
 2 files changed, 91 insertions(+), 91 deletions(-)

diff --git a/ferritin-bevy/src/structure.rs b/ferritin-bevy/src/structure.rs
index 902ced00..d8e27c2c 100644
--- a/ferritin-bevy/src/structure.rs
+++ b/ferritin-bevy/src/structure.rs
@@ -272,7 +272,7 @@ mod tests {
 
     #[test]
     fn test_pdb_to_mesh() {
-        let (molfile, _handle) = TestFile::protein_04().create_temp()?;
+        let (molfile, _handle) = TestFile::protein_04().create_temp();
         // let (pdb, _errors) = pdbtbx::open("examples/1fap.cif").unwrap();
         let (pdb, _errors) = pdbtbx::open("examples/1fap.cif").unwrap();
 
diff --git a/ferritin-esm/src/esm2/models/esm2.rs b/ferritin-esm/src/esm2/models/esm2.rs
index dc2e53b0..869853a9 100644
--- a/ferritin-esm/src/esm2/models/esm2.rs
+++ b/ferritin-esm/src/esm2/models/esm2.rs
@@ -199,107 +199,107 @@ impl ESM2 {
     //     })
     // }
 
-    fn forward(
-        &self,
-        tokens: &Tensor,
-        repr_layers: &[i32],
-        need_head_weights: bool,
-        return_contacts: bool,
-    ) -> Result<BTreeMap<String, Tensor>> {
-        let need_head_weights = need_head_weights || return_contacts;
-
-        let padding_mask = tokens.eq(self.padding_idx)?;
-
-        let mut x = self
-            .embed_tokens
-            .forward(tokens)?
-            .mul_scalar(self.embed_scale)?;
-
-        if self.token_dropout {
-            let mask = tokens.eq(self.mask_idx)?.unsqueeze(-1)?;
-            x = x.masked_fill(&mask, 0.0)?;
-
-            let mask_ratio_train = 0.15 * 0.8;
-            let src_lengths = padding_mask.logical_not()?.sum_keepdim(-1)?;
-            let mask_ratio_observed = tokens
-                .eq(self.mask_idx)?
-                .sum_keepdim(-1)?
-                .to_dtype(x.dtype())?
-                .div(&src_lengths)?;
-            let scale = (1.0 - mask_ratio_train) / (1.0 - mask_ratio_observed)?;
-            x = x.mul(&scale.unsqueeze(-1)?)?;
-        }
+    // fn forward(
+    //     &self,
+    //     tokens: &Tensor,
+    //     repr_layers: &[i32],
+    //     need_head_weights: bool,
+    //     return_contacts: bool,
+    // ) -> Result<BTreeMap<String, Tensor>> {
+    //     let need_head_weights = need_head_weights || return_contacts;
+
+    //     let padding_mask = tokens.eq(self.padding_idx)?;
+
+    //     let mut x = self
+    //         .embed_tokens
+    //         .forward(tokens)?
+    //         .mul_scalar(self.embed_scale)?;
+
+    //     if self.token_dropout {
+    //         let mask = tokens.eq(self.mask_idx)?.unsqueeze(-1)?;
+    //         x = x.masked_fill(&mask, 0.0)?;
+
+    //         let mask_ratio_train = 0.15 * 0.8;
+    //         let src_lengths = padding_mask.logical_not()?.sum_keepdim(-1)?;
+    //         let mask_ratio_observed = tokens
+    //             .eq(self.mask_idx)?
+    //             .sum_keepdim(-1)?
+    //             .to_dtype(x.dtype())?
+    //             .div(&src_lengths)?;
+    //         let scale = (1.0 - mask_ratio_train) / (1.0 - mask_ratio_observed)?;
+    //         x = x.mul(&scale.unsqueeze(-1)?)?;
+    //     }
 
-        if !padding_mask.all()? {
-            let not_padding = padding_mask.logical_not()?.to_dtype(x.dtype())?;
-            x = x.mul(&not_padding.unsqueeze(-1)?)?;
-        }
+    //     if !padding_mask.all()? {
+    //         let not_padding = padding_mask.logical_not()?.to_dtype(x.dtype())?;
+    //         x = x.mul(&not_padding.unsqueeze(-1)?)?;
+    //     }
 
-        let repr_layers: HashSet<_> = repr_layers.iter().cloned().collect();
-        let mut hidden_representations = BTreeMap::new();
-        if repr_layers.contains(&0) {
-            hidden_representations.insert("0".to_string(), x.clone());
-        }
+    //     let repr_layers: HashSet<_> = repr_layers.iter().cloned().collect();
+    //     let mut hidden_representations = BTreeMap::new();
+    //     if repr_layers.contains(&0) {
+    //         hidden_representations.insert("0".to_string(), x.clone());
+    //     }
 
-        let mut attn_weights = Vec::new();
-        x = x.transpose(0, 1)?;
+    //     let mut attn_weights = Vec::new();
+    //     x = x.transpose(0, 1)?;
 
-        let padding_mask = if !padding_mask.any()? {
-            None
-        } else {
-            Some(padding_mask)
-        };
+    //     let padding_mask = if !padding_mask.any()? {
+    //         None
+    //     } else {
+    //         Some(padding_mask)
+    //     };
 
-        for (layer_idx, layer) in self.layers.iter().enumerate() {
-            let (new_x, attn) = layer.forward(&x, padding_mask.as_ref(), need_head_weights)?;
-            x = new_x;
+    //     for (layer_idx, layer) in self.layers.iter().enumerate() {
+    //         let (new_x, attn) = layer.forward(&x, padding_mask.as_ref(), need_head_weights)?;
+    //         x = new_x;
 
-            if repr_layers.contains(&(layer_idx as i32 + 1)) {
-                hidden_representations
-                    .insert((layer_idx + 1).to_string(), x.transpose(0, 1)?.clone());
-            }
+    //         if repr_layers.contains(&(layer_idx as i32 + 1)) {
+    //             hidden_representations
+    //                 .insert((layer_idx + 1).to_string(), x.transpose(0, 1)?.clone());
+    //         }
 
-            if need_head_weights {
-                attn_weights.push(attn.transpose(1, 0)?);
-            }
-        }
+    //         if need_head_weights {
+    //             attn_weights.push(attn.transpose(1, 0)?);
+    //         }
+    //     }
 
-        x = self.emb_layer_norm_after.forward(&x)?;
-        x = x.transpose(0, 1)?;
+    //     x = self.emb_layer_norm_after.forward(&x)?;
+    //     x = x.transpose(0, 1)?;
 
-        if repr_layers.contains(&(self.layers.len() as i32)) {
-            hidden_representations.insert(self.layers.len().to_string(), x.clone());
-        }
+    //     if repr_layers.contains(&(self.layers.len() as i32)) {
+    //         hidden_representations.insert(self.layers.len().to_string(), x.clone());
+    //     }
 
-        let logits = self.lm_head.forward(&x)?;
-
-        let mut result = BTreeMap::new();
-        result.insert("logits".to_string(), logits);
-        result.insert("representations".to_string(), x);
-
-        if need_head_weights {
-            let attentions = Tensor::stack(&attn_weights, 1)?;
-            if let Some(padding_mask) = padding_mask {
-                let attention_mask = padding_mask.logical_not()?.to_dtype(attentions.dtype())?;
-                let attention_mask = attention_mask
-                    .unsqueeze(1)?
-                    .mul(&attention_mask.unsqueeze(2)?)?;
-                result.insert(
-                    "attentions".to_string(),
-                    attentions.mul(&attention_mask.unsqueeze(1)?.unsqueeze(1)?)?,
-                );
-            } else {
-                result.insert("attentions".to_string(), attentions);
-            }
-
-            if return_contacts {
-                let contacts = self.contact_head.forward(tokens, &attentions)?;
-                result.insert("contacts".to_string(), contacts);
-            }
-        }
+    //     let logits = self.lm_head.forward(&x)?;
+
+    //     let mut result = BTreeMap::new();
+    //     result.insert("logits".to_string(), logits);
+    //     result.insert("representations".to_string(), x);
+
+    //     if need_head_weights {
+    //         let attentions = Tensor::stack(&attn_weights, 1)?;
+    //         if let Some(padding_mask) = padding_mask {
+    //             let attention_mask = padding_mask.logical_not()?.to_dtype(attentions.dtype())?;
+    //             let attention_mask = attention_mask
+    //                 .unsqueeze(1)?
+    //                 .mul(&attention_mask.unsqueeze(2)?)?;
+    //             result.insert(
+    //                 "attentions".to_string(),
+    //                 attentions.mul(&attention_mask.unsqueeze(1)?.unsqueeze(1)?)?,
+    //             );
+    //         } else {
+    //             result.insert("attentions".to_string(), attentions);
+    //         }
+
+    //         if return_contacts {
+    //             let contacts = self.contact_head.forward(tokens, &attentions)?;
+    //             result.insert("contacts".to_string(), contacts);
+    //         }
+    //     }
 
-        Ok(result)
-    }
+    //     Ok(result)
+    // }
 
     // pub fn predict_contacts(&self, tokens: &Tensor) -> Result<Tensor> {
     //     let mut result = self.forward(tokens, &[], false, true)?;