diff --git a/data.py b/data.py new file mode 100644 index 00000000..2545eeab --- /dev/null +++ b/data.py @@ -0,0 +1,3 @@ +from torchvision import datasets +datasets.CIFAR10(root='./data', train=True, download=True) +datasets.MNIST(root='./data', train=True, download=True) diff --git a/requirments.txt b/requirments.txt index 4069e78d..81c3982c 100644 Binary files a/requirments.txt and b/requirments.txt differ diff --git a/src/Federated_Learning_PyTorch_Colab.ipynb b/src/Federated_Learning_PyTorch_Colab.ipynb new file mode 100644 index 00000000..1259c5c4 --- /dev/null +++ b/src/Federated_Learning_PyTorch_Colab.ipynb @@ -0,0 +1,4527 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a0eef6c3", + "metadata": { + "id": "a0eef6c3" + }, + "source": [ + "# Federated Learning with PyTorch\n", + "This notebook runs the Federated Learning project in Colab using CPU." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "48ea5fdf", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "48ea5fdf", + "outputId": "821266f9-2310-42d9-b81e-d7c10d7a7d80" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.1+cu121)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.19.1+cu121)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (3.7.1)\n", + "Requirement already satisfied: tensorboard in /usr/local/lib/python3.10/dist-packages (2.17.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.26.4)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.5)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (10.4.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.3.0)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (4.53.1)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.4.7)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (24.1)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (3.1.4)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (2.8.2)\n", + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (1.4.0)\n", + "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (1.64.1)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (3.7)\n", + "Requirement already satisfied: protobuf!=4.24.0,<5.0.0,>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (3.20.3)\n", + "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (71.0.4)\n", + "Requirement already satisfied: six>1.9 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (1.16.0)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (0.7.2)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (3.0.4)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard) (2.1.5)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n" + ] + } + ], + "source": [ + "# Install necessary libraries if needed\n", + "!pip install torch torchvision matplotlib tensorboard numpy tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f068b3ed", + "metadata": { + "id": "f068b3ed" + }, + "outputs": [], + "source": [ + "# Import required libraries and uploaded scripts\n", + "import os\n", + "import copy\n", + "import time\n", + "import pickle\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "import torch\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "\n", + "from options import args_parser\n", + "from update import LocalUpdate, test_inference\n", + "from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar\n", + "from utils import get_dataset, average_weights, exp_details" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "19fe865a", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "19fe865a", + "outputId": "489f5651-3f27-4550-917d-8ebcaad6948c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Experimental details:\n", + " Model : cnn\n", + " Optimizer : sgd\n", + " Learning : 0.01\n", + " Global Rounds : 10\n", + "\n", + " Federated parameters:\n", + " IID\n", + " Fraction of users : 0.1\n", + " Local Batch size : 10\n", + " Local Epochs : 10\n", + "\n", + "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar/cifar-10-python.tar.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 170498071/170498071 [00:06<00:00, 28119577.96it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ../data/cifar/cifar-10-python.tar.gz to ../data/cifar/\n", + "Files already downloaded and verified\n", + "CNNCifar(\n", + " (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n", + " (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n", + " (fc1): Linear(in_features=400, out_features=120, bias=True)\n", + " (fc2): Linear(in_features=120, out_features=84, bias=True)\n", + " (fc3): Linear(in_features=84, out_features=10, bias=True)\n", + ")\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r 0%| | 0/10 [00:00