# Review notes. This is free text ahead of the first `diff --git` header;
# patch tools skip leading text like this, so the hunks below are untouched.
#
# What this patch does
# --------------------
# nmrcraft/analysis/plotting.py
#   * Renames plot_metric_1() to plot_bar(); the parameters visible in the
#     hunks (data, title, filename, iterative_column, xdata) are unchanged.
#   * Replaces the per-iterator plt.errorbar() plot with a grouped bar chart
#     drawn on an explicit fig/ax pair:
#       - takes the second value returned by style_setup() as `colors`;
#       - parses string entries of data["model_targets"] with
#         ast.literal_eval and maps them to the short labels M / E / X3,
#         joined with ", ", in a new "xlabel" column (previously done only
#         when iterative_column == "target", now unconditional);
#       - averages <metric>_mean, <metric>_lb and <metric>_hb per
#         (xlabel, target) pair, then pivots each result to xlabel rows by
#         target columns;
#       - draws one bar series per target column, width 0.2, shifted by
#         i * width, with asymmetric error bars (mean - lb below,
#         hb - mean above) and capsize=5;
#       - fixes the y range to 0..1.2, centres the x ticks under each group
#         of bars, rotates the tick labels by 45 degrees, adds a dashed
#         y-axis grid and calls fig.tight_layout() before
#         plt.savefig(filename) / plt.close().
#   * Drops the debugging print(data).
#
# scripts/analysis/visulize_results.py
#   * Imports plot_bar instead of plot_metric_1.
#   * Deletes the commented-out per-model loop and the loop over
#     full_df["nmr_only"].unique().
#   * Keeps only the rows selected by full_df[full_df["nmr_only"]] and, for
#     each model, writes plots/02_accuracy_<model>.png and
#     plots/02_f1-score_<model>.png through plot_bar().
#   * The removed loop built sub_df as full_df[mode_df["model"] == model],
#     i.e. the mask came from a different frame than the one being indexed;
#     the new loop filters true_df with a mask taken from true_df itself.
#
# NOTE(review): points to confirm before merging
# ----------------------------------------------
#   * plot_bar() assigns data["model_targets"] and data["xlabel"] on the
#     frame it receives, so the caller's object is modified. In the script
#     that object is sub_df, a boolean-mask selection of true_df; pandas may
#     emit SettingWithCopyWarning there. Working on data.copy() inside
#     plot_bar() would avoid both effects.
#   * iterative_column and xdata stay in the signature and are still passed
#     by the script, but the new body never reads them.
#   * colors[i] is indexed by target position with no bound check;
#     presumably style_setup() returns at least as many colours as there
#     are targets - verify, or use colors[i % len(colors)].
#   * label_dict[i] raises KeyError for any target name other than
#     "metal", "E_ligand" or "X3_ligand".
#   * full_df[full_df["nmr_only"]] assumes the nmr_only column is boolean -
#     TODO confirm against the CSV loader.
#   * plt.grid(...) goes through pyplot while every other call uses ax;
#     ax.grid(...) would be consistent.
#   * print(sub_df) is still executed for every model in the script.
#   * np and style_setup must already be in scope in plotting.py; their
#     imports are not part of these hunks.
#   * In this copy the hunks have lost their line breaks (each physical line
#     holds many diff lines); they must be re-wrapped before `git apply` or
#     `patch` can read them.
#
diff --git a/nmrcraft/analysis/plotting.py b/nmrcraft/analysis/plotting.py index 6045f06..cfcdc6f 100644 --- a/nmrcraft/analysis/plotting.py +++ b/nmrcraft/analysis/plotting.py @@ -275,7 +275,7 @@ def convert_to_labels(target_list): plt.close() -def plot_metric_1( +def plot_bar( data, title="Accuracy", filename="plots/accuracy.png", @@ -283,44 +283,74 @@ def plot_metric_1( iterative_column="model", xdata="dataset_fraction", ): - if iterative_column == "target": + _, colors, _ = style_setup() - def convert_to_labels(target_list): - label_dict = {"metal": "M", "E_ligand": "E", "X3_ligand": "X3"} - return ", ".join([label_dict[i] for i in target_list]) + def convert_to_labels(target_list): + label_dict = {"metal": "M", "E_ligand": "E", "X3_ligand": "X3"} + return ", ".join([label_dict[i] for i in target_list]) - # Convert string representations of lists to actual lists - data["model_targets"] = data["model_targets"].apply( - lambda x: ast.literal_eval(x) if isinstance(x, str) else x - ) + # Convert string representations of lists to actual lists + data["model_targets"] = data["model_targets"].apply( + lambda x: ast.literal_eval(x) if isinstance(x, str) else x + ) - data["xlabel"] = data["model_targets"].apply(convert_to_labels) - print(data) + data["xlabel"] = data["model_targets"].apply(convert_to_labels) - for iterator in data[iterative_column].unique(): - model_data = data[data[iterative_column] == iterator] - errors = [ - model_data[metric + "_mean"].values - - model_data[metric + "_lb"].values, - model_data[metric + "_hb"].values - - model_data[metric + "_mean"].values, - ] - plt.errorbar( - model_data[xdata], - model_data[metric + "_mean"], - yerr=errors, - fmt="o", - label=model_data["target"], + # Aggregate the data to handle duplicates + aggregated_data = ( + data.groupby(["xlabel", "target"]) + .agg({metric + "_mean": "mean"}) + .reset_index() + ) + aggregated_data_lb = ( + data.groupby(["xlabel", "target"]) + .agg({metric + "_lb": "mean"}) + .reset_index() 
+ ) + aggregated_data_hb = ( + data.groupby(["xlabel", "target"]) + .agg({metric + "_hb": "mean"}) + .reset_index() + ) + + # Pivot the aggregated data + new_df = aggregated_data.pivot( + index="xlabel", columns="target", values=metric + "_mean" + ) + new_lb = aggregated_data_lb.pivot( + index="xlabel", columns="target", values=metric + "_lb" + ) + new_hb = aggregated_data_hb.pivot( + index="xlabel", columns="target", values=metric + "_hb" + ) + + fig, ax = plt.subplots() + width = 0.2 # width of the bar + x = np.arange(len(new_df.index)) + + # Plotting each column (target) as a separate group + for i, column in enumerate(new_df.columns): + ax.bar( + x + i * width, + new_df[column], + width, + color=colors[i], + label=column, + yerr=[ + new_df[column] - new_lb[column], + new_hb[column] - new_df[column], + ], capsize=5, ) - plt.legend() - plt.title(title) - plt.grid(True) - if iterative_column == "model": - plt.xlim(0, 1) - plt.ylim(0, 1.2) - plt.xlabel("Dataset Size") - plt.ylabel(metric) + + ax.set_ylabel(f"{metric}") + ax.set_ylim(0, 1.2) + ax.set_xticks(x + width * (len(new_df.columns) - 1) / 2) + ax.set_xticklabels(new_df.index, rotation=45, ha="right") + ax.set_title(title) + ax.legend() + plt.grid(True, "major", "y", ls="--", lw=0.5, c="k", alpha=0.3) + fig.tight_layout() plt.savefig(filename) plt.close() diff --git a/scripts/analysis/visulize_results.py b/scripts/analysis/visulize_results.py index b88689c..e87a2c0 100755 --- a/scripts/analysis/visulize_results.py +++ b/scripts/analysis/visulize_results.py @@ -1,6 +1,6 @@ import pandas as pd -from nmrcraft.analysis.plotting import plot_metric, plot_metric_1 +from nmrcraft.analysis.plotting import plot_bar, plot_metric import_filename_base = "metrics/results_baselines.csv" import_filename_one = "metrics/results_one_target.csv" @@ -39,34 +39,24 @@ df = pd.concat([df_one, df_multi]) full_df = df[df["dataset_fraction"] == 1] - -# models = full_df['model'].unique() -# for model in models: -# sub_df = 
full_df[full_df["model"] == model] -# print(sub_df) -# plot_metric_1(sub_df, title=f"Accuracy for {model} Predictions", filename=f'plots/02_accuracy_{model}.png', metric="accuracy", iterative_column='target', xdata='xlabel') -# plot_metric_1(sub_df, title=f"F1-Score for {model} Predictions", filename=f'plots/02_f1-score_{model}.png', metric="f1", iterative_column='target', xdata='xlabel') - -nmr_only = full_df["nmr_only"].unique() -for mode in nmr_only: - mode_df = full_df[full_df["nmr_only"] == mode] - models = mode_df["model"].unique() - for model in models: - sub_df = full_df[mode_df["model"] == model] - print(sub_df) - plot_metric_1( - sub_df, - title=f"Accuracy for {model} Predictions with onlyNMR = {mode}", - filename=f"plots/03_accuracy_{model}_NMR_{mode}.png", - metric="accuracy", - iterative_column="target", - xdata="xlabel", - ) - plot_metric_1( - sub_df, - title=f"F1-Score for {model} Predictions with onlyNMR = {mode}", - filename=f"plots/03_f1-score_{model}_NMR_{mode}.png", - metric="f1", - iterative_column="target", - xdata="xlabel", - ) +true_df = full_df[full_df["nmr_only"]] +models = true_df["model"].unique() +for model in models: + sub_df = true_df[true_df["model"] == model] + print(sub_df) + plot_bar( + sub_df, + title=f"Accuracy for {model} Predictions", + filename=f"plots/02_accuracy_{model}.png", + metric="accuracy", + iterative_column="target", + xdata="xlabel", + ) + plot_bar( + sub_df, + title=f"F1-Score for {model} Predictions", + filename=f"plots/02_f1-score_{model}.png", + metric="f1", + iterative_column="target", + xdata="xlabel", + )