{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6a764d92",
   "metadata": {},
   "source": [
    "# Mixture Models\n",
    "\n",
    "A common objective is to find similarities in data and group the samples into clusters. Many clustering methods rely on heuristics to identify such similarities; those are mentioned at the end. Here we focus on a few methods with a strong theoretical motivation.\n",
    "\n",
    "For the purposes of this example, we will generate a fake dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fce4d8e8",
   "metadata": {},
   "source": [
    "We start by loading the necessary Python modules. If you have not yet installed them, run the following cell to install them with pip:"
   ]
  },
  {
   "cell_type": "code",
   "id": "44ca341e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: numpy in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (1.19.2)\r\n",
      "Requirement already satisfied: scikit-learn in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (0.24.2)\r\n",
      "Requirement already satisfied: pandas in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (1.1.5)\r\n",
      "Requirement already satisfied: matplotlib in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (3.3.4)\r\n",
      "Requirement already satisfied: scipy>=0.19.1 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from scikit-learn) (1.5.2)\r\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from scikit-learn) (2.2.0)\r\n",
      "Requirement already satisfied: joblib>=0.11 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from scikit-learn) (1.0.1)\r\n",
      "Requirement already satisfied: python-dateutil>=2.7.3 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from pandas) (2.8.2)\r\n",
      "Requirement already satisfied: pytz>=2017.2 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from pandas) (2021.3)\r\n",
      "Requirement already satisfied: pillow>=6.2.0 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from matplotlib) (8.3.1)\r\n",
      "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from matplotlib) (3.0.4)\r\n",
      "Requirement already satisfied: cycler>=0.10 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from matplotlib) (0.11.0)\r\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from matplotlib) (1.3.1)\r\n",
      "Requirement already satisfied: six>=1.5 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from python-dateutil>=2.7.3->pandas) (1.16.0)\r\n"
     ]
    }
   ],
   "source": [
    "!pip install numpy scikit-learn pandas matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "id": "300cf8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.mixture import GaussianMixture, BayesianGaussianMixture\n",
    "from sklearn.cluster import KMeans"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ecd6a69",
   "metadata": {},
   "source": [
    "Let's generate the fake data now to have something to cluster."
   ]
  },
  {
   "cell_type": "code",
   "id": "4959a292",
   "metadata": {},
   "outputs": [],
   "source": [
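    "def generate_clusters(mu: list, sigma: list, N: int) -> np.ndarray:\n",
    "    \"\"\"Draw N samples from each Gaussian and append the index of the generating component.\"\"\"\n",
    "    assert len(mu) == len(sigma)\n",
    "    assert N > 1\n",
    "    rng = np.random.default_rng()\n",
    "    # one block of N samples per (mean, covariance) pair, stacked row-wise\n",
    "    data = np.concatenate([rng.multivariate_normal(mean=mu_k, cov=sigma_k, size=N)\n",
    "                           for mu_k, sigma_k in zip(mu, sigma)], axis=0)\n",
    "    # label each row with the index k of the component it was drawn from\n",
    "    source = np.concatenate([k*np.ones([N, 1]) for k in range(len(mu))], axis=0)\n",
    "    return np.concatenate([data, source], axis=1)\n"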
   ]
  },
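  {
   "cell_type": "markdown",
   "id": "b7e1c2a4",
   "metadata": {},
   "source": [
    "As a sketch of what the function above produces: because every component contributes the same number $N$ of samples, the dataset can be read as a sample from a Gaussian mixture with equal weights. The symbols $K$ (the number of components, `len(mu)`) and $\\pi_k$ (the mixture weights) are notation introduced here for illustration only:\n",
    "\n",
    "$$p(x) = \\sum_{k=0}^{K-1} \\pi_k \\, \\mathcal{N}(x \\mid \\mu_k, \\Sigma_k), \\qquad \\pi_k = \\frac{1}{K},$$\n",
    "\n",
    "where $\\mu_k$ and $\\Sigma_k$ are the entries of `mu` and `sigma`. Strictly speaking, the per-component counts are fixed at $N$ rather than drawn at random, so this is an approximation of the sampling procedure. The extra `source` column stores the index $k$ of the component that generated each row."
   ]
  },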
  {
   "cell_type": "code",
   "id": "82929490",
   "metadata": {},
   "outputs": [],
   "source": [
    "# each covariance matrix must be symmetric and positive semi-definite\n",
    "data = generate_clusters(mu=[np.array([5.0, -2.0]),\n",
    "                             np.array([1.0, 5.0]),\n",
    "                             np.array([-5.0, -1.0])],\n",
    "                         sigma=[np.array([[0.2, 0.1],\n",
    "                                          [0.1, 0.2]]),\n",
    "                                np.array([[1.0, 0.5],\n",
    "                                          [0.5, 1.0]]),\n",
    "                                np.array([[2.0, 0.0],\n",
    "                                          [0.0, 5.0]])],\n",
    "                         N=1000)\n",
    "data = pd.DataFrame(data, columns=[\"x\", \"y\", \"source\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8295e8a",
   "metadata": {},
   "source": [
    "Let's first print out the generated dataset."
   ]
  },
  {
   "cell_type": "code",
   "id": "024fb65a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x</th>\n",
       "      <th>y</th>\n",
       "      <th>source</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6.128013</td>\n",
       "      <td>-1.708080</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3.938043</td>\n",
       "      <td>-2.788596</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.221406</td>\n",
       "      <td>-1.813124</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.530432</td>\n",
       "      <td>-2.052744</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4.903167</td>\n",
       "      <td>-1.418898</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",