Skip to content

Commit

Permalink
Merge pull request #41 from marcpinet/feat-decision-boundary-viz
Browse files Browse the repository at this point in the history
Feat decision boundary viz
  • Loading branch information
marcpinet authored May 18, 2024
2 parents 7b24e76 + fe3ed49 commit 959fd4d
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 8 deletions.
65 changes: 59 additions & 6 deletions neuralnetlib/model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import json
import time
import matplotlib

import numpy as np

from neuralnetlib.activations import ActivationFunction
from neuralnetlib.layers import Layer, Input, Activation, Dropout, compatibility_dict
from neuralnetlib.losses import LossFunction, CategoricalCrossentropy
from neuralnetlib.metrics import accuracy_score
from neuralnetlib.preprocessing import PCA
from neuralnetlib.optimizers import Optimizer
from neuralnetlib.utils import shuffle, progress_bar
import matplotlib.pyplot as plt


# Select the Tk-based interactive backend so the decision-boundary figure can be
# redrawn live while training.
# NOTE(review): this runs at import time, so merely importing this module picks the
# backend for the whole process; presumably this fails or warns on machines without
# Tk or a display (CI, servers) -- confirm, and consider selecting it lazily.
matplotlib.use("TkAgg")


class Model:
Expand Down Expand Up @@ -106,7 +111,7 @@ def train_on_batch(self, x_batch: np.ndarray, y_batch: np.ndarray) -> float:

def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size: int = None,
verbose: bool = True, metrics: list = None, random_state: int = None, validation_data: tuple = None,
callbacks: list = None):
callbacks: list = None, plot_decision_boundary: bool = False):
"""
Fit the model to the training data.
Expand All @@ -120,7 +125,9 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size:
random_state: Random seed for shuffling the data
validation_data: Tuple of validation data and labels
callbacks: List of callback objects (e.g., EarlyStopping)
plot_decision_boundary: Whether to plot the decision boundary
"""
global update_plot
x_train = np.array(x_train) if not isinstance(
x_train, np.ndarray) else x_train
y_train = np.array(y_train) if not isinstance(
Expand All @@ -131,6 +138,45 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size:
x_test = np.array(x_test)
y_test = np.array(y_test)

if plot_decision_boundary:
pca = PCA(n_components=2, random_state=random_state)
x_train_2d = pca.fit_transform(x_train)

fig, ax = plt.subplots(figsize=(8, 6))

x_min, x_max = x_train_2d[:, 0].min() - 1, x_train_2d[:, 0].max() + 1
y_min, y_max = x_train_2d[:, 1].min() - 1, x_train_2d[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
np.arange(y_min, y_max, 0.1))

y_train_encoded = np.argmax(y_train, axis=1) if y_train.ndim > 1 else y_train

def update_plot(epoch):
ax.clear()

scatter = ax.scatter(x_train_2d[:, 0], x_train_2d[:, 1], c=y_train_encoded, cmap='viridis', alpha=0.7)

labels = np.unique(y_train_encoded)
handles = [
plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=scatter.cmap(scatter.norm(label)),
label=f'Class {label}', markersize=8) for label in labels]
ax.legend(handles=handles, title='Classes')

grid_points = np.c_[xx.ravel(), yy.ravel()]
Z = self.predict(pca.inverse_transform(grid_points))
if Z.shape[1] > 1: # Multiclass classification
Z = np.argmax(Z, axis=1).reshape(xx.shape)
ax.contourf(xx, yy, Z, alpha=0.2, cmap=plt.cm.RdYlBu, levels=np.arange(Z.max() + 1))
else: # Binary classification
Z = (Z > 0.5).astype(int).reshape(xx.shape)
ax.contourf(xx, yy, Z, alpha=0.2, cmap=plt.cm.RdYlBu, levels=1)

ax.set_xlabel("PCA Component 1")
ax.set_ylabel("PCA Component 2")
ax.set_title(f"Decision Boundary (Epoch {epoch + 1})")

fig.canvas.draw()

for i in range(epochs):
start_time = time.time()

Expand Down Expand Up @@ -190,10 +236,10 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size:
for metric in metrics:
# Change extend to append
val_metrics.append(metric(val_predictions, y_test))
if verbose:
val_metrics_str = ' - '.join(
f'{metric.__name__}: {val_metric:.4f}' for metric, val_metric in zip(metrics, val_metrics))
print(f' - {val_metrics_str}', end='')
if verbose:
val_metrics_str = ' - '.join(
f'{metric.__name__}: {val_metric:.4f}' for metric, val_metric in zip(metrics, val_metrics))
print(f' - {val_metrics_str}', end='')

if callbacks:
metrics_values = {}
Expand All @@ -218,6 +264,13 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size:

if verbose:
print()

if plot_decision_boundary:
update_plot(i)
plt.pause(0.1) # Pause pour laisser le temps de mettre à jour le graphique

if plot_decision_boundary and 'IPython' in globals():
plt.show(block=False)

if verbose:
print()
Expand Down
34 changes: 34 additions & 0 deletions neuralnetlib/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,37 @@ def inverse_transform(self, X):
if self.min_ is None or self.scale_ is None:
raise ValueError("MinMaxScaler has not been fitted yet.")
return (X - self.feature_range[0]) / (self.feature_range[1] - self.feature_range[0]) * self.scale_ + self.min_


class PCA:
    """Principal component analysis via eigendecomposition of the covariance matrix.

    The projection axes are the eigenvectors of the feature covariance matrix
    that have the largest eigenvalues.

    Attributes:
        n_components: Number of leading components kept by ``fit``.
        random_state: Stored as given; no method of this class reads it.
        components: ``(n_features, n_components)`` matrix whose columns are the
            kept eigenvectors, or ``None`` before ``fit`` has been called.
        mean: Per-feature mean of the data passed to ``fit``, or ``None``
            before ``fit`` has been called.
    """

    def __init__(self, n_components: int, random_state: int = None):
        """Store the configuration; nothing is computed until ``fit``."""
        self.n_components = n_components
        self.random_state = random_state
        self.components = None
        self.mean = None

    def _check_is_fitted(self) -> None:
        """Raise ``ValueError`` if ``fit`` has not populated the model yet."""
        # Same error type/message shape as MinMaxScaler uses for the unfitted case.
        if self.mean is None or self.components is None:
            raise ValueError("PCA has not been fitted yet.")

    def fit(self, X: np.ndarray):
        """Learn the feature means and principal axes from ``X``.

        Args:
            X: Data with one sample per row and one feature per column.
        """
        X = np.asarray(X)
        self.mean = np.mean(X, axis=0)
        X_centered = X - self.mean

        # np.cov collapses to a 0-d array when there is a single feature;
        # eigh needs a square 2-D matrix, so promote it back to (1, 1).
        covariance_matrix = np.atleast_2d(np.cov(X_centered, rowvar=False))

        eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)

        # eigh returns eigenvalues in ascending order; we want the directions
        # of largest variance first.
        sorted_indices = np.argsort(eigenvalues)[::-1]
        eigenvectors = eigenvectors[:, sorted_indices]

        self.components = eigenvectors[:, :self.n_components]

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Project ``X`` onto the fitted principal axes.

        Args:
            X: Data with the same number of features as the data used in ``fit``.

        Returns:
            Array with one row per sample and ``n_components`` columns.

        Raises:
            ValueError: If the model has not been fitted yet.
        """
        self._check_is_fitted()
        X_centered = np.asarray(X) - self.mean

        return np.dot(X_centered, self.components)

    def fit_transform(self, X: np.ndarray) -> np.ndarray:
        """Fit the model on ``X`` and return the projection of ``X``."""
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
        """Map projected points back into the original feature space.

        The reconstruction is exact only when every component was kept;
        otherwise the variance along the discarded axes is lost.

        Args:
            X: Points expressed in the principal-component basis.

        Returns:
            Array with one row per point and one column per original feature.

        Raises:
            ValueError: If the model has not been fitted yet.
        """
        self._check_is_fitted()
        return np.dot(np.asarray(X), self.components.T) + self.mean
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
numpy
numpy
matplotlib # for plotting the decision boundary (yeah, I won't rewrite matplotlib too)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='neuralnetlib',
version='2.5.1',
version='2.6.0',
author='Marc Pinet',
description='A simple convolutional neural network library with only numpy as dependency',
long_description=open('README.md', encoding="utf-8").read(),
Expand Down

0 comments on commit 959fd4d

Please sign in to comment.