{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1bba0128",
   "metadata": {},
   "source": [
    "# Supervised learning with PyTorch\n",
    "\n",
    "This is an example of how to build and optimize neural networks with PyTorch. PyTorch and Tensorflow offer a handy mechanism to provide automatic differentiation, using the chain rule in Calculus to calculate the derivative of a function very fast and with GPU support.\n",
    "\n",
    "Our dataset will consist of images of handwritten digits and the task shall be to classify those handwritten digits in the classes {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}.\n",
    "\n",
    "If this was a regression problem, we would often try to minimise the mean-squared-error between the output of the neural network and the correct prediction. As we saw in the presentation, this assumes that the underlying probability distribution of the prediction is a Gaussian, which is certainly not true for the distribution of digit classes: for one, the digit classes are discrete and Gaussians are only defined for continuous outputs. The most general probability distribution for a choice of 10 classes is a Categorical distribution (https://en.wikipedia.org/wiki/Categorical_distribution), which is simply a discrete distribution with a given probability value for each class. How can we then sculpt a function that maps the input image to a given class?\n",
    "\n",
    "Suppose the neural network provides an output $f_k(x)$ in the form of a list of probabilities, informing us of the probability that a given image belongs to a certain class $k$. If we know that a given input image x belongs to class C, then the true probability t for this image x to belong to each class is zero for classes that differ from C and 1 for the class C. The network's objective will be to output such probabilities, so that only the i-th component of the output is 1 if the input belongs to class i. The presentation shows how the Bayes' rule leads us naturally to minimize the cross entropy between the target probabilities and the predicted probabilities: $- \\sum_k t_k \\log f_k(x)$. One can gain intuition on this by reading more on the Information Theory concept of cross-entropy and how it relates to Mutual Information: minimizing the mutual information between the labels distribution and the predicted one moves them closer together:  https://en.wikipedia.org/wiki/Cross_entropy\n",
    "\n",
    "The neural network will therefore model a parametrized function that maps the input image pixels into a vector with 10 components, which refer to the probability that the image correspond to that digit.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d0681795",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torchvision in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (0.11.2)\n",
      "Requirement already satisfied: torch in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (1.10.1)\n",
      "Requirement already satisfied: pandas in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (1.1.5)\n",
      "Requirement already satisfied: numpy in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (1.19.2)\n",
      "Requirement already satisfied: matplotlib in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (3.3.4)\n",
      "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (from torchvision) (8.3.1)\n",
      "Requirement already satisfied: typing_extensions in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (from torch) (3.10.0.2)\n",
      "Requirement already satisfied: dataclasses in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (from torch) (0.8)\n",
      "Requirement already satisfied: python-dateutil>=2.7.3 in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (from pandas) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2017.2 in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (from pandas) (2021.3)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (from matplotlib) (1.3.1)\n",
      "Requirement already satisfied: cycler>=0.10 in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (from matplotlib) (0.11.0)\n",
      "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (from matplotlib) (3.0.4)\n",
      "Requirement already satisfied: six>=1.5 in /home/daniloefl/miniconda3/envs/mltut/lib/python3.6/site-packages (from python-dateutil>=2.7.3->pandas) (1.16.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install torchvision torch pandas numpy matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "23feddde",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib notebook\n",
    "\n",
    "# import standard PyTorch modules\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# import torchvision module to handle image manipulation\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48433f6f",
   "metadata": {},
   "source": [
    "PyTorch allows you to create a class that outputs a single data entry and use that to feed input to your neural network. An example of how you would write such a class is given below, but for this exercise we shall use something ready-made which loads the standard MNIST handwritten digits dataset, just to simplify things.\n",
    "\n",
    "If you want to load a different dataset (for example your own data!), feel free to copy and modify the example Dataset class below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "30205402",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDataset(object):\n",
    "    def __init__(self):\n",
    "        pass\n",
    "    def __len__(self):\n",
    "        return 10 # these are how many samples I have\n",
    "    def __getitem__(self, idx):\n",
    "        # give me item with index idx\n",
    "        # read this from some file, but for the purposes of this example, generate a random image and label\n",
    "        my_image = np.random.randn(10,10, 1)\n",
    "        my_label = np.array(np.random.randint(10))\n",
    "        my_image = torch.from_numpy(my_image)\n",
    "        my_label = torch.from_numpy(my_label)\n",
    "        return {\"data\": my_image, \"label\": my_label}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cc0b0774",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n"
     ]
    }
   ],
   "source": [
    "my_dataset = MyDataset()\n",
    "print(len(my_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6dccfac6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'data': tensor([[[-2.2541e-01],\n",
      "         [-8.7492e-01],\n",
      "         [ 2.2757e-01],\n",
      "         [ 3.6083e-01],\n",
      "         [-1.0728e+00],\n",
      "         [-1.9445e+00],\n",
      "         [ 1.2723e+00],\n",
      "         [-4.6512e-01],\n",
      "         [ 7.8006e-01],\n",
      "         [-2.8025e+00]],\n",
      "\n",
      "        [[ 2.5337e-01],\n",
      "         [-3.9780e-01],\n",
      "         [-1.5573e+00],\n",
      "         [-6.1211e-01],\n",
      "         [-1.7922e+00],\n",
      "         [ 1.4484e+00],\n",
      "         [-3.1899e-01],\n",
      "         [ 1.4920e-03],\n",
      "         [-1.3394e+00],\n",
      "         [ 2.9670e-01]],\n",
      "\n",
      "        [[-1.1016e-01],\n",
      "         [ 1.6733e+00],\n",
      "         [ 2.7996e-01],\n",
      "         [-3.7264e-01],\n",
      "         [-1.4377e+00],\n",
      "         [ 1.7887e+00],\n",
      "         [ 1.6038e-01],\n",
      "         [ 7.4539e-01],\n",
      "         [ 1.9320e+00],\n",
      "         [ 5.8520e-01]],\n",
      "\n",
      "        [[ 1.8378e-03],\n",
      "         [ 7.1617e-01],\n",
      "         [ 1.9320e-01],\n",
      "         [-9.8246e-01],\n",
      "         [ 1.0930e+00],\n",
      "         [ 1.4183e+00],\n",
      "         [ 1.9863e-01],\n",
      "         [-1.1845e-01],\n",
      "         [ 1.3640e+00],\n",
      "         [ 8.7122e-01]],\n",
      "\n",
      "        [[ 1.6613e+00],\n",
      "         [ 1.2155e-01],\n",
      "         [ 3.4827e-01],\n",
      "         [ 9.9064e-01],\n",
      "         [-9.4938e-01],\n",
      "         [ 4.7694e-01],\n",
      "         [ 5.2231e-01],\n",
      "         [ 5.8428e-01],\n",
      "         [ 4.5853e-02],\n",
      "         [-2.5305e+00]],\n",
      "\n",
      "        [[-8.4262e-01],\n",
      "         [ 9.2893e-01],\n",
      "         [ 1.0124e+00],\n",
      "         [-1.6017e+00],\n",
      "         [-1.4961e-01],\n",
      "         [ 2.2560e+00],\n",
      "         [ 2.0985e-01],\n",
      "         [ 7.7839e-01],\n",
      "         [ 1.5581e+00],\n",
      "         [-2.3309e+00]],\n",
      "\n",
      "        [[ 1.5367e+00],\n",
      "         [-3.7147e-01],\n",
      "         [-2.3096e-01],\n",
      "         [ 1.1904e+00],\n",
      "         [ 1.9962e+00],\n",
      "         [-1.0334e-01],\n",
      "         [-6.0397e-01],\n",
      "         [ 1.0851e+00],\n",
      "         [-2.3377e-01],\n",
      "         [-7.0282e-01]],\n",
      "\n",
      "        [[ 1.7610e+00],\n",
      "         [-1.3054e+00],\n",
      "         [ 1.6490e+00],\n",
      "         [-3.9569e-01],\n",
      "         [-1.5368e+00],\n",
      "         [-1.0189e+00],\n",
      "         [-1.9441e-02],\n",
      "         [ 9.1458e-01],\n",
      "         [ 1.2236e+00],\n",
      "         [ 8.6759e-01]],\n",
      "\n",
      "        [[ 4.9696e-01],\n",
      "         [-1.0176e+00],\n",
      "         [ 2.7300e+00],\n",
      "         [-1.1193e+00],\n",
      "         [ 1.5149e+00],\n",
      "         [ 2.2174e+00],\n",
      "         [ 4.5604e-01],\n",
      "         [-1.0018e+00],\n",
      "         [ 6.1021e-01],\n",
      "         [-5.1552e-01]],\n",
      "\n",
      "        [[ 2.3127e+00],\n",
      "         [-5.9054e-01],\n",
      "         [-1.1682e+00],\n",
      "         [-2.8864e-02],\n",
      "         [-1.8656e+00],\n",
      "         [-1.4874e+00],\n",
      "         [-1.7110e+00],\n",
      "         [ 8.5748e-01],\n",
      "         [-1.2645e+00],\n",
      "         [-3.9108e-01]]], dtype=torch.float64), 'label': tensor(3)}\n"
     ]
    }
   ],
   "source": [
    "print(my_dataset[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f1c9da9",
   "metadata": {},
   "source": [
    "But let's keep things simple and just focus on the actual neural network, using a standard class to load a standard dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e97239d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use standard MNIST dataset\n",
    "my_dataset = torchvision.datasets.MNIST(\n",
    "    root = './data/MNIST',\n",
    "    train = True,\n",
    "    download = True,\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor()                                 \n",
    "    ])\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "527089bd",
   "metadata": {},
   "source": [
    "Plot some of the data with their labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "067b8105",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "/* Put everything inside the global mpl namespace */\n",
       "/* global mpl */\n",
       "window.mpl = {};\n",
       "\n",
       "mpl.get_websocket_type = function () {\n",
       "    if (typeof WebSocket !== 'undefined') {\n",
       "        return WebSocket;\n",
       "    } else if (typeof MozWebSocket !== 'undefined') {\n",
       "        return MozWebSocket;\n",
       "    } else {\n",
       "        alert(\n",
       "            'Your browser does not have WebSocket support. ' +\n",
       "                'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
       "                'Firefox 4 and 5 are also supported but you ' +\n",
       "                'have to enable WebSockets in about:config.'\n",
       "        );\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n",
       "    this.id = figure_id;\n",
       "\n",
       "    this.ws = websocket;\n",
       "\n",
       "    this.supports_binary = this.ws.binaryType !== undefined;\n",
       "\n",
       "    if (!this.supports_binary) {\n",
       "        var warnings = document.getElementById('mpl-warnings');\n",
       "        if (warnings) {\n",
       "            warnings.style.display = 'block';\n",
       "            warnings.textContent =\n",
       "                'This browser does not support binary websocket messages. ' +\n",
       "                'Performance may be slow.';\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.imageObj = new Image();\n",
       "\n",
       "    this.context = undefined;\n",
       "    this.message = undefined;\n",
       "    this.canvas = undefined;\n",
       "    this.rubberband_canvas = undefined;\n",
       "    this.rubberband_context = undefined;\n",
       "    this.format_dropdown = undefined;\n",
       "\n",
       "    this.image_mode = 'full';\n",
       "\n",
       "    this.root = document.createElement('div');\n",
       "    this.root.setAttribute('style', 'display: inline-block');\n",
       "    this._root_extra_style(this.root);\n",
       "\n",
       "    parent_element.appendChild(this.root);\n",
       "\n",
       "    this._init_header(this);\n",
       "    this._init_canvas(this);\n",
       "    this._init_toolbar(this);\n",
       "\n",
       "    var fig = this;\n",
       "\n",
       "    this.waiting = false;\n",
       "\n",
       "    this.ws.onopen = function () {\n",
       "        fig.send_message('supports_binary', { value: fig.supports_binary });\n",
       "        fig.send_message('send_image_mode', {});\n",
       "        if (fig.ratio !== 1) {\n",
       "            fig.send_message('set_dpi_ratio', { dpi_ratio: fig.ratio });\n",
       "        }\n",
       "        fig.send_message('refresh', {});\n",
       "    };\n",
       "\n",
       "    this.imageObj.onload = function () {\n",
       "        if (fig.image_mode === 'full') {\n",
       "            // Full images could contain transparency (where diff images\n",
       "            // almost always do), so we need to clear the canvas so that\n",
       "            // there is no ghosting.\n",
       "            fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
       "        }\n",
       "        fig.context.drawImage(fig.imageObj, 0, 0);\n",
       "    };\n",
       "\n",
       "    this.imageObj.onunload = function () {\n",
       "        fig.ws.close();\n",
       "    };\n",
       "\n",
       "    this.ws.onmessage = this._make_on_message_function(this);\n",
       "\n",
       "    this.ondownload = ondownload;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_header = function () {\n",
       "    var titlebar = document.createElement('div');\n",
       "    titlebar.classList =\n",
       "        'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n",
       "    var titletext = document.createElement('div');\n",
       "    titletext.classList = 'ui-dialog-title';\n",
       "    titletext.setAttribute(\n",
       "        'style',\n",
       "        'width: 100%; text-align: center; padding: 3px;'\n",
       "    );\n",
       "    titlebar.appendChild(titletext);\n",
       "    this.root.appendChild(titlebar);\n",
       "    this.header = titletext;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n",
       "\n",
       "mpl.figure.prototype._init_canvas = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var canvas_div = (this.canvas_div = document.createElement('div'));\n",
       "    canvas_div.setAttribute(\n",
       "        'style',\n",
       "        'border: 1px solid #ddd;' +\n",
       "            'box-sizing: content-box;' +\n",
       "            'clear: both;' +\n",
       "            'min-height: 1px;' +\n",
       "            'min-width: 1px;' +\n",
       "            'outline: 0;' +\n",
       "            'overflow: hidden;' +\n",
       "            'position: relative;' +\n",
       "            'resize: both;'\n",
       "    );\n",
       "\n",
       "    function on_keyboard_event_closure(name) {\n",
       "        return function (event) {\n",
       "            return fig.key_event(event, name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    canvas_div.addEventListener(\n",
       "        'keydown',\n",
       "        on_keyboard_event_closure('key_press')\n",
       "    );\n",
       "    canvas_div.addEventListener(\n",
       "        'keyup',\n",
       "        on_keyboard_event_closure('key_release')\n",
       "    );\n",
       "\n",
       "    this._canvas_extra_style(canvas_div);\n",
       "    this.root.appendChild(canvas_div);\n",
       "\n",
       "    var canvas = (this.canvas = document.createElement('canvas'));\n",
       "    canvas.classList.add('mpl-canvas');\n",
       "    canvas.setAttribute('style', 'box-sizing: content-box;');\n",
       "\n",
       "    this.context = canvas.getContext('2d');\n",
       "\n",
       "    var backingStore =\n",
       "        this.context.backingStorePixelRatio ||\n",
       "        this.context.webkitBackingStorePixelRatio ||\n",
       "        this.context.mozBackingStorePixelRatio ||\n",
       "        this.context.msBackingStorePixelRatio ||\n",
       "        this.context.oBackingStorePixelRatio ||\n",
       "        this.context.backingStorePixelRatio ||\n",
       "        1;\n",
       "\n",
       "    this.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
       "\n",
       "    var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n",
       "        'canvas'\n",
       "    ));\n",
       "    rubberband_canvas.setAttribute(\n",
       "        'style',\n",
       "        'box-sizing: content-box; position: absolute; left: 0; top: 0; z-index: 1;'\n",
       "    );\n",
       "\n",
       "    // Apply a ponyfill if ResizeObserver is not implemented by browser.\n",
       "    if (this.ResizeObserver === undefined) {\n",
       "        if (window.ResizeObserver !== undefined) {\n",
       "            this.ResizeObserver = window.ResizeObserver;\n",
       "        } else {\n",
       "            var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n",
       "            this.ResizeObserver = obs.ResizeObserver;\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n",
       "        var nentries = entries.length;\n",
       "        for (var i = 0; i < nentries; i++) {\n",
       "            var entry = entries[i];\n",
       "            var width, height;\n",
       "            if (entry.contentBoxSize) {\n",
       "                if (entry.contentBoxSize instanceof Array) {\n",
       "                    // Chrome 84 implements new version of spec.\n",
       "                    width = entry.contentBoxSize[0].inlineSize;\n",
       "                    height = entry.contentBoxSize[0].blockSize;\n",
       "                } else {\n",
       "                    // Firefox implements old version of spec.\n",
       "                    width = entry.contentBoxSize.inlineSize;\n",
       "                    height = entry.contentBoxSize.blockSize;\n",
       "                }\n",
       "            } else {\n",
       "                // Chrome <84 implements even older version of spec.\n",
       "                width = entry.contentRect.width;\n",
       "                height = entry.contentRect.height;\n",
       "            }\n",
       "\n",
       "            // Keep the size of the canvas and rubber band canvas in sync with\n",
       "            // the canvas container.\n",
       "            if (entry.devicePixelContentBoxSize) {\n",
       "                // Chrome 84 implements new version of spec.\n",
       "                canvas.setAttribute(\n",
       "                    'width',\n",
       "                    entry.devicePixelContentBoxSize[0].inlineSize\n",
       "                );\n",
       "                canvas.setAttribute(\n",
       "                    'height',\n",
       "                    entry.devicePixelContentBoxSize[0].blockSize\n",
       "                );\n",
       "            } else {\n",
       "                canvas.setAttribute('width', width * fig.ratio);\n",
       "                canvas.setAttribute('height', height * fig.ratio);\n",
       "            }\n",
       "            canvas.setAttribute(\n",
       "                'style',\n",
       "                'width: ' + width + 'px; height: ' + height + 'px;'\n",
       "            );\n",
       "\n",
       "            rubberband_canvas.setAttribute('width', width);\n",
       "            rubberband_canvas.setAttribute('height', height);\n",
       "\n",
       "            // And update the size in Python. We ignore the initial 0/0 size\n",
       "            // that occurs as the element is placed into the DOM, which should\n",
       "            // otherwise not happen due to the minimum size styling.\n",
       "            if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n",
       "                fig.request_resize(width, height);\n",
       "            }\n",
       "        }\n",
       "    });\n",
       "    this.resizeObserverInstance.observe(canvas_div);\n",
       "\n",
       "    function on_mouse_event_closure(name) {\n",
       "        return function (event) {\n",
       "            return fig.mouse_event(event, name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mousedown',\n",
       "        on_mouse_event_closure('button_press')\n",
       "    );\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseup',\n",
       "        on_mouse_event_closure('button_release')\n",
       "    );\n",
       "    // Throttle sequential mouse events to 1 every 20ms.\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mousemove',\n",
       "        on_mouse_event_closure('motion_notify')\n",
       "    );\n",
       "\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseenter',\n",
       "        on_mouse_event_closure('figure_enter')\n",
       "    );\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseleave',\n",
       "        on_mouse_event_closure('figure_leave')\n",
       "    );\n",
       "\n",
       "    canvas_div.addEventListener('wheel', function (event) {\n",
       "        if (event.deltaY < 0) {\n",
       "            event.step = 1;\n",
       "        } else {\n",
       "            event.step = -1;\n",
       "        }\n",
       "        on_mouse_event_closure('scroll')(event);\n",
       "    });\n",
       "\n",
       "    canvas_div.appendChild(canvas);\n",
       "    canvas_div.appendChild(rubberband_canvas);\n",
       "\n",
       "    this.rubberband_context = rubberband_canvas.getContext('2d');\n",
       "    this.rubberband_context.strokeStyle = '#000000';\n",
       "\n",
       "    this._resize_canvas = function (width, height, forward) {\n",
       "        if (forward) {\n",
       "            canvas_div.style.width = width + 'px';\n",
       "            canvas_div.style.height = height + 'px';\n",
       "        }\n",
       "    };\n",
       "\n",
       "    // Disable right mouse context menu.\n",
       "    this.rubberband_canvas.addEventListener('contextmenu', function (_e) {\n",
       "        event.preventDefault();\n",
       "        return false;\n",
       "    });\n",
       "\n",
       "    function set_focus() {\n",
       "        canvas.focus();\n",
       "        canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    window.setTimeout(set_focus, 100);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var toolbar = document.createElement('div');\n",
       "    toolbar.classList = 'mpl-toolbar';\n",
       "    this.root.appendChild(toolbar);\n",
       "\n",
       "    function on_click_closure(name) {\n",
       "        return function (_event) {\n",
       "            return fig.toolbar_button_onclick(name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    function on_mouseover_closure(tooltip) {\n",
       "        return function (event) {\n",
       "            if (!event.currentTarget.disabled) {\n",
       "                return fig.toolbar_button_onmouseover(tooltip);\n",
       "            }\n",
       "        };\n",
       "    }\n",
       "\n",
       "    fig.buttons = {};\n",
       "    var buttonGroup = document.createElement('div');\n",
       "    buttonGroup.classList = 'mpl-button-group';\n",
       "    for (var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            /* Instead of a spacer, we start a new button group. */\n",
       "            if (buttonGroup.hasChildNodes()) {\n",
       "                toolbar.appendChild(buttonGroup);\n",
       "            }\n",
       "            buttonGroup = document.createElement('div');\n",
       "            buttonGroup.classList = 'mpl-button-group';\n",
       "            continue;\n",
       "        }\n",
       "\n",
       "        var button = (fig.buttons[name] = document.createElement('button'));\n",
       "        button.classList = 'mpl-widget';\n",
       "        button.setAttribute('role', 'button');\n",
       "        button.setAttribute('aria-disabled', 'false');\n",
       "        button.addEventListener('click', on_click_closure(method_name));\n",
       "        button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n",
       "\n",
       "        var icon_img = document.createElement('img');\n",
       "        icon_img.src = '_images/' + image + '.png';\n",
       "        icon_img.srcset = '_images/' + image + '_large.png 2x';\n",
       "        icon_img.alt = tooltip;\n",
       "        button.appendChild(icon_img);\n",
       "\n",
       "        buttonGroup.appendChild(button);\n",
       "    }\n",
       "\n",
       "    if (buttonGroup.hasChildNodes()) {\n",
       "        toolbar.appendChild(buttonGroup);\n",
       "    }\n",
       "\n",
       "    var fmt_picker = document.createElement('select');\n",
       "    fmt_picker.classList = 'mpl-widget';\n",
       "    toolbar.appendChild(fmt_picker);\n",
       "    this.format_dropdown = fmt_picker;\n",
       "\n",
       "    for (var ind in mpl.extensions) {\n",
       "        var fmt = mpl.extensions[ind];\n",
       "        var option = document.createElement('option');\n",
       "        option.selected = fmt === mpl.default_extension;\n",
       "        option.innerHTML = fmt;\n",
       "        fmt_picker.appendChild(option);\n",
       "    }\n",
       "\n",
       "    var status_bar = document.createElement('span');\n",
       "    status_bar.classList = 'mpl-message';\n",
       "    toolbar.appendChild(status_bar);\n",
       "    this.message = status_bar;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n",
       "    // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
       "    // which will in turn request a refresh of the image.\n",
       "    this.send_message('resize', { width: x_pixels, height: y_pixels });\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.send_message = function (type, properties) {\n",
       "    properties['type'] = type;\n",
       "    properties['figure_id'] = this.id;\n",
       "    this.ws.send(JSON.stringify(properties));\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.send_draw_message = function () {\n",
       "    if (!this.waiting) {\n",
       "        this.waiting = true;\n",
       "        this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_save = function (fig, _msg) {\n",
       "    var format_dropdown = fig.format_dropdown;\n",
       "    var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
       "    fig.ondownload(fig, format);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_resize = function (fig, msg) {\n",
       "    var size = msg['size'];\n",
       "    if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n",
       "        fig._resize_canvas(size[0], size[1], msg['forward']);\n",
       "        fig.send_message('refresh', {});\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_rubberband = function (fig, msg) {\n",
       "    var x0 = msg['x0'] / fig.ratio;\n",
       "    var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n",
       "    var x1 = msg['x1'] / fig.ratio;\n",
       "    var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n",
       "    x0 = Math.floor(x0) + 0.5;\n",
       "    y0 = Math.floor(y0) + 0.5;\n",
       "    x1 = Math.floor(x1) + 0.5;\n",
       "    y1 = Math.floor(y1) + 0.5;\n",
       "    var min_x = Math.min(x0, x1);\n",
       "    var min_y = Math.min(y0, y1);\n",
       "    var width = Math.abs(x1 - x0);\n",
       "    var height = Math.abs(y1 - y0);\n",
       "\n",
       "    fig.rubberband_context.clearRect(\n",
       "        0,\n",
       "        0,\n",
       "        fig.canvas.width / fig.ratio,\n",
       "        fig.canvas.height / fig.ratio\n",
       "    );\n",
       "\n",
       "    fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_figure_label = function (fig, msg) {\n",
       "    // Updates the figure title.\n",
       "    fig.header.textContent = msg['label'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_cursor = function (fig, msg) {\n",
       "    var cursor = msg['cursor'];\n",
       "    switch (cursor) {\n",
       "        case 0:\n",
       "            cursor = 'pointer';\n",
       "            break;\n",
       "        case 1:\n",
       "            cursor = 'default';\n",
       "            break;\n",
       "        case 2:\n",
       "            cursor = 'crosshair';\n",
       "            break;\n",
       "        case 3:\n",
       "            cursor = 'move';\n",
       "            break;\n",
       "    }\n",
       "    fig.rubberband_canvas.style.cursor = cursor;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_message = function (fig, msg) {\n",
       "    fig.message.textContent = msg['message'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_draw = function (fig, _msg) {\n",
       "    // Request the server to send over a new figure.\n",
       "    fig.send_draw_message();\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_image_mode = function (fig, msg) {\n",
       "    fig.image_mode = msg['mode'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n",
       "    for (var key in msg) {\n",
       "        if (!(key in fig.buttons)) {\n",
       "            continue;\n",
       "        }\n",
       "        fig.buttons[key].disabled = !msg[key];\n",
       "        fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n",
       "    if (msg['mode'] === 'PAN') {\n",
       "        fig.buttons['Pan'].classList.add('active');\n",
       "        fig.buttons['Zoom'].classList.remove('active');\n",
       "    } else if (msg['mode'] === 'ZOOM') {\n",
       "        fig.buttons['Pan'].classList.remove('active');\n",
       "        fig.buttons['Zoom'].classList.add('active');\n",
       "    } else {\n",
       "        fig.buttons['Pan'].classList.remove('active');\n",
       "        fig.buttons['Zoom'].classList.remove('active');\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function () {\n",
       "    // Called whenever the canvas gets updated.\n",
       "    this.send_message('ack', {});\n",
       "};\n",
       "\n",
       "// A function to construct a web socket function for onmessage handling.\n",
       "// Called in the figure constructor.\n",
       "mpl.figure.prototype._make_on_message_function = function (fig) {\n",
       "    return function socket_on_message(evt) {\n",
       "        if (evt.data instanceof Blob) {\n",
       "            /* FIXME: We get \"Resource interpreted as Image but\n",
       "             * transferred with MIME type text/plain:\" errors on\n",
       "             * Chrome.  But how to set the MIME type?  It doesn't seem\n",
       "             * to be part of the websocket stream */\n",
       "            evt.data.type = 'image/png';\n",
       "\n",
       "            /* Free the memory for the previous frames */\n",
       "            if (fig.imageObj.src) {\n",
       "                (window.URL || window.webkitURL).revokeObjectURL(\n",
       "                    fig.imageObj.src\n",
       "                );\n",
       "            }\n",
       "\n",
       "            fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
       "                evt.data\n",
       "            );\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        } else if (\n",
       "            typeof evt.data === 'string' &&\n",
       "            evt.data.slice(0, 21) === 'data:image/png;base64'\n",
       "        ) {\n",
       "            fig.imageObj.src = evt.data;\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        var msg = JSON.parse(evt.data);\n",
       "        var msg_type = msg['type'];\n",
       "\n",
       "        // Call the  \"handle_{type}\" callback, which takes\n",
       "        // the figure and JSON message as its only arguments.\n",
       "        try {\n",
       "            var callback = fig['handle_' + msg_type];\n",
       "        } catch (e) {\n",
       "            console.log(\n",
       "                \"No handler for the '\" + msg_type + \"' message type: \",\n",
       "                msg\n",
       "            );\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        if (callback) {\n",
       "            try {\n",
       "                // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
       "                callback(fig, msg);\n",
       "            } catch (e) {\n",
       "                console.log(\n",
       "                    \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n",
       "                    e,\n",
       "                    e.stack,\n",
       "                    msg\n",
       "                );\n",
       "            }\n",
       "        }\n",
       "    };\n",
       "};\n",
       "\n",
       "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
       "mpl.findpos = function (e) {\n",
       "    //this section is from http://www.quirksmode.org/js/events_properties.html\n",
       "    var targ;\n",
       "    if (!e) {\n",
       "        e = window.event;\n",
       "    }\n",
       "    if (e.target) {\n",
       "        targ = e.target;\n",
       "    } else if (e.srcElement) {\n",
       "        targ = e.srcElement;\n",
       "    }\n",
       "    if (targ.nodeType === 3) {\n",
       "        // defeat Safari bug\n",
       "        targ = targ.parentNode;\n",
       "    }\n",
       "\n",
       "    // pageX,Y are the mouse positions relative to the document\n",
       "    var boundingRect = targ.getBoundingClientRect();\n",
       "    var x = e.pageX - (boundingRect.left + document.body.scrollLeft);\n",
       "    var y = e.pageY - (boundingRect.top + document.body.scrollTop);\n",
       "\n",
       "    return { x: x, y: y };\n",
       "};\n",
       "\n",
       "/*\n",
       " * return a copy of an object with only non-object keys\n",
       " * we need this to avoid circular references\n",
       " * http://stackoverflow.com/a/24161582/3208463\n",
       " */\n",
       "function simpleKeys(original) {\n",
       "    return Object.keys(original).reduce(function (obj, key) {\n",
       "        if (typeof original[key] !== 'object') {\n",
       "            obj[key] = original[key];\n",
       "        }\n",
       "        return obj;\n",
       "    }, {});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.mouse_event = function (event, name) {\n",
       "    var canvas_pos = mpl.findpos(event);\n",
       "\n",
       "    if (name === 'button_press') {\n",
       "        this.canvas.focus();\n",
       "        this.canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    var x = canvas_pos.x * this.ratio;\n",
       "    var y = canvas_pos.y * this.ratio;\n",
       "\n",
       "    this.send_message(name, {\n",
       "        x: x,\n",
       "        y: y,\n",
       "        button: event.button,\n",
       "        step: event.step,\n",
       "        guiEvent: simpleKeys(event),\n",
       "    });\n",
       "\n",
       "    /* This prevents the web browser from automatically changing to\n",
       "     * the text insertion cursor when the button is pressed.  We want\n",
       "     * to control all of the cursor setting manually through the\n",
       "     * 'cursor' event from matplotlib */\n",
       "    event.preventDefault();\n",
       "    return false;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function (_event, _name) {\n",
       "    // Handle any extra behaviour associated with a key event\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.key_event = function (event, name) {\n",
       "    // Prevent repeat events\n",
       "    if (name === 'key_press') {\n",
       "        if (event.which === this._key) {\n",
       "            return;\n",
       "        } else {\n",
       "            this._key = event.which;\n",
       "        }\n",
       "    }\n",
       "    if (name === 'key_release') {\n",
       "        this._key = null;\n",
       "    }\n",
       "\n",
       "    var value = '';\n",
       "    if (event.ctrlKey && event.which !== 17) {\n",
       "        value += 'ctrl+';\n",
       "    }\n",
       "    if (event.altKey && event.which !== 18) {\n",
       "        value += 'alt+';\n",
       "    }\n",
       "    if (event.shiftKey && event.which !== 16) {\n",
       "        value += 'shift+';\n",
       "    }\n",
       "\n",
       "    value += 'k';\n",
       "    value += event.which.toString();\n",
       "\n",
       "    this._key_event_extra(event, name);\n",
       "\n",
       "    this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n",
       "    return false;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onclick = function (name) {\n",
       "    if (name === 'download') {\n",
       "        this.handle_save(this, null);\n",
       "    } else {\n",
       "        this.send_message('toolbar_button', { name: name });\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n",
       "    this.message.textContent = tooltip;\n",
       "};\n",
       "\n",
       "///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n",
       "// prettier-ignore\n",
       "var _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // eslint-disable-line\n",
       "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
       "\n",
       "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
       "\n",
       "mpl.default_extension = \"png\";/* global mpl */\n",
       "\n",
       "var comm_websocket_adapter = function (comm) {\n",
       "    // Create a \"websocket\"-like object which calls the given IPython comm\n",
       "    // object with the appropriate methods. Currently this is a non binary\n",
       "    // socket, so there is still some room for performance tuning.\n",
       "    var ws = {};\n",
       "\n",
       "    ws.close = function () {\n",
       "        comm.close();\n",
       "    };\n",
       "    ws.send = function (m) {\n",
       "        //console.log('sending', m);\n",
       "        comm.send(m);\n",
       "    };\n",
       "    // Register the callback with on_msg.\n",
       "    comm.on_msg(function (msg) {\n",
       "        //console.log('receiving', msg['content']['data'], msg);\n",
       "        // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
       "        ws.onmessage(msg['content']['data']);\n",
       "    });\n",
       "    return ws;\n",
       "};\n",
       "\n",
       "mpl.mpl_figure_comm = function (comm, msg) {\n",
       "    // This is the function which gets called when the mpl process\n",
       "    // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
       "\n",
       "    var id = msg.content.data.id;\n",
       "    // Get hold of the div created by the display call when the Comm\n",
       "    // socket was opened in Python.\n",
       "    var element = document.getElementById(id);\n",
       "    var ws_proxy = comm_websocket_adapter(comm);\n",
       "\n",
       "    function ondownload(figure, _format) {\n",
       "        window.open(figure.canvas.toDataURL());\n",
       "    }\n",
       "\n",
       "    var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n",
       "\n",
       "    // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
       "    // web socket which is closed, not our websocket->open comm proxy.\n",
       "    ws_proxy.onopen();\n",
       "\n",
       "    fig.parent_element = element;\n",
       "    fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
       "    if (!fig.cell_info) {\n",
       "        console.error('Failed to find cell for figure', id, fig);\n",
       "        return;\n",
       "    }\n",
       "    fig.cell_info[0].output_area.element.on(\n",
       "        'cleared',\n",
       "        { fig: fig },\n",
       "        fig._remove_fig_handler\n",
       "    );\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_close = function (fig, msg) {\n",
       "    var width = fig.canvas.width / fig.ratio;\n",
       "    fig.cell_info[0].output_area.element.off(\n",
       "        'cleared',\n",
       "        fig._remove_fig_handler\n",
       "    );\n",
       "    fig.resizeObserverInstance.unobserve(fig.canvas_div);\n",
       "\n",
       "    // Update the output cell to use the data from the current canvas.\n",
       "    fig.push_to_output();\n",
       "    var dataURL = fig.canvas.toDataURL();\n",
       "    // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
       "    // the notebook keyboard shortcuts fail.\n",
       "    IPython.keyboard_manager.enable();\n",
       "    fig.parent_element.innerHTML =\n",
       "        '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "    fig.close_ws(fig, msg);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.close_ws = function (fig, msg) {\n",
       "    fig.send_message('closing', msg);\n",
       "    // fig.ws.close()\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.push_to_output = function (_remove_interactive) {\n",
       "    // Turn the data on the canvas into data in the output cell.\n",
       "    var width = this.canvas.width / this.ratio;\n",
       "    var dataURL = this.canvas.toDataURL();\n",
       "    this.cell_info[1]['text/html'] =\n",
       "        '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function () {\n",
       "    // Tell IPython that the notebook contents must change.\n",
       "    IPython.notebook.set_dirty(true);\n",
       "    this.send_message('ack', {});\n",
       "    var fig = this;\n",
       "    // Wait a second, then push the new image to the DOM so\n",
       "    // that it is saved nicely (might be nice to debounce this).\n",
       "    setTimeout(function () {\n",
       "        fig.push_to_output();\n",
       "    }, 1000);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var toolbar = document.createElement('div');\n",
       "    toolbar.classList = 'btn-toolbar';\n",
       "    this.root.appendChild(toolbar);\n",
       "\n",
       "    function on_click_closure(name) {\n",
       "        return function (_event) {\n",
       "            return fig.toolbar_button_onclick(name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    function on_mouseover_closure(tooltip) {\n",
       "        return function (event) {\n",
       "            if (!event.currentTarget.disabled) {\n",
       "                return fig.toolbar_button_onmouseover(tooltip);\n",
       "            }\n",
       "        };\n",
       "    }\n",
       "\n",
       "    fig.buttons = {};\n",
       "    var buttonGroup = document.createElement('div');\n",
       "    buttonGroup.classList = 'btn-group';\n",
       "    var button;\n",
       "    for (var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            /* Instead of a spacer, we start a new button group. */\n",
       "            if (buttonGroup.hasChildNodes()) {\n",
       "                toolbar.appendChild(buttonGroup);\n",
       "            }\n",
       "            buttonGroup = document.createElement('div');\n",
       "            buttonGroup.classList = 'btn-group';\n",
       "            continue;\n",
       "        }\n",
       "\n",
       "        button = fig.buttons[name] = document.createElement('button');\n",
       "        button.classList = 'btn btn-default';\n",
       "        button.href = '#';\n",
       "        button.title = name;\n",
       "        button.innerHTML = '<i class=\"fa ' + image + ' fa-lg\"></i>';\n",
       "        button.addEventListener('click', on_click_closure(method_name));\n",
       "        button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n",
       "        buttonGroup.appendChild(button);\n",
       "    }\n",
       "\n",
       "    if (buttonGroup.hasChildNodes()) {\n",
       "        toolbar.appendChild(buttonGroup);\n",
       "    }\n",
       "\n",
       "    // Add the status bar.\n",
       "    var status_bar = document.createElement('span');\n",
       "    status_bar.classList = 'mpl-message pull-right';\n",
       "    toolbar.appendChild(status_bar);\n",
       "    this.message = status_bar;\n",
       "\n",
       "    // Add the close button to the window.\n",
       "    var buttongrp = document.createElement('div');\n",
       "    buttongrp.classList = 'btn-group inline pull-right';\n",
       "    button = document.createElement('button');\n",
       "    button.classList = 'btn btn-mini btn-primary';\n",
       "    button.href = '#';\n",
       "    button.title = 'Stop Interaction';\n",
       "    button.innerHTML = '<i class=\"fa fa-power-off icon-remove icon-large\"></i>';\n",
       "    button.addEventListener('click', function (_evt) {\n",
       "        fig.handle_close(fig, {});\n",
       "    });\n",
       "    button.addEventListener(\n",
       "        'mouseover',\n",
       "        on_mouseover_closure('Stop Interaction')\n",
       "    );\n",
       "    buttongrp.appendChild(button);\n",
       "    var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n",
       "    titlebar.insertBefore(buttongrp, titlebar.firstChild);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._remove_fig_handler = function (event) {\n",
       "    var fig = event.data.fig;\n",
       "    if (event.target !== this) {\n",
       "        // Ignore bubbled events from children.\n",
       "        return;\n",
       "    }\n",
       "    fig.close_ws(fig, {});\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function (el) {\n",
       "    el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function (el) {\n",
       "    // this is important to make the div 'focusable\n",
       "    el.setAttribute('tabindex', 0);\n",
       "    // reach out to IPython and tell the keyboard manager to turn it's self\n",
       "    // off when our div gets focus\n",
       "\n",
       "    // location in version 3\n",
       "    if (IPython.notebook.keyboard_manager) {\n",
       "        IPython.notebook.keyboard_manager.register_events(el);\n",
       "    } else {\n",
       "        // location in version 2\n",
       "        IPython.keyboard_manager.register_events(el);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function (event, _name) {\n",
       "    var manager = IPython.notebook.keyboard_manager;\n",
       "    if (!manager) {\n",
       "        manager = IPython.keyboard_manager;\n",
       "    }\n",
       "\n",
       "    // Check for shift+enter\n",
       "    if (event.shiftKey && event.which === 13) {\n",
       "        this.canvas_div.blur();\n",
       "        // select the cell after this one\n",
       "        var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
       "        IPython.notebook.select(index + 1);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_save = function (fig, _msg) {\n",
       "    fig.ondownload(fig, null);\n",
       "};\n",
       "\n",
       "mpl.find_output_cell = function (html_output) {\n",
       "    // Return the cell and output element which can be found *uniquely* in the notebook.\n",
       "    // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
       "    // IPython event is triggered only after the cells have been serialised, which for\n",
       "    // our purposes (turning an active figure into a static one), is too late.\n",
       "    var cells = IPython.notebook.get_cells();\n",
       "    var ncells = cells.length;\n",
       "    for (var i = 0; i < ncells; i++) {\n",
       "        var cell = cells[i];\n",
       "        if (cell.cell_type === 'code') {\n",
       "            for (var j = 0; j < cell.output_area.outputs.length; j++) {\n",
       "                var data = cell.output_area.outputs[j];\n",
       "                if (data.data) {\n",
       "                    // IPython >= 3 moved mimebundle to data attribute of output\n",
       "                    data = data.data;\n",
       "                }\n",
       "                if (data['text/html'] === html_output) {\n",
       "                    return [cell, data, j];\n",
       "                }\n",
       "            }\n",
       "        }\n",
       "    }\n",
       "};\n",
       "\n",
       "// Register the function which deals with the matplotlib target/channel.\n",
       "// The kernel may be null if the page has been refreshed.\n",
       "if (IPython.notebook.kernel !== null) {\n",
       "    IPython.notebook.kernel.comm_manager.register_target(\n",
       "        'matplotlib',\n",
       "        mpl.mpl_figure_comm\n",
       "    );\n",
       "}\n"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"\" width=\"1000\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(10,10))\n",
    "for i in range(5):\n",
    "    for j in range(5):\n",
    "        idx = i*5+j\n",
    "        img = my_dataset[idx][0]\n",
    "        label = my_dataset[idx][1]\n",
    "        ax[i, j].imshow(img[0,...].detach().cpu().numpy())\n",
    "        ax[i, j].set(title=f\"Im. {idx}, true {label}\")\n",
    "        ax[i, j].set_xticklabels([])\n",
    "        ax[i, j].set_yticklabels([])\n",
    "plt.subplots_adjust(hspace=0.2,wspace=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e517975c",
   "metadata": {},
   "source": [
    "And now let us define the neural network. In PyTorch, neural networks always extend `nn.Module`. They define their sub-parts in their constructor, which are convolutional layers and fully connected linear layers in this case, and the method `forward` is expected to receive an input image and output the network target.\n",
    "\n",
    "The network parameters are the weights of the `Conv2d` and `Linear` layers, which are conveniently hidden here, but can be accessed if you try to access their `weights` elements.\n",
    "\n",
    "We will not directly output the label probabilities, since we do not actually need it to optimize the neural network: we need only the logits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d908ef86",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(nn.Module):\n",
    "    \"\"\"\n",
    "        This is our parametrized function.\n",
    "        It stores all the parametrized weights theta inside the conv1, conv2, fc1 and fc2 objects.\n",
    "        The forward function receives an image and outputs a vector.\n",
    "        \n",
    "        The intuition is that the i-th component of the vector represents the probability that\n",
    "        the probability that the image belongs to the i-th class however\n",
    "        we do not normalize the output to be in the range [0,1] and to sum to 1. The reason is\n",
    "        that this normalization is done later, in the training step, where the numerical error in it can be\n",
    "        minimized by calculating directly log(probability) instead of calculating first the probability\n",
    "        and then the log of it. Keep in mind therefore, that to get probabilities\n",
    "        from this object one should do F.softmax(my_network(x), dim=1).\n",
    "        \n",
    "        The code has been written like this, as this is a common optimization done in classification problems.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        Constructor. Here we initialize the weights.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # define parameters\n",
    "        \n",
    "        # all these steps are purely linear (affine if one considers the bias)\n",
    "        # the forward function adds a non-linearity through the ReLU to allow this to do more than\n",
    "        # simple linear filters\n",
    "        \n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)\n",
    "\n",
    "        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)\n",
    "        self.fc2 = nn.Linear(in_features=120, out_features=60)\n",
    "        self.out = nn.Linear(in_features=60, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        This function is called when one does my_network(x) and it represents the action\n",
    "        of our parametrized function in the image, outputting the probabilities for that image as\n",
    "        a column vector. The input x has shape (B, C, H, W) (ie: batch dimension, channels, height and width).\n",
    "        The output has shape (B, K), where K is the number of classes.\n",
    "        Each row of the output has the probability for each class as a column vector.\n",
    "        Each column of the output has the probability for a single class for all images B given as an input.\n",
    "        \"\"\"\n",
    "\n",
    "        # first convolution\n",
    "        t = self.conv1(x)\n",
    "        # non-linearity\n",
    "        t = F.relu(t)\n",
    "        # reduce size of the image in width and height by taking the maximum\n",
    "        # pixel value in each 2x2 pixel matrix (kernel_size) and skipping one pixel (stride)\n",
    "        # the convolution receives one channel and outputs more\n",
    "        # the goal of the max_pool layer is to reduce the image size, so we\n",
    "        # can get more images in several channels which are smaller in size\n",
    "        # this is a trade off between memory and compute\n",
    "        t = F.max_pool2d(t, kernel_size=2, stride=2)\n",
    "\n",
    "        # second convolution\n",
    "        t = self.conv2(t)\n",
    "        # non-linearity\n",
    "        t = F.relu(t)\n",
    "        # reduce the size of the image in width and height again\n",
    "        t = F.max_pool2d(t, kernel_size=2, stride=2)\n",
    "\n",
    "        # transform images into a single vector using reshape\n",
    "        # this puts all pixel values in a single vector\n",
    "        t = t.reshape(-1, 12*4*4)\n",
    "        \n",
    "        # apply a linear transformation\n",
    "        t = self.fc1(t)\n",
    "        # add a non-linearity\n",
    "        t = F.relu(t)\n",
    "\n",
    "        # another linear transformation\n",
    "        t = self.fc2(t)\n",
    "        # another non-linearity\n",
    "        t = F.relu(t)\n",
    "\n",
    "        # final linear transformation\n",
    "        # the output of this has been set to 10 features, so the output will have the size\n",
    "        # (B, 10)\n",
    "        t = self.out(t)\n",
    "\n",
    "        # note: while we want the function to output a probability,\n",
    "        # we do not actually do any effort to normalize these numbers so that they are in [0, 1]\n",
    "        # and so that their sum is 1\n",
    "        # this would often be done by applying a transformation called Softmax(t) = exp(t)/sum(exp(t))\n",
    "        # however, this will be done internally by PyTorch in the function F.cross_entropy\n",
    "        # which we will call later on when training\n",
    "\n",
    "        return t"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c5620dc",
   "metadata": {},
   "source": [
    "Let us create one instance of this network. We also create an instance of PyTorch's `DataLoader`, which has the task of taking a given number of data elements and outputing it in a single object. This \"mini-batch\" of data is used during training, so that we do not need to load the entire data in memory during the optimization procedure.\n",
    "\n",
    "We also create an instance of the Adam optimizer, which is used to tune the parameters of the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "988e1979",
   "metadata": {},
   "outputs": [],
   "source": [
    "network = Network()\n",
    "B = 64\n",
    "loader = torch.utils.data.DataLoader(my_dataset, batch_size=B)\n",
    "optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ee54520",
   "metadata": {},
   "source": [
    "Now we actually repeatedly try to optimize the network parameters. Each time we go through all the data we have, we go through one \"epoch\". For each epoch, we take several \"mini-batches\" of data (given by the `DataLoader` in `loader`) and use it to make one training step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d15d655d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/10: average loss 0.35421\n",
      "Epoch 1/10: average loss 0.10931\n",
      "Epoch 2/10: average loss 0.07470\n",
      "Epoch 3/10: average loss 0.05767\n",
      "Epoch 4/10: average loss 0.04682\n",
      "Epoch 5/10: average loss 0.03882\n",
      "Epoch 6/10: average loss 0.03285\n",
      "Epoch 7/10: average loss 0.02714\n",
      "Epoch 8/10: average loss 0.02437\n",
      "Epoch 9/10: average loss 0.02176\n"
     ]
    }
   ],
   "source": [
    "epochs = 10\n",
    "# for each epoch\n",
    "for epoch in range(epochs):\n",
    "    losses = list()\n",
    "    # for each mini-batch given by the loader:\n",
    "    for batch in loader:\n",
    "        # get the images in the mini-batch\n",
    "        # this has size (B, C, H, W)\n",
    "        # where B is the mini-batch size\n",
    "        # C is the number of channels in the image (1 for grayscale)\n",
    "        # H is the height of the image\n",
    "        # W is the width of the image\n",
    "        images = batch[0]\n",
    "        # get the labels in the mini-batch (there shall be B of them)\n",
    "        labels = batch[1]\n",
    "        # get the output of the neural network:\n",
    "        logits = network(images)\n",
    "        \n",
    "        # note: the network does not output probabilities directly: it outputs logits\n",
    "        # to get probabilities from it we would need to do F.softmax(logits, dim=1)\n",
    "        # however, this is done inside F.cross_entropy below and we therefore should\n",
    "        # not do it twice here\n",
    "        # the reason it is done internally, in F.cross_entropy, is that what we really\n",
    "        # need is log(probability) and we can reduce the numerical error\n",
    "        # in its calculation by calculating log(softmax(.)) in one go\n",
    "        # (remember softmax(x) = exp(x)/sum(exp(x)), so log(softmax(x)) = x - log(sum(exp(x))))\n",
    "        \n",
    "        # calculate the loss function being minimized\n",
    "        # in this case, it is the cross-entropy between the logits and the true labels\n",
    "        loss = F.cross_entropy(logits, labels)\n",
    "\n",
    "        # clean the optimizer temporary gradient storage\n",
    "        optimizer.zero_grad()\n",
    "        # calculate the gradient of the loss function as a function of the gradients\n",
    "        loss.backward()\n",
    "        # ask the Adam optimizer to change the parameters in the direction of - gradient\n",
    "        # Adam scales the gradient by a constant which is adaptively tuned\n",
    "        # take a look at the Adam paper for more details: https://arxiv.org/abs/1412.6980\n",
    "        optimizer.step()\n",
    "        losses.append(loss.detach().cpu().item())\n",
    "    avg_loss = np.mean(np.array(losses))\n",
    "    print(f\"Epoch {epoch}/{epochs}: average loss {avg_loss:.5f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4980bf4",
   "metadata": {},
   "source": [
    "Let us check what the network says about some new data it has never seen before (note that we set `train` to `False`, to take a statistically independent part of the dataset)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "09646d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset = torchvision.datasets.MNIST(\n",
    "    root = './data/MNIST',\n",
    "    train = False,\n",
    "    download = True,\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor()                                 \n",
    "    ])\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e315b5dc",
   "metadata": {},
   "source": [
    "And now we can plot again the new images, now showing what the network tells us about it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7a06a4c0",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "/* Put everything inside the global mpl namespace */\n",
       "/* global mpl */\n",
       "window.mpl = {};\n",
       "\n",
       "mpl.get_websocket_type = function () {\n",
       "    if (typeof WebSocket !== 'undefined') {\n",
       "        return WebSocket;\n",
       "    } else if (typeof MozWebSocket !== 'undefined') {\n",
       "        return MozWebSocket;\n",
       "    } else {\n",
       "        alert(\n",
       "            'Your browser does not have WebSocket support. ' +\n",
       "                'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
       "                'Firefox 4 and 5 are also supported but you ' +\n",
       "                'have to enable WebSockets in about:config.'\n",
       "        );\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n",
       "    this.id = figure_id;\n",
       "\n",
       "    this.ws = websocket;\n",
       "\n",
       "    this.supports_binary = this.ws.binaryType !== undefined;\n",
       "\n",
       "    if (!this.supports_binary) {\n",
       "        var warnings = document.getElementById('mpl-warnings');\n",
       "        if (warnings) {\n",
       "            warnings.style.display = 'block';\n",
       "            warnings.textContent =\n",
       "                'This browser does not support binary websocket messages. ' +\n",
       "                'Performance may be slow.';\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.imageObj = new Image();\n",
       "\n",
       "    this.context = undefined;\n",
       "    this.message = undefined;\n",
       "    this.canvas = undefined;\n",
       "    this.rubberband_canvas = undefined;\n",
       "    this.rubberband_context = undefined;\n",
       "    this.format_dropdown = undefined;\n",
       "\n",
       "    this.image_mode = 'full';\n",
       "\n",
       "    this.root = document.createElement('div');\n",
       "    this.root.setAttribute('style', 'display: inline-block');\n",
       "    this._root_extra_style(this.root);\n",
       "\n",
       "    parent_element.appendChild(this.root);\n",
       "\n",
       "    this._init_header(this);\n",
       "    this._init_canvas(this);\n",
       "    this._init_toolbar(this);\n",
       "\n",
       "    var fig = this;\n",
       "\n",
       "    this.waiting = false;\n",
       "\n",
       "    this.ws.onopen = function () {\n",
       "        fig.send_message('supports_binary', { value: fig.supports_binary });\n",
       "        fig.send_message('send_image_mode', {});\n",
       "        if (fig.ratio !== 1) {\n",
       "            fig.send_message('set_dpi_ratio', { dpi_ratio: fig.ratio });\n",
       "        }\n",
       "        fig.send_message('refresh', {});\n",
       "    };\n",
       "\n",
       "    this.imageObj.onload = function () {\n",
       "        if (fig.image_mode === 'full') {\n",
       "            // Full images could contain transparency (where diff images\n",
       "            // almost always do), so we need to clear the canvas so that\n",
       "            // there is no ghosting.\n",
       "            fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
       "        }\n",
       "        fig.context.drawImage(fig.imageObj, 0, 0);\n",
       "    };\n",
       "\n",
       "    this.imageObj.onunload = function () {\n",
       "        fig.ws.close();\n",
       "    };\n",
       "\n",
       "    this.ws.onmessage = this._make_on_message_function(this);\n",
       "\n",
       "    this.ondownload = ondownload;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_header = function () {\n",
       "    var titlebar = document.createElement('div');\n",
       "    titlebar.classList =\n",
       "        'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n",
       "    var titletext = document.createElement('div');\n",
       "    titletext.classList = 'ui-dialog-title';\n",
       "    titletext.setAttribute(\n",
       "        'style',\n",
       "        'width: 100%; text-align: center; padding: 3px;'\n",
       "    );\n",
       "    titlebar.appendChild(titletext);\n",
       "    this.root.appendChild(titlebar);\n",
       "    this.header = titletext;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n",
       "\n",
       "mpl.figure.prototype._init_canvas = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var canvas_div = (this.canvas_div = document.createElement('div'));\n",
       "    canvas_div.setAttribute(\n",
       "        'style',\n",
       "        'border: 1px solid #ddd;' +\n",
       "            'box-sizing: content-box;' +\n",
       "            'clear: both;' +\n",
       "            'min-height: 1px;' +\n",
       "            'min-width: 1px;' +\n",
       "            'outline: 0;' +\n",
       "            'overflow: hidden;' +\n",
       "            'position: relative;' +\n",
       "            'resize: both;'\n",
       "    );\n",
       "\n",
       "    function on_keyboard_event_closure(name) {\n",
       "        return function (event) {\n",
       "            return fig.key_event(event, name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    canvas_div.addEventListener(\n",
       "        'keydown',\n",
       "        on_keyboard_event_closure('key_press')\n",
       "    );\n",
       "    canvas_div.addEventListener(\n",
       "        'keyup',\n",
       "        on_keyboard_event_closure('key_release')\n",
       "    );\n",
       "\n",
       "    this._canvas_extra_style(canvas_div);\n",
       "    this.root.appendChild(canvas_div);\n",
       "\n",
       "    var canvas = (this.canvas = document.createElement('canvas'));\n",
       "    canvas.classList.add('mpl-canvas');\n",
       "    canvas.setAttribute('style', 'box-sizing: content-box;');\n",
       "\n",
       "    this.context = canvas.getContext('2d');\n",
       "\n",
       "    var backingStore =\n",
       "        this.context.backingStorePixelRatio ||\n",
       "        this.context.webkitBackingStorePixelRatio ||\n",
       "        this.context.mozBackingStorePixelRatio ||\n",
       "        this.context.msBackingStorePixelRatio ||\n",
       "        this.context.oBackingStorePixelRatio ||\n",
       "        this.context.backingStorePixelRatio ||\n",
       "        1;\n",
       "\n",
       "    this.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
       "\n",
       "    var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n",
       "        'canvas'\n",
       "    ));\n",
       "    rubberband_canvas.setAttribute(\n",
       "        'style',\n",
       "        'box-sizing: content-box; position: absolute; left: 0; top: 0; z-index: 1;'\n",
       "    );\n",
       "\n",
       "    // Apply a ponyfill if ResizeObserver is not implemented by browser.\n",
       "    if (this.ResizeObserver === undefined) {\n",
       "        if (window.ResizeObserver !== undefined) {\n",
       "            this.ResizeObserver = window.ResizeObserver;\n",
       "        } else {\n",
       "            var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n",
       "            this.ResizeObserver = obs.ResizeObserver;\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n",
       "        var nentries = entries.length;\n",
       "        for (var i = 0; i < nentries; i++) {\n",
       "            var entry = entries[i];\n",
       "            var width, height;\n",
       "            if (entry.contentBoxSize) {\n",
       "                if (entry.contentBoxSize instanceof Array) {\n",
       "                    // Chrome 84 implements new version of spec.\n",
       "                    width = entry.contentBoxSize[0].inlineSize;\n",
       "                    height = entry.contentBoxSize[0].blockSize;\n",
       "                } else {\n",
       "                    // Firefox implements old version of spec.\n",
       "                    width = entry.contentBoxSize.inlineSize;\n",
       "                    height = entry.contentBoxSize.blockSize;\n",
       "                }\n",
       "            } else {\n",
       "                // Chrome <84 implements even older version of spec.\n",
       "                width = entry.contentRect.width;\n",
       "                height = entry.contentRect.height;\n",
       "            }\n",
       "\n",
       "            // Keep the size of the canvas and rubber band canvas in sync with\n",
       "            // the canvas container.\n",
       "            if (entry.devicePixelContentBoxSize) {\n",
       "                // Chrome 84 implements new version of spec.\n",
       "                canvas.setAttribute(\n",
       "                    'width',\n",
       "                    entry.devicePixelContentBoxSize[0].inlineSize\n",
       "                );\n",
       "                canvas.setAttribute(\n",
       "                    'height',\n",
       "                    entry.devicePixelContentBoxSize[0].blockSize\n",
       "                );\n",
       "            } else {\n",
       "                canvas.setAttribute('width', width * fig.ratio);\n",
       "                canvas.setAttribute('height', height * fig.ratio);\n",
       "            }\n",
       "            canvas.setAttribute(\n",
       "                'style',\n",
       "                'width: ' + width + 'px; height: ' + height + 'px;'\n",
       "            );\n",
       "\n",
       "            rubberband_canvas.setAttribute('width', width);\n",
       "            rubberband_canvas.setAttribute('height', height);\n",
       "\n",
       "            // And update the size in Python. We ignore the initial 0/0 size\n",
       "            // that occurs as the element is placed into the DOM, which should\n",
       "            // otherwise not happen due to the minimum size styling.\n",
       "            if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n",
       "                fig.request_resize(width, height);\n",
       "            }\n",
       "        }\n",
       "    });\n",
       "    this.resizeObserverInstance.observe(canvas_div);\n",
       "\n",
       "    function on_mouse_event_closure(name) {\n",
       "        return function (event) {\n",
       "            return fig.mouse_event(event, name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mousedown',\n",
       "        on_mouse_event_closure('button_press')\n",
       "    );\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseup',\n",
       "        on_mouse_event_closure('button_release')\n",
       "    );\n",
       "    // Throttle sequential mouse events to 1 every 20ms.\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mousemove',\n",
       "        on_mouse_event_closure('motion_notify')\n",
       "    );\n",
       "\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseenter',\n",
       "        on_mouse_event_closure('figure_enter')\n",
       "    );\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseleave',\n",
       "        on_mouse_event_closure('figure_leave')\n",
       "    );\n",
       "\n",
       "    canvas_div.addEventListener('wheel', function (event) {\n",
       "        if (event.deltaY < 0) {\n",
       "            event.step = 1;\n",
       "        } else {\n",
       "            event.step = -1;\n",
       "        }\n",
       "        on_mouse_event_closure('scroll')(event);\n",
       "    });\n",
       "\n",
       "    canvas_div.appendChild(canvas);\n",
       "    canvas_div.appendChild(rubberband_canvas);\n",
       "\n",
       "    this.rubberband_context = rubberband_canvas.getContext('2d');\n",
       "    this.rubberband_context.strokeStyle = '#000000';\n",
       "\n",
       "    this._resize_canvas = function (width, height, forward) {\n",
       "        if (forward) {\n",
       "            canvas_div.style.width = width + 'px';\n",
       "            canvas_div.style.height = height + 'px';\n",
       "        }\n",
       "    };\n",
       "\n",
       "    // Disable right mouse context menu.\n",
       "    this.rubberband_canvas.addEventListener('contextmenu', function (_e) {\n",
       "        event.preventDefault();\n",
       "        return false;\n",
       "    });\n",
       "\n",
       "    function set_focus() {\n",
       "        canvas.focus();\n",
       "        canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    window.setTimeout(set_focus, 100);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var toolbar = document.createElement('div');\n",
       "    toolbar.classList = 'mpl-toolbar';\n",
       "    this.root.appendChild(toolbar);\n",
       "\n",
       "    function on_click_closure(name) {\n",
       "        return function (_event) {\n",
       "            return fig.toolbar_button_onclick(name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    function on_mouseover_closure(tooltip) {\n",
       "        return function (event) {\n",
       "            if (!event.currentTarget.disabled) {\n",
       "                return fig.toolbar_button_onmouseover(tooltip);\n",
       "            }\n",
       "        };\n",
       "    }\n",
       "\n",
       "    fig.buttons = {};\n",
       "    var buttonGroup = document.createElement('div');\n",
       "    buttonGroup.classList = 'mpl-button-group';\n",
       "    for (var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            /* Instead of a spacer, we start a new button group. */\n",
       "            if (buttonGroup.hasChildNodes()) {\n",
       "                toolbar.appendChild(buttonGroup);\n",
       "            }\n",
       "            buttonGroup = document.createElement('div');\n",
       "            buttonGroup.classList = 'mpl-button-group';\n",
       "            continue;\n",
       "        }\n",
       "\n",
       "        var button = (fig.buttons[name] = document.createElement('button'));\n",
       "        button.classList = 'mpl-widget';\n",
       "        button.setAttribute('role', 'button');\n",
       "        button.setAttribute('aria-disabled', 'false');\n",
       "        button.addEventListener('click', on_click_closure(method_name));\n",
       "        button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n",
       "\n",
       "        var icon_img = document.createElement('img');\n",
       "        icon_img.src = '_images/' + image + '.png';\n",
       "        icon_img.srcset = '_images/' + image + '_large.png 2x';\n",
       "        icon_img.alt = tooltip;\n",
       "        button.appendChild(icon_img);\n",
       "\n",
       "        buttonGroup.appendChild(button);\n",
       "    }\n",
       "\n",
       "    if (buttonGroup.hasChildNodes()) {\n",
       "        toolbar.appendChild(buttonGroup);\n",
       "    }\n",
       "\n",
       "    var fmt_picker = document.createElement('select');\n",
       "    fmt_picker.classList = 'mpl-widget';\n",
       "    toolbar.appendChild(fmt_picker);\n",
       "    this.format_dropdown = fmt_picker;\n",
       "\n",
       "    for (var ind in mpl.extensions) {\n",
       "        var fmt = mpl.extensions[ind];\n",
       "        var option = document.createElement('option');\n",
       "        option.selected = fmt === mpl.default_extension;\n",
       "        option.innerHTML = fmt;\n",
       "        fmt_picker.appendChild(option);\n",
       "    }\n",
       "\n",
       "    var status_bar = document.createElement('span');\n",
       "    status_bar.classList = 'mpl-message';\n",
       "    toolbar.appendChild(status_bar);\n",
       "    this.message = status_bar;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n",
       "    // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
       "    // which will in turn request a refresh of the image.\n",
       "    this.send_message('resize', { width: x_pixels, height: y_pixels });\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.send_message = function (type, properties) {\n",
       "    properties['type'] = type;\n",
       "    properties['figure_id'] = this.id;\n",
       "    this.ws.send(JSON.stringify(properties));\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.send_draw_message = function () {\n",
       "    if (!this.waiting) {\n",
       "        this.waiting = true;\n",
       "        this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_save = function (fig, _msg) {\n",
       "    var format_dropdown = fig.format_dropdown;\n",
       "    var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
       "    fig.ondownload(fig, format);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_resize = function (fig, msg) {\n",
       "    var size = msg['size'];\n",
       "    if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n",
       "        fig._resize_canvas(size[0], size[1], msg['forward']);\n",
       "        fig.send_message('refresh', {});\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_rubberband = function (fig, msg) {\n",
       "    var x0 = msg['x0'] / fig.ratio;\n",
       "    var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n",
       "    var x1 = msg['x1'] / fig.ratio;\n",
       "    var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n",
       "    x0 = Math.floor(x0) + 0.5;\n",
       "    y0 = Math.floor(y0) + 0.5;\n",
       "    x1 = Math.floor(x1) + 0.5;\n",
       "    y1 = Math.floor(y1) + 0.5;\n",
       "    var min_x = Math.min(x0, x1);\n",
       "    var min_y = Math.min(y0, y1);\n",
       "    var width = Math.abs(x1 - x0);\n",
       "    var height = Math.abs(y1 - y0);\n",
       "\n",
       "    fig.rubberband_context.clearRect(\n",
       "        0,\n",
       "        0,\n",
       "        fig.canvas.width / fig.ratio,\n",
       "        fig.canvas.height / fig.ratio\n",
       "    );\n",
       "\n",
       "    fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_figure_label = function (fig, msg) {\n",
       "    // Updates the figure title.\n",
       "    fig.header.textContent = msg['label'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_cursor = function (fig, msg) {\n",
       "    var cursor = msg['cursor'];\n",
       "    switch (cursor) {\n",
       "        case 0:\n",
       "            cursor = 'pointer';\n",
       "            break;\n",
       "        case 1:\n",
       "            cursor = 'default';\n",
       "            break;\n",
       "        case 2:\n",
       "            cursor = 'crosshair';\n",
       "            break;\n",
       "        case 3:\n",
       "            cursor = 'move';\n",
       "            break;\n",
       "    }\n",
       "    fig.rubberband_canvas.style.cursor = cursor;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_message = function (fig, msg) {\n",
       "    fig.message.textContent = msg['message'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_draw = function (fig, _msg) {\n",
       "    // Request the server to send over a new figure.\n",
       "    fig.send_draw_message();\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_image_mode = function (fig, msg) {\n",
       "    fig.image_mode = msg['mode'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n",
       "    for (var key in msg) {\n",
       "        if (!(key in fig.buttons)) {\n",
       "            continue;\n",
       "        }\n",
       "        fig.buttons[key].disabled = !msg[key];\n",
       "        fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n",
       "    if (msg['mode'] === 'PAN') {\n",
       "        fig.buttons['Pan'].classList.add('active');\n",
       "        fig.buttons['Zoom'].classList.remove('active');\n",
       "    } else if (msg['mode'] === 'ZOOM') {\n",
       "        fig.buttons['Pan'].classList.remove('active');\n",
       "        fig.buttons['Zoom'].classList.add('active');\n",
       "    } else {\n",
       "        fig.buttons['Pan'].classList.remove('active');\n",
       "        fig.buttons['Zoom'].classList.remove('active');\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function () {\n",
       "    // Called whenever the canvas gets updated.\n",
       "    this.send_message('ack', {});\n",
       "};\n",
       "\n",
       "// A function to construct a web socket function for onmessage handling.\n",
       "// Called in the figure constructor.\n",
       "mpl.figure.prototype._make_on_message_function = function (fig) {\n",
       "    return function socket_on_message(evt) {\n",
       "        if (evt.data instanceof Blob) {\n",
       "            /* FIXME: We get \"Resource interpreted as Image but\n",
       "             * transferred with MIME type text/plain:\" errors on\n",
       "             * Chrome.  But how to set the MIME type?  It doesn't seem\n",
       "             * to be part of the websocket stream */\n",
       "            evt.data.type = 'image/png';\n",
       "\n",
       "            /* Free the memory for the previous frames */\n",
       "            if (fig.imageObj.src) {\n",
       "                (window.URL || window.webkitURL).revokeObjectURL(\n",
       "                    fig.imageObj.src\n",
       "                );\n",
       "            }\n",
       "\n",
       "            fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
       "                evt.data\n",
       "            );\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        } else if (\n",
       "            typeof evt.data === 'string' &&\n",
       "            evt.data.slice(0, 21) === 'data:image/png;base64'\n",
       "        ) {\n",
       "            fig.imageObj.src = evt.data;\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        var msg = JSON.parse(evt.data);\n",
       "        var msg_type = msg['type'];\n",
       "\n",
       "        // Call the  \"handle_{type}\" callback, which takes\n",
       "        // the figure and JSON message as its only arguments.\n",
       "        try {\n",
       "            var callback = fig['handle_' + msg_type];\n",
       "        } catch (e) {\n",
       "            console.log(\n",
       "                \"No handler for the '\" + msg_type + \"' message type: \",\n",
       "                msg\n",
       "            );\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        if (callback) {\n",
       "            try {\n",
       "                // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
       "                callback(fig, msg);\n",
       "            } catch (e) {\n",
       "                console.log(\n",
       "                    \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n",
       "                    e,\n",
       "                    e.stack,\n",
       "                    msg\n",
       "                );\n",
       "            }\n",
       "        }\n",
       "    };\n",
       "};\n",
       "\n",
       "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
       "mpl.findpos = function (e) {\n",
       "    //this section is from http://www.quirksmode.org/js/events_properties.html\n",
       "    var targ;\n",
       "    if (!e) {\n",
       "        e = window.event;\n",
       "    }\n",
       "    if (e.target) {\n",
       "        targ = e.target;\n",
       "    } else if (e.srcElement) {\n",
       "        targ = e.srcElement;\n",
       "    }\n",
       "    if (targ.nodeType === 3) {\n",
       "        // defeat Safari bug\n",
       "        targ = targ.parentNode;\n",
       "    }\n",
       "\n",
       "    // pageX,Y are the mouse positions relative to the document\n",
       "    var boundingRect = targ.getBoundingClientRect();\n",
       "    var x = e.pageX - (boundingRect.left + document.body.scrollLeft);\n",
       "    var y = e.pageY - (boundingRect.top + document.body.scrollTop);\n",
       "\n",
       "    return { x: x, y: y };\n",
       "};\n",
       "\n",
       "/*\n",
       " * return a copy of an object with only non-object keys\n",
       " * we need this to avoid circular references\n",
       " * http://stackoverflow.com/a/24161582/3208463\n",
       " */\n",
       "function simpleKeys(original) {\n",
       "    return Object.keys(original).reduce(function (obj, key) {\n",
       "        if (typeof original[key] !== 'object') {\n",
       "            obj[key] = original[key];\n",
       "        }\n",
       "        return obj;\n",
       "    }, {});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.mouse_event = function (event, name) {\n",
       "    var canvas_pos = mpl.findpos(event);\n",
       "\n",
       "    if (name === 'button_press') {\n",
       "        this.canvas.focus();\n",
       "        this.canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    var x = canvas_pos.x * this.ratio;\n",
       "    var y = canvas_pos.y * this.ratio;\n",
       "\n",
       "    this.send_message(name, {\n",
       "        x: x,\n",
       "        y: y,\n",
       "        button: event.button,\n",
       "        step: event.step,\n",
       "        guiEvent: simpleKeys(event),\n",
       "    });\n",
       "\n",
       "    /* This prevents the web browser from automatically changing to\n",
       "     * the text insertion cursor when the button is pressed.  We want\n",
       "     * to control all of the cursor setting manually through the\n",
       "     * 'cursor' event from matplotlib */\n",
       "    event.preventDefault();\n",
       "    return false;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function (_event, _name) {\n",
       "    // Handle any extra behaviour associated with a key event\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.key_event = function (event, name) {\n",
       "    // Prevent repeat events\n",
       "    if (name === 'key_press') {\n",
       "        if (event.which === this._key) {\n",
       "            return;\n",
       "        } else {\n",
       "            this._key = event.which;\n",
       "        }\n",
       "    }\n",
       "    if (name === 'key_release') {\n",
       "        this._key = null;\n",
       "    }\n",
       "\n",
       "    var value = '';\n",
       "    if (event.ctrlKey && event.which !== 17) {\n",
       "        value += 'ctrl+';\n",
       "    }\n",
       "    if (event.altKey && event.which !== 18) {\n",
       "        value += 'alt+';\n",
       "    }\n",
       "    if (event.shiftKey && event.which !== 16) {\n",
       "        value += 'shift+';\n",
       "    }\n",
       "\n",
       "    value += 'k';\n",
       "    value += event.which.toString();\n",
       "\n",
       "    this._key_event_extra(event, name);\n",
       "\n",
       "    this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n",
       "    return false;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onclick = function (name) {\n",
       "    if (name === 'download') {\n",
       "        this.handle_save(this, null);\n",
       "    } else {\n",
       "        this.send_message('toolbar_button', { name: name });\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n",
       "    this.message.textContent = tooltip;\n",
       "};\n",
       "\n",
       "///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n",
       "// prettier-ignore\n",
       "var _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // eslint-disable-line\n",
       "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
       "\n",
       "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
       "\n",
       "mpl.default_extension = \"png\";/* global mpl */\n",
       "\n",
       "var comm_websocket_adapter = function (comm) {\n",
       "    // Create a \"websocket\"-like object which calls the given IPython comm\n",
       "    // object with the appropriate methods. Currently this is a non binary\n",
       "    // socket, so there is still some room for performance tuning.\n",
       "    var ws = {};\n",
       "\n",
       "    ws.close = function () {\n",
       "        comm.close();\n",
       "    };\n",
       "    ws.send = function (m) {\n",
       "        //console.log('sending', m);\n",
       "        comm.send(m);\n",
       "    };\n",
       "    // Register the callback with on_msg.\n",
       "    comm.on_msg(function (msg) {\n",
       "        //console.log('receiving', msg['content']['data'], msg);\n",
       "        // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
       "        ws.onmessage(msg['content']['data']);\n",
       "    });\n",
       "    return ws;\n",
       "};\n",
       "\n",
       "mpl.mpl_figure_comm = function (comm, msg) {\n",
       "    // This is the function which gets called when the mpl process\n",
       "    // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
       "\n",
       "    var id = msg.content.data.id;\n",
       "    // Get hold of the div created by the display call when the Comm\n",
       "    // socket was opened in Python.\n",
       "    var element = document.getElementById(id);\n",
       "    var ws_proxy = comm_websocket_adapter(comm);\n",
       "\n",
       "    function ondownload(figure, _format) {\n",
       "        window.open(figure.canvas.toDataURL());\n",
       "    }\n",
       "\n",
       "    var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n",
       "\n",
       "    // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
       "    // web socket which is closed, not our websocket->open comm proxy.\n",
       "    ws_proxy.onopen();\n",
       "\n",
       "    fig.parent_element = element;\n",
       "    fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
       "    if (!fig.cell_info) {\n",
       "        console.error('Failed to find cell for figure', id, fig);\n",
       "        return;\n",
       "    }\n",
       "    fig.cell_info[0].output_area.element.on(\n",
       "        'cleared',\n",
       "        { fig: fig },\n",
       "        fig._remove_fig_handler\n",
       "    );\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_close = function (fig, msg) {\n",
       "    var width = fig.canvas.width / fig.ratio;\n",
       "    fig.cell_info[0].output_area.element.off(\n",
       "        'cleared',\n",
       "        fig._remove_fig_handler\n",
       "    );\n",
       "    fig.resizeObserverInstance.unobserve(fig.canvas_div);\n",
       "\n",
       "    // Update the output cell to use the data from the current canvas.\n",
       "    fig.push_to_output();\n",
       "    var dataURL = fig.canvas.toDataURL();\n",
       "    // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
       "    // the notebook keyboard shortcuts fail.\n",
       "    IPython.keyboard_manager.enable();\n",
       "    fig.parent_element.innerHTML =\n",
       "        '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "    fig.close_ws(fig, msg);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.close_ws = function (fig, msg) {\n",
       "    fig.send_message('closing', msg);\n",
       "    // fig.ws.close()\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.push_to_output = function (_remove_interactive) {\n",
       "    // Turn the data on the canvas into data in the output cell.\n",
       "    var width = this.canvas.width / this.ratio;\n",
       "    var dataURL = this.canvas.toDataURL();\n",
       "    this.cell_info[1]['text/html'] =\n",
       "        '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function () {\n",
       "    // Tell IPython that the notebook contents must change.\n",
       "    IPython.notebook.set_dirty(true);\n",
       "    this.send_message('ack', {});\n",
       "    var fig = this;\n",
       "    // Wait a second, then push the new image to the DOM so\n",
       "    // that it is saved nicely (might be nice to debounce this).\n",
       "    setTimeout(function () {\n",
       "        fig.push_to_output();\n",
       "    }, 1000);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var toolbar = document.createElement('div');\n",
       "    toolbar.classList = 'btn-toolbar';\n",
       "    this.root.appendChild(toolbar);\n",
       "\n",
       "    function on_click_closure(name) {\n",
       "        return function (_event) {\n",
       "            return fig.toolbar_button_onclick(name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    function on_mouseover_closure(tooltip) {\n",
       "        return function (event) {\n",
       "            if (!event.currentTarget.disabled) {\n",
       "                return fig.toolbar_button_onmouseover(tooltip);\n",
       "            }\n",
       "        };\n",
       "    }\n",
       "\n",
       "    fig.buttons = {};\n",
       "    var buttonGroup = document.createElement('div');\n",
       "    buttonGroup.classList = 'btn-group';\n",
       "    var button;\n",
       "    for (var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            /* Instead of a spacer, we start a new button group. */\n",
       "            if (buttonGroup.hasChildNodes()) {\n",
       "                toolbar.appendChild(buttonGroup);\n",
       "            }\n",
       "            buttonGroup = document.createElement('div');\n",
       "            buttonGroup.classList = 'btn-group';\n",
       "            continue;\n",
       "        }\n",
       "\n",
       "        button = fig.buttons[name] = document.createElement('button');\n",
       "        button.classList = 'btn btn-default';\n",
       "        button.href = '#';\n",
       "        button.title = name;\n",
       "        button.innerHTML = '<i class=\"fa ' + image + ' fa-lg\"></i>';\n",
       "        button.addEventListener('click', on_click_closure(method_name));\n",
       "        button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n",
       "        buttonGroup.appendChild(button);\n",
       "    }\n",
       "\n",
       "    if (buttonGroup.hasChildNodes()) {\n",
       "        toolbar.appendChild(buttonGroup);\n",
       "    }\n",
       "\n",
       "    // Add the status bar.\n",
       "    var status_bar = document.createElement('span');\n",
       "    status_bar.classList = 'mpl-message pull-right';\n",
       "    toolbar.appendChild(status_bar);\n",
       "    this.message = status_bar;\n",
       "\n",
       "    // Add the close button to the window.\n",
       "    var buttongrp = document.createElement('div');\n",
       "    buttongrp.classList = 'btn-group inline pull-right';\n",
       "    button = document.createElement('button');\n",
       "    button.classList = 'btn btn-mini btn-primary';\n",
       "    button.href = '#';\n",
       "    button.title = 'Stop Interaction';\n",
       "    button.innerHTML = '<i class=\"fa fa-power-off icon-remove icon-large\"></i>';\n",
       "    button.addEventListener('click', function (_evt) {\n",
       "        fig.handle_close(fig, {});\n",
       "    });\n",
       "    button.addEventListener(\n",
       "        'mouseover',\n",
       "        on_mouseover_closure('Stop Interaction')\n",
       "    );\n",
       "    buttongrp.appendChild(button);\n",
       "    var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n",
       "    titlebar.insertBefore(buttongrp, titlebar.firstChild);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._remove_fig_handler = function (event) {\n",
       "    var fig = event.data.fig;\n",
       "    if (event.target !== this) {\n",
       "        // Ignore bubbled events from children.\n",
       "        return;\n",
       "    }\n",
       "    fig.close_ws(fig, {});\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function (el) {\n",
       "    el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function (el) {\n",
       "    // this is important to make the div 'focusable\n",
       "    el.setAttribute('tabindex', 0);\n",
       "    // reach out to IPython and tell the keyboard manager to turn it's self\n",
       "    // off when our div gets focus\n",
       "\n",
       "    // location in version 3\n",
       "    if (IPython.notebook.keyboard_manager) {\n",
       "        IPython.notebook.keyboard_manager.register_events(el);\n",
       "    } else {\n",
       "        // location in version 2\n",
       "        IPython.keyboard_manager.register_events(el);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function (event, _name) {\n",
       "    var manager = IPython.notebook.keyboard_manager;\n",
       "    if (!manager) {\n",
       "        manager = IPython.keyboard_manager;\n",
       "    }\n",
       "\n",
       "    // Check for shift+enter\n",
       "    if (event.shiftKey && event.which === 13) {\n",
       "        this.canvas_div.blur();\n",
       "        // select the cell after this one\n",
       "        var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
       "        IPython.notebook.select(index + 1);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_save = function (fig, _msg) {\n",
       "    fig.ondownload(fig, null);\n",
       "};\n",
       "\n",
       "mpl.find_output_cell = function (html_output) {\n",
       "    // Return the cell and output element which can be found *uniquely* in the notebook.\n",
       "    // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
       "    // IPython event is triggered only after the cells have been serialised, which for\n",
       "    // our purposes (turning an active figure into a static one), is too late.\n",
       "    var cells = IPython.notebook.get_cells();\n",
       "    var ncells = cells.length;\n",
       "    for (var i = 0; i < ncells; i++) {\n",
       "        var cell = cells[i];\n",
       "        if (cell.cell_type === 'code') {\n",
       "            for (var j = 0; j < cell.output_area.outputs.length; j++) {\n",
       "                var data = cell.output_area.outputs[j];\n",
       "                if (data.data) {\n",
       "                    // IPython >= 3 moved mimebundle to data attribute of output\n",
       "                    data = data.data;\n",
       "                }\n",
       "                if (data['text/html'] === html_output) {\n",
       "                    return [cell, data, j];\n",
       "                }\n",
       "            }\n",
       "        }\n",
       "    }\n",
       "};\n",
       "\n",
       "// Register the function which deals with the matplotlib target/channel.\n",
       "// The kernel may be null if the page has been refreshed.\n",
       "if (IPython.notebook.kernel !== null) {\n",
       "    IPython.notebook.kernel.comm_manager.register_target(\n",
       "        'matplotlib',\n",
       "        mpl.mpl_figure_comm\n",
       "    );\n",
       "}\n"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"\" width=\"1000\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(10,15))\n",
    "for i in range(5):\n",
    "    for j in range(5):\n",
    "        idx = i*5+j\n",
    "        img = test_dataset[idx][0]\n",
    "        label = test_dataset[idx][1]\n",
    "        logits = network(img[None, ...]) # output\n",
    "        probs = F.softmax(logits, dim=1) # apply softmax to normalize them\n",
    "        predicted = torch.argmax(probs[0, ...]) # index of the highest probability\n",
    "        ax[i, j].imshow(img[0,...].detach().cpu().numpy())\n",
    "        ax[i, j].set(title=f\"True {label}, pred. {predicted}\")\n",
    "        ax[i, j].set_xticklabels([])\n",
    "        ax[i, j].set_yticklabels([])\n",
    "plt.subplots_adjust(hspace=0.2,wspace=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73e79a6d",
   "metadata": {},
   "source": [
    "We can also examine the probability that the network gives for all images in the expected class. That is: what is the predicted probability of the network for all images in the true class 4?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b447df87",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=32)\n",
    "logits = list()\n",
    "label = list()\n",
    "for batch in test_loader:\n",
    "    logits += [network(batch[0])]\n",
    "    label += [batch[1]]\n",
    "logits = torch.cat(logits, dim=0)\n",
    "label = torch.cat(label, dim=0)\n",
    "probs = F.softmax(logits, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d6df7b33",
   "metadata": {},
   "outputs": [],
   "source": [
    "probs_4 = probs.detach().cpu().numpy()[label.detach().cpu().numpy() == 4]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c850c24a",
   "metadata": {},
   "source": [
    "We can histogram it and see that it did mostly a good job, but sometimes it failed. We can go forward and make a cut in the probability and look at those images to identify which images were incorrectly classified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "93a5fe94",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "/* Put everything inside the global mpl namespace */\n",
       "/* global mpl */\n",
       "window.mpl = {};\n",
       "\n",
       "mpl.get_websocket_type = function () {\n",
       "    if (typeof WebSocket !== 'undefined') {\n",
       "        return WebSocket;\n",
       "    } else if (typeof MozWebSocket !== 'undefined') {\n",
       "        return MozWebSocket;\n",
       "    } else {\n",
       "        alert(\n",
       "            'Your browser does not have WebSocket support. ' +\n",
       "                'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
       "                'Firefox 4 and 5 are also supported but you ' +\n",
       "                'have to enable WebSockets in about:config.'\n",
       "        );\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n",
       "    this.id = figure_id;\n",
       "\n",
       "    this.ws = websocket;\n",
       "\n",
       "    this.supports_binary = this.ws.binaryType !== undefined;\n",
       "\n",
       "    if (!this.supports_binary) {\n",
       "        var warnings = document.getElementById('mpl-warnings');\n",
       "        if (warnings) {\n",
       "            warnings.style.display = 'block';\n",
       "            warnings.textContent =\n",
       "                'This browser does not support binary websocket messages. ' +\n",
       "                'Performance may be slow.';\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.imageObj = new Image();\n",
       "\n",
       "    this.context = undefined;\n",
       "    this.message = undefined;\n",
       "    this.canvas = undefined;\n",
       "    this.rubberband_canvas = undefined;\n",
       "    this.rubberband_context = undefined;\n",
       "    this.format_dropdown = undefined;\n",
       "\n",
       "    this.image_mode = 'full';\n",
       "\n",
       "    this.root = document.createElement('div');\n",
       "    this.root.setAttribute('style', 'display: inline-block');\n",
       "    this._root_extra_style(this.root);\n",
       "\n",
       "    parent_element.appendChild(this.root);\n",
       "\n",
       "    this._init_header(this);\n",
       "    this._init_canvas(this);\n",
       "    this._init_toolbar(this);\n",
       "\n",
       "    var fig = this;\n",
       "\n",
       "    this.waiting = false;\n",
       "\n",
       "    this.ws.onopen = function () {\n",
       "        fig.send_message('supports_binary', { value: fig.supports_binary });\n",
       "        fig.send_message('send_image_mode', {});\n",
       "        if (fig.ratio !== 1) {\n",
       "            fig.send_message('set_dpi_ratio', { dpi_ratio: fig.ratio });\n",
       "        }\n",
       "        fig.send_message('refresh', {});\n",
       "    };\n",
       "\n",
       "    this.imageObj.onload = function () {\n",
       "        if (fig.image_mode === 'full') {\n",
       "            // Full images could contain transparency (where diff images\n",
       "            // almost always do), so we need to clear the canvas so that\n",
       "            // there is no ghosting.\n",
       "            fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
       "        }\n",
       "        fig.context.drawImage(fig.imageObj, 0, 0);\n",
       "    };\n",
       "\n",
       "    this.imageObj.onunload = function () {\n",
       "        fig.ws.close();\n",
       "    };\n",
       "\n",
       "    this.ws.onmessage = this._make_on_message_function(this);\n",
       "\n",
       "    this.ondownload = ondownload;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_header = function () {\n",
       "    var titlebar = document.createElement('div');\n",
       "    titlebar.classList =\n",
       "        'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n",
       "    var titletext = document.createElement('div');\n",
       "    titletext.classList = 'ui-dialog-title';\n",
       "    titletext.setAttribute(\n",
       "        'style',\n",
       "        'width: 100%; text-align: center; padding: 3px;'\n",
       "    );\n",
       "    titlebar.appendChild(titletext);\n",
       "    this.root.appendChild(titlebar);\n",
       "    this.header = titletext;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n",
       "\n",
       "mpl.figure.prototype._init_canvas = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var canvas_div = (this.canvas_div = document.createElement('div'));\n",
       "    canvas_div.setAttribute(\n",
       "        'style',\n",
       "        'border: 1px solid #ddd;' +\n",
       "            'box-sizing: content-box;' +\n",
       "            'clear: both;' +\n",
       "            'min-height: 1px;' +\n",
       "            'min-width: 1px;' +\n",
       "            'outline: 0;' +\n",
       "            'overflow: hidden;' +\n",
       "            'position: relative;' +\n",
       "            'resize: both;'\n",
       "    );\n",
       "\n",
       "    function on_keyboard_event_closure(name) {\n",
       "        return function (event) {\n",
       "            return fig.key_event(event, name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    canvas_div.addEventListener(\n",
       "        'keydown',\n",
       "        on_keyboard_event_closure('key_press')\n",
       "    );\n",
       "    canvas_div.addEventListener(\n",
       "        'keyup',\n",
       "        on_keyboard_event_closure('key_release')\n",
       "    );\n",
       "\n",
       "    this._canvas_extra_style(canvas_div);\n",
       "    this.root.appendChild(canvas_div);\n",
       "\n",
       "    var canvas = (this.canvas = document.createElement('canvas'));\n",
       "    canvas.classList.add('mpl-canvas');\n",
       "    canvas.setAttribute('style', 'box-sizing: content-box;');\n",
       "\n",
       "    this.context = canvas.getContext('2d');\n",
       "\n",
       "    var backingStore =\n",
       "        this.context.backingStorePixelRatio ||\n",
       "        this.context.webkitBackingStorePixelRatio ||\n",
       "        this.context.mozBackingStorePixelRatio ||\n",
       "        this.context.msBackingStorePixelRatio ||\n",
       "        this.context.oBackingStorePixelRatio ||\n",
       "        this.context.backingStorePixelRatio ||\n",
       "        1;\n",
       "\n",
       "    this.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
       "\n",
       "    var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n",
       "        'canvas'\n",
       "    ));\n",
       "    rubberband_canvas.setAttribute(\n",
       "        'style',\n",
       "        'box-sizing: content-box; position: absolute; left: 0; top: 0; z-index: 1;'\n",
       "    );\n",
       "\n",
       "    // Apply a ponyfill if ResizeObserver is not implemented by browser.\n",
       "    if (this.ResizeObserver === undefined) {\n",
       "        if (window.ResizeObserver !== undefined) {\n",
       "            this.ResizeObserver = window.ResizeObserver;\n",
       "        } else {\n",
       "            var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n",
       "            this.ResizeObserver = obs.ResizeObserver;\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n",
       "        var nentries = entries.length;\n",
       "        for (var i = 0; i < nentries; i++) {\n",
       "            var entry = entries[i];\n",
       "            var width, height;\n",
       "            if (entry.contentBoxSize) {\n",
       "                if (entry.contentBoxSize instanceof Array) {\n",
       "                    // Chrome 84 implements new version of spec.\n",
       "                    width = entry.contentBoxSize[0].inlineSize;\n",
       "                    height = entry.contentBoxSize[0].blockSize;\n",
       "                } else {\n",
       "                    // Firefox implements old version of spec.\n",
       "                    width = entry.contentBoxSize.inlineSize;\n",
       "                    height = entry.contentBoxSize.blockSize;\n",
       "                }\n",
       "            } else {\n",
       "                // Chrome <84 implements even older version of spec.\n",
       "                width = entry.contentRect.width;\n",
       "                height = entry.contentRect.height;\n",
       "            }\n",
       "\n",
       "            // Keep the size of the canvas and rubber band canvas in sync with\n",
       "            // the canvas container.\n",
       "            if (entry.devicePixelContentBoxSize) {\n",
       "                // Chrome 84 implements new version of spec.\n",
       "                canvas.setAttribute(\n",
       "                    'width',\n",
       "                    entry.devicePixelContentBoxSize[0].inlineSize\n",
       "                );\n",
       "                canvas.setAttribute(\n",
       "                    'height',\n",
       "                    entry.devicePixelContentBoxSize[0].blockSize\n",
       "                );\n",
       "            } else {\n",
       "                canvas.setAttribute('width', width * fig.ratio);\n",
       "                canvas.setAttribute('height', height * fig.ratio);\n",
       "            }\n",
       "            canvas.setAttribute(\n",
       "                'style',\n",
       "                'width: ' + width + 'px; height: ' + height + 'px;'\n",
       "            );\n",
       "\n",
       "            rubberband_canvas.setAttribute('width', width);\n",
       "            rubberband_canvas.setAttribute('height', height);\n",
       "\n",
       "            // And update the size in Python. We ignore the initial 0/0 size\n",
       "            // that occurs as the element is placed into the DOM, which should\n",
       "            // otherwise not happen due to the minimum size styling.\n",
       "            if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n",
       "                fig.request_resize(width, height);\n",
       "            }\n",
       "        }\n",
       "    });\n",
       "    this.resizeObserverInstance.observe(canvas_div);\n",
       "\n",
       "    function on_mouse_event_closure(name) {\n",
       "        return function (event) {\n",
       "            return fig.mouse_event(event, name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mousedown',\n",
       "        on_mouse_event_closure('button_press')\n",
       "    );\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseup',\n",
       "        on_mouse_event_closure('button_release')\n",
       "    );\n",
       "    // Throttle sequential mouse events to 1 every 20ms.\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mousemove',\n",
       "        on_mouse_event_closure('motion_notify')\n",
       "    );\n",
       "\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseenter',\n",
       "        on_mouse_event_closure('figure_enter')\n",
       "    );\n",
       "    rubberband_canvas.addEventListener(\n",
       "        'mouseleave',\n",
       "        on_mouse_event_closure('figure_leave')\n",
       "    );\n",
       "\n",
       "    canvas_div.addEventListener('wheel', function (event) {\n",
       "        if (event.deltaY < 0) {\n",
       "            event.step = 1;\n",
       "        } else {\n",
       "            event.step = -1;\n",
       "        }\n",
       "        on_mouse_event_closure('scroll')(event);\n",
       "    });\n",
       "\n",
       "    canvas_div.appendChild(canvas);\n",
       "    canvas_div.appendChild(rubberband_canvas);\n",
       "\n",
       "    this.rubberband_context = rubberband_canvas.getContext('2d');\n",
       "    this.rubberband_context.strokeStyle = '#000000';\n",
       "\n",
       "    this._resize_canvas = function (width, height, forward) {\n",
       "        if (forward) {\n",
       "            canvas_div.style.width = width + 'px';\n",
       "            canvas_div.style.height = height + 'px';\n",
       "        }\n",
       "    };\n",
       "\n",
       "    // Disable right mouse context menu.\n",
       "    this.rubberband_canvas.addEventListener('contextmenu', function (_e) {\n",
       "        event.preventDefault();\n",
       "        return false;\n",
       "    });\n",
       "\n",
       "    function set_focus() {\n",
       "        canvas.focus();\n",
       "        canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    window.setTimeout(set_focus, 100);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var toolbar = document.createElement('div');\n",
       "    toolbar.classList = 'mpl-toolbar';\n",
       "    this.root.appendChild(toolbar);\n",
       "\n",
       "    function on_click_closure(name) {\n",
       "        return function (_event) {\n",
       "            return fig.toolbar_button_onclick(name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    function on_mouseover_closure(tooltip) {\n",
       "        return function (event) {\n",
       "            if (!event.currentTarget.disabled) {\n",
       "                return fig.toolbar_button_onmouseover(tooltip);\n",
       "            }\n",
       "        };\n",
       "    }\n",
       "\n",
       "    fig.buttons = {};\n",
       "    var buttonGroup = document.createElement('div');\n",
       "    buttonGroup.classList = 'mpl-button-group';\n",
       "    for (var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            /* Instead of a spacer, we start a new button group. */\n",
       "            if (buttonGroup.hasChildNodes()) {\n",
       "                toolbar.appendChild(buttonGroup);\n",
       "            }\n",
       "            buttonGroup = document.createElement('div');\n",
       "            buttonGroup.classList = 'mpl-button-group';\n",
       "            continue;\n",
       "        }\n",
       "\n",
       "        var button = (fig.buttons[name] = document.createElement('button'));\n",
       "        button.classList = 'mpl-widget';\n",
       "        button.setAttribute('role', 'button');\n",
       "        button.setAttribute('aria-disabled', 'false');\n",
       "        button.addEventListener('click', on_click_closure(method_name));\n",
       "        button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n",
       "\n",
       "        var icon_img = document.createElement('img');\n",
       "        icon_img.src = '_images/' + image + '.png';\n",
       "        icon_img.srcset = '_images/' + image + '_large.png 2x';\n",
       "        icon_img.alt = tooltip;\n",
       "        button.appendChild(icon_img);\n",
       "\n",
       "        buttonGroup.appendChild(button);\n",
       "    }\n",
       "\n",
       "    if (buttonGroup.hasChildNodes()) {\n",
       "        toolbar.appendChild(buttonGroup);\n",
       "    }\n",
       "\n",
       "    var fmt_picker = document.createElement('select');\n",
       "    fmt_picker.classList = 'mpl-widget';\n",
       "    toolbar.appendChild(fmt_picker);\n",
       "    this.format_dropdown = fmt_picker;\n",
       "\n",
       "    for (var ind in mpl.extensions) {\n",
       "        var fmt = mpl.extensions[ind];\n",
       "        var option = document.createElement('option');\n",
       "        option.selected = fmt === mpl.default_extension;\n",
       "        option.innerHTML = fmt;\n",
       "        fmt_picker.appendChild(option);\n",
       "    }\n",
       "\n",
       "    var status_bar = document.createElement('span');\n",
       "    status_bar.classList = 'mpl-message';\n",
       "    toolbar.appendChild(status_bar);\n",
       "    this.message = status_bar;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n",
       "    // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
       "    // which will in turn request a refresh of the image.\n",
       "    this.send_message('resize', { width: x_pixels, height: y_pixels });\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.send_message = function (type, properties) {\n",
       "    properties['type'] = type;\n",
       "    properties['figure_id'] = this.id;\n",
       "    this.ws.send(JSON.stringify(properties));\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.send_draw_message = function () {\n",
       "    if (!this.waiting) {\n",
       "        this.waiting = true;\n",
       "        this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_save = function (fig, _msg) {\n",
       "    var format_dropdown = fig.format_dropdown;\n",
       "    var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
       "    fig.ondownload(fig, format);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_resize = function (fig, msg) {\n",
       "    var size = msg['size'];\n",
       "    if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n",
       "        fig._resize_canvas(size[0], size[1], msg['forward']);\n",
       "        fig.send_message('refresh', {});\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_rubberband = function (fig, msg) {\n",
       "    var x0 = msg['x0'] / fig.ratio;\n",
       "    var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n",
       "    var x1 = msg['x1'] / fig.ratio;\n",
       "    var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n",
       "    x0 = Math.floor(x0) + 0.5;\n",
       "    y0 = Math.floor(y0) + 0.5;\n",
       "    x1 = Math.floor(x1) + 0.5;\n",
       "    y1 = Math.floor(y1) + 0.5;\n",
       "    var min_x = Math.min(x0, x1);\n",
       "    var min_y = Math.min(y0, y1);\n",
       "    var width = Math.abs(x1 - x0);\n",
       "    var height = Math.abs(y1 - y0);\n",
       "\n",
       "    fig.rubberband_context.clearRect(\n",
       "        0,\n",
       "        0,\n",
       "        fig.canvas.width / fig.ratio,\n",
       "        fig.canvas.height / fig.ratio\n",
       "    );\n",
       "\n",
       "    fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_figure_label = function (fig, msg) {\n",
       "    // Updates the figure title.\n",
       "    fig.header.textContent = msg['label'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_cursor = function (fig, msg) {\n",
       "    var cursor = msg['cursor'];\n",
       "    switch (cursor) {\n",
       "        case 0:\n",
       "            cursor = 'pointer';\n",
       "            break;\n",
       "        case 1:\n",
       "            cursor = 'default';\n",
       "            break;\n",
       "        case 2:\n",
       "            cursor = 'crosshair';\n",
       "            break;\n",
       "        case 3:\n",
       "            cursor = 'move';\n",
       "            break;\n",
       "    }\n",
       "    fig.rubberband_canvas.style.cursor = cursor;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_message = function (fig, msg) {\n",
       "    fig.message.textContent = msg['message'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_draw = function (fig, _msg) {\n",
       "    // Request the server to send over a new figure.\n",
       "    fig.send_draw_message();\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_image_mode = function (fig, msg) {\n",
       "    fig.image_mode = msg['mode'];\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n",
       "    for (var key in msg) {\n",
       "        if (!(key in fig.buttons)) {\n",
       "            continue;\n",
       "        }\n",
       "        fig.buttons[key].disabled = !msg[key];\n",
       "        fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n",
       "    if (msg['mode'] === 'PAN') {\n",
       "        fig.buttons['Pan'].classList.add('active');\n",
       "        fig.buttons['Zoom'].classList.remove('active');\n",
       "    } else if (msg['mode'] === 'ZOOM') {\n",
       "        fig.buttons['Pan'].classList.remove('active');\n",
       "        fig.buttons['Zoom'].classList.add('active');\n",
       "    } else {\n",
       "        fig.buttons['Pan'].classList.remove('active');\n",
       "        fig.buttons['Zoom'].classList.remove('active');\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function () {\n",
       "    // Called whenever the canvas gets updated.\n",
       "    this.send_message('ack', {});\n",
       "};\n",
       "\n",
       "// A function to construct a web socket function for onmessage handling.\n",
       "// Called in the figure constructor.\n",
       "mpl.figure.prototype._make_on_message_function = function (fig) {\n",
       "    return function socket_on_message(evt) {\n",
       "        if (evt.data instanceof Blob) {\n",
       "            /* FIXME: We get \"Resource interpreted as Image but\n",
       "             * transferred with MIME type text/plain:\" errors on\n",
       "             * Chrome.  But how to set the MIME type?  It doesn't seem\n",
       "             * to be part of the websocket stream */\n",
       "            evt.data.type = 'image/png';\n",
       "\n",
       "            /* Free the memory for the previous frames */\n",
       "            if (fig.imageObj.src) {\n",
       "                (window.URL || window.webkitURL).revokeObjectURL(\n",
       "                    fig.imageObj.src\n",
       "                );\n",
       "            }\n",
       "\n",
       "            fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
       "                evt.data\n",
       "            );\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        } else if (\n",
       "            typeof evt.data === 'string' &&\n",
       "            evt.data.slice(0, 21) === 'data:image/png;base64'\n",
       "        ) {\n",
       "            fig.imageObj.src = evt.data;\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        var msg = JSON.parse(evt.data);\n",
       "        var msg_type = msg['type'];\n",
       "\n",
       "        // Call the  \"handle_{type}\" callback, which takes\n",
       "        // the figure and JSON message as its only arguments.\n",
       "        try {\n",
       "            var callback = fig['handle_' + msg_type];\n",
       "        } catch (e) {\n",
       "            console.log(\n",
       "                \"No handler for the '\" + msg_type + \"' message type: \",\n",
       "                msg\n",
       "            );\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        if (callback) {\n",
       "            try {\n",
       "                // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
       "                callback(fig, msg);\n",
       "            } catch (e) {\n",
       "                console.log(\n",
       "                    \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n",
       "                    e,\n",
       "                    e.stack,\n",
       "                    msg\n",
       "                );\n",
       "            }\n",
       "        }\n",
       "    };\n",
       "};\n",
       "\n",
       "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
       "mpl.findpos = function (e) {\n",
       "    //this section is from http://www.quirksmode.org/js/events_properties.html\n",
       "    var targ;\n",
       "    if (!e) {\n",
       "        e = window.event;\n",
       "    }\n",
       "    if (e.target) {\n",
       "        targ = e.target;\n",
       "    } else if (e.srcElement) {\n",
       "        targ = e.srcElement;\n",
       "    }\n",
       "    if (targ.nodeType === 3) {\n",
       "        // defeat Safari bug\n",
       "        targ = targ.parentNode;\n",
       "    }\n",
       "\n",
       "    // pageX,Y are the mouse positions relative to the document\n",
       "    var boundingRect = targ.getBoundingClientRect();\n",
       "    var x = e.pageX - (boundingRect.left + document.body.scrollLeft);\n",
       "    var y = e.pageY - (boundingRect.top + document.body.scrollTop);\n",
       "\n",
       "    return { x: x, y: y };\n",
       "};\n",
       "\n",
       "/*\n",
       " * return a copy of an object with only non-object keys\n",
       " * we need this to avoid circular references\n",
       " * http://stackoverflow.com/a/24161582/3208463\n",
       " */\n",
       "function simpleKeys(original) {\n",
       "    return Object.keys(original).reduce(function (obj, key) {\n",
       "        if (typeof original[key] !== 'object') {\n",
       "            obj[key] = original[key];\n",
       "        }\n",
       "        return obj;\n",
       "    }, {});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.mouse_event = function (event, name) {\n",
       "    var canvas_pos = mpl.findpos(event);\n",
       "\n",
       "    if (name === 'button_press') {\n",
       "        this.canvas.focus();\n",
       "        this.canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    var x = canvas_pos.x * this.ratio;\n",
       "    var y = canvas_pos.y * this.ratio;\n",
       "\n",
       "    this.send_message(name, {\n",
       "        x: x,\n",
       "        y: y,\n",
       "        button: event.button,\n",
       "        step: event.step,\n",
       "        guiEvent: simpleKeys(event),\n",
       "    });\n",
       "\n",
       "    /* This prevents the web browser from automatically changing to\n",
       "     * the text insertion cursor when the button is pressed.  We want\n",
       "     * to control all of the cursor setting manually through the\n",
       "     * 'cursor' event from matplotlib */\n",
       "    event.preventDefault();\n",
       "    return false;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function (_event, _name) {\n",
       "    // Handle any extra behaviour associated with a key event\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.key_event = function (event, name) {\n",
       "    // Prevent repeat events\n",
       "    if (name === 'key_press') {\n",
       "        if (event.which === this._key) {\n",
       "            return;\n",
       "        } else {\n",
       "            this._key = event.which;\n",
       "        }\n",
       "    }\n",
       "    if (name === 'key_release') {\n",
       "        this._key = null;\n",
       "    }\n",
       "\n",
       "    var value = '';\n",
       "    if (event.ctrlKey && event.which !== 17) {\n",
       "        value += 'ctrl+';\n",
       "    }\n",
       "    if (event.altKey && event.which !== 18) {\n",
       "        value += 'alt+';\n",
       "    }\n",
       "    if (event.shiftKey && event.which !== 16) {\n",
       "        value += 'shift+';\n",
       "    }\n",
       "\n",
       "    value += 'k';\n",
       "    value += event.which.toString();\n",
       "\n",
       "    this._key_event_extra(event, name);\n",
       "\n",
       "    this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n",
       "    return false;\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onclick = function (name) {\n",
       "    if (name === 'download') {\n",
       "        this.handle_save(this, null);\n",
       "    } else {\n",
       "        this.send_message('toolbar_button', { name: name });\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n",
       "    this.message.textContent = tooltip;\n",
       "};\n",
       "\n",
       "///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n",
       "// prettier-ignore\n",
       "var _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // eslint-disable-line\n",
       "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
       "\n",
       "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
       "\n",
       "mpl.default_extension = \"png\";/* global mpl */\n",
       "\n",
       "var comm_websocket_adapter = function (comm) {\n",
       "    // Create a \"websocket\"-like object which calls the given IPython comm\n",
       "    // object with the appropriate methods. Currently this is a non binary\n",
       "    // socket, so there is still some room for performance tuning.\n",
       "    var ws = {};\n",
       "\n",
       "    ws.close = function () {\n",
       "        comm.close();\n",
       "    };\n",
       "    ws.send = function (m) {\n",
       "        //console.log('sending', m);\n",
       "        comm.send(m);\n",
       "    };\n",
       "    // Register the callback with on_msg.\n",
       "    comm.on_msg(function (msg) {\n",
       "        //console.log('receiving', msg['content']['data'], msg);\n",
       "        // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
       "        ws.onmessage(msg['content']['data']);\n",
       "    });\n",
       "    return ws;\n",
       "};\n",
       "\n",
       "mpl.mpl_figure_comm = function (comm, msg) {\n",
       "    // This is the function which gets called when the mpl process\n",
       "    // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
       "\n",
       "    var id = msg.content.data.id;\n",
       "    // Get hold of the div created by the display call when the Comm\n",
       "    // socket was opened in Python.\n",
       "    var element = document.getElementById(id);\n",
       "    var ws_proxy = comm_websocket_adapter(comm);\n",
       "\n",
       "    function ondownload(figure, _format) {\n",
       "        window.open(figure.canvas.toDataURL());\n",
       "    }\n",
       "\n",
       "    var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n",
       "\n",
       "    // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
       "    // web socket which is closed, not our websocket->open comm proxy.\n",
       "    ws_proxy.onopen();\n",
       "\n",
       "    fig.parent_element = element;\n",
       "    fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
       "    if (!fig.cell_info) {\n",
       "        console.error('Failed to find cell for figure', id, fig);\n",
       "        return;\n",
       "    }\n",
       "    fig.cell_info[0].output_area.element.on(\n",
       "        'cleared',\n",
       "        { fig: fig },\n",
       "        fig._remove_fig_handler\n",
       "    );\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_close = function (fig, msg) {\n",
       "    var width = fig.canvas.width / fig.ratio;\n",
       "    fig.cell_info[0].output_area.element.off(\n",
       "        'cleared',\n",
       "        fig._remove_fig_handler\n",
       "    );\n",
       "    fig.resizeObserverInstance.unobserve(fig.canvas_div);\n",
       "\n",
       "    // Update the output cell to use the data from the current canvas.\n",
       "    fig.push_to_output();\n",
       "    var dataURL = fig.canvas.toDataURL();\n",
       "    // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
       "    // the notebook keyboard shortcuts fail.\n",
       "    IPython.keyboard_manager.enable();\n",
       "    fig.parent_element.innerHTML =\n",
       "        '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "    fig.close_ws(fig, msg);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.close_ws = function (fig, msg) {\n",
       "    fig.send_message('closing', msg);\n",
       "    // fig.ws.close()\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.push_to_output = function (_remove_interactive) {\n",
       "    // Turn the data on the canvas into data in the output cell.\n",
       "    var width = this.canvas.width / this.ratio;\n",
       "    var dataURL = this.canvas.toDataURL();\n",
       "    this.cell_info[1]['text/html'] =\n",
       "        '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function () {\n",
       "    // Tell IPython that the notebook contents must change.\n",
       "    IPython.notebook.set_dirty(true);\n",
       "    this.send_message('ack', {});\n",
       "    var fig = this;\n",
       "    // Wait a second, then push the new image to the DOM so\n",
       "    // that it is saved nicely (might be nice to debounce this).\n",
       "    setTimeout(function () {\n",
       "        fig.push_to_output();\n",
       "    }, 1000);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function () {\n",
       "    var fig = this;\n",
       "\n",
       "    var toolbar = document.createElement('div');\n",
       "    toolbar.classList = 'btn-toolbar';\n",
       "    this.root.appendChild(toolbar);\n",
       "\n",
       "    function on_click_closure(name) {\n",
       "        return function (_event) {\n",
       "            return fig.toolbar_button_onclick(name);\n",
       "        };\n",
       "    }\n",
       "\n",
       "    function on_mouseover_closure(tooltip) {\n",
       "        return function (event) {\n",
       "            if (!event.currentTarget.disabled) {\n",
       "                return fig.toolbar_button_onmouseover(tooltip);\n",
       "            }\n",
       "        };\n",
       "    }\n",
       "\n",
       "    fig.buttons = {};\n",
       "    var buttonGroup = document.createElement('div');\n",
       "    buttonGroup.classList = 'btn-group';\n",
       "    var button;\n",
       "    for (var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            /* Instead of a spacer, we start a new button group. */\n",
       "            if (buttonGroup.hasChildNodes()) {\n",
       "                toolbar.appendChild(buttonGroup);\n",
       "            }\n",
       "            buttonGroup = document.createElement('div');\n",
       "            buttonGroup.classList = 'btn-group';\n",
       "            continue;\n",
       "        }\n",
       "\n",
       "        button = fig.buttons[name] = document.createElement('button');\n",
       "        button.classList = 'btn btn-default';\n",
       "        button.href = '#';\n",
       "        button.title = name;\n",
       "        button.innerHTML = '<i class=\"fa ' + image + ' fa-lg\"></i>';\n",
       "        button.addEventListener('click', on_click_closure(method_name));\n",
       "        button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n",
       "        buttonGroup.appendChild(button);\n",
       "    }\n",
       "\n",
       "    if (buttonGroup.hasChildNodes()) {\n",
       "        toolbar.appendChild(buttonGroup);\n",
       "    }\n",
       "\n",
       "    // Add the status bar.\n",
       "    var status_bar = document.createElement('span');\n",
       "    status_bar.classList = 'mpl-message pull-right';\n",
       "    toolbar.appendChild(status_bar);\n",
       "    this.message = status_bar;\n",
       "\n",
       "    // Add the close button to the window.\n",
       "    var buttongrp = document.createElement('div');\n",
       "    buttongrp.classList = 'btn-group inline pull-right';\n",
       "    button = document.createElement('button');\n",
       "    button.classList = 'btn btn-mini btn-primary';\n",
       "    button.href = '#';\n",
       "    button.title = 'Stop Interaction';\n",
       "    button.innerHTML = '<i class=\"fa fa-power-off icon-remove icon-large\"></i>';\n",
       "    button.addEventListener('click', function (_evt) {\n",
       "        fig.handle_close(fig, {});\n",
       "    });\n",
       "    button.addEventListener(\n",
       "        'mouseover',\n",
       "        on_mouseover_closure('Stop Interaction')\n",
       "    );\n",
       "    buttongrp.appendChild(button);\n",
       "    var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n",
       "    titlebar.insertBefore(buttongrp, titlebar.firstChild);\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._remove_fig_handler = function (event) {\n",
       "    var fig = event.data.fig;\n",
       "    if (event.target !== this) {\n",
       "        // Ignore bubbled events from children.\n",
       "        return;\n",
       "    }\n",
       "    fig.close_ws(fig, {});\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function (el) {\n",
       "    el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function (el) {\n",
       "    // this is important to make the div 'focusable\n",
       "    el.setAttribute('tabindex', 0);\n",
       "    // reach out to IPython and tell the keyboard manager to turn it's self\n",
       "    // off when our div gets focus\n",
       "\n",
       "    // location in version 3\n",
       "    if (IPython.notebook.keyboard_manager) {\n",
       "        IPython.notebook.keyboard_manager.register_events(el);\n",
       "    } else {\n",
       "        // location in version 2\n",
       "        IPython.keyboard_manager.register_events(el);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function (event, _name) {\n",
       "    var manager = IPython.notebook.keyboard_manager;\n",
       "    if (!manager) {\n",
       "        manager = IPython.keyboard_manager;\n",
       "    }\n",
       "\n",
       "    // Check for shift+enter\n",
       "    if (event.shiftKey && event.which === 13) {\n",
       "        this.canvas_div.blur();\n",
       "        // select the cell after this one\n",
       "        var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
       "        IPython.notebook.select(index + 1);\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_save = function (fig, _msg) {\n",
       "    fig.ondownload(fig, null);\n",
       "};\n",
       "\n",
       "mpl.find_output_cell = function (html_output) {\n",
       "    // Return the cell and output element which can be found *uniquely* in the notebook.\n",
       "    // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
       "    // IPython event is triggered only after the cells have been serialised, which for\n",
       "    // our purposes (turning an active figure into a static one), is too late.\n",
       "    var cells = IPython.notebook.get_cells();\n",
       "    var ncells = cells.length;\n",
       "    for (var i = 0; i < ncells; i++) {\n",
       "        var cell = cells[i];\n",
       "        if (cell.cell_type === 'code') {\n",
       "            for (var j = 0; j < cell.output_area.outputs.length; j++) {\n",
       "                var data = cell.output_area.outputs[j];\n",
       "                if (data.data) {\n",
       "                    // IPython >= 3 moved mimebundle to data attribute of output\n",
       "                    data = data.data;\n",
       "                }\n",
       "                if (data['text/html'] === html_output) {\n",
       "                    return [cell, data, j];\n",
       "                }\n",
       "            }\n",
       "        }\n",
       "    }\n",
       "};\n",
       "\n",
       "// Register the function which deals with the matplotlib target/channel.\n",
       "// The kernel may be null if the page has been refreshed.\n",
       "if (IPython.notebook.kernel !== null) {\n",
       "    IPython.notebook.kernel.comm_manager.register_target(\n",
       "        'matplotlib',\n",
       "        mpl.mpl_figure_comm\n",
       "    );\n",
       "}\n"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"\" width=\"1000\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "plt.hist(probs_4[:, 4], bins=20)\n",
    "ax.set(xlabel=r\"Probability that a true 4 belongs to class 4 according to the network (p(C$_4$|data))\",\n",
    "       ylabel=\"Number of occurences\",\n",
    "       title=\"\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bff1f8db",
   "metadata": {},
   "source": [
    "### Contact us at the EuXFEL Data Analysis group at any time if you need help analysing your data!\n",
    "\n",
    "#### Data Analysis group: da@xfel.eu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0879d703",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}