diff --git a/posebench/data/components/plot_dataset_rmsd.py b/posebench/data/components/plot_dataset_rmsd.py index 1a60169..5f363c0 100644 --- a/posebench/data/components/plot_dataset_rmsd.py +++ b/posebench/data/components/plot_dataset_rmsd.py @@ -88,6 +88,7 @@ def calculate_usalign_metrics( def plot_dataset_rmsd( + dataset: str, dataset_name: str, pred_pdb_dir: str, ref_pdb_dir: str, @@ -111,7 +112,8 @@ def plot_dataset_rmsd( ): """Plot the RMSD between predicted and reference protein structures in a given dataset. - :param dataset_name: Name of the dataset. + :param dataset: Informal name of the dataset. + :param dataset_name: Formal name of the dataset. :param pred_pdb_dir: Directory containing predicted protein structures in PDB format. :param ref_pdb_dir: Directory containing reference protein structures in PDB format. :param output_dir: Directory to save the plots. @@ -209,12 +211,12 @@ def plot_dataset_rmsd( plt.clf() sns.histplot(dataset_df["TM-score"]) plt.title("Apo-To-Holo Protein TM-score") - plt.savefig(plot_dir / "a2h_TM-score_hist.png") + plt.savefig(plot_dir / f"{dataset}_a2h_TM-score_hist.png") plt.clf() sns.histplot(dataset_df["RMSD"]) plt.title("Apo-To-Holo Protein RMSD") - plt.savefig(plot_dir / "a2h_RMSD_hist.png") + plt.savefig(plot_dir / f"{dataset}_a2h_RMSD_hist.png") @hydra.main( @@ -233,6 +235,7 @@ def main(cfg: DictConfig): # NOTE: Make sure to update the `usalign_exec_path` value in `configs/data/components/plot_dataset_rmsd.yaml` to reflect where you have placed the US-align executable on your machine. plot_dataset_rmsd( + "astex_diverse", "Astex Diverse Set", os.path.join( cfg.data_dir, @@ -248,6 +251,7 @@ def main(cfg: DictConfig): ) plot_dataset_rmsd( + "posebusters_benchmark", "PoseBusters Benchmark Set", os.path.join( cfg.data_dir, @@ -264,6 +268,7 @@ def main(cfg: DictConfig): ) plot_dataset_rmsd( + "dockgen", "DockGen Set", os.path.join(cfg.data_dir, "dockgen_set", "dockgen_holo_aligned_predicted_structures"), os.path.join(cfg.data_dir, "dockgen_set"), @@ -276,6 +281,7 @@ def main(cfg: DictConfig): ) plot_dataset_rmsd( + "casp15", "CASP15 Set", os.path.join(cfg.data_dir, "casp15_set", "casp15_holo_aligned_predicted_structures"), os.path.join(cfg.data_dir, "casp15_set", "targets"), @@ -296,6 +302,7 @@ def main(cfg: DictConfig): ) # plot_dataset_rmsd( + # "casp15", # "CASP15 Set", # os.path.join(cfg.data_dir, "casp15_set", "casp15_holo_aligned_predicted_structures"), # os.path.join(cfg.data_dir, "casp15_set", "targets"),