Skip to content

Commit

Permalink
Added Style and plot #3
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Stricker committed May 29, 2024
1 parent aae6013 commit 06b48c6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 4 deletions.
23 changes: 19 additions & 4 deletions nmrcraft/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@

def style_setup():
"""Function to set up matplotlib parameters."""
colors = ["#C28340", "#854F2B", "#61371F", "#8FCA5C", "#70B237", "#477A1E"]
colors = [
"#C28340",
"#854F2B",
"#61371F",
"#8FCA5C",
"#70B237",
"#477A1E",
"#3B661A",
]
cmap = LinearSegmentedColormap.from_list("custom", colors)

plt.style.use("./style.mplstyle")
Expand Down Expand Up @@ -233,6 +241,7 @@ def plot_metric(
iterative_column="model",
xdata="dataset_fraction",
):
_, colors, _ = style_setup()
if iterative_column == "target":

def convert_to_labels(target_list):
Expand All @@ -247,7 +256,7 @@ def convert_to_labels(target_list):
data["xlabel"] = data["model_targets"].apply(convert_to_labels)
print(data)

for iterator in data[iterative_column].unique():
for i, iterator in enumerate(data[iterative_column].unique()):
model_data = data[data[iterative_column] == iterator]
errors = [
model_data[metric + "_mean"].values
Expand All @@ -259,8 +268,9 @@ def convert_to_labels(target_list):
model_data[xdata],
model_data[metric + "_mean"],
yerr=errors,
fmt="o",
fmt="o-",
label=iterator,
color=colors[i],
capsize=5,
)
plt.legend()
Expand All @@ -286,7 +296,12 @@ def plot_bar(
_, colors, _ = style_setup()

def convert_to_labels(target_list):
label_dict = {"metal": "M", "E_ligand": "E", "X3_ligand": "X3"}
label_dict = {
"metal": "M",
"E_ligand": "E",
"X3_ligand": "X3",
"lig": "\n (ligands input)",
}
return ", ".join([label_dict[i] for i in target_list])

# Convert string representations of lists to actual lists
Expand Down
42 changes: 42 additions & 0 deletions scripts/analysis/visulize_results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import ast
import os

import pandas as pd

from nmrcraft.analysis.plotting import plot_bar, plot_metric
Expand All @@ -13,6 +16,8 @@

df = pd.concat([df_base, df_one])

if not os.path.exists("./plots"):
os.makedirs("./plots")

targets = df["target"].unique()
for target in targets:
Expand Down Expand Up @@ -60,3 +65,40 @@
iterative_column="target",
xdata="xlabel",
)


full_df["model_targets"] = full_df["model_targets"].apply(
lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

# Add 'lig' to model_targets if nmr_only is True
full_df["model_targets"] = full_df.apply(
lambda row: row["model_targets"] + ["lig"]
if not row["nmr_only"]
else row["model_targets"],
axis=1,
)

print(full_df)


models = full_df["model"].unique()
for model in models:
sub_df = full_df[full_df["model"] == model]
print(sub_df)
plot_bar(
sub_df,
title=f"Accuracy for {model} Predictions",
filename=f"plots/03_accuracy_{model}.png",
metric="accuracy",
iterative_column="target",
xdata="xlabel",
)
plot_bar(
sub_df,
title=f"F1-Score for {model} Predictions",
filename=f"plots/03_f1-score_{model}.png",
metric="f1",
iterative_column="target",
xdata="xlabel",
)

0 comments on commit 06b48c6

Please sign in to comment.