Skip to content

Commit

Permalink
Streamline A2H plot names
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead committed Dec 28, 2024
1 parent b1ff7bc commit d70b5ba
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions posebench/data/components/plot_dataset_rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def calculate_usalign_metrics(


def plot_dataset_rmsd(
dataset: str,
dataset_name: str,
pred_pdb_dir: str,
ref_pdb_dir: str,
Expand All @@ -111,7 +112,8 @@ def plot_dataset_rmsd(
):
"""Plot the RMSD between predicted and reference protein structures in a given dataset.
:param dataset_name: Name of the dataset.
:param dataset: Informal name of the dataset.
:param dataset_name: Formal name of the dataset.
:param pred_pdb_dir: Directory containing predicted protein structures in PDB format.
:param ref_pdb_dir: Directory containing reference protein structures in PDB format.
:param output_dir: Directory to save the plots.
Expand Down Expand Up @@ -209,12 +211,12 @@ def plot_dataset_rmsd(
plt.clf()
sns.histplot(dataset_df["TM-score"])
plt.title("Apo-To-Holo Protein TM-score")
plt.savefig(plot_dir / "a2h_TM-score_hist.png")
plt.savefig(plot_dir / f"{dataset}_a2h_TM-score_hist.png")

plt.clf()
sns.histplot(dataset_df["RMSD"])
plt.title("Apo-To-Holo Protein RMSD")
plt.savefig(plot_dir / "a2h_RMSD_hist.png")
plt.savefig(plot_dir / f"{dataset}_a2h_RMSD_hist.png")


@hydra.main(
Expand All @@ -233,6 +235,7 @@ def main(cfg: DictConfig):
# NOTE: Make sure to update the `usalign_exec_path` value in `configs/data/components/plot_dataset_rmsd.yaml` to reflect where you have placed the US-align executable on your machine.

plot_dataset_rmsd(
"astex_diverse",
"Astex Diverse Set",
os.path.join(
cfg.data_dir,
Expand All @@ -248,6 +251,7 @@ def main(cfg: DictConfig):
)

plot_dataset_rmsd(
"posebusters_benchmark",
"PoseBusters Benchmark Set",
os.path.join(
cfg.data_dir,
Expand All @@ -264,6 +268,7 @@ def main(cfg: DictConfig):
)

plot_dataset_rmsd(
"dockgen",
"DockGen Set",
os.path.join(cfg.data_dir, "dockgen_set", "dockgen_holo_aligned_predicted_structures"),
os.path.join(cfg.data_dir, "dockgen_set"),
Expand All @@ -276,6 +281,7 @@ def main(cfg: DictConfig):
)

plot_dataset_rmsd(
"casp15",
"CASP15 Set",
os.path.join(cfg.data_dir, "casp15_set", "casp15_holo_aligned_predicted_structures"),
os.path.join(cfg.data_dir, "casp15_set", "targets"),
Expand All @@ -296,6 +302,7 @@ def main(cfg: DictConfig):
)

# plot_dataset_rmsd(
# "casp15",
# "CASP15 Set",
# os.path.join(cfg.data_dir, "casp15_set", "casp15_holo_aligned_predicted_structures"),
# os.path.join(cfg.data_dir, "casp15_set", "targets"),
Expand Down

0 comments on commit d70b5ba

Please sign in to comment.