-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Streamline the LigandMPNN Model (#92)
* move ligandmpnn to logits ino a method * move ligand mpnn testing to file * refactor the code to simpler methods; update invocation * update cargo deps * update docs
- Loading branch information
Showing
10 changed files
with
250 additions
and
128 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,25 @@ | ||
[workspace.package] | ||
version = "0.1.0" | ||
edition = "2021" | ||
authors = ["Zach Charlop-Powers<[email protected]>"] | ||
description = "Molecular visualization tools" | ||
license = "MIT OR Apache-2.0" | ||
|
||
|
||
[workspace] | ||
members = [ | ||
"ferritin-bevy", | ||
"ferritin-cellscape", | ||
"ferritin-core", | ||
"ferritin-plms", | ||
"ferritin-examples", | ||
"ferritin-molviewspec", | ||
"ferritin-onnx-models", | ||
"ferritin-plms", | ||
"ferritin-pymol", | ||
"ferritin-test-data", | ||
# "ferritin-tui", | ||
] | ||
resolver = "2" | ||
|
||
[workspace.package] | ||
version = "0.1.0" | ||
authors = ["Zach Charlop-Powers<[email protected]>"] | ||
description = "Molecular visualization tools" | ||
license = "MIT OR Apache-2.0" | ||
edition = "2021" | ||
|
||
[workspace.dependencies] | ||
anyhow = "1.0" | ||
bitflags = "2.6.0" | ||
|
@@ -34,8 +34,14 @@ once_cell = "1.20.2" | |
pdbtbx = "0.12.0" | ||
rand = "0.8.5" | ||
safetensors = "0.4.5" | ||
serde = { version = "1.0", features = ["derive"] } | ||
serde_bytes = "0.11.15" | ||
serde_json = "1.0.134" | ||
serde_repr = "0.1.19" | ||
tokenizers = { version = "0.21.0", default-features = false } | ||
|
||
[workspace.dependencies.serde] | ||
version = "1.0" | ||
features = ["derive"] | ||
|
||
[workspace.dependencies.tokenizers] | ||
version = "0.21.0" | ||
default-features = false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,7 @@ | ||
# ESM2-ONNX | ||
# LigandMPNN-ONNX | ||
|
||
Convert ESM2_6T_35M to the ONNX format. Run it locally via RUST / [ORT](https://ort.pyke.io) | ||
|
||
```sh | ||
cargo run --example ligandmpnn-onnx -- --model-id 8M --protein-string \ | ||
MAFSAEDVLKEYDRR\<mask\>RMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL | ||
cargo run --example ligandmpnn-onnx | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,91 +1,16 @@ | ||
use anyhow::Result; | ||
use candle_core::Device; | ||
use ferritin_core::{AtomCollection, StructureFeatures}; | ||
use ferritin_onnx_models::{ | ||
ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64, LigandMPNN, | ||
LigandMPNNModels, | ||
}; | ||
use ferritin_core::AtomCollection; | ||
use ferritin_onnx_models::LigandMPNN; | ||
use ferritin_test_data::TestFile; | ||
use ort::{ | ||
execution_providers::CUDAExecutionProvider, | ||
session::{builder::GraphOptimizationLevel, Session}, | ||
value::Tensor, | ||
}; | ||
|
||
fn main() -> Result<()> { | ||
let (encoder_path, decoder_path) = LigandMPNN::load_model_path(LigandMPNNModels::LigandMPNN)?; | ||
|
||
ort::init() | ||
.with_name("LigandMPNN") | ||
.with_execution_providers([CUDAExecutionProvider::default().build()]) | ||
.commit()?; | ||
|
||
// Common session builder configuration | ||
let session_config = Session::builder()? | ||
.with_optimization_level(GraphOptimizationLevel::Level1)? | ||
.with_intra_threads(1)?; | ||
|
||
let encoder_model = session_config.clone().commit_from_file(&encoder_path)?; | ||
let decoder_model = session_config.clone().commit_from_file(&decoder_path)?; | ||
|
||
// https://github.com/zachcp/ferritin/blob/main/ferritin-plms/src/ligandmpnn/ligandmpnn/configs.rs#L82 | ||
println!("Loading the Model and Tokenizer......."); | ||
let (protfile, _handle) = TestFile::protein_01().create_temp()?; | ||
let (pdb, _) = pdbtbx::open(protfile).expect("PDB/CIF"); | ||
let ac = AtomCollection::from(&pdb); | ||
|
||
println!("Creating the input Tensors......."); | ||
let device = Device::Cpu; | ||
let x_bb = ac.to_numeric_backbone_atoms(&device)?; | ||
let (lig_coords_array, lig_elements_array, lig_mask_array) = | ||
ac.to_numeric_ligand_atoms(&device)?; | ||
let data_nd = tensor_to_ndarray_f32(x_bb)?; | ||
let lig_coords_array_nd = tensor_to_ndarray_f32(lig_coords_array)?; | ||
let lig_elements_array_nd = tensor_to_ndarray_i64(lig_elements_array)?; | ||
let lig_mask_array_nd = tensor_to_ndarray_f32(lig_mask_array)?; | ||
|
||
println!("Runnning the Encoder Model......."); | ||
let encoder_outputs = encoder_model.run(ort::inputs![ | ||
"coords" => data_nd, | ||
"ligand_coords" => lig_coords_array_nd, | ||
"ligand_types" => lig_elements_array_nd, | ||
"ligand_mask" => lig_mask_array_nd | ||
]?)?; | ||
|
||
println!("Creating the Inpute to the Decoder......."); | ||
let h_V = encoder_outputs["h_V"].try_extract_tensor::<f32>()?; | ||
let h_E = encoder_outputs["h_E"].try_extract_tensor::<f32>()?; | ||
let E_idx = encoder_outputs["E_idx"].try_extract_tensor::<i64>()?; | ||
|
||
let position_tensor = { | ||
let data = vec![10 as i64]; // Single value | ||
let array = ndarray::Array::from_shape_vec([1], data)?; // Shape [1] | ||
Tensor::from_array(array)? | ||
}; | ||
|
||
println!("Temp and Position Are Hardcoded........"); | ||
let temp_tensor = { | ||
let data = vec![0.1 as f32]; // Single value | ||
let array = ndarray::Array::from_shape_vec([1], data)?; | ||
Tensor::from_array(array)? | ||
}; | ||
|
||
let decoder_outputs = decoder_model.run(ort::inputs![ | ||
"h_V" => h_V, | ||
"h_E" => h_E, | ||
"E_idx" => E_idx, | ||
"position" => position_tensor, | ||
"temperature" => temp_tensor, | ||
]?)?; | ||
|
||
println!("Decoder Outputs are logits."); | ||
let logits = decoder_outputs["logits"] | ||
.try_extract_tensor::<f32>()? | ||
.to_owned(); | ||
|
||
println!("Converted to Candle."); | ||
let logit_tensor = ndarray_to_tensor_f32(logits); | ||
println!("{:?}", logit_tensor); | ||
let model = LigandMPNN::new().unwrap(); | ||
let logits = model.run_model(ac, 10, 0.1).unwrap(); | ||
println!("{:?}", logits); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,13 @@ | ||
//! Ferritin Onnx Models | ||
//! | ||
//! This crate provides easy access to various ONNX models for protein and ligand prediction. | ||
//! The models are downloaded from HuggingFace and run using ONNX Runtime. | ||
//! Currently supports ESM2 and LigandMPNN models. | ||
//! | ||
pub mod models; | ||
pub mod utilities; | ||
|
||
// pub use models::amplify::{AMPLIFYModels, AMPLIFY}; | ||
pub use models::esm2::{ESM2Models, ESM2}; | ||
pub use models::ligandmpnn::{LigandMPNN, LigandMPNNModels}; | ||
pub use models::ligandmpnn::{LigandMPNN, ModelType}; | ||
pub use utilities::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.