Skip to content

Commit

Permalink
Streamline the LigandMPNN Model (#92)
Browse files Browse the repository at this point in the history
* move ligandmpnn to logits into a method

* move ligand mpnn testing to file

* refactor the code to simpler methods;  update invocation

* update cargo deps

* update docs
  • Loading branch information
zachcp authored Dec 29, 2024
1 parent 0b99001 commit bf1c91e
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 128 deletions.
13 changes: 7 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 17 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
[workspace.package]
version = "0.1.0"
edition = "2021"
authors = ["Zach Charlop-Powers<[email protected]>"]
description = "Molecular visualization tools"
license = "MIT OR Apache-2.0"


[workspace]
members = [
"ferritin-bevy",
"ferritin-cellscape",
"ferritin-core",
"ferritin-plms",
"ferritin-examples",
"ferritin-molviewspec",
"ferritin-onnx-models",
"ferritin-plms",
"ferritin-pymol",
"ferritin-test-data",
# "ferritin-tui",
]
resolver = "2"

[workspace.package]
version = "0.1.0"
authors = ["Zach Charlop-Powers<[email protected]>"]
description = "Molecular visualization tools"
license = "MIT OR Apache-2.0"
edition = "2021"

[workspace.dependencies]
anyhow = "1.0"
bitflags = "2.6.0"
Expand All @@ -34,8 +34,14 @@ once_cell = "1.20.2"
pdbtbx = "0.12.0"
rand = "0.8.5"
safetensors = "0.4.5"
serde = { version = "1.0", features = ["derive"] }
serde_bytes = "0.11.15"
serde_json = "1.0.134"
serde_repr = "0.1.19"
tokenizers = { version = "0.21.0", default-features = false }

[workspace.dependencies.serde]
version = "1.0"
features = ["derive"]

[workspace.dependencies.tokenizers]
version = "0.21.0"
default-features = false
5 changes: 5 additions & 0 deletions docs/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ Native Rust access to pymol `pse` binary files. [rust-docs](https://zachcp.gith

Utilities for consuming and creating [molviewspec files](https://molstar.org/mol-view-spec/). [rust-docs](https://zachcp.github.io/ferritin/doc/ferritin_molviewspec/index.html). [source](https://github.com/zachcp/ferritin/tree/main/ferritin-molviewspec).

### Ferritin-ONNX-Models

Access to protein language models (PLMs) via the ONNX Runtime. [rust-docs](https://zachcp.github.io/ferritin/doc/ferritin_onnx_models/index.html). [source](https://github.com/zachcp/ferritin/tree/main/ferritin-onnx-models).


### Ferritin-PLMs

Utilities for working with protein language models. [rust-docs](https://zachcp.github.io/ferritin/doc/ferritin_plms/index.html). [source](https://github.com/zachcp/ferritin/tree/main/ferritin-plms).
2 changes: 1 addition & 1 deletion ferritin-bevy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ description.workspace = true

[dependencies]
bevy = "0.15.0"
bon = "3.3.1"
bon = "3.3.2"
ferritin-core = { path = "../ferritin-core" }
pdbtbx.workspace = true

Expand Down
5 changes: 2 additions & 3 deletions ferritin-examples/examples/ligandmpnn-onnx/Readme.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# ESM2-ONNX
# LigandMPNN-ONNX

Convert LigandMPNN to the ONNX format. Run it locally via Rust / [ORT](https://ort.pyke.io)

```sh
cargo run --example ligandmpnn-onnx -- --model-id 8M --protein-string \
MAFSAEDVLKEYDRR\<mask\>RMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL
cargo run --example ligandmpnn-onnx
```
85 changes: 5 additions & 80 deletions ferritin-examples/examples/ligandmpnn-onnx/main.rs
Original file line number Diff line number Diff line change
@@ -1,91 +1,16 @@
use anyhow::Result;
use candle_core::Device;
use ferritin_core::{AtomCollection, StructureFeatures};
use ferritin_onnx_models::{
ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64, LigandMPNN,
LigandMPNNModels,
};
use ferritin_core::AtomCollection;
use ferritin_onnx_models::LigandMPNN;
use ferritin_test_data::TestFile;
use ort::{
execution_providers::CUDAExecutionProvider,
session::{builder::GraphOptimizationLevel, Session},
value::Tensor,
};

fn main() -> Result<()> {
let (encoder_path, decoder_path) = LigandMPNN::load_model_path(LigandMPNNModels::LigandMPNN)?;

ort::init()
.with_name("LigandMPNN")
.with_execution_providers([CUDAExecutionProvider::default().build()])
.commit()?;

// Common session builder configuration
let session_config = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)?;

let encoder_model = session_config.clone().commit_from_file(&encoder_path)?;
let decoder_model = session_config.clone().commit_from_file(&decoder_path)?;

// https://github.com/zachcp/ferritin/blob/main/ferritin-plms/src/ligandmpnn/ligandmpnn/configs.rs#L82
println!("Loading the Model and Tokenizer.......");
let (protfile, _handle) = TestFile::protein_01().create_temp()?;
let (pdb, _) = pdbtbx::open(protfile).expect("PDB/CIF");
let ac = AtomCollection::from(&pdb);

println!("Creating the input Tensors.......");
let device = Device::Cpu;
let x_bb = ac.to_numeric_backbone_atoms(&device)?;
let (lig_coords_array, lig_elements_array, lig_mask_array) =
ac.to_numeric_ligand_atoms(&device)?;
let data_nd = tensor_to_ndarray_f32(x_bb)?;
let lig_coords_array_nd = tensor_to_ndarray_f32(lig_coords_array)?;
let lig_elements_array_nd = tensor_to_ndarray_i64(lig_elements_array)?;
let lig_mask_array_nd = tensor_to_ndarray_f32(lig_mask_array)?;

println!("Runnning the Encoder Model.......");
let encoder_outputs = encoder_model.run(ort::inputs![
"coords" => data_nd,
"ligand_coords" => lig_coords_array_nd,
"ligand_types" => lig_elements_array_nd,
"ligand_mask" => lig_mask_array_nd
]?)?;

println!("Creating the Inpute to the Decoder.......");
let h_V = encoder_outputs["h_V"].try_extract_tensor::<f32>()?;
let h_E = encoder_outputs["h_E"].try_extract_tensor::<f32>()?;
let E_idx = encoder_outputs["E_idx"].try_extract_tensor::<i64>()?;

let position_tensor = {
let data = vec![10 as i64]; // Single value
let array = ndarray::Array::from_shape_vec([1], data)?; // Shape [1]
Tensor::from_array(array)?
};

println!("Temp and Position Are Hardcoded........");
let temp_tensor = {
let data = vec![0.1 as f32]; // Single value
let array = ndarray::Array::from_shape_vec([1], data)?;
Tensor::from_array(array)?
};

let decoder_outputs = decoder_model.run(ort::inputs![
"h_V" => h_V,
"h_E" => h_E,
"E_idx" => E_idx,
"position" => position_tensor,
"temperature" => temp_tensor,
]?)?;

println!("Decoder Outputs are logits.");
let logits = decoder_outputs["logits"]
.try_extract_tensor::<f32>()?
.to_owned();

println!("Converted to Candle.");
let logit_tensor = ndarray_to_tensor_f32(logits);
println!("{:?}", logit_tensor);
let model = LigandMPNN::new().unwrap();
let logits = model.run_model(ac, 10, 0.1).unwrap();
println!("{:?}", logits);

Ok(())
}
1 change: 1 addition & 0 deletions ferritin-onnx-models/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ ferritin-core = { path = "../ferritin-core" }
[dev-dependencies]
anyhow.workspace = true
ferritin-test-data = { path = "../ferritin-test-data" }
pdbtbx.workspace = true
8 changes: 7 additions & 1 deletion ferritin-onnx-models/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
//! Ferritin Onnx Models
//!
//! This crate provides easy access to various ONNX models for protein and ligand prediction.
//! The models are downloaded from HuggingFace and run using ONNX Runtime.
//! Currently supports ESM2 and LigandMPNN models.
//!
pub mod models;
pub mod utilities;

// pub use models::amplify::{AMPLIFYModels, AMPLIFY};
pub use models::esm2::{ESM2Models, ESM2};
pub use models::ligandmpnn::{LigandMPNN, LigandMPNNModels};
pub use models::ligandmpnn::{LigandMPNN, ModelType};
pub use utilities::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64};
11 changes: 10 additions & 1 deletion ferritin-onnx-models/src/models/esm2/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
//! ESM2 Struct. Loads the hf tokenizer
//! ESM2 Tokenizer. Models converted to ONNX format from [ESM2](https://github.com/facebookresearch/esm)
//! and uploaded to HuggingFace hub. The tokenizer is included in this crate and loaded from
//! memory using `tokenizer.json`. This is fairly minimal - for the full set of ESM2 models
//! please see the ESM2 repository and the HuggingFace hub.
//!
//! # Models:
//! * ESM2_T6_8M - small 6-layer protein language model
//! * ESM2_T12_35M - medium 12-layer protein language model
//! * ESM2_T30_150M - large 30-layer protein language model
//!
use anyhow::{anyhow, Result};
use candle_hf_hub::api::sync::Api;
use std::path::PathBuf;
Expand Down
Loading

0 comments on commit bf1c91e

Please sign in to comment.