Skip to content

Commit

Permalink
Clean and Lint Pass (#99)
Browse files Browse the repository at this point in the history
* ruct lint pass

* a bit more

* new safetensors messe sup conversion
  • Loading branch information
zachcp authored Jan 3, 2025
1 parent b45f47a commit 68190a2
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 49 deletions.
32 changes: 20 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ members = [
"ferritin-plms",
"ferritin-pymol",
"ferritin-test-data",
# "ferritin-tui",
]
resolver = "2"

Expand All @@ -29,7 +28,7 @@ candle-hf-hub = "0.3.3"
candle-metal-kernels = "0.8.1"
candle-nn = "0.8.1"
candle-transformers = "0.8.1"
itertools = "0.13.0"
itertools = "0.14.0"
once_cell = "1.20.2"
pdbtbx = "0.12.0"
rand = "0.8.5"
Expand Down
4 changes: 2 additions & 2 deletions ferritin-core/src/featurize/structure_features.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Protein->Tensor utiilities useful for Machine Learning
use super::utilities::{aa1to_int, aa3to1, get_nearest_neighbours, int_to_aa1, AAAtom};
use super::utilities::{aa1to_int, aa3to1, get_nearest_neighbours, AAAtom};
use crate::AtomCollection;
use candle_core::{DType, Device, Error as CandleError, IndexOp, Result, Tensor, D};
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use itertools::MultiUnzip;
use pdbtbx::Element;
use strum::IntoEnumIterator;
Expand Down
6 changes: 1 addition & 5 deletions ferritin-core/src/featurize/utilities.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_core::{IndexOp, Result, Tensor, D};
use strum::{Display, EnumIter, EnumString};

#[rustfmt::skip]
Expand Down Expand Up @@ -329,11 +329,7 @@ pub fn get_nearest_neighbours(
#[cfg(test)]
mod tests {
use super::*;
// use crate::ligandmpnn::proteinfeatures::LMPNNFeatures;
use crate::AtomCollection;
use ferritin_test_data::TestFile;
use pdbtbx;
use pdbtbx::Element;

#[test]
fn test_residue_codes() {
Expand Down
3 changes: 0 additions & 3 deletions ferritin-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,10 @@ ferritin-bevy = { path = "../ferritin-bevy" }
ferritin-core = { path = "../ferritin-core" }
ferritin-plms = { path = "../ferritin-plms" }
ferritin-onnx-models = { path = "../ferritin-onnx-models" }
ndarray-safetensors = "0.2"
ndarray = "0.16"
ort = "=2.0.0-rc.9"
pdbtbx.workspace = true
safetensors.workspace = true
serde_json.workspace = true
# wonnx = "0.5.1"


[target.'cfg(target_os = "macos")'.features]
Expand Down
8 changes: 1 addition & 7 deletions ferritin-examples/examples/esm2-onnx/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
use anyhow::{Error as E, Result};
use anyhow::Result;
use clap::Parser;
use ferritin_onnx_models::{ESM2Models, ESM2};
use ndarray::Array2;
use ort::{
execution_providers::CUDAExecutionProvider,
session::{builder::GraphOptimizationLevel, Session},
};
use std::env;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
Expand Down
4 changes: 2 additions & 2 deletions ferritin-examples/examples/ligandmpnn-onnx/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ fn main() -> Result<()> {
let (protfile, _handle) = TestFile::protein_01().create_temp()?;
let (pdb, _) = pdbtbx::open(protfile).expect("PDB/CIF");
let ac = AtomCollection::from(&pdb);
let model = LigandMPNN::new().unwrap();
let logits = model.run_model(ac, 10, 0.1).unwrap();
let model = LigandMPNN::new()?;
let logits = model.run_model(ac, 10, 0.1)?;
println!("{:?}", logits);

Ok(())
Expand Down
13 changes: 6 additions & 7 deletions ferritin-onnx-models/src/models/esm2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@ use anyhow::{anyhow, Result};
use candle_core::{Tensor, D};
use candle_hf_hub::api::sync::Api;
use candle_nn::ops;
use ndarray::s;
use ndarray::Array2;
use ort::{
execution_providers::CUDAExecutionProvider,
session::{
builder::{GraphOptimizationLevel, SessionBuilder},
output, Session,
},
value::Value,
builder::{GraphOptimizationLevel, SessionBuilder}
, Session,
}
,
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
Expand Down Expand Up @@ -60,7 +59,7 @@ impl ESM2 {
})
}
pub fn load_model_path(model: ESM2Models) -> Result<PathBuf> {
let api = Api::new().unwrap();
let api = Api::new()?;
let repo_id = match model {
ESM2Models::ESM2_T6_8M => "zcpbx/esm2-t6-8m-UR50D-onnx",
ESM2Models::ESM2_T12_35M => "zcpbx/esm2-t12-35M-UR50D-onnx",
Expand All @@ -69,7 +68,7 @@ impl ESM2 {
}
.to_string();

let model_path = api.model(repo_id).get("model.onnx").unwrap();
let model_path = api.model(repo_id).get("model.onnx")?;
Ok(model_path)
}
pub fn load_tokenizer() -> Result<Tokenizer> {
Expand Down
4 changes: 2 additions & 2 deletions ferritin-plms/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ candle-core.workspace = true
candle-metal-kernels = { workspace = true, optional = true }
candle-nn.workspace = true
candle-transformers.workspace = true
clap = "4.5.23"
clap = { version = "4.5", features = ["derive"] }
ferritin-core = { path = "../ferritin-core" }
ferritin-test-data = { path = "../ferritin-test-data" }
itertools.workspace = true
Expand All @@ -41,4 +41,4 @@ candle-metal-kernels.workspace = true
[dev-dependencies]
assert_cmd = "2.0.16"
ferritin-test-data = { path = "../ferritin-test-data" }
tempfile = "3.14.0"
tempfile = "3.15.0"
6 changes: 3 additions & 3 deletions ferritin-plms/src/esm/layers/blocks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::attention::MultiHeadAttention;
use super::geom_attention::GeometricReasoningOriginalImpl;
use crate::esm::models::esmc::{ESMCConfig, Ffn_Type};
use crate::esm::models::esmc::{ESMCConfig, FfnType};
// use crate::esm::utils::structure::affine3d::Affine3D;
use candle_core::{Module, Result, Tensor, D};
use candle_nn::{self as nn, VarBuilder};
Expand Down Expand Up @@ -155,8 +155,8 @@ impl UnifiedTransformerBlock {
let geom_attn = None;

let ffn = match ffn_type {
Ffn_Type::SWIGLU => SwiGLU::load(vb.pp("ffn"), config)?,
_ => unimplemented!(), // Ffn_Type::GLU => unimplemented!(),
FfnType::SWIGLU => SwiGLU::load(vb.pp("ffn"), config)?,
_ => unimplemented!(), // FfnType::GLU => unimplemented!(),
};

Ok(Self {
Expand Down
6 changes: 3 additions & 3 deletions ferritin-plms/src/esm/models/esmc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl ESMTokenizer {
}

#[derive(Clone, Copy)]
pub enum Ffn_Type {
pub enum FfnType {
SWIGLU,
GLU,
}
Expand All @@ -49,7 +49,7 @@ pub struct ESMCConfig {
pub n_heads: usize,
pub n_layers: usize,
pub v_head_transformer: Option<usize>,
pub ffn_type: Ffn_Type,
pub ffn_type: FfnType,
pub tokenizer: ESMTokenizer,
// oringal above.
pub use_plain_attn: bool,
Expand Down Expand Up @@ -80,7 +80,7 @@ impl ESMCConfig {
n_heads: 15,
n_layers: 30,
v_head_transformer: None,
ffn_type: Ffn_Type::SWIGLU,
ffn_type: FfnType::SWIGLU,
tokenizer: ESMTokenizer::Esm3OpenSmall,
use_plain_attn: true,
n_layers_geom: 1,
Expand Down
2 changes: 1 addition & 1 deletion ferritin-test-data/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ license.workspace = true
description.workspace = true

[dependencies]
tempfile = "3.14.0"
tempfile = "3.15.0"

0 comments on commit 68190a2

Please sign in to comment.