Feat: Added statistics for the bootstrapped Metrics
Calculate the 95% confidence interval and the mean for the bootstrapped metric values.
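Concretely, the 95% interval computed below via `scipy.stats.t.interval` is the Student-t confidence interval over the bootstrap replicates:

$$
\bar{x} \pm t_{0.975,\,n-1} \cdot \frac{s}{\sqrt{n}}
$$

where $n$ is the number of bootstrap replicates, $\bar{x}$ their sample mean, and $s/\sqrt{n}$ the standard error of the mean (`scipy.stats.sem`).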
Tiago Würthner committed May 27, 2024
1 parent 9fddfe5 commit 2b77d23
Showing 2 changed files with 49 additions and 0 deletions.
44 changes: 44 additions & 0 deletions nmrcraft/evaluation/evaluation.py
@@ -1,6 +1,7 @@
from typing import Dict, List, Tuple

import numpy as np
import scipy.stats as st
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
@@ -86,3 +87,46 @@ def evaluate_bootstrap(X_test, y_test, model, targets, n_times=10):
        )
        bootstrap_metrics[target]["F1"].append(metrics[target]["F1"])
    return bootstrap_metrics


def metrics_statistics(bootstrapped_metrics):
    """
    Compute summary statistics for the bootstrapped metrics.

    Args:
        bootstrapped_metrics (dict): Bootstrapped metric values per target.

    Returns:
        dict: Mean and 95% confidence interval of the bootstrapped
            Accuracy and F1 values for each target.
    """
    metrics_stats = {}
    for key, value in bootstrapped_metrics.items():
        metrics_stats[key] = {
            "Accuracy_mean": None,
            "Accuracy_ci": None,
            "F1_mean": None,
            "F1_ci": None,
        }

        # Mean and 95% confidence interval for Accuracy
        metrics_stats[key]["Accuracy_mean"] = np.mean(value["Accuracy"])
        metrics_stats[key]["Accuracy_ci"] = st.t.interval(
            confidence=0.95,
            df=len(value["Accuracy"]) - 1,
            loc=np.mean(value["Accuracy"]),
            scale=st.sem(value["Accuracy"]),
        )

        # Mean and 95% confidence interval for the F1 score
        metrics_stats[key]["F1_mean"] = np.mean(value["F1"])
        metrics_stats[key]["F1_ci"] = st.t.interval(
            confidence=0.95,
            df=len(value["F1"]) - 1,
            loc=np.mean(value["F1"]),
            scale=st.sem(value["F1"]),
        )

    return metrics_stats
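A minimal usage sketch (the input data below is hypothetical; the shape — a dict of per-target metric lists, one value per bootstrap replicate — matches what `evaluate_bootstrap` builds above):

```python
from nmrcraft.evaluation.evaluation import metrics_statistics

# Hypothetical bootstrap results for a single target named "metal":
# each list holds one metric value per bootstrap replicate.
bootstrapped = {
    "metal": {
        "Accuracy": [0.91, 0.89, 0.93, 0.90, 0.92],
        "F1": [0.88, 0.86, 0.90, 0.87, 0.89],
    }
}

stats = metrics_statistics(bootstrapped)
print(stats["metal"]["Accuracy_mean"])  # mean over the replicates
print(stats["metal"]["Accuracy_ci"])    # (lower, upper) bounds of the 95% CI
```

Note that `st.t.interval` returns a `(lower, upper)` tuple, so the `*_ci` entries are tuples rather than single numbers.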
5 changes: 5 additions & 0 deletions scripts/training/one_target.py
@@ -120,6 +120,11 @@
            X_test, y_test, best_model, args.target
        )

        bootstrap_stat_metrics = evaluation.metrics_statistics(
            bootstrap_metrics
        )
        print(bootstrap_stat_metrics)

        # TODO: Adapt this code to the new structure
        # visualizer = Visualizer(
        #     model_name=model_name,
