Skip to content

Commit

Permalink
Barplot #2 works now
Browse files: browse the repository at this point in the history
  • Loading branch information
Samuel Stricker committed May 29, 2024
1 parent fcc944b commit aae6013
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 65 deletions.
96 changes: 63 additions & 33 deletions nmrcraft/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,52 +275,82 @@ def convert_to_labels(target_list):
plt.close()


def plot_metric_1(
def plot_bar(
data,
title="Accuracy",
filename="plots/accuracy.png",
metric="accuracy",
iterative_column="model",
xdata="dataset_fraction",
):
if iterative_column == "target":
_, colors, _ = style_setup()

def convert_to_labels(target_list):
label_dict = {"metal": "M", "E_ligand": "E", "X3_ligand": "X3"}
return ", ".join([label_dict[i] for i in target_list])
def convert_to_labels(target_list):
label_dict = {"metal": "M", "E_ligand": "E", "X3_ligand": "X3"}
return ", ".join([label_dict[i] for i in target_list])

# Convert string representations of lists to actual lists
data["model_targets"] = data["model_targets"].apply(
lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)
# Convert string representations of lists to actual lists
data["model_targets"] = data["model_targets"].apply(
lambda x: ast.literal_eval(x) if isinstance(x, str) else x
)

data["xlabel"] = data["model_targets"].apply(convert_to_labels)
print(data)
data["xlabel"] = data["model_targets"].apply(convert_to_labels)

for iterator in data[iterative_column].unique():
model_data = data[data[iterative_column] == iterator]
errors = [
model_data[metric + "_mean"].values
- model_data[metric + "_lb"].values,
model_data[metric + "_hb"].values
- model_data[metric + "_mean"].values,
]
plt.errorbar(
model_data[xdata],
model_data[metric + "_mean"],
yerr=errors,
fmt="o",
label=model_data["target"],
# Aggregate the data to handle duplicates
aggregated_data = (
data.groupby(["xlabel", "target"])
.agg({metric + "_mean": "mean"})
.reset_index()
)
aggregated_data_lb = (
data.groupby(["xlabel", "target"])
.agg({metric + "_lb": "mean"})
.reset_index()
)
aggregated_data_hb = (
data.groupby(["xlabel", "target"])
.agg({metric + "_hb": "mean"})
.reset_index()
)

# Pivot the aggregated data
new_df = aggregated_data.pivot(
index="xlabel", columns="target", values=metric + "_mean"
)
new_lb = aggregated_data_lb.pivot(
index="xlabel", columns="target", values=metric + "_lb"
)
new_hb = aggregated_data_hb.pivot(
index="xlabel", columns="target", values=metric + "_hb"
)

fig, ax = plt.subplots()
width = 0.2 # width of the bar
x = np.arange(len(new_df.index))

# Plotting each column (target) as a separate group
for i, column in enumerate(new_df.columns):
ax.bar(
x + i * width,
new_df[column],
width,
color=colors[i],
label=column,
yerr=[
new_df[column] - new_lb[column],
new_hb[column] - new_df[column],
],
capsize=5,
)
plt.legend()
plt.title(title)
plt.grid(True)
if iterative_column == "model":
plt.xlim(0, 1)
plt.ylim(0, 1.2)
plt.xlabel("Dataset Size")
plt.ylabel(metric)

ax.set_ylabel(f"{metric}")
ax.set_ylim(0, 1.2)
ax.set_xticks(x + width * (len(new_df.columns) - 1) / 2)
ax.set_xticklabels(new_df.index, rotation=45, ha="right")
ax.set_title(title)
ax.legend()
plt.grid(True, "major", "y", ls="--", lw=0.5, c="k", alpha=0.3)
fig.tight_layout()
plt.savefig(filename)
plt.close()

Expand Down
54 changes: 22 additions & 32 deletions scripts/analysis/visulize_results.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd

from nmrcraft.analysis.plotting import plot_metric, plot_metric_1
from nmrcraft.analysis.plotting import plot_bar, plot_metric

import_filename_base = "metrics/results_baselines.csv"
import_filename_one = "metrics/results_one_target.csv"
Expand Down Expand Up @@ -39,34 +39,24 @@
df = pd.concat([df_one, df_multi])
full_df = df[df["dataset_fraction"] == 1]


# models = full_df['model'].unique()
# for model in models:
# sub_df = full_df[full_df["model"] == model]
# print(sub_df)
# plot_metric_1(sub_df, title=f"Accuracy for {model} Predictions", filename=f'plots/02_accuracy_{model}.png', metric="accuracy", iterative_column='target', xdata='xlabel')
# plot_metric_1(sub_df, title=f"F1-Score for {model} Predictions", filename=f'plots/02_f1-score_{model}.png', metric="f1", iterative_column='target', xdata='xlabel')

nmr_only = full_df["nmr_only"].unique()
for mode in nmr_only:
mode_df = full_df[full_df["nmr_only"] == mode]
models = mode_df["model"].unique()
for model in models:
sub_df = full_df[mode_df["model"] == model]
print(sub_df)
plot_metric_1(
sub_df,
title=f"Accuracy for {model} Predictions with onlyNMR = {mode}",
filename=f"plots/03_accuracy_{model}_NMR_{mode}.png",
metric="accuracy",
iterative_column="target",
xdata="xlabel",
)
plot_metric_1(
sub_df,
title=f"F1-Score for {model} Predictions with onlyNMR = {mode}",
filename=f"plots/03_f1-score_{model}_NMR_{mode}.png",
metric="f1",
iterative_column="target",
xdata="xlabel",
)
true_df = full_df[full_df["nmr_only"]]
models = true_df["model"].unique()
for model in models:
sub_df = true_df[true_df["model"] == model]
print(sub_df)
plot_bar(
sub_df,
title=f"Accuracy for {model} Predictions",
filename=f"plots/02_accuracy_{model}.png",
metric="accuracy",
iterative_column="target",
xdata="xlabel",
)
plot_bar(
sub_df,
title=f"F1-Score for {model} Predictions",
filename=f"plots/02_f1-score_{model}.png",
metric="f1",
iterative_column="target",
xdata="xlabel",
)

0 comments on commit aae6013

Please sign in to comment.