Skip to content

Commit

Permalink
Restore Candle for Featurization (#90)
Browse files Browse the repository at this point in the history
* restore Candle

* add featurization for Ligands.
  • Loading branch information
zachcp authored Dec 28, 2024
1 parent e32c410 commit 8bc4f58
Show file tree
Hide file tree
Showing 11 changed files with 337 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions ferritin-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ license.workspace = true
description.workspace = true

[dependencies]
pdbtbx.workspace = true
candle-core.workspace = true
itertools.workspace = true
ndarray = { version = "0.16" }
lazy_static = "1.5.0"
pdbtbx.workspace = true
strum = { version = "0.25", features = ["derive"] }

[dev-dependencies]
Expand Down
6 changes: 3 additions & 3 deletions ferritin-core/src/featurize/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
//! - Geometric features like distances, angles
//! - Chemical features like hydrophobicity, charge
//! - Evolutionary features from MSA profiles
mod ndarray_impl;
// mod ndarray_impl;
mod structure_features;
mod utilities;

pub use ndarray_impl::{ProteinFeatures, StructureFeatures};
pub use utilities::{aa1to_int, aa3to1, int_to_aa1};
pub use structure_features::StructureFeatures;
62 changes: 58 additions & 4 deletions ferritin-core/src/featurize/ndarray_impl.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::utilities::{aa1to_int, aa3to1, AAAtom};
use crate::AtomCollection;
use itertools::MultiUnzip;
use ndarray::{Array, Array1, Array2, Array4};
use ndarray::{Array, Array1, Array2, Array3, Array4};
use pdbtbx::Element;
use std::collections::{HashMap, HashSet};
use strum::IntoEnumIterator;
Expand Down Expand Up @@ -104,9 +104,62 @@ impl StructureFeatures for AtomCollection {
Array::from_shape_vec((1, res_count, 37, 3), atom37_data)
}

// https://github.com/dauparas/LigandMPNN/blob/091ab1ff5fb4d13854cf6a7c41ec531e1d9d3e67/data_utils.py#L833
// try:
// Y = np.array(other_atoms.getCoords(), dtype=np.float32)
// Y_t = list(other_atoms.getElements())
// Y_t = np.array(
// [
// element_dict[y_t.upper()] if y_t.upper() in element_list else 0
// for y_t in Y_t
// ],
// dtype=np.int32,
// )
// Y_m = (Y_t != 1) * (Y_t != 0)

// Y = Y[Y_m, :]
// Y_t = Y_t[Y_m]
// Y_m = Y_m[Y_m]
// except:
// Y = np.zeros([1, 3], np.float32)
// Y_t = np.zeros([1], np.int32)
// Y_m = np.zeros([1], np.int32)

// Y = input_dict["Y"]
// Y_t = input_dict["Y_t"]
// Y_m = input_dict["Y_m"]
//
// output_dict["Y"] = Y[None,]
// output_dict["Y_t"] = Y_t[None,]
// output_dict["Y_m"] = Y_m[None,]
// if not use_atom_context:
// output_dict["Y_m"] = 0.0 * output_dict["Y_m"]

// 5gb weights
//
// Name: ligand_coords, Type: Input {
// name: "ligand_coords",
// input_type: Tensor {
// ty: Float32,
// dimensions: [-1,-1,-1,3,],
// dimension_symbols: [Some("batch",),Some("sequence",),Some("num_atoms",),None,],
//
// Name: ligand_types, Type: Input {
// name: "ligand_types",
// input_type: Tensor {
// ty: Int64,
// dimensions: [-1,-1,-1,],
// dimension_symbols: [Some("batch"), Some("sequence"),Some("num_atoms"),],
//
// Name: ligand_mask, Type: Input {
// name: "ligand_mask",
// input_type: Tensor {
// ty: Float32,
// dimensions: [-1,-1,-1,],
// dimension_symbols: [Some("batch",),Some("sequence",),Some("num_atoms",),],},
fn to_numeric_ligand_atoms(
&self,
) -> Result<(Array2<f32>, Array1<f32>, Array2<f32>), Self::Error> {
) -> Result<(Array4<f32>, Array1<i64>, Array3<f32>), Self::Error> {
let (coords, elements): (Vec<[f32; 3]>, Vec<Element>) = self
.iter_residues_all()
.filter(|residue| {
Expand All @@ -124,10 +177,11 @@ impl StructureFeatures for AtomCollection {

let n_atoms = coords.len();
let coords_flat: Vec<f32> = coords.into_iter().flat_map(|[x, y, z]| [x, y, z]).collect();
let coords_array = Array::from_shape_vec((n_atoms, 3), coords_flat)?;
// this gets a 2D array. We also need a Batch array and a sequence array
let coords_array = Array::from_shape_vec((1, 1, n_atoms, 3), coords_flat)?;

let elements_array =
Array1::from_vec(elements.iter().map(|e| e.atomic_number() as f32).collect());
Array1::from_vec(elements.iter().map(|e| e.atomic_number() as i64).collect());

let mask_array = Array::ones((n_atoms, 3));

Expand Down
214 changes: 214 additions & 0 deletions ferritin-core/src/featurize/structure_features.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
//! Protein->Tensor utilities useful for Machine Learning
use super::utilities::{aa1to_int, aa3to1, get_nearest_neighbours, int_to_aa1, AAAtom};
use crate::AtomCollection;
use candle_core::{DType, Device, Error as CandleError, IndexOp, Result, Tensor};
use itertools::MultiUnzip;
use pdbtbx::Element;
use strum::IntoEnumIterator;

// Helper Fns --------------------------------------
/// Returns `true` for every element except hydrogen (`H`) and helium (`He`).
fn is_heavy_atom(element: &Element) -> bool {
    let is_light = matches!(element, Element::H | Element::He);
    !is_light
}

/// Trait defining Protein->Tensor utilities useful for Machine Learning.
///
/// Every tensor-producing method takes the candle `Device` on which the
/// resulting tensor(s) should be created.
pub trait StructureFeatures {
/// Decode a numeric amino-acid representation.
///
/// NOTE(review): presumably the inverse of `encode_amino_acids`; the exact
/// output layout is not established by this trait — confirm with implementors.
fn decode_amino_acids(&self, device: &Device) -> Result<Tensor>;

/// Convert the amino acid sequence to a numeric representation
/// (one integer code per amino-acid residue).
fn encode_amino_acids(&self, device: &Device) -> Result<Tensor>;

/// Compute virtual C-beta (CB) coordinates from the backbone N, CA and C
/// atom positions.
fn create_CB(&self, device: &Device) -> Result<Tensor>;

/// Get the residue ids of the amino-acid residues as `u32` values.
fn get_res_index(&self) -> Vec<u32>;

/// Extract backbone atom coordinates (N, CA, C, O) for every amino-acid
/// residue.
fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor>;

/// Extract all atom coordinates in the standard 37-atom ordering.
fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor>;

/// Extract ligand atom coordinates and properties.
///
/// Returns three tensors, in order: ligand coordinates, ligand element
/// types, and a ligand mask.
fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)>;
}

impl StructureFeatures for AtomCollection {
/// Convert amino acid sequence to numeric representation
fn decode_amino_acids(&self, device: &Device) -> Result<Tensor> {
todo!()
}

/// Convert amino acid sequence to numeric representation
fn encode_amino_acids(&self, device: &Device) -> Result<Tensor> {
let n = self.iter_residues_aminoacid().count();
let s = self
.iter_residues_aminoacid()
.map(|res| res.res_name)
.map(|res| aa3to1(&res))
.map(|res| aa1to_int(res));

Ok(Tensor::from_iter(s, device)?.reshape((1, n))?)
}

/// Convert amino acid sequence to numeric representation
fn create_CB(&self, device: &Device) -> Result<Tensor> {
// N = input_dict["X"][:, 0, :]
// CA = input_dict["X"][:, 1, :]
// C = input_dict["X"][:, 2, :]
// b = CA - N
// c = C - CA
// a = torch.cross(b, c, axis=-1)
// CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
//
let backbone = self.to_numeric_backbone_atoms(device)?;
// Extract N, CA, C coordinates
let n = backbone.i((.., 0, ..))?; // First atom (N)
let ca = backbone.i((.., 1, ..))?; // Second atom (CA)
let c = backbone.i((.., 2, ..))?; // Third atom (C)

// Constants for CB calculation
let a_coeff = -0.58273431_f64;
let b_coeff = 0.56802827_f64;
let c_coeff = -0.54067466_f64;

// Calculate vectors
let b = (&ca - &n)?; // CA - N
let c = (&c - &ca)?; // C - CA

// Manual cross product components
// a_x = b_y * c_z - b_z * c_y
// a_y = b_z * c_x - b_x * c_z
// a_z = b_x * c_y - b_y * c_x
let b_x = b.i((.., 0))?;
let b_y = b.i((.., 1))?;
let b_z = b.i((.., 2))?;
let c_x = c.i((.., 0))?;
let c_y = c.i((.., 1))?;
let c_z = c.i((.., 2))?;

let a_x = ((&b_y * &c_z)? - (&b_z * &c_y)?)?;
let a_y = ((&b_z * &c_x)? - (&b_x * &c_z)?)?;
let a_z = ((&b_x * &c_y)? - (&b_y * &c_x)?)?;

// Stack the cross product components back together
let a = Tensor::stack(&[&a_x, &a_y, &a_z], 1)?;

// Final CB calculation: -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
let cb = (&a * a_coeff)? + (&b * b_coeff)? + (&c * c_coeff)? + &ca;

Ok(cb?)
}

/// Get residue indices
fn get_res_index(&self) -> Vec<u32> {
self.iter_residues_aminoacid()
.map(|res| res.res_id as u32)
.collect()
}

/// create numeric Tensor of shape [1, <sequence-length>, 4, 3] where the 4 is N/CA/C/O
fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor> {
let res_count = self.iter_residues_aminoacid().count();
let mut backbone_data = vec![0f32; res_count * 4 * 3];
for residue in self.iter_residues_aminoacid() {
let resid = residue.res_id as usize;
let backbone_atoms = [
residue.find_atom_by_name("N"),
residue.find_atom_by_name("CA"),
residue.find_atom_by_name("C"),
residue.find_atom_by_name("O"),
];
for (atom_idx, maybe_atom) in backbone_atoms.iter().enumerate() {
if let Some(atom) = maybe_atom {
let [x, y, z] = atom.coords;
let base_idx = (resid * 4 + atom_idx) * 3;
backbone_data[base_idx] = *x;
backbone_data[base_idx + 1] = *y;
backbone_data[base_idx + 2] = *z;
}
}
}
// Create tensor with shape [1,residues, 4, 3]
Tensor::from_vec(backbone_data, (1, res_count, 4, 3), &device)
}

/// create numeric Tensor of shape [1, <sequence-length>, 37, 3]
fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor> {
let res_count = self.iter_residues_aminoacid().count();
let mut atom37_data = vec![0f32; res_count * 37 * 3];

for (idx, residue) in self.iter_residues_aminoacid().enumerate() {
for atom_type in AAAtom::iter().filter(|&a| a != AAAtom::Unknown) {
if let Some(atom) = residue.find_atom_by_name(&atom_type.to_string()) {
let [x, y, z] = atom.coords;
let base_idx = (idx * 37 + atom_type as usize) * 3;
atom37_data[base_idx] = *x;
atom37_data[base_idx + 1] = *y;
atom37_data[base_idx + 2] = *z;
}
}
}
// Create tensor with shape [residues, 37, 3]
Tensor::from_vec(atom37_data, (1, res_count, 37, 3), &device)
}

// The purpose of this function it to create 3 output tensors that relate
// key information about a protein sequence and ligands it interacts with.
//
// The outputs are:
// - y: 4D tensor of dimensions (<batch=1>, <num_residues>, <number_of_ligand_atoms>, <coords=3>)
// - y_t: 1D tensor of dimension = <num_residues>
// - y_m: 3D tensor of dimensions: (<batch=1>, <num_residues>, <number_of_ligand_atoms>))
//
fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)> {
let number_of_ligand_atoms = 16;
let cutoff_for_score = 5.;
// keep only the non-protein, non-water residues that are heavy
let (coords, elements): (Vec<[f32; 3]>, Vec<Element>) = self
.iter_residues_all()
.filter(|residue| {
let res_name = &residue.res_name;
!residue.is_amino_acid() && res_name != "HOH" && res_name != "WAT"
})
.flat_map(|residue| {
residue
.iter_atoms()
.filter(|atom| is_heavy_atom(&atom.element))
.map(|atom| (*atom.coords, atom.element.clone()))
.collect::<Vec<_>>()
})
.multiunzip();

// raw starting tensors
let y = Tensor::from_slice(&coords.concat(), (coords.len(), 3), device)?;
let y_m = Tensor::ones_like(&y)?;
let y_t = Tensor::from_slice(
&elements
.iter()
.map(|e| e.atomic_number() as f32)
.collect::<Vec<_>>(),
(elements.len(),),
device,
)?;

// get the C-beta coordinate tensro.
let CB = self.create_CB(device)?;
let num_residues = CB.dim(0)?;
let mask = Tensor::zeros(num_residues, DType::F32, device)?;
let (y, y_t, y_m, d_xy) =
get_nearest_neighbours(&CB, &mask, &y, &y_t, &y_m, number_of_ligand_atoms)?;

let distance_mask = d_xy.lt(cutoff_for_score)?;
let y_m_first = y_m.i((.., 0))?;
let mask_xy = distance_mask.mul(&mask)?.mul(&y_m_first)?;

let y = y.unsqueeze(0)?;
let y_t = y_t.unsqueeze(0)?;
let y_m = y_m.unsqueeze(0)?; // mask_xy??

Ok((y, y_t, y_m))
}
}
47 changes: 47 additions & 0 deletions ferritin-core/src/featurize/utilities.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use strum::{Display, EnumIter, EnumString};

#[rustfmt::skip]
Expand Down Expand Up @@ -141,6 +142,52 @@ define_residues! {
UNK: "UNK", 'X', 20, [0.0, 0.0], [AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
}

// Use CB to find nearest Ligand coords.
pub fn get_nearest_neighbours(
CB: &Tensor,
mask: &Tensor,
Y: &Tensor,
Y_t: &Tensor,
Y_m: &Tensor,
number_of_ligand_atoms: i64,
) -> Result<(Tensor, Tensor, Tensor, Tensor)> {
let device = CB.device();
let num_residues = CB.dim(0)?;
let num_ligand_atoms = Y.dim(0)?;
let xyz_dims = 3;
let mask_CBY = mask.unsqueeze(1)?.matmul(&Y_m.unsqueeze(0)?)?;

// Calculate L2 distances
let CB_expanded = CB.unsqueeze(1)?;
let Y_expanded = Y.unsqueeze(0)?;
let diff = &CB_expanded - &Y_expanded;
let L2_AB = diff?.powf(2.0)?.sum(D::Minus1)?;
let complement_mask = mask_CBY.neg()? + 1.0;
let padding = complement_mask? * 1000.0;
let L2_AB = L2_AB.mul(&mask_CBY)?.add(&padding?)?;

let nn_idx = L2_AB.arg_sort_last_dim(false)?;
let nn_idx = nn_idx.narrow(1, 0, number_of_ligand_atoms as usize)?;

let D_AB_closest = L2_AB.gather(&nn_idx, 1)?.i((.., 0))?.sqrt()?;
let Y_r = Y
.unsqueeze(0)?
.expand((num_residues, num_ligand_atoms, xyz_dims))?;
let Y_t_r = Y_t.unsqueeze(0)?.expand((num_residues, num_ligand_atoms))?;
let Y_m_r = Y_m.unsqueeze(0)?.expand((num_residues, num_ligand_atoms))?;
let nn_idx_expanded =
nn_idx
.unsqueeze(2)?
.expand((num_residues, number_of_ligand_atoms as usize, xyz_dims))?;

// Gather nearest neighbors
let Y = Y_r.gather(&nn_idx_expanded, 1)?;
let Y_t = Y_t_r.gather(&nn_idx, 1)?;
let Y_m = Y_m_r.gather(&nn_idx, 1)?;

Ok((Y, Y_t, Y_m, D_AB_closest))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 8bc4f58

Please sign in to comment.