Skip to content

Commit

Permalink
Update CASP15 plots
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead committed Dec 17, 2024
1 parent 5938be9 commit aaef02b
Showing 1 changed file with 101 additions and 47 deletions.
148 changes: 101 additions & 47 deletions notebooks/casp15_inference_results_plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,47 +69,32 @@
"source": [
"# General variables\n",
"baseline_methods = [\n",
" \"vina_p2rank\",\n",
" \"diffdock\",\n",
" \"diffdockv1\",\n",
" \"dynamicbind\",\n",
" \"neuralplexer\",\n",
" \"neuralplexer_no_ilcl\",\n",
" \"rfaa\",\n",
" \"chai-lab\",\n",
" \"tulip\",\n",
" \"vina_diffdock\",\n",
" \"vina_p2rank\",\n",
" \"consensus_ensemble\",\n",
"]\n",
"max_num_repeats_per_method = 3\n",
"\n",
"# Mappings\n",
"method_mapping = {\n",
" \"vina_p2rank\": \"P2Rank-Vina\",\n",
" \"diffdock\": \"DiffDock-L\",\n",
" \"diffdockv1\": \"-> w/o SCT\",\n",
" \"dynamicbind\": \"DynamicBind\",\n",
" \"neuralplexer\": \"NeuralPLexer\",\n",
" \"neuralplexer_no_ilcl\": \"-> w/o ILCL\",\n",
" \"rfaa\": \"RoseTTAFold-AA\",\n",
" \"chai-lab\": \"Chai-1\",\n",
" \"tulip\": \"TULIP\",\n",
" \"vina_diffdock\": \"DiffDock-L-Vina\",\n",
" \"vina_p2rank\": \"P2Rank-Vina\",\n",
" \"consensus_ensemble\": \"Ensemble (Con)\",\n",
"}\n",
"\n",
"method_category_mapping = {\n",
" \"vina_p2rank\": \"Conventional blind\",\n",
" \"diffdock\": \"DL-based blind\",\n",
" \"diffdockv1\": \"DL-based blind\",\n",
" \"dynamicbind\": \"DL-based blind\",\n",
" \"neuralplexer\": \"DL-based blind\",\n",
" \"neuralplexer_no_ilcl\": \"DL-based blind\",\n",
" \"rfaa\": \"DL-based blind\",\n",
" \"chai-lab\": \"DL-based blind\",\n",
" \"tulip\": \"Conventional blind\",\n",
" \"vina_diffdock\": \"Conventional blind\",\n",
" \"vina_p2rank\": \"Conventional blind\",\n",
" \"consensus_ensemble\": \"Hybrid blind\",\n",
"}"
]
},
Expand Down Expand Up @@ -287,6 +272,16 @@
" return list(method_mapping.keys()).index(method)\n",
"\n",
"\n",
"def assign_category_index(category: str) -> str:\n",
" \"\"\"\n",
" Assign category index for plotting.\n",
"\n",
" :param category: Category name.\n",
" :return: Category index.\n",
" \"\"\"\n",
" return list(method_mapping.values()).index(category)\n",
"\n",
"\n",
"def categorize_method(method: str) -> str:\n",
" \"\"\"\n",
" Categorize method for plotting.\n",
Expand All @@ -312,6 +307,12 @@
"source": [
"# load and organize the CASP15 results CSV\n",
"for repeat_index in range(1, max_num_repeats_per_method + 1):\n",
" # PLIF metrics\n",
" globals()[f\"casp15_plif_metrics_csv_filepath_{repeat_index}\"] = \"casp15_plif_metrics.csv\"\n",
" globals()[f\"casp15_plif_metrics_table_{repeat_index}\"] = pd.read_csv(\n",
" globals()[f\"casp15_plif_metrics_csv_filepath_{repeat_index}\"]\n",
" )\n",
"\n",
" globals()[f\"scoring_results_table_{repeat_index}\"] = pd.concat(\n",
" [\n",
" globals()[f\"{method}{config}_scoring_results_table_{repeat_index}\"]\n",
Expand All @@ -326,6 +327,11 @@
" globals()[f\"scoring_results_table_{repeat_index}\"].loc[\n",
" :, \"method_assignment_index\"\n",
" ] = globals()[f\"scoring_results_table_{repeat_index}\"][\"method\"].apply(assign_method_index)\n",
" globals()[f\"casp15_plif_metrics_table_{repeat_index}\"].loc[\n",
" :, \"category_assignment_index\"\n",
" ] = globals()[f\"casp15_plif_metrics_table_{repeat_index}\"][\"Category\"].apply(\n",
" assign_category_index\n",
" )\n",
" globals()[f\"scoring_results_table_{repeat_index}\"].loc[:, \"RMSD ≤ 2 Å\"] = (\n",
" globals()[f\"scoring_results_table_{repeat_index}\"]\n",
" .loc[:, \"rmsd_≤_2å\"]\n",
Expand Down Expand Up @@ -531,15 +537,20 @@
"# RMSD ≤ 2 Å Bar Chart of CASP15 Set (Relaxed vs. Unrelaxed) Results #\n",
"\n",
"# prepare data for the bar charts to plot\n",
"colors = [\"#FB8072\", \"#BEBADA\"]\n",
"colors = [\"#FB8072\", \"#BEBADA\", \"#FCCDE5\"]\n",
"\n",
"bar_width = 0.75\n",
"r1 = [item - 0.25 for item in range(2, 24, 2)]\n",
"bar_width = 0.5\n",
"r1 = [item - 0.5 for item in range(2, 14, 2)]\n",
"r2 = [x + bar_width for x in r1]\n",
"r3 = [x + bar_width for x in r2]\n",
"\n",
"for complex_type in [\"single\", \"multi\"]:\n",
" for complex_license in [\"all\", \"public\"]:\n",
" casp15_rmsd_lt_2_data_list, casp15_relaxed_rmsd_lt_2_data_list = [], []\n",
" (\n",
" casp15_rmsd_lt_2_data_list,\n",
" casp15_relaxed_rmsd_lt_2_data_list,\n",
" casp15_plif_wm_data_list,\n",
" ) = ([], [], [])\n",
" for repeat_index in range(1, max_num_repeats_per_method + 1):\n",
" # filter the data based on the complex type and license\n",
" casp15_results_table = globals()[f\"scoring_results_table_{repeat_index}\"][\n",
Expand Down Expand Up @@ -640,6 +651,15 @@
" )\n",
" casp15_relaxed_rmsd_lt_2_data_list.append(casp15_relaxed_rmsd_lt_2_data)\n",
"\n",
" # CASP15 PLIF-WM results\n",
" casp15_plif_wm_data = (\n",
" globals()[f\"casp15_plif_metrics_table_{repeat_index}\"]\n",
" .groupby(\"Category\")\n",
" .agg({\"WM\": \"mean\", \"category_assignment_index\": \"first\"})\n",
" )\n",
" casp15_plif_wm_data = casp15_plif_wm_data.sort_values(\"category_assignment_index\")\n",
" casp15_plif_wm_data_list.append(casp15_plif_wm_data)\n",
"\n",
" # calculate means and standard deviations\n",
" casp15_rmsd_lt_2_data_mean = (\n",
" pd.concat([df for df in casp15_rmsd_lt_2_data_list])\n",
Expand Down Expand Up @@ -686,8 +706,33 @@
" .std()\n",
" .sort_values([\"method_assignment_index\"])[\"RMSD ≤ 2 Å\"]\n",
" )\n",
"\n",
" casp15_plif_wm_data_mean = (\n",
" pd.concat([df for df in casp15_plif_wm_data_list])\n",
" .groupby(\n",
" [\n",
" \"Category\",\n",
" \"category_assignment_index\",\n",
" ]\n",
" )\n",
" .mean()\n",
" .sort_values([\"category_assignment_index\"])[\"WM\"]\n",
" )\n",
" casp15_plif_wm_data_std = (\n",
" pd.concat([df for df in casp15_plif_wm_data_list])\n",
" .groupby(\n",
" [\n",
" \"Category\",\n",
" \"category_assignment_index\",\n",
" ]\n",
" )\n",
" .std()\n",
" .sort_values([\"category_assignment_index\"])[\"WM\"]\n",
" )\n",
"\n",
" casp15_rmsd_lt_2_data_std.fillna(0, inplace=True)\n",
" casp15_relaxed_rmsd_lt_2_data_std.fillna(0, inplace=True)\n",
" casp15_plif_wm_data_std.fillna(0, inplace=True)\n",
"\n",
" # define font properties\n",
" plt.rcParams[\"font.size\"] = 22\n",
Expand Down Expand Up @@ -724,14 +769,26 @@
" width=bar_width,\n",
" )\n",
"\n",
" # plot PLIF-WM data for the CASP15 set\n",
" casp15_plif_wm_bar = axis.bar(\n",
" r3,\n",
" casp15_plif_wm_data_mean,\n",
" yerr=casp15_plif_wm_data_std,\n",
" label=\"PLIF-WM\",\n",
" color=colors[2],\n",
" hatch=\"\\\\\\\\\\\\\",\n",
" width=bar_width,\n",
" )\n",
"\n",
" # add labels, titles, ticks, etc.\n",
" axis.set_xlabel(f\"{complex_type.title()}-ligand blind docking ({complex_license})\")\n",
" axis.set_ylabel(\"Percentage of predictions\")\n",
" axis.set_xlim(1, 23 + 0.1)\n",
" axis.set_ylim(0, 100)\n",
" axis.set_xlim(1, 13 + 0.1)\n",
" axis.set_ylim(0, 125)\n",
"\n",
" axis.bar_label(casp15_rmsd_lt2_bar, fmt=\"{:,.1f}%\", label_type=\"center\")\n",
" axis.bar_label(casp15_relaxed_rmsd_lt_2_bar, fmt=\"{:,.1f}%\", label_type=\"center\")\n",
" axis.bar_label(casp15_plif_wm_bar, fmt=\"{:,.1f}%\", label_type=\"center\")\n",
"\n",
" axis.yaxis.set_major_formatter(mtick.PercentFormatter())\n",
"\n",
Expand All @@ -740,24 +797,18 @@
" axis.grid(axis=\"y\", color=\"#EAEFF8\")\n",
" axis.set_axisbelow(True)\n",
"\n",
" axis.set_xticks([2, 4, 6, 8, 8 + 1e-3, 10, 12, 14, 16, 18, 18 + 1e-3, 20, 22, 22 + 1e-3])\n",
" axis.set_xticks([2, 2 + 1e-3, 4, 6, 8, 8 + 1e-3, 10, 12])\n",
" axis.set_xticks([1 + 0.1], minor=True)\n",
" axis.set_xticklabels(\n",
" [\n",
" \"P2Rank-Vina\",\n",
" \"Conventional blind\",\n",
" \"DiffDock-L\",\n",
" \"-> w/o SCT\",\n",
" \"DynamicBind\",\n",
" \"NeuralPLexer\",\n",
" \"DL-based blind\",\n",
" \"-> w/o ILCL\",\n",
" \"RoseTTAFold-AA\",\n",
" \"Chai-1\",\n",
" \"TULIP\",\n",
" \"DiffDock-L-Vina\",\n",
" \"Conventional blind\",\n",
" \"P2Rank-Vina\",\n",
" \"Ensemble (Con)\",\n",
" \"Hybrid blind\",\n",
" ]\n",
" )\n",
"\n",
Expand All @@ -769,7 +820,7 @@
" axis.tick_params(axis=\"y\", which=\"major\", left=\"off\", right=\"on\", color=\"#EAEFF8\")\n",
"\n",
" # vertical alignment of xtick labels\n",
" vert_alignments = [0.0, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0, -0.1]\n",
" vert_alignments = [0.0, -0.1, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0]\n",
" for tick, y in zip(axis.get_xticklabels(), vert_alignments):\n",
" tick.set_y(y)\n",
"\n",
Expand All @@ -779,17 +830,25 @@
" [\"RMSD ≤ 2Å\"],\n",
" loc=\"upper right\",\n",
" title=\"No post-processing\",\n",
" bbox_to_anchor=(1, 1, -0.20, -0.05),\n",
" bbox_to_anchor=(1, 1, -0.40, -0.05),\n",
" )\n",
" legend_1 = fig.legend(\n",
" [casp15_relaxed_rmsd_lt_2_bar],\n",
" [\"RMSD ≤ 2Å\"],\n",
" loc=\"upper right\",\n",
" title=\"With relaxation\",\n",
" bbox_to_anchor=(1, 1, -0.2, -0.05),\n",
" )\n",
" legend_2 = fig.legend(\n",
" [casp15_plif_wm_bar],\n",
" [\"PLIF-WM\"],\n",
" loc=\"upper right\",\n",
" title=\"Protein-ligand interactions\\n (no post-processing)\",\n",
" bbox_to_anchor=(1, 1, -0.01, -0.05),\n",
" )\n",
" legend_0.get_frame().set_alpha(0)\n",
" legend_1.get_frame().set_alpha(0)\n",
" legend_2.get_frame().set_alpha(0)\n",
"\n",
" # display the plots\n",
" plt.tight_layout()\n",
Expand Down Expand Up @@ -864,8 +923,9 @@
"colors = [\"#FB8072\", \"#BEBADA\"]\n",
"\n",
"bar_width = 0.75\n",
"r1 = [item - 0.25 for item in range(2, 24, 2)]\n",
"r1 = [item - 0.25 for item in range(2, 14, 2)]\n",
"r2 = [x + bar_width for x in r1]\n",
"r3 = [x + bar_width for x in r2]\n",
"\n",
"for complex_type in [\"single\", \"multi\"]:\n",
" for complex_license in [\"all\", \"public\"]:\n",
Expand Down Expand Up @@ -1047,7 +1107,7 @@
" # add labels, titles, ticks, etc.\n",
" axis.set_xlabel(f\"{complex_type.title()}-ligand blind docking ({complex_license})\")\n",
" axis.set_ylabel(\"Percentage of complex predictions\")\n",
" axis.set_xlim(1, 23 + 0.1)\n",
" axis.set_xlim(1, 13 + 0.1)\n",
" axis.set_ylim(0, 100)\n",
"\n",
" axis.bar_label(casp15_pb_valid_bar, fmt=\"{:,.1f}%\", label_type=\"center\")\n",
Expand All @@ -1060,24 +1120,18 @@
" axis.grid(axis=\"y\", color=\"#EAEFF8\")\n",
" axis.set_axisbelow(True)\n",
"\n",
" axis.set_xticks([2, 4, 6, 8, 8 + 1e-3, 10, 12, 14, 16, 18, 18 + 1e-3, 20, 22, 22 + 1e-3])\n",
" axis.set_xticks([2, 2 + 1e-3, 4, 6, 8, 8 + 1e-3, 10, 12])\n",
" axis.set_xticks([1 + 0.1], minor=True)\n",
" axis.set_xticklabels(\n",
" [\n",
" \"P2Rank-Vina\",\n",
" \"Conventional blind\",\n",
" \"DiffDock-L\",\n",
" \"-> w/o SCT\",\n",
" \"DynamicBind\",\n",
" \"NeuralPLexer\",\n",
" \"DL-based blind\",\n",
" \"-> w/o ILCL\",\n",
" \"RoseTTAFold-AA\",\n",
" \"Chai-1\",\n",
" \"TULIP\",\n",
" \"DiffDock-L-Vina\",\n",
" \"Conventional blind\",\n",
" \"P2Rank-Vina\",\n",
" \"Ensemble (Con)\",\n",
" \"Hybrid blind\",\n",
" ]\n",
" )\n",
"\n",
Expand All @@ -1089,7 +1143,7 @@
" axis.tick_params(axis=\"y\", which=\"major\", left=\"off\", right=\"on\", color=\"#EAEFF8\")\n",
"\n",
" # vertical alignment of xtick labels\n",
" vert_alignments = [0.0, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0, -0.1]\n",
" vert_alignments = [0.0, -0.1, 0.0, 0.0, 0.0, -0.1, 0.0, 0.0]\n",
" for tick, y in zip(axis.get_xticklabels(), vert_alignments):\n",
" tick.set_y(y)\n",
"\n",
Expand Down

0 comments on commit aaef02b

Please sign in to comment.