Skip to content

Commit

Permalink
fix the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zachcp committed Dec 27, 2024
1 parent e96bda8 commit 3f69c2c
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 185 deletions.
40 changes: 35 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

299 changes: 119 additions & 180 deletions ferritin-core/src/featurize/ndarray_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,184 +378,123 @@ mod tests {
}
}

// #[test]
// fn test_all_atom37_tensor() {
// let device = Device::Cpu;
// let (pdb_file, _temp) = TestFile::protein_01().create_temp().unwrap();
// let (pdb, _) = pdbtbx::open(pdb_file).unwrap();
// let ac = AtomCollection::from(&pdb);
// let ac_backbone_tensor: Tensor = ac.to_numeric_atom37(&device).expect("REASON");
// // batch size of 1154 residues; all atoms; positions
// assert_eq!(ac_backbone_tensor.dims(), &[1, 154, 37, 3]);

// // Check my residue coords in the Tensor
// // ATOM 1 N N . MET A 1 1 ? 24.277 8.374 -9.854 1.00 38.41 ? 0 MET A N 1
// // ATOM 2 C CA . MET A 1 1 ? 24.404 9.859 -9.939 1.00 37.90 ? 0 MET A CA 1
// // ATOM 3 C C . MET A 1 1 ? 25.814 10.249 -10.359 1.00 36.65 ? 0 MET A C 1
// // ATOM 4 O O . MET A 1 1 ? 26.748 9.469 -10.197 1.00 37.13 ? 0 MET A O 1
// // ATOM 5 C CB . MET A 1 1 ? 24.070 10.495 -8.596 1.00 39.58 ? 0 MET A CB 1
// // ATOM 6 C CG . MET A 1 1 ? 24.880 9.939 -7.442 1.00 41.49 ? 0 MET A CG 1
// // ATOM 7 S SD . MET A 1 1 ? 24.262 10.555 -5.873 1.00 44.70 ? 0 MET A SD 1
// // ATOM 8 C CE . MET A 1 1 ? 24.822 12.266 -5.967 1.00 41.59 ? 0 MET A CE 1
// //
// // pub enum AAAtom {
// // N = 0, CA = 1, C = 2, CB = 3, O = 4,
// // CG = 5, CG1 = 6, CG2 = 7, OG = 8, OG1 = 9,
// // SG = 10, CD = 11, CD1 = 12, CD2 = 13, ND1 = 14,
// // ND2 = 15, OD1 = 16, OD2 = 17, SD = 18, CE = 19,
// // CE1 = 20, CE2 = 21, CE3 = 22, NE = 23, NE1 = 24,
// // NE2 = 25, OE1 = 26, OE2 = 27, CH2 = 28, NH1 = 29,
// // NH2 = 30, OH = 31, CZ = 32, CZ2 = 33, CZ3 = 34,
// // NZ = 35, OXT = 36,
// // Unknown = -1,
// // }
// let allatom_coords = [
// // Methionine - AA00
// // We iterate through these positions. Not all AA's have each
// ("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
// ("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
// ("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
// ("CB", (0, 0, 3, ..), vec![24.070, 10.495, -8.596]),
// ("O", (0, 0, 4, ..), vec![26.748, 9.469, -10.197]),
// ("CG", (0, 0, 5, ..), vec![24.880, 9.939, -7.442]),
// ("CG1", (0, 0, 6, ..), vec![0.0, 0.0, 0.0]),
// ("CG2", (0, 0, 7, ..), vec![0.0, 0.0, 0.0]),
// ("OG", (0, 0, 8, ..), vec![0.0, 0.0, 0.0]),
// ("OG1", (0, 0, 9, ..), vec![0.0, 0.0, 0.0]),
// ("SG", (0, 0, 10, ..), vec![0.0, 0.0, 0.0]),
// ("CD", (0, 0, 11, ..), vec![0.0, 0.0, 0.0]),
// ("CD1", (0, 0, 12, ..), vec![0.0, 0.0, 0.0]),
// ("CD2", (0, 0, 13, ..), vec![0.0, 0.0, 0.0]),
// ("ND1", (0, 0, 14, ..), vec![0.0, 0.0, 0.0]),
// ("ND2", (0, 0, 15, ..), vec![0.0, 0.0, 0.0]),
// ("OD1", (0, 0, 16, ..), vec![0.0, 0.0, 0.0]),
// ("OD2", (0, 0, 17, ..), vec![0.0, 0.0, 0.0]),
// ("SD", (0, 0, 18, ..), vec![24.262, 10.555, -5.873]),
// ("CE", (0, 0, 19, ..), vec![24.822, 12.266, -5.967]),
// ("CE1", (0, 0, 20, ..), vec![0.0, 0.0, 0.0]),
// ("CE2", (0, 0, 21, ..), vec![0.0, 0.0, 0.0]),
// ("CE3", (0, 0, 22, ..), vec![0.0, 0.0, 0.0]),
// ("NE", (0, 0, 23, ..), vec![0.0, 0.0, 0.0]),
// ("NE1", (0, 0, 24, ..), vec![0.0, 0.0, 0.0]),
// ("NE2", (0, 0, 25, ..), vec![0.0, 0.0, 0.0]),
// ("OE1", (0, 0, 26, ..), vec![0.0, 0.0, 0.0]),
// ("OE2", (0, 0, 27, ..), vec![0.0, 0.0, 0.0]),
// ("CH2", (0, 0, 28, ..), vec![0.0, 0.0, 0.0]),
// ("NH1", (0, 0, 29, ..), vec![0.0, 0.0, 0.0]),
// ("NH2", (0, 0, 30, ..), vec![0.0, 0.0, 0.0]),
// ("OH", (0, 0, 31, ..), vec![0.0, 0.0, 0.0]),
// ("CZ", (0, 0, 32, ..), vec![0.0, 0.0, 0.0]),
// ("CZ2", (0, 0, 33, ..), vec![0.0, 0.0, 0.0]),
// ("CZ3", (0, 0, 34, ..), vec![0.0, 0.0, 0.0]),
// ("NZ", (0, 0, 35, ..), vec![0.0, 0.0, 0.0]),
// ("OXT", (0, 0, 36, ..), vec![0.0, 0.0, 0.0]),
// ];
// for (atom_name, (b, i, j, k), expected) in allatom_coords {
// let actual: Vec<f32> = ac_backbone_tensor
// .i((b, i, j, k))
// .unwrap()
// .to_vec1()
// .unwrap();
// assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
// }
// }

// #[test]
// fn test_ligand_tensor() {
// let device = Device::Cpu;
// let (pdb_file, _temp) = TestFile::protein_01().create_temp().unwrap();
// let (pdb, _) = pdbtbx::open(pdb_file).unwrap();
// let ac = AtomCollection::from(&pdb);
// let (ligand_coords, ligand_elements, _) =
// ac.to_numeric_ligand_atoms(&device).expect("REASON");

// // 54 residues; N/CA/C/O; positions
// assert_eq!(ligand_coords.dims(), &[54, 3]);

// // Check my residue coords in the Tensor
// //
// // HETATM 1222 S S . SO4 B 2 . ? 30.746 18.706 28.896 1.00 47.98 ? 157 SO4 A S 1
// // HETATM 1223 O O1 . SO4 B 2 . ? 30.697 20.077 28.620 1.00 48.06 ? 157 SO4 A O1 1
// // HETATM 1224 O O2 . SO4 B 2 . ? 31.104 18.021 27.725 1.00 47.52 ? 157 SO4 A O2 1
// // HETATM 1225 O O3 . SO4 B 2 . ? 29.468 18.179 29.331 1.00 47.79 ? 157 SO4 A O3 1
// // HETATM 1226 O O4 . SO4 B 2 . ? 31.722 18.578 29.881 1.00 47.85 ? 157 SO4 A O4 1
// let allatom_coords = [
// ("S", (0, ..), vec![30.746, 18.706, 28.896]),
// ("O1", (1, ..), vec![30.697, 20.077, 28.620]),
// ("O2", (2, ..), vec![31.104, 18.021, 27.725]),
// ("O3", (3, ..), vec![29.468, 18.179, 29.331]),
// ("O4", (4, ..), vec![31.722, 18.578, 29.881]),
// ];

// for (atom_name, (i, j), expected) in allatom_coords {
// let actual: Vec<f32> = ligand_coords.i((i, j)).unwrap().to_vec1().unwrap();
// assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
// }

// // Now check the elements
// let elements: Vec<&str> = ligand_elements
// .to_vec1::<f32>()
// .unwrap()
// .into_iter()
// .map(|elem| Element::new(elem as usize).unwrap().symbol())
// .collect();

// assert_eq!(elements[0], "S");
// assert_eq!(elements[1], "O");
// assert_eq!(elements[2], "O");
// assert_eq!(elements[3], "O");
// }

// #[test]
// fn test_backbone_tensor() {
// let device = Device::Cpu;
// let (pdb_file, _temp) = TestFile::protein_01().create_temp().unwrap();
// let (pdb, _) = pdbtbx::open(pdb_file).unwrap();
// let ac = AtomCollection::from(&pdb);
// let xyz_37 = ac
// .to_numeric_atom37(&device)
// .expect("XYZ creation for all-atoms");
// assert_eq!(xyz_37.dims(), [1, 154, 37, 3]);

// // # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords
// let xyz_m = create_backbone_mask_37(&xyz_37).expect("masking procedure should work");
// assert_eq!(xyz_m.dims(), &[1, 154, 37]);
// }

// #[test]
// fn test_compute_nearest_neighbors() {
// // let device = Device::Cpu;
// let test_dtype = DType::F32;

// // Create a simple 2x3x3 tensor representing 2 sequences of 3 points in 3D space
// let coords = Tensor::new(
// &[
// [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]], // First sequence
// [[0.0, 1.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]], // Second sequence
// ],
// &device,
// )
// .unwrap()
// .to_dtype(test_dtype)
// .unwrap();

// // Create mask indicating all points are valid
// let mask = Tensor::ones((2, 3), test_dtype, &device).unwrap();

// // Get 2 nearest neighbors for each point
// let (distances, indices) = compute_nearest_neighbors(&coords, &mask, 2, 1e-6).unwrap();

// // Check shapes
// assert_eq!(distances.dims(), &[2, 3, 2]); // [batch, seq_len, k]
// assert_eq!(indices.dims(), &[2, 3, 2]); // [batch, seq_len, k]

// // For first sequence, point [1,0,0] should have [0,0,0] and [2,0,0] as nearest neighbors
// let point_neighbors: Vec<u32> = indices.i((0, 1, ..)).unwrap().to_vec1().unwrap();
// assert_eq!(point_neighbors, vec![0, 2]);

// // Check distances are correct
// let point_distances: Vec<f32> = distances.i((0, 1, ..)).unwrap().to_vec1().unwrap();
// assert!((point_distances[0] - 1.0).abs() < 1e-5);
// assert!((point_distances[1] - 1.0).abs() < 1e-5);
// }
#[test]
fn test_all_atom37_tensor() {
let (pdb_file, _temp) = TestFile::protein_01().create_temp().unwrap();
let (pdb, _) = pdbtbx::open(pdb_file).unwrap();
let ac = AtomCollection::from(&pdb);
let ac_backbone_tensor: Array4<f32> = ac.to_numeric_atom37().expect("REASON");
// batch size of 1154 residues; all atoms; positions
assert_eq!(ac_backbone_tensor.dim(), (1, 154, 37, 3));

// Check my residue coords in the Tensor
// ATOM 1 N N . MET A 1 1 ? 24.277 8.374 -9.854 1.00 38.41 ? 0 MET A N 1
// ATOM 2 C CA . MET A 1 1 ? 24.404 9.859 -9.939 1.00 37.90 ? 0 MET A CA 1
// ATOM 3 C C . MET A 1 1 ? 25.814 10.249 -10.359 1.00 36.65 ? 0 MET A C 1
// ATOM 4 O O . MET A 1 1 ? 26.748 9.469 -10.197 1.00 37.13 ? 0 MET A O 1
// ATOM 5 C CB . MET A 1 1 ? 24.070 10.495 -8.596 1.00 39.58 ? 0 MET A CB 1
// ATOM 6 C CG . MET A 1 1 ? 24.880 9.939 -7.442 1.00 41.49 ? 0 MET A CG 1
// ATOM 7 S SD . MET A 1 1 ? 24.262 10.555 -5.873 1.00 44.70 ? 0 MET A SD 1
// ATOM 8 C CE . MET A 1 1 ? 24.822 12.266 -5.967 1.00 41.59 ? 0 MET A CE 1
//
// pub enum AAAtom {
// N = 0, CA = 1, C = 2, CB = 3, O = 4,
// CG = 5, CG1 = 6, CG2 = 7, OG = 8, OG1 = 9,
// SG = 10, CD = 11, CD1 = 12, CD2 = 13, ND1 = 14,
// ND2 = 15, OD1 = 16, OD2 = 17, SD = 18, CE = 19,
// CE1 = 20, CE2 = 21, CE3 = 22, NE = 23, NE1 = 24,
// NE2 = 25, OE1 = 26, OE2 = 27, CH2 = 28, NH1 = 29,
// NH2 = 30, OH = 31, CZ = 32, CZ2 = 33, CZ3 = 34,
// NZ = 35, OXT = 36,
// Unknown = -1,
// }
let allatom_coords = [
// Methionine - AA00
// We iterate through these positions. Not all AA's have each
("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
("CB", (0, 0, 3, ..), vec![24.070, 10.495, -8.596]),
("O", (0, 0, 4, ..), vec![26.748, 9.469, -10.197]),
("CG", (0, 0, 5, ..), vec![24.880, 9.939, -7.442]),
("CG1", (0, 0, 6, ..), vec![0.0, 0.0, 0.0]),
("CG2", (0, 0, 7, ..), vec![0.0, 0.0, 0.0]),
("OG", (0, 0, 8, ..), vec![0.0, 0.0, 0.0]),
("OG1", (0, 0, 9, ..), vec![0.0, 0.0, 0.0]),
("SG", (0, 0, 10, ..), vec![0.0, 0.0, 0.0]),
("CD", (0, 0, 11, ..), vec![0.0, 0.0, 0.0]),
("CD1", (0, 0, 12, ..), vec![0.0, 0.0, 0.0]),
("CD2", (0, 0, 13, ..), vec![0.0, 0.0, 0.0]),
("ND1", (0, 0, 14, ..), vec![0.0, 0.0, 0.0]),
("ND2", (0, 0, 15, ..), vec![0.0, 0.0, 0.0]),
("OD1", (0, 0, 16, ..), vec![0.0, 0.0, 0.0]),
("OD2", (0, 0, 17, ..), vec![0.0, 0.0, 0.0]),
("SD", (0, 0, 18, ..), vec![24.262, 10.555, -5.873]),
("CE", (0, 0, 19, ..), vec![24.822, 12.266, -5.967]),
("CE1", (0, 0, 20, ..), vec![0.0, 0.0, 0.0]),
("CE2", (0, 0, 21, ..), vec![0.0, 0.0, 0.0]),
("CE3", (0, 0, 22, ..), vec![0.0, 0.0, 0.0]),
("NE", (0, 0, 23, ..), vec![0.0, 0.0, 0.0]),
("NE1", (0, 0, 24, ..), vec![0.0, 0.0, 0.0]),
("NE2", (0, 0, 25, ..), vec![0.0, 0.0, 0.0]),
("OE1", (0, 0, 26, ..), vec![0.0, 0.0, 0.0]),
("OE2", (0, 0, 27, ..), vec![0.0, 0.0, 0.0]),
("CH2", (0, 0, 28, ..), vec![0.0, 0.0, 0.0]),
("NH1", (0, 0, 29, ..), vec![0.0, 0.0, 0.0]),
("NH2", (0, 0, 30, ..), vec![0.0, 0.0, 0.0]),
("OH", (0, 0, 31, ..), vec![0.0, 0.0, 0.0]),
("CZ", (0, 0, 32, ..), vec![0.0, 0.0, 0.0]),
("CZ2", (0, 0, 33, ..), vec![0.0, 0.0, 0.0]),
("CZ3", (0, 0, 34, ..), vec![0.0, 0.0, 0.0]),
("NZ", (0, 0, 35, ..), vec![0.0, 0.0, 0.0]),
("OXT", (0, 0, 36, ..), vec![0.0, 0.0, 0.0]),
];
for (atom_name, (b, i, j, k), expected) in allatom_coords {
let actual: Vec<f32> = ac_backbone_tensor.slice(s![b, i, j, k]).to_vec();
assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
}
}

#[test]
fn test_ligand_tensor() {
let (pdb_file, _temp) = TestFile::protein_01().create_temp().unwrap();
let (pdb, _) = pdbtbx::open(pdb_file).unwrap();
let ac = AtomCollection::from(&pdb);
let (ligand_coords, ligand_elements, _) = ac.to_numeric_ligand_atoms().expect("REASON");
// 54 residues; N/CA/C/O; positions
assert_eq!(ligand_coords.dim(), (54, 3));

// Check my residue coords in the Tensor
//
// HETATM 1222 S S . SO4 B 2 . ? 30.746 18.706 28.896 1.00 47.98 ? 157 SO4 A S 1
// HETATM 1223 O O1 . SO4 B 2 . ? 30.697 20.077 28.620 1.00 48.06 ? 157 SO4 A O1 1
// HETATM 1224 O O2 . SO4 B 2 . ? 31.104 18.021 27.725 1.00 47.52 ? 157 SO4 A O2 1
// HETATM 1225 O O3 . SO4 B 2 . ? 29.468 18.179 29.331 1.00 47.79 ? 157 SO4 A O3 1
// HETATM 1226 O O4 . SO4 B 2 . ? 31.722 18.578 29.881 1.00 47.85 ? 157 SO4 A O4 1
let allatom_coords = [
("S", (0, ..), vec![30.746, 18.706, 28.896]),
("O1", (1, ..), vec![30.697, 20.077, 28.620]),
("O2", (2, ..), vec![31.104, 18.021, 27.725]),
("O3", (3, ..), vec![29.468, 18.179, 29.331]),
("O4", (4, ..), vec![31.722, 18.578, 29.881]),
];

for (atom_name, (i, j), expected) in allatom_coords {
let actual: Vec<f32> = ligand_coords.slice(s![i, j]).to_vec();
assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
}

// Now check the elements
//
let elements: Vec<&str> = ligand_elements
.to_vec()
.into_iter()
.map(|elem| Element::new(elem as usize).unwrap().symbol())
.collect();

assert_eq!(elements[0], "S");
assert_eq!(elements[1], "O");
assert_eq!(elements[2], "O");
assert_eq!(elements[3], "O");
}
}

0 comments on commit 3f69c2c

Please sign in to comment.