From 7e555a041c7ef05ae914c8abe5161e19f0adafcc Mon Sep 17 00:00:00 2001 From: zachcp Date: Thu, 2 Jan 2025 16:14:40 -0500 Subject: [PATCH] update main.rs (#95) --- ferritin-examples/examples/amplify/main.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ferritin-examples/examples/amplify/main.rs b/ferritin-examples/examples/amplify/main.rs index e894e82c..f109904f 100644 --- a/ferritin-examples/examples/amplify/main.rs +++ b/ferritin-examples/examples/amplify/main.rs @@ -88,16 +88,20 @@ fn main() -> Result<()> { let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; println!("Encoding......."); - let encoded = model.forward(&token_ids, None, false, false)?; + let encoded = model.forward(&token_ids, None, false, true)?; println!("Predicting......."); let predictions = encoded.logits.argmax(D::Minus1)?; + println!("Pred Dims: {:?}", encoded.logits.dims()); println!("Decoding......."); let indices: Vec = predictions.to_vec2()?[0].to_vec(); let decoded = tokenizer.decode(indices.as_slice(), true); println!("Decoded: {:?}, ", decoded); + + let contact_map = encoded.get_contact_map()?; + println!("Contact Map Calculated: {:?}, ", contact_map); } Ok(())