From bf1c91e09d894e80a3e78803488d96b0eef22a83 Mon Sep 17 00:00:00 2001 From: zachcp Date: Sun, 29 Dec 2024 16:59:35 -0500 Subject: [PATCH] Streamline the LigandMPNN Model (#92) * move ligandmpnn to logits into a method * move ligand mpnn testing to file * refactor the code to simpler methods; update invocation * update cargo deps * update docs --- Cargo.lock | 13 +- Cargo.toml | 28 ++- docs/index.qmd | 5 + ferritin-bevy/Cargo.toml | 2 +- .../examples/ligandmpnn-onnx/Readme.md | 5 +- .../examples/ligandmpnn-onnx/main.rs | 85 +------ ferritin-onnx-models/Cargo.toml | 1 + ferritin-onnx-models/src/lib.rs | 8 +- ferritin-onnx-models/src/models/esm2/mod.rs | 11 +- .../src/models/ligandmpnn/mod.rs | 220 ++++++++++++++++-- 10 files changed, 250 insertions(+), 128 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 105e675f..80ee3176 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1469,9 +1469,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.3.1" +version = "3.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61030e4aaccae9727cc388843dcc7ad1fb9e1ccdef5571e3e8393976b49b74ce" +checksum = "fe7acc34ff59877422326db7d6f2d845a582b16396b6b08194942bf34c6528ab" dependencies = [ "bon-macros", "rustversion", @@ -1479,9 +1479,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.3.1" +version = "3.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67d5d25cc9bd33120702000acc60836db15f06eabc4466230bf79dc80bd0a6ee" +checksum = "4159dd617a7fbc9be6a692fe69dc2954f8e6bb6bb5e4d7578467441390d77fd0" dependencies = [ "darling", "ident_case", @@ -2598,6 +2598,7 @@ dependencies = [ "ndarray", "ndarray-safetensors", "ort", + "pdbtbx", "safetensors", "tokenizers 0.21.0", ] @@ -3142,9 +3143,9 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "glow" diff --git a/Cargo.toml b/Cargo.toml index 24f018e7..abd2a163 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,25 +1,25 @@ -[workspace.package] -version = "0.1.0" -edition = "2021" -authors = ["Zach Charlop-Powers"] -description = "Molecular visualization tools" -license = "MIT OR Apache-2.0" - - [workspace] members = [ "ferritin-bevy", "ferritin-cellscape", "ferritin-core", - "ferritin-plms", "ferritin-examples", "ferritin-molviewspec", "ferritin-onnx-models", + "ferritin-plms", "ferritin-pymol", "ferritin-test-data", + # "ferritin-tui", ] resolver = "2" +[workspace.package] +version = "0.1.0" +authors = ["Zach Charlop-Powers"] +description = "Molecular visualization tools" +license = "MIT OR Apache-2.0" +edition = "2021" + [workspace.dependencies] anyhow = "1.0" bitflags = "2.6.0" @@ -34,8 +34,14 @@ once_cell = "1.20.2" pdbtbx = "0.12.0" rand = "0.8.5" safetensors = "0.4.5" -serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11.15" serde_json = "1.0.134" serde_repr = "0.1.19" -tokenizers = { version = "0.21.0", default-features = false } + +[workspace.dependencies.serde] +version = "1.0" +features = ["derive"] + +[workspace.dependencies.tokenizers] +version = "0.21.0" +default-features = false diff --git a/docs/index.qmd b/docs/index.qmd index 8c84fd5c..93f3ccde 100644 --- a/docs/index.qmd +++ b/docs/index.qmd @@ -39,6 +39,11 @@ Native Rust access to pymol `pse` binary files. [rust-docs](https://zachcp.gith Utilities for consuming and creating [molviewspec files](https://molstar.org/mol-view-spec/). [rust-docs](https://zachcp.github.io/ferritin/doc/ferritin_molviewspec/index.html). [source](https://github.com/zachcp/ferritin/tree/main/ferritin-molviewspec). +### Ferritin-ONNX-Models + +Access to protein PLM via the ONNX runtime. 
[rust-docs](https://zachcp.github.io/ferritin/doc/ferritin_onnx_models/index.html). [source](https://github.com/zachcp/ferritin/tree/main/ferritin-onnx-models). + + ### Ferritin-PLMs Utilities for working with protein language models. [rust-docs](https://zachcp.github.io/ferritin/doc/ferritin_plms/index.html). [source](https://github.com/zachcp/ferritin/tree/main/ferritin-plms). diff --git a/ferritin-bevy/Cargo.toml b/ferritin-bevy/Cargo.toml index 2702dfbe..d6fbbf68 100644 --- a/ferritin-bevy/Cargo.toml +++ b/ferritin-bevy/Cargo.toml @@ -8,7 +8,7 @@ description.workspace = true [dependencies] bevy = "0.15.0" -bon = "3.3.1" +bon = "3.3.2" ferritin-core = { path = "../ferritin-core" } pdbtbx.workspace = true diff --git a/ferritin-examples/examples/ligandmpnn-onnx/Readme.md b/ferritin-examples/examples/ligandmpnn-onnx/Readme.md index 35a3230d..de000c62 100644 --- a/ferritin-examples/examples/ligandmpnn-onnx/Readme.md +++ b/ferritin-examples/examples/ligandmpnn-onnx/Readme.md @@ -1,8 +1,7 @@ -# ESM2-ONNX +# LigandMPNN-ONNX Convert ESM2_6T_35M to the ONNX format. 
Run it locally via RUST / [ORT](https://ort.pyke.io) ```sh -cargo run --example ligandmpnn-onnx -- --model-id 8M --protein-string \ -MAFSAEDVLKEYDRR\RMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL +cargo run --example ligandmpnn-onnx ``` diff --git a/ferritin-examples/examples/ligandmpnn-onnx/main.rs b/ferritin-examples/examples/ligandmpnn-onnx/main.rs index 989a1aa3..ddf93a77 100644 --- a/ferritin-examples/examples/ligandmpnn-onnx/main.rs +++ b/ferritin-examples/examples/ligandmpnn-onnx/main.rs @@ -1,91 +1,16 @@ use anyhow::Result; -use candle_core::Device; -use ferritin_core::{AtomCollection, StructureFeatures}; -use ferritin_onnx_models::{ - ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64, LigandMPNN, - LigandMPNNModels, -}; +use ferritin_core::AtomCollection; +use ferritin_onnx_models::LigandMPNN; use ferritin_test_data::TestFile; -use ort::{ - execution_providers::CUDAExecutionProvider, - session::{builder::GraphOptimizationLevel, Session}, - value::Tensor, -}; fn main() -> Result<()> { - let (encoder_path, decoder_path) = LigandMPNN::load_model_path(LigandMPNNModels::LigandMPNN)?; - - ort::init() - .with_name("LigandMPNN") - .with_execution_providers([CUDAExecutionProvider::default().build()]) - .commit()?; - - // Common session builder configuration - let session_config = Session::builder()? - .with_optimization_level(GraphOptimizationLevel::Level1)? 
- .with_intra_threads(1)?; - - let encoder_model = session_config.clone().commit_from_file(&encoder_path)?; - let decoder_model = session_config.clone().commit_from_file(&decoder_path)?; - - // https://github.com/zachcp/ferritin/blob/main/ferritin-plms/src/ligandmpnn/ligandmpnn/configs.rs#L82 println!("Loading the Model and Tokenizer......."); let (protfile, _handle) = TestFile::protein_01().create_temp()?; let (pdb, _) = pdbtbx::open(protfile).expect("PDB/CIF"); let ac = AtomCollection::from(&pdb); - - println!("Creating the input Tensors......."); - let device = Device::Cpu; - let x_bb = ac.to_numeric_backbone_atoms(&device)?; - let (lig_coords_array, lig_elements_array, lig_mask_array) = - ac.to_numeric_ligand_atoms(&device)?; - let data_nd = tensor_to_ndarray_f32(x_bb)?; - let lig_coords_array_nd = tensor_to_ndarray_f32(lig_coords_array)?; - let lig_elements_array_nd = tensor_to_ndarray_i64(lig_elements_array)?; - let lig_mask_array_nd = tensor_to_ndarray_f32(lig_mask_array)?; - - println!("Runnning the Encoder Model......."); - let encoder_outputs = encoder_model.run(ort::inputs![ - "coords" => data_nd, - "ligand_coords" => lig_coords_array_nd, - "ligand_types" => lig_elements_array_nd, - "ligand_mask" => lig_mask_array_nd - ]?)?; - - println!("Creating the Inpute to the Decoder......."); - let h_V = encoder_outputs["h_V"].try_extract_tensor::()?; - let h_E = encoder_outputs["h_E"].try_extract_tensor::()?; - let E_idx = encoder_outputs["E_idx"].try_extract_tensor::()?; - - let position_tensor = { - let data = vec![10 as i64]; // Single value - let array = ndarray::Array::from_shape_vec([1], data)?; // Shape [1] - Tensor::from_array(array)? - }; - - println!("Temp and Position Are Hardcoded........"); - let temp_tensor = { - let data = vec![0.1 as f32]; // Single value - let array = ndarray::Array::from_shape_vec([1], data)?; - Tensor::from_array(array)? 
- }; - - let decoder_outputs = decoder_model.run(ort::inputs![ - "h_V" => h_V, - "h_E" => h_E, - "E_idx" => E_idx, - "position" => position_tensor, - "temperature" => temp_tensor, - ]?)?; - - println!("Decoder Outputs are logits."); - let logits = decoder_outputs["logits"] - .try_extract_tensor::()? - .to_owned(); - - println!("Converted to Candle."); - let logit_tensor = ndarray_to_tensor_f32(logits); - println!("{:?}", logit_tensor); + let model = LigandMPNN::new().unwrap(); + let logits = model.run_model(ac, 10, 0.1).unwrap(); + println!("{:?}", logits); Ok(()) } diff --git a/ferritin-onnx-models/Cargo.toml b/ferritin-onnx-models/Cargo.toml index 35b36c25..7b6a78b1 100644 --- a/ferritin-onnx-models/Cargo.toml +++ b/ferritin-onnx-models/Cargo.toml @@ -20,3 +20,4 @@ ferritin-core = { path = "../ferritin-core" } [dev-dependencies] anyhow.workspace = true ferritin-test-data = { path = "../ferritin-test-data" } +pdbtbx.workspace = true diff --git a/ferritin-onnx-models/src/lib.rs b/ferritin-onnx-models/src/lib.rs index 9266d304..81d8954a 100644 --- a/ferritin-onnx-models/src/lib.rs +++ b/ferritin-onnx-models/src/lib.rs @@ -1,7 +1,13 @@ +//! Ferritin Onnx Models +//! +//! This crate provides easy access to various ONNX models for protein and ligand prediction. +//! The models are downloaded from HuggingFace and run using ONNX Runtime. +//! Currently supports ESM2 and LigandMPNN models. +//! 
pub mod models; pub mod utilities; // pub use models::amplify::{AMPLIFYModels, AMPLIFY}; pub use models::esm2::{ESM2Models, ESM2}; -pub use models::ligandmpnn::{LigandMPNN, LigandMPNNModels}; +pub use models::ligandmpnn::{LigandMPNN, ModelType}; pub use utilities::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64}; diff --git a/ferritin-onnx-models/src/models/esm2/mod.rs b/ferritin-onnx-models/src/models/esm2/mod.rs index 9378cc80..efe0813c 100644 --- a/ferritin-onnx-models/src/models/esm2/mod.rs +++ b/ferritin-onnx-models/src/models/esm2/mod.rs @@ -1,5 +1,14 @@ -//! ESM2 Struct. Loads the hf tokenizer +//! ESM2 Tokenizer. Models converted to ONNX format from [ESM2](https://github.com/facebookresearch/esm) +//! and uploaded to HuggingFace hub. The tokenizer is included in this crate and loaded from +//! memory using `tokenizer.json`. This is fairly minimal - for the full set of ESM2 models +//! please see the ESM2 repository and the HuggingFace hub. //! +//! # Models: +//! * ESM2_T6_8M - small 6-layer protein language model +//! * ESM2_T12_35M - medium 12-layer protein language model +//! * ESM2_T30_150M - large 30-layer protein language model +//! + use anyhow::{anyhow, Result}; use candle_hf_hub::api::sync::Api; use std::path::PathBuf; diff --git a/ferritin-onnx-models/src/models/ligandmpnn/mod.rs b/ferritin-onnx-models/src/models/ligandmpnn/mod.rs index 5004a1b0..bb30f3b4 100644 --- a/ferritin-onnx-models/src/models/ligandmpnn/mod.rs +++ b/ferritin-onnx-models/src/models/ligandmpnn/mod.rs @@ -1,43 +1,213 @@ -//! ESM2 Struct. Loads the hf tokenizer +//! Module for running Ligand- and Protein-MPNN Models //! -use anyhow::{anyhow, Result}; +//! This module provides functionality for running LigandMPNN and ProteinMPNN models +//! to predict amino acid sequences given protein structure coordinates and ligand information. +//! +//! The models are loaded from the Hugging Face model hub and executed using ONNX Runtime. +//! +//! 
+use crate::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64}; +use anyhow::Result; +use candle_core::{Device, Tensor}; use candle_hf_hub::api::sync::Api; +use ferritin_core::{AtomCollection, StructureFeatures}; +use ndarray::ArrayBase; +use ort::{ + execution_providers::CUDAExecutionProvider, + session::{ + builder::{GraphOptimizationLevel, SessionBuilder}, + Session, + }, +}; use std::path::PathBuf; -pub enum LigandMPNNModels { - ProteinMPNN, - LigandMPNN, -} +type NdArrayF32 = ArrayBase, ndarray::Dim>; +type NdArrayI64 = ArrayBase, ndarray::Dim>; -pub struct LigandMPNN {} +pub enum ModelType { + Protein, + Ligand, +} -impl LigandMPNN { - pub fn load_model_path(model: LigandMPNNModels) -> Result<(PathBuf, PathBuf)> { - let api = Api::new().unwrap(); - let (repo_id, encoder_filename, decoder_filename) = match model { - LigandMPNNModels::ProteinMPNN => ( - "zcpbx/proteinmpnn-v48-030-onnx".to_string(), +impl ModelType { + pub fn get_paths(&self) -> (&'static str, &'static str, &'static str) { + match self { + ModelType::Protein => ( + "zcpbx/proteinmpnn-v48-030-onnx", "protmpnn_encoder.onnx", "protmpnn_decoder_step.onnx", ), - LigandMPNNModels::LigandMPNN => ( - "zcpbx/ligandmpnn-v32-030-25-onnx".to_string(), + ModelType::Ligand => ( + "zcpbx/ligandmpnn-v32-030-25-onnx", "ligand_encoder.onnx", "ligand_decoder.onnx", ), - }; - let encoder_path = api.model(repo_id.clone()).get(encoder_filename).unwrap(); - let decoder_path = api.model(repo_id).get(decoder_filename).unwrap(); - Ok((encoder_path, decoder_path)) - } - // pub fn load_tokenizer() -> Result { - // let tokenizer_bytes = include_bytes!("tokenizer.json"); - // Tokenizer::from_bytes(tokenizer_bytes) - // .map_err(|e| anyhow!("Failed to load tokenizer: {}", e)) - // } + } + } +} + +/// A deep learning model for predicting amino acid sequences from protein structure coordinates. 
+/// +/// This model comes in two variants: +/// * `Protein` - The original ProteinMPNN model for protein sequence design +/// * `Ligand` - An extended version that considers ligand information +/// +/// # Example +/// ```no_run +/// use ferritin_onnx_models::ModelType; +/// let model_type = ModelType::Ligand; +/// let paths = model_type.get_paths(); +/// ``` +/// +/// The models are loaded from pre-trained ONNX format files hosted on the Hugging Face model hub. +/// Each model consists of an encoder and decoder component that work together to generate +/// amino acid sequence predictions based on structural information. +/// +pub struct LigandMPNN { + session: SessionBuilder, + encoder_path: PathBuf, + decoder_path: PathBuf, +} + +impl LigandMPNN { + pub fn new() -> Result { + let session = Self::create_session()?; + let (encoder_path, decoder_path) = Self::load_model_paths(ModelType::Ligand)?; + + Ok(Self { + session, + encoder_path, + decoder_path, + }) + } + + fn create_session() -> Result { + ort::init() + .with_name("LigandMPNN") + .with_execution_providers([CUDAExecutionProvider::default().build()]) + .commit()?; + Ok(Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level1)? + .with_intra_threads(1)?) 
+ } + + fn load_model_paths(model_type: ModelType) -> Result<(PathBuf, PathBuf)> { + let api = Api::new()?; + let (repo_id, encoder_name, decoder_name) = model_type.get_paths(); + Ok(( + api.model(repo_id.to_string()).get(&encoder_name)?, + api.model(repo_id.to_string()).get(&decoder_name)?, + )) + } + + pub fn run_model(&self, ac: AtomCollection, position: i64, temperature: f32) -> Result { + let (h_V, h_E, E_idx) = self.run_encoder(&ac)?; + self.run_decoder(h_V, h_E, E_idx, temperature, position) + } + + pub fn run_encoder(&self, ac: &AtomCollection) -> Result<(NdArrayF32, NdArrayF32, NdArrayI64)> { + let device = Device::Cpu; + let encoder_model = self.session.clone().commit_from_file(&self.encoder_path)?; + + let x_bb = ac.to_numeric_backbone_atoms(&device)?; + let (lig_coords, lig_elements, lig_mask) = ac.to_numeric_ligand_atoms(&device)?; + let coords_nd = tensor_to_ndarray_f32(x_bb)?; + let lig_coords_nd = tensor_to_ndarray_f32(lig_coords)?; + let lig_types_nd = tensor_to_ndarray_i64(lig_elements)?; + let lig_mask_nd = tensor_to_ndarray_f32(lig_mask)?; + + let encoder_inputs = ort::inputs![ + "coords" => coords_nd, + "ligand_coords" => lig_coords_nd, + "ligand_types" => lig_types_nd, + "ligand_mask" => lig_mask_nd + ]?; + + let encoder_outputs = encoder_model.run(encoder_inputs)?; + + Ok(( + encoder_outputs["h_V"] + .try_extract_tensor::()? + .to_owned(), + encoder_outputs["h_E"] + .try_extract_tensor::()? + .to_owned(), + encoder_outputs["E_idx"] + .try_extract_tensor::()? 
+ .to_owned(), + )) + } + + pub fn run_decoder( + &self, + h_V: NdArrayF32, + h_E: NdArrayF32, + E_idx: NdArrayI64, + temperature: f32, + position: i64, + ) -> Result { + let decoder_model = self.session.clone().commit_from_file(&self.decoder_path)?; + + let position_tensor = + ort::value::Tensor::from_array(ndarray::Array::from_shape_vec([1], vec![position])?)?; + + let temp_tensor = ort::value::Tensor::from_array(ndarray::Array::from_shape_vec( + [1], + vec![temperature], + )?)?; + + let decoder_inputs = ort::inputs![ + "h_V" => h_V, + "h_E" => h_E, + "E_idx" => E_idx, + "position" => position_tensor, + "temperature" => temp_tensor, + ]?; + + let decoder_outputs = decoder_model.run(decoder_inputs)?; + let logits = decoder_outputs["logits"] + .try_extract_tensor::()? + .to_owned(); + + ndarray_to_tensor_f32(logits) + } } #[cfg(test)] mod tests { use super::*; + use ferritin_test_data::TestFile; + use pdbtbx; + + fn setup_test_data() -> AtomCollection { + let (protfile, _handle) = TestFile::protein_01().create_temp().unwrap(); + let (pdb, _) = pdbtbx::open(protfile).expect("PDB/CIF"); + AtomCollection::from(&pdb) + } + + #[test] + fn test_model_initialization() { + let model = LigandMPNN::new().unwrap(); + assert!(model.encoder_path.exists()); + assert!(model.decoder_path.exists()); + } + + #[test] + fn test_encoder_output_dimensions() { + let model = LigandMPNN::new().unwrap(); + let ac = setup_test_data(); + let (h_v, h_e, e_idx) = model.run_encoder(&ac).unwrap(); + + assert_eq!(h_v.shape(), &[1, 154, 128]); + assert_eq!(h_e.shape(), &[1, 154, 16, 128]); + assert_eq!(e_idx.shape(), &[1, 154, 16]); + } + + #[test] + fn test_full_pipeline() { + let model = LigandMPNN::new().unwrap(); + let ac = setup_test_data(); + let logits = model.run_model(ac, 10, 0.1).unwrap(); + assert_eq!(logits.dims2().unwrap(), (1, 21)); + } }