From 25a96e49c1dfa25e56c5f89f465a1c16bcff9037 Mon Sep 17 00:00:00 2001
From: Parth Raut
Date: Sat, 11 Jan 2025 02:12:36 +0530
Subject: [PATCH] removed file

---
 examples/jax/train_single_NO_MONITOR.py | 114 ------------------------
 1 file changed, 114 deletions(-)
 delete mode 100644 examples/jax/train_single_NO_MONITOR.py

diff --git a/examples/jax/train_single_NO_MONITOR.py b/examples/jax/train_single_NO_MONITOR.py
deleted file mode 100644
index 65f1a6bd..00000000
--- a/examples/jax/train_single_NO_MONITOR.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Adapted from Training a simple neural network, with tensorflow/datasets data loading (https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html)
-
-import jax.numpy as jnp
-from jax import grad, jit, vmap
-from jax import random
-from jax.scipy.special import logsumexp
-import tensorflow as tf
-import tensorflow_datasets as tfds
-import time
-
-# A helper function to randomly initialize weights and biases
-# for a dense neural network layer
-def random_layer_params(m, n, key, scale=1e-2):
-  w_key, b_key = random.split(key)
-  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
-
-# Initialize all layers for a fully-connected neural network with sizes "sizes"
-def init_network_params(sizes, key):
-  keys = random.split(key, len(sizes))
-  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
-
-layer_sizes = [784, 512, 512, 10]
-step_size = 0.01
-num_epochs = 10
-batch_size = 128
-n_targets = 10
-params = init_network_params(layer_sizes, random.key(0))
-
-def relu(x):
-  return jnp.maximum(0, x)
-
-def predict(params, image):
-  # per-example predictions
-  activations = image
-  for w, b in params[:-1]:
-    outputs = jnp.dot(w, activations) + b
-    activations = relu(outputs)
-
-  final_w, final_b = params[-1]
-  logits = jnp.dot(final_w, activations) + final_b
-  return logits - logsumexp(logits)
-
-def one_hot(x, k, dtype=jnp.float32):
-  """Create a one-hot encoding of x of size k."""
-  return jnp.array(x[:, None] == jnp.arange(k), dtype)
-
-def accuracy(params, images, targets):
-  target_class = jnp.argmax(targets, axis=1)
-  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
-  return jnp.mean(predicted_class == target_class)
-
-# Make a batched version of the `predict` function
-batched_predict = vmap(predict, in_axes=(None, 0))
-
-def loss(params, images, targets):
-  preds = batched_predict(params, images)
-  return -jnp.mean(preds * targets)
-
-@jit
-def update(params, x, y):
-  grads = grad(loss)(params, x, y)
-  return [(w - step_size * dw, b - step_size * db)
-          for (w, b), (dw, db) in zip(params, grads)]
-
-
-# Ensure TF does not see GPU and grab all GPU memory.
-tf.config.set_visible_devices([], device_type='GPU')
-
-data_dir = '/tmp/tfds'
-
-# Fetch full datasets for evaluation
-# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
-# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
-mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
-mnist_data = tfds.as_numpy(mnist_data)
-train_data, test_data = mnist_data['train'], mnist_data['test']
-num_labels = info.features['label'].num_classes
-h, w, c = info.features['image'].shape
-num_pixels = h * w * c
-
-# Full train set
-train_images, train_labels = train_data['image'], train_data['label']
-train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
-train_labels = one_hot(train_labels, num_labels)
-
-# Full test set
-test_images, test_labels = test_data['image'], test_data['label']
-test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
-test_labels = one_hot(test_labels, num_labels)
-
-print('Train:', train_images.shape, train_labels.shape)
-print('Test:', test_images.shape, test_labels.shape)
-
-def get_train_batches():
-  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
-  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
-  # You can build up an arbitrary tf.data input pipeline
-  ds = ds.batch(batch_size).prefetch(1)
-  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
-  return tfds.as_numpy(ds)
-
-for epoch in range(num_epochs):
-  start_time = time.time()
-  for x, y in get_train_batches():
-    x = jnp.reshape(x, (len(x), num_pixels))
-    y = one_hot(y, num_labels)
-    params = update(params, x, y)
-  epoch_time = time.time() - start_time
-
-  train_acc = accuracy(params, train_images, train_labels)
-  test_acc = accuracy(params, test_images, test_labels)
-  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
-  print("Training set accuracy {}".format(train_acc))
-  print("Test set accuracy {}".format(test_acc))
\ No newline at end of file
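
--
Note (trailing commentary, not applied by git am): the deleted script is the stock
JAX MNIST example, using vmap for per-example prediction, grad for the loss
gradient, and jit for the compiled SGD update. A minimal self-contained sketch of
that same pattern follows, with hypothetical toy shapes and a squared-error loss
standing in for the script's MNIST pipeline and cross-entropy loss:

import jax.numpy as jnp
from jax import grad, jit, random, vmap

step_size = 0.01  # same SGD step size as the deleted script

# Toy one-layer model; 784 -> 10 mirrors the script's input/output sizes.
w_key, b_key = random.split(random.key(0))
params = (1e-2 * random.normal(w_key, (10, 784)),
          1e-2 * random.normal(b_key, (10,)))

def predict(params, x):
    # Per-example logits for a single flattened image.
    w, b = params
    return jnp.dot(w, x) + b

# vmap maps predict over the leading batch axis of x while sharing params.
batched_predict = vmap(predict, in_axes=(None, 0))

def loss(params, xs, ys):
    # Squared error for simplicity (the script used cross-entropy).
    return jnp.mean((batched_predict(params, xs) - ys) ** 2)

@jit
def update(params, xs, ys):
    # grad differentiates loss w.r.t. the params pytree; jit compiles the step.
    dw, db = grad(loss)(params, xs, ys)
    w, b = params
    return (w - step_size * dw, b - step_size * db)

xs = random.normal(random.key(1), (128, 784))  # fake batch of "images"
ys = random.normal(random.key(2), (128, 10))   # fake targets
params = update(params, xs, ys)                # one compiled SGD step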