{
"cells": [
{
"cell_type": "markdown",
"id": "6a764d92",
"metadata": {},
"source": [
"# Mixture Models\n",
"\n",
"A common objective in data analysis is to find similarities in the data and group similar samples into clusters. Many clustering methods exist. Here we focus on a few with a strong theoretical motivation; methods that rely on heuristics to identify similarities are mentioned at the end.\n",
"\n",
"For the purposes of this example, we will generate a fake dataset."
]
},
{
"cell_type": "markdown",
"id": "fce4d8e8",
"metadata": {},
"source": [
"We start by loading the necessary Python modules. If you have not yet installed them, run the following cell to install them with pip:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "44ca341e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: numpy in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (1.19.2)\r\n",
"Requirement already satisfied: scikit-learn in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (0.24.2)\r\n",
"Requirement already satisfied: pandas in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (1.1.5)\r\n",
"Requirement already satisfied: matplotlib in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (3.3.4)\r\n",
"Requirement already satisfied: scipy>=0.19.1 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from scikit-learn) (1.5.2)\r\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from scikit-learn) (2.2.0)\r\n",
"Requirement already satisfied: joblib>=0.11 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from scikit-learn) (1.0.1)\r\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from pandas) (2.8.2)\r\n",
"Requirement already satisfied: pytz>=2017.2 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from pandas) (2021.3)\r\n",
"Requirement already satisfied: pillow>=6.2.0 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from matplotlib) (8.3.1)\r\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from matplotlib) (3.0.4)\r\n",
"Requirement already satisfied: cycler>=0.10 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from matplotlib) (0.11.0)\r\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from matplotlib) (1.3.1)\r\n",
"Requirement already satisfied: six>=1.5 in /home/daniloefl/miniconda3/envs/ml2/lib/python3.6/site-packages (from python-dateutil>=2.7.3->pandas) (1.16.0)\r\n"
]
}
],
"source": [
"!pip install numpy scikit-learn pandas matplotlib"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "300cf8d3",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib notebook\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.mixture import GaussianMixture, BayesianGaussianMixture\n",
"from sklearn.cluster import KMeans"
]
},
{
"cell_type": "markdown",
"id": "0ecd6a69",
"metadata": {},
"source": [
"Let's generate the fake data now to have something to cluster."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4959a292",
"metadata": {},
"outputs": [],
"source": [
"def generate_clusters(mu: list, sigma: list, N: int) -> np.ndarray:\n",
"    \"\"\"Draw N samples from each Gaussian; return rows of (features..., source index).\"\"\"\n",
"    assert len(mu) == len(sigma)\n",
"    assert N > 1\n",
"    rng = np.random.default_rng()\n",
"    # One block of N samples per component, stacked along the sample axis.\n",
"    data = np.concatenate([rng.multivariate_normal(mean=mu_k, cov=sigma_k, size=N)\n",
"                           for mu_k, sigma_k in zip(mu, sigma)], axis=0)\n",
"    # Label each sample with the index of the component that generated it.\n",
"    source = np.concatenate([k*np.ones([N, 1]) for k in range(len(mu))], axis=0)\n",
"    return np.concatenate([data, source], axis=1)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "82929490",
"metadata": {},
"outputs": [],
"source": [
"# Each covariance matrix must be symmetric and positive semi-definite.\n",
"data = generate_clusters(mu=[np.array([5.0, -2.0]),\n",
"                             np.array([1.0, 5.0]),\n",
"                             np.array([-5.0, -1.0])],\n",
"                         sigma=[np.array([[0.2, 0.1],\n",
"                                          [0.1, 0.2]]),\n",
"                                np.array([[1.0, 0.5],\n",
"                                          [0.5, 1.0]]),\n",
"                                np.array([[2.0, 0.0],\n",
"                                          [0.0, 5.0]])],\n",
"                         N=1000)\n",
"data = pd.DataFrame(data, columns=[\"x\", \"y\", \"source\"])"
]
},
{
"cell_type": "markdown",
"id": "d8295e8a",
"metadata": {},
"source": [
"Let's first print out the generated dataset."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "024fb65a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" <th>source</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6.128013</td>\n",
" <td>-1.708080</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3.938043</td>\n",
" <td>-2.788596</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.221406</td>\n",
" <td>-1.813124</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.530432</td>\n",
" <td>-2.052744</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4.903167</td>\n",
" <td>-1.418898</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",