Skip to content

Commit

Permalink
Added Normalisation to the confusion matrix (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
strsamue authored Jun 15, 2024
1 parent e16077a commit 26edd6d
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion nmrcraft/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,12 @@ def plot_confusion_matrix(
f"ConfusionMatrix_{model_name}_{dataset_size}_{target}.png",
)
cm = cm_list[target]
cm_normalized = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
classes = y_labels[target]
plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation="nearest", cmap=cmap)
plt.imshow(
cm_normalized, interpolation="nearest", cmap=cmap, vmin=0, vmax=1
)
plt.title(f"{target} Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(classes))
Expand Down

0 comments on commit 26edd6d

Please sign in to comment.