Skip to content

Commit

Permalink
Added Normalisation to the confusion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
strsamue committed Jun 15, 2024
1 parent e16077a commit 15d666b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion nmrcraft/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,12 @@ def plot_confusion_matrix(
f"ConfusionMatrix_{model_name}_{dataset_size}_{target}.png",
)
cm = cm_list[target]
cm_normalized = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
classes = y_labels[target]
plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation="nearest", cmap=cmap)
plt.imshow(
cm_normalized, interpolation="nearest", cmap=cmap, vmin=0, vmax=1
)
plt.title(f"{target} Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(classes))
Expand Down

0 comments on commit 15d666b

Please sign in to comment.