{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1bba0128",
   "metadata": {},
   "source": [
    "# Supervised learning with PyTorch\n",
    "\n",
    "This is an example of how to build and optimize neural networks with PyTorch. PyTorch and TensorFlow provide automatic differentiation: they apply the chain rule of calculus to compute the derivatives of a function efficiently, including on GPUs.\n",
    "\n",
    "Our dataset consists of images of handwritten digits, and the task is to classify each image into one of the classes {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}.\n",
    "\n",
    "If this were a regression problem, we would often minimise the mean squared error between the output of the neural network and the correct value. As we saw in the presentation, this assumes that the target follows a Gaussian distribution around the prediction, which is certainly not true for digit classes: for one, the digit classes are discrete, while a Gaussian is defined over continuous values. The most general probability distribution over 10 classes is a Categorical distribution (https://en.wikipedia.org/wiki/Categorical_distribution), which is simply a discrete distribution with a given probability for each class. How can we then build a function that maps an input image to a given class?\n",
    "\n",
    "Suppose the neural network provides an output $f_k(x)$ in the form of a list of probabilities, giving the probability that a given image $x$ belongs to class $k$. If we know that the image $x$ belongs to class $C$, then the true probability $t_k$ for this image to belong to class $k$ is 1 for $k = C$ and 0 for every other class. The network's objective is to output such probabilities, so that only the $i$-th component of the output is 1 if the input belongs to class $i$. The presentation shows how Bayes' rule naturally leads us to minimize the cross-entropy between the target probabilities and the predicted probabilities: $- \\sum_k t_k \\log f_k(x)$. One can gain intuition on this by reading more about the Information Theory concept of cross-entropy and how it relates to the Kullback-Leibler divergence: since the entropy of the target distribution does not depend on the network, minimizing the cross-entropy is equivalent to minimizing the Kullback-Leibler divergence between the label distribution and the predicted one, which moves them closer together: https://en.wikipedia.org/wiki/Cross_entropy\n",
    "\n",
    "The neural network will therefore model a parametrized function that maps the input image pixels into a vector with 10 components, where the $k$-th component is the probability that the image shows digit $k$.\n"
   ]
  },
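  {
   "cell_type": "markdown",
   "id": "4c1e7a20",
   "metadata": {},
   "source": [
    "As a quick numerical check of the formula above, the sketch below evaluates $- \\sum_k t_k \\log f_k(x)$ by hand for a one-hot target and compares it with PyTorch's `F.cross_entropy`, which takes unnormalized scores (logits) plus an integer class label and applies the softmax internally. The logits and the label are made-up values chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c1e7a21",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# made-up network scores (logits) for one image and 10 classes\n",
    "logits = torch.tensor([[2.0, -1.0, 0.5, 0.0, 3.0, -2.0, 1.0, 0.0, -0.5, 0.2]])\n",
    "true_class = torch.tensor([4])  # hypothetically, the image shows a 4\n",
    "\n",
    "# predicted probabilities f_k(x), obtained with a softmax\n",
    "f = F.softmax(logits, dim=1)\n",
    "# one-hot target t_k: 1 for the true class, 0 elsewhere\n",
    "t = F.one_hot(true_class, num_classes=10).float()\n",
    "\n",
    "# cross-entropy written out explicitly: -sum_k t_k log f_k(x), averaged over the batch\n",
    "manual = -(t * torch.log(f)).sum(dim=1).mean()\n",
    "# PyTorch's built-in version, which works directly on the logits\n",
    "builtin = F.cross_entropy(logits, true_class)\n",
    "\n",
    "print(manual.item(), builtin.item())  # both approximately 0.588"
   ]
  },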
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0681795",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torchvision torch pandas numpy matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "23feddde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import standard PyTorch modules\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# import torchvision module to handle image manipulation\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48433f6f",
   "metadata": {},
   "source": [
    "PyTorch lets you write a `Dataset` class that returns a single data sample at a time, which is then used to feed input to your neural network. An example of such a class is given below, but for this exercise we shall use a ready-made class that loads the standard MNIST handwritten digits dataset, to keep things simple.\n",
    "\n",
    "If you want to load a different dataset (for example your own data!), feel free to copy and modify the example Dataset class below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "30205402",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self):\n",
    "        # open files or build an index of the samples here\n",
    "        pass\n",
    "    def __len__(self):\n",
    "        return 10 # number of samples in the dataset\n",
    "    def __getitem__(self, idx):\n",
    "        # return the sample with index idx\n",
    "        # normally it is read from a file; for the purposes of this example, generate a random image and label\n",
    "        # the image is stored as (H, W, C); nn.Conv2d expects (C, H, W), e.g. via my_image.transpose(2, 0, 1)\n",
    "        my_image = np.random.randn(10,10, 1)\n",
    "        my_label = np.array(np.random.randint(10))\n",
    "        my_image = torch.from_numpy(my_image)\n",
    "        my_label = torch.from_numpy(my_label)\n",
    "        return {\"data\": my_image, \"label\": my_label}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc0b0774",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n"
     ]
    }
   ],
   "source": [
    "my_dataset = MyDataset()\n",
    "print(len(my_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6dccfac6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'data': tensor([[[ 1.5547e+00],\n",
      "         [ 5.5951e-01],\n",
      "         [-4.4580e-01],\n",
      "         [-4.0911e-01],\n",
      "         [-1.9626e+00],\n",
      "         [-1.2957e+00],\n",
      "         [ 1.0107e+00],\n",
      "         [ 8.5706e-01],\n",
      "         [ 8.2698e-02],\n",
      "         [ 1.7445e+00]],\n",
      "\n",
      "        [[-1.1114e+00],\n",
      "         [-1.5847e+00],\n",
      "         [-2.0654e-01],\n",
      "         [-1.0200e+00],\n",
      "         [-1.6865e-01],\n",
      "         [-1.2053e-01],\n",
      "         [ 7.0255e-01],\n",
      "         [-5.0251e-01],\n",
      "         [ 1.0529e+00],\n",
      "         [-2.9051e-01]],\n",
      "\n",
      "        [[-9.3932e-02],\n",
      "         [ 2.6510e+00],\n",
      "         [ 1.4673e+00],\n",
      "         [-1.8302e+00],\n",
      "         [-1.2404e-01],\n",
      "         [ 3.8249e-01],\n",
      "         [-5.5515e-02],\n",
      "         [-1.3505e+00],\n",
      "         [ 1.3203e-01],\n",
      "         [ 9.6623e-02]],\n",
      "\n",
      "        [[-7.8525e-01],\n",
      "         [ 6.6473e-01],\n",
      "         [ 4.6917e-01],\n",
      "         [-1.2006e+00],\n",
      "         [-7.7406e-01],\n",
      "         [-1.3107e+00],\n",
      "         [ 4.2693e-01],\n",
      "         [-8.7382e-01],\n",
      "         [-2.5915e-01],\n",
      "         [ 1.5292e+00]],\n",
      "\n",
      "        [[-6.6223e-01],\n",
      "         [-2.2870e-01],\n",
      "         [-1.1778e-01],\n",
      "         [ 1.1825e+00],\n",
      "         [-1.1801e+00],\n",
      "         [-2.1859e-01],\n",
      "         [-1.6676e+00],\n",
      "         [-1.0415e-01],\n",
      "         [ 8.8033e-01],\n",
      "         [-7.0019e-01]],\n",
      "\n",
      "        [[-1.9371e-01],\n",
      "         [ 5.4381e-01],\n",
      "         [-2.9687e-01],\n",
      "         [ 6.8429e-01],\n",
      "         [ 5.0528e-01],\n",
      "         [-6.3122e-02],\n",
      "         [ 2.4948e-02],\n",
      "         [ 3.4935e-02],\n",
      "         [-5.9903e-01],\n",
      "         [ 2.9530e-01]],\n",
      "\n",
      "        [[-9.3742e-02],\n",
      "         [ 5.5731e-01],\n",
      "         [ 4.4727e-01],\n",
      "         [-1.9633e+00],\n",
      "         [ 7.6218e-01],\n",
      "         [-9.8049e-01],\n",
      "         [ 1.6627e-02],\n",
      "         [ 2.7729e-01],\n",
      "         [ 1.7569e-01],\n",
      "         [ 1.2022e+00]],\n",
      "\n",
      "        [[ 2.4165e-02],\n",
      "         [ 3.4443e-01],\n",
      "         [-1.3817e+00],\n",
      "         [-1.6941e+00],\n",
      "         [ 5.7643e-01],\n",
      "         [-3.3574e-01],\n",
      "         [-8.5208e-04],\n",
      "         [ 6.7266e-01],\n",
      "         [ 2.4279e-01],\n",
      "         [ 1.8059e+00]],\n",
      "\n",
      "        [[ 1.5710e+00],\n",
      "         [ 2.8216e+00],\n",
      "         [-2.3268e-02],\n",
      "         [-1.1153e+00],\n",
      "         [-8.6641e-01],\n",
      "         [ 5.0544e-01],\n",
      "         [-3.7233e-02],\n",
      "         [-2.8511e-01],\n",
      "         [-2.3818e+00],\n",
      "         [ 8.0363e-01]],\n",
      "\n",
      "        [[-2.4681e-01],\n",
      "         [ 1.0006e+00],\n",
      "         [ 1.9276e-01],\n",
      "         [-7.3025e-01],\n",
      "         [-1.0975e+00],\n",
      "         [ 9.3319e-01],\n",
      "         [-5.5379e-01],\n",
      "         [-5.1401e-01],\n",
      "         [-8.8545e-01],\n",
      "         [ 4.6912e-01]]], dtype=torch.float64), 'label': tensor(7)}\n"
     ]
    }
   ],
   "source": [
    "print(my_dataset[1])"
   ]
  },
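  {
   "cell_type": "markdown",
   "id": "5d2f8b30",
   "metadata": {},
   "source": [
    "To see how such a dataset would be consumed, here is a minimal sketch that wraps a dataset in a `DataLoader`. `RandomDigits` is a hypothetical stand-in for `MyDataset` that subclasses `torch.utils.data.Dataset` and returns images in the (C, H, W) float32 layout that `nn.Conv2d` expects; the batch size of 4 is arbitrary. The default collate function of the `DataLoader` stacks each dictionary field along a new batch dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d2f8b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class RandomDigits(Dataset):\n",
    "    \"\"\"Hypothetical toy dataset: random 1x10x10 images with random labels in 0..9.\"\"\"\n",
    "    def __init__(self, n_samples=10):\n",
    "        self.n_samples = n_samples\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n_samples\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # image in (C, H, W) layout and float32, as expected by nn.Conv2d\n",
    "        image = torch.from_numpy(np.random.randn(1, 10, 10).astype(np.float32))\n",
    "        label = torch.tensor(np.random.randint(10))\n",
    "        return {\"data\": image, \"label\": label}\n",
    "\n",
    "loader = DataLoader(RandomDigits(), batch_size=4, shuffle=True)\n",
    "batch = next(iter(loader))\n",
    "print(batch[\"data\"].shape, batch[\"label\"].shape)  # torch.Size([4, 1, 10, 10]) torch.Size([4])\n",
    "# the last batch is smaller, because 10 samples do not divide evenly into batches of 4\n",
    "print([b[\"data\"].shape[0] for b in loader])  # [4, 4, 2]"
   ]
  },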
  {
   "cell_type": "markdown",
   "id": "6f1c9da9",
   "metadata": {},
   "source": [
    "But let's keep things simple and just focus on the actual neural network, using a standard class to load a standard dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e97239d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81f14eed13584b959df99d123c11d53f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/9912422 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "35beff87bf47420fafba7e5a6301aaf9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/28881 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7b5d6f41eef14f8a9ad29beae3acf62b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1648877 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e980b594b1942d68ab90d2255e2f42a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4542 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Use standard MNIST dataset\n",
    "my_dataset = torchvision.datasets.MNIST(\n",
    "    root = './data/MNIST',\n",
    "    train = True,\n",
    "    download = True,\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor()\n",
    "    ])\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "527089bd",
   "metadata": {},
   "source": [
    "Plot some of the data with their labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "067b8105",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(20,24))\n",
    "for i in range(5):\n",
    "    for j in range(5):\n",
    "        idx = i*5+j\n",
    "        img = my_dataset[idx][0]\n",
    "        label = my_dataset[idx][1]\n",
    "        ax[i, j].imshow(img[0,...].detach().cpu().numpy())\n",
    "        ax[i, j].set(title=f\"Image {idx}, true label {label}\")\n",
    "        ax[i, j].set_xticklabels([])\n",
    "        ax[i, j].set_yticklabels([])\n",
    "plt.subplots_adjust(hspace=0.2,wspace=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e517975c",
   "metadata": {},
   "source": [
    "And now let us define the neural network. In PyTorch, neural networks are subclasses of `nn.Module`. They define their sub-parts in their constructor (here, convolutional layers and fully connected linear layers), and the method `forward` receives a batch of input images and returns the network output.\n",
    "\n",
    "The network parameters are the weights and biases of the `Conv2d` and `Linear` layers. They are conveniently hidden here, but can be accessed through each layer's `weight` and `bias` attributes (for example `network.conv1.weight`), or all at once with `network.parameters()`.\n",
    "\n",
    "We will not directly output the label probabilities, since we do not need them to optimize the neural network: we only need the logits, that is, the unnormalized scores that the softmax function turns into probabilities."
   ]
  },
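  {
   "cell_type": "markdown",
   "id": "6e3a9c40",
   "metadata": {},
   "source": [
    "As an illustration of the `weight` and `bias` attributes mentioned above, the sketch below inspects a standalone `nn.Conv2d` layer configured like `conv1` in the network below and counts its trainable parameters: $6 \\times 1 \\times 5 \\times 5 = 150$ weights plus $6$ biases, i.e. 156 in total."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e3a9c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# a standalone layer configured like conv1 in the network below\n",
    "conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)\n",
    "\n",
    "# the weights have shape (out_channels, in_channels, kernel_height, kernel_width)\n",
    "print(conv.weight.shape)  # torch.Size([6, 1, 5, 5])\n",
    "# there is one bias per output channel\n",
    "print(conv.bias.shape)  # torch.Size([6])\n",
    "\n",
    "# parameters() yields every trainable tensor of a module; this is what the optimizer receives\n",
    "n_params = sum(p.numel() for p in conv.parameters())\n",
    "print(n_params)  # 156"
   ]
  },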
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d908ef86",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(nn.Module):\n",
    "    \"\"\"\n",
    "        This is our parametrized function.\n",
    "        It stores all the parametrized weights theta inside the conv1, conv2, fc1 and fc2 objects.\n",
    "        The forward function receives an image and outputs a vector.\n",
    "        \n",
    "        The intuition is that the i-th component of the vector represents the probability that\n",
    "        the image belongs to the i-th class. However,\n",
    "        we do not normalize the output to be in the range [0,1] and to sum to 1. The reason is\n",
    "        that this normalization is done later, in the training step, where its numerical error can be\n",
    "        minimized by calculating log(probability) directly, instead of first calculating the probability\n",
    "        and then taking its log. Keep in mind, therefore, that to get probabilities\n",
    "        from this object one should do F.softmax(my_network(x), dim=1).\n",
    "        \n",
    "        The code has been written like this, as this is a common optimization done in classification problems.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        Constructor. Here we initialize the weights.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # define parameters\n",
    "        \n",
    "        # all these steps are purely linear (affine if one considers the bias)\n",
    "        # the forward function adds a non-linearity through the ReLU to allow this to do more than\n",
    "        # simple linear filters\n",
    "        \n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)\n",
    "\n",
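    "        # after two 5x5 convolutions (no padding) and two 2x2 max-pools, a 28x28 image becomes 12 channels of 4x4 pixels\n",
    "        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)\n",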
    "        self.fc2 = nn.Linear(in_features=120, out_features=60)\n",
    "        self.out = nn.Linear(in_features=60, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        This function is called when one does my_network(x) and it represents the action\n",
    "        of our parametrized function on the image, outputting one unnormalized score (logit) per class.\n",
    "        The input x has shape (B, C, H, W) (ie: batch dimension, channels, height and width).\n",
    "        The output has shape (B, K), where K is the number of classes.\n",
    "        Row b of the output holds the K class scores of the b-th image.\n",
    "        Column k of the output holds the score of class k for all B images given as an input.\n",
    "        \"\"\"\n",
    "\n",
    "        # first convolution\n",
    "        t = self.conv1(x)\n",
    "        # non-linearity\n",
    "        t = F.relu(t)\n",
    "        # reduce the width and height of the image by taking the maximum\n",
    "        # pixel value in each 2x2 window (kernel_size), moving the window by 2 pixels (stride)\n",
    "        # so that windows do not overlap and the width and height are halved\n",
    "        # the convolution receives one channel and outputs six; the goal of the max_pool layer\n",
    "        # is to reduce the image size, so that we can afford more channels, each smaller in size\n",
    "        # this is a trade off between memory and compute\n",
    "        t = F.max_pool2d(t, kernel_size=2, stride=2)\n",
    "\n",
    "        # second convolution\n",
    "        t = self.conv2(t)\n",
    "        # non-linearity\n",
    "        t = F.relu(t)\n",
    "        # reduce the size of the image in width and height again\n",
    "        t = F.max_pool2d(t, kernel_size=2, stride=2)\n",
    "\n",
    "        # transform images into a single vector using reshape\n",
    "        # this puts all pixel values in a single vector\n",
    "        t = t.reshape(-1, 12*4*4)\n",
    "        \n",
    "        # apply a linear transformation\n",
    "        t = self.fc1(t)\n",
    "        # add a non-linearity\n",
    "        t = F.relu(t)\n",
    "\n",
    "        # another linear transformation\n",
    "        t = self.fc2(t)\n",
    "        # another non-linearity\n",
    "        t = F.relu(t)\n",
    "\n",
    "        # final linear transformation\n",
    "        # the output of this has been set to 10 features, so the output will have the size\n",
    "        # (B, 10)\n",
    "        t = self.out(t)\n",
    "\n",
    "        # note: while we want the function to output a probability,\n",
    "        # we make no effort to normalize these numbers so that they are in [0, 1]\n",
    "        # and sum to 1\n",
    "        # this would often be done by applying a transformation called Softmax(t) = exp(t)/sum(exp(t))\n",
    "        # however, this will be done internally by PyTorch in the function F.cross_entropy\n",
    "        # which we will call later on when training\n",
    "\n",
    "        return t"
   ]
  },
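  {
   "cell_type": "markdown",
   "id": "7f4bad50",
   "metadata": {},
   "source": [
    "The input size `12*4*4` of `fc1` follows from the 28x28 MNIST images: each 5x5 convolution without padding shrinks the width and height by 4 pixels, and each 2x2 max-pool with stride 2 halves them, giving 28, 24, 12, 8 and finally 4 pixels. The sketch below checks this by passing a dummy batch through standalone layers configured like `conv1` and `conv2`; the batch size of 3 is arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4bad51",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# standalone layers configured like conv1 and conv2 in Network\n",
    "conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)\n",
    "conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)\n",
    "\n",
    "x = torch.zeros(3, 1, 28, 28)  # dummy batch of 3 MNIST-sized images: (B, C, H, W)\n",
    "shapes = [tuple(x.shape)]\n",
    "t = F.relu(conv1(x))\n",
    "shapes.append(tuple(t.shape))  # (3, 6, 24, 24)\n",
    "t = F.max_pool2d(t, kernel_size=2, stride=2)\n",
    "shapes.append(tuple(t.shape))  # (3, 6, 12, 12)\n",
    "t = F.relu(conv2(t))\n",
    "shapes.append(tuple(t.shape))  # (3, 12, 8, 8)\n",
    "t = F.max_pool2d(t, kernel_size=2, stride=2)\n",
    "shapes.append(tuple(t.shape))  # (3, 12, 4, 4)\n",
    "\n",
    "# flatten everything except the batch dimension: this is the input size of fc1\n",
    "flat = t.flatten(start_dim=1)\n",
    "print(shapes)\n",
    "print(flat.shape)  # torch.Size([3, 192]), and 12*4*4 == 192"
   ]
  },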
  {
   "cell_type": "markdown",
   "id": "9c5620dc",
   "metadata": {},
   "source": [
    "Let us create one instance of this network. We also create an instance of PyTorch's `DataLoader`, which takes a given number of data samples and stacks them into a single object. This \"mini-batch\" of data is used during training, so that we do not need to load the entire dataset in memory during the optimization procedure.\n",
    "\n",
    "We also create an instance of the Adam optimizer, which is used to tune the parameters of the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "988e1979",
   "metadata": {},
   "outputs": [],
   "source": [
    "network = Network()\n",
    "B = 64\n",
    "loader = torch.utils.data.DataLoader(my_dataset, batch_size=B, shuffle=True)  # reshuffle the training data every epoch\n",
    "optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ee54520",
   "metadata": {},
   "source": [
    "Now we repeatedly optimize the network parameters. Each pass through all the data we have is one \"epoch\". In each epoch, we take several \"mini-batches\" of data (given by the `DataLoader` in `loader`) and use each of them to make one training step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d15d655d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/danilo/miniconda3/envs/mlmkl/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448224956/work/c10/core/TensorImpl.h:1156.)\n",
      "  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/10: average loss 0.36116\n",
      "Epoch 1/10: average loss 0.10886\n",
      "Epoch 2/10: average loss 0.07327\n",
      "Epoch 3/10: average loss 0.05648\n",
      "Epoch 4/10: average loss 0.04617\n",
      "Epoch 5/10: average loss 0.03912\n",
      "Epoch 6/10: average loss 0.03184\n",
      "Epoch 7/10: average loss 0.02736\n",
      "Epoch 8/10: average loss 0.02388\n",
      "Epoch 9/10: average loss 0.02053\n"
     ]
    }
   ],
   "source": [
    "epochs = 10\n",
    "# for each epoch\n",
    "for epoch in range(epochs):\n",
    "    losses = list()\n",
    "    # for each mini-batch given by the loader:\n",
    "    for batch in loader:\n",
    "        # get the images in the mini-batch\n",
    "        # this has size (B, C, H, W)\n",
    "        # where B is the mini-batch size\n",
    "        # C is the number of channels in the image (1 for grayscale)\n",
    "        # H is the height of the image\n",
    "        # W is the width of the image\n",
    "        images = batch[0]\n",
    "        # get the labels in the mini-batch (there shall be B of them)\n",
    "        labels = batch[1]\n",
    "        # get the output of the neural network:\n",
    "        logits = network(images)\n",
    "        \n",
    "        # note: the network does not output probabilities directly: it outputs logits\n",
    "        # to get probabilities from it we would need to do F.softmax(logits, dim=1)\n",
    "        # however, this is done inside F.cross_entropy below and we therefore should\n",
    "        # not do it twice here\n",
    "        # the reason it is done internally, in F.cross_entropy, is that what we really\n",
    "        # need is log(probability) and we can reduce the numerical error\n",
    "        # in its calculation by calculating log(softmax(.)) in one go\n",
    "        # (remember softmax(x) = exp(x)/sum(exp(x)), so log(softmax(x)) = x - log(sum(exp(x))))\n",
    "        \n",
    "        # calculate the loss function being minimized\n",
    "        # in this case, it is the cross-entropy between the logits and the true labels\n",
    "        loss = F.cross_entropy(logits, labels)\n",
    "\n",
    "        # reset the gradients stored by the previous step (PyTorch accumulates gradients by default)\n",
    "        optimizer.zero_grad()\n",
    "        # calculate the gradient of the loss function with respect to the network parameters\n",
    "        loss.backward()\n",
    "        # ask the Adam optimizer to change the parameters in the direction of - gradient\n",
    "        # Adam adapts the step size of each parameter using running averages of its gradient and squared gradient\n",
    "        # take a look at the Adam paper for more details: https://arxiv.org/abs/1412.6980\n",
    "        optimizer.step()\n",
    "        losses.append(loss.detach().cpu().item())\n",
    "    avg_loss = np.mean(np.array(losses))\n",
    "    print(f\"Epoch {epoch}/{epochs}: average loss {avg_loss:.5f}\")"
   ]
  },
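  {
   "cell_type": "markdown",
   "id": "8a5cbe60",
   "metadata": {},
   "source": [
    "The comments in the training loop mention that `F.cross_entropy` computes $\\log \\mathrm{softmax}(x)_j = x_j - \\log \\sum_k \\exp(x_k)$ in one go. The sketch below, with deliberately large made-up logits, shows why this matters: computing the softmax by hand and then taking its logarithm produces `nan`, because $\\exp(1000)$ overflows in float32, while `F.log_softmax` returns finite values that agree with `x - torch.logsumexp(x)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a5cbe61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# deliberately large made-up logits for one example and 3 classes\n",
    "x = torch.tensor([[1000.0, 0.0, -1000.0]])\n",
    "\n",
    "# naive: softmax by hand, then log; exp(1000) overflows to inf in float32\n",
    "naive = torch.log(torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True))\n",
    "# stable: PyTorch's log_softmax\n",
    "stable = F.log_softmax(x, dim=1)\n",
    "# the identity log(softmax(x)) = x - log(sum(exp(x))), with logsumexp computed stably\n",
    "by_identity = x - torch.logsumexp(x, dim=1, keepdim=True)\n",
    "\n",
    "print(naive)   # nan for the first class (inf / inf), -inf for the others\n",
    "print(stable)  # finite values: 0, -1000, -2000"
   ]
  },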
  {
   "cell_type": "markdown",
   "id": "a4980bf4",
   "metadata": {},
   "source": [
    "Let us check what the network says about some new data it has never seen before (note that we set `train` to `False`, which loads the separate MNIST test set, statistically independent of the training data)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "09646d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset = torchvision.datasets.MNIST(\n",
    "    root = './data/MNIST',\n",
    "    train = False,\n",
    "    download = True,\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor()\n",
    "    ])\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e315b5dc",
   "metadata": {},
   "source": [
    "Now we can plot the new images again, this time also showing what the network predicts for each of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "7a06a4c0",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(20,24))\n",
    "for i in range(5):\n",
    "    for j in range(5):\n",
    "        idx = i*5+j\n",
    "        img = test_dataset[idx][0]\n",
    "        label = test_dataset[idx][1]\n",
    "        logits = network(img[None, ...]) # output\n",
    "        probs = F.softmax(logits, dim=1) # apply softmax to normalize them\n",
    "        predicted = torch.argmax(probs[0, ...]) # index of the highest probability\n",
    "        ax[i, j].imshow(img[0,...].detach().cpu().numpy())\n",
    "        ax[i, j].set(title=f\"Image {idx}, true label {label}, predicted: {predicted}\")\n",
    "        ax[i, j].set_xticklabels([])\n",
    "        ax[i, j].set_yticklabels([])\n",
    "plt.subplots_adjust(hspace=0.2,wspace=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73e79a6d",
   "metadata": {},
   "source": [
    "We can also examine the probability that the network assigns to the true class. For example: for all test images whose true label is 4, what probability does the network assign to class 4?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "b447df87",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=32)\n",
    "logits = list()\n",
    "label = list()\n",
    "# no gradients are needed for evaluation: torch.no_grad() avoids storing the computation graph\n",
    "with torch.no_grad():\n",
    "    for batch in test_loader:\n",
    "        logits += [network(batch[0])]\n",
    "        label += [batch[1]]\n",
    "logits = torch.cat(logits, dim=0)\n",
    "label = torch.cat(label, dim=0)\n",
    "probs = F.softmax(logits, dim=1)"
   ]
  },
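  {
   "cell_type": "markdown",
   "id": "9b6dcf70",
   "metadata": {},
   "source": [
    "With the `logits` and `label` tensors collected above, the overall test accuracy is the fraction of images whose highest-scoring class equals the true label. A minimal sketch of this, as a helper named `accuracy` (a name introduced here for illustration, not part of PyTorch), is shown below on made-up tensors so that it runs on its own; applied to the notebook's results it would be called as `accuracy(logits, label)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b6dcf71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:\n",
    "    \"\"\"Fraction of rows whose highest-scoring class equals the label (illustrative helper).\"\"\"\n",
    "    # softmax preserves the ordering of the scores, so the arg-max of the logits is enough\n",
    "    predicted = torch.argmax(logits, dim=1)\n",
    "    return (predicted == labels).float().mean().item()\n",
    "\n",
    "# made-up logits for 4 images and 3 classes, with their true labels\n",
    "toy_logits = torch.tensor([[2.0, 0.1, -1.0],\n",
    "                           [0.0, 3.0, 0.5],\n",
    "                           [1.0, 0.2, 0.3],\n",
    "                           [-1.0, 0.0, 4.0]])\n",
    "toy_labels = torch.tensor([0, 1, 2, 2])\n",
    "print(accuracy(toy_logits, toy_labels))  # 0.75: the third image is predicted as class 0"
   ]
  },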
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "d6df7b33",
   "metadata": {},
   "outputs": [],
   "source": [
    "probs_4 = probs.detach().cpu().numpy()[label.detach().cpu().numpy() == 4]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c850c24a",
   "metadata": {},
   "source": [
    "We can histogram these probabilities and see that the network mostly does a good job, but sometimes fails. We could go further, apply a cut on the probability and look at the images below it to identify which ones were misclassified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "93a5fe94",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.hist(probs_4[:, 4], bins=20)\n",
    "ax.set(xlabel=r\"Probability that a true 4 belongs to class 4 according to the network (p(C$_4$|data))\",\n",
    "       ylabel=\"Number of occurrences\",\n",
    "       title=\"\")\n",
    "plt.show()"
   ]
  },
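  {
   "cell_type": "markdown",
   "id": "ac7ed080",
   "metadata": {},
   "source": [
    "The cut mentioned above could be sketched as follows: select the true 4s whose predicted probability for class 4 is below a threshold, keeping their positions in the test set so that the corresponding images can be plotted. The threshold of 0.5 and the small made-up `example_probs` and `example_labels` arrays are assumptions for illustration; with the notebook's results one would use `probs.detach().cpu().numpy()` and `label.numpy()` instead. Unlike `probs_4`, which is already filtered, `np.flatnonzero` returns the original indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac7ed081",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# made-up probability of class 4 for 5 test images\n",
    "p4 = np.array([0.82, 0.05, 0.30, 0.95, 0.10])\n",
    "# spread the remaining probability evenly over the other 9 classes, so each row sums to 1\n",
    "example_probs = np.repeat(((1.0 - p4) / 9.0)[:, None], 10, axis=1)\n",
    "example_probs[:, 4] = p4\n",
    "example_labels = np.array([4, 1, 4, 4, 7])  # made-up true labels\n",
    "\n",
    "threshold = 0.5  # hypothetical cut on p(C_4|data)\n",
    "# indices (in the full test set) of true 4s that the network considers unlikely to be a 4\n",
    "suspicious = np.flatnonzero((example_labels == 4) & (example_probs[:, 4] < threshold))\n",
    "print(suspicious)  # [2]"
   ]
  },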
  {
   "cell_type": "markdown",
   "id": "bff1f8db",
   "metadata": {},
   "source": [
    "### Contact us at the EuXFEL Data Analysis group at any time if you need help analysing your data!\n",
    "\n",
    "#### Danilo Ferreira de Lima: danilo.enoque.ferreira.de.lima@xfel.eu\n",
    "#### Arman Davtyan: arman.davtyan@xfel.eu"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}