From 2b77d2320b4eb44aff4ff083174192d2f3f27356 Mon Sep 17 00:00:00 2001
From: Tiago Würthner
Date: Mon, 27 May 2024 14:21:59 +0000
Subject: [PATCH] Feat: Add statistics for the bootstrapped metrics

Calculate the mean and 95% confidence interval for the bootstrapped
metric values.

---
 nmrcraft/evaluation/evaluation.py | 41 +++++++++++++++++++++++++++++++++++++++++
 scripts/training/one_target.py    |  5 +++++
 2 files changed, 46 insertions(+)

diff --git a/nmrcraft/evaluation/evaluation.py b/nmrcraft/evaluation/evaluation.py
index 97dd781..1cb3de3 100644
--- a/nmrcraft/evaluation/evaluation.py
+++ b/nmrcraft/evaluation/evaluation.py
@@ -1,6 +1,7 @@
 from typing import Dict, List, Tuple
 
 import numpy as np
+import scipy.stats as st
 from sklearn.metrics import (
     accuracy_score,
     confusion_matrix,
@@ -86,3 +87,43 @@ def evaluate_bootstrap(X_test, y_test, model, targets, n_times=10):
         )
         bootstrap_metrics[target]["F1"].append(metrics[target]["F1"])
     return bootstrap_metrics
+
+
+def metrics_statistics(bootstrapped_metrics):
+    """
+    Compute summary statistics for the bootstrapped metrics.
+
+    Args:
+        bootstrapped_metrics (dict): Metrics returned by evaluate_bootstrap.
+
+    Returns:
+        dict: Mean and 95% confidence interval for each target.
+    """
+    metrics_stats = {}
+    for key, value in bootstrapped_metrics.items():
+        metrics_stats[key] = {
+            "Accuracy_mean": None,
+            "Accuracy_ci": None,
+            "F1_mean": None,
+            "F1_ci": None,
+        }
+
+        # Mean and 95% confidence interval for the accuracy
+        metrics_stats[key]["Accuracy_mean"] = np.mean(value["Accuracy"])
+        metrics_stats[key]["Accuracy_ci"] = st.t.interval(
+            confidence=0.95,
+            df=len(value["Accuracy"]) - 1,
+            loc=np.mean(value["Accuracy"]),
+            scale=st.sem(value["Accuracy"]),
+        )
+
+        # Mean and 95% confidence interval for the F1 score
+        metrics_stats[key]["F1_mean"] = np.mean(value["F1"])
+        metrics_stats[key]["F1_ci"] = st.t.interval(
+            confidence=0.95,
+            df=len(value["F1"]) - 1,
+            loc=np.mean(value["F1"]),
+            scale=st.sem(value["F1"]),
+        )
+
+    return metrics_stats
diff --git a/scripts/training/one_target.py b/scripts/training/one_target.py
index a67934d..ff952af 100644
--- a/scripts/training/one_target.py
+++ b/scripts/training/one_target.py
@@ -120,6 +120,11 @@
     X_test, y_test, best_model, args.target
 )
 
+bootstrap_stat_metrics = evaluation.metrics_statistics(
+    bootstrap_metrics
+)
+print(bootstrap_stat_metrics)
+
 # TODO: Adapt this code to the new structure
 # visualizer = Visualizer(
 #     model_name=model_name,
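
Note (not part of the commit): a minimal usage sketch of the new helper. The
"metal" target name and the metric values below are fabricated for
illustration; the dictionary shape mirrors what evaluate_bootstrap returns.
The confidence= keyword of scipy.stats.t.interval is the SciPy >= 1.9
spelling; older releases call it alpha=.

    from nmrcraft.evaluation import evaluation

    # Fabricated bootstrap results; the "metal" target and the numbers are
    # assumptions, shaped like the dict evaluate_bootstrap returns.
    bootstrap_metrics = {
        "metal": {
            "Accuracy": [0.91, 0.89, 0.93, 0.90, 0.92],
            "F1": [0.88, 0.86, 0.90, 0.87, 0.89],
        }
    }

    stats = evaluation.metrics_statistics(bootstrap_metrics)
    print(stats["metal"]["Accuracy_mean"])  # mean of the bootstrap sample
    print(stats["metal"]["Accuracy_ci"])    # (lower, upper) Student-t 95% CI

With only n_times=10 bootstrap draws, the Student-t interval is wide; the
interval narrows as the number of resamples grows.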