K-Means and Gaussian Mixture Models
K-means and Gaussian Mixtures are popular and effective clustering techniques. In this post I introduce the problem of clustering and motivate their use.
In machine learning, clustering encompasses a family of problems and techniques for grouping unlabeled data into categories. In contrast to classification, where our job is to learn relationships between data and their associated labels, in clustering our job is to learn the labels themselves by leveraging the patterns and structure present in the data. To make this more concrete, imagine you are given two datasets containing images of cats and dogs. The datasets are identical except that in the first dataset each image comes with a label, cat or dog, and in the second you are only given the image. In classification our job is to learn what features of the image are most predictive of the label (cat ears, dog nose, etc.), whereas in clustering our job is to learn how to group similar images together without any pre-existing labels. That is, to identify clusters that correspond to cat and dog based solely on the inherent features and patterns observed in the images.
The vast majority of data is unlabeled, and effective algorithmic techniques for extracting meaningful information from such data are critical to putting it to use. From the grouping of genetic sequences to identify disease patterns, to image segmentation, uses for clustering abound. In this post I introduce the problem of clustering and derive two popular techniques: $K$-means and Gaussian mixture models (GMMs). To begin I will apply $K$-means to the Old Faithful dataset to give a concrete example of clustering before using the technique's shortcomings to motivate GMMs.
$K$-Means
The $K$-means algorithm is a simple and intuitive approach to clustering that doesn't require any probability theory to understand. Suppose we are given a dataset $\{\mathbf{x}_1, \dots, \mathbf{x}_N\}$ consisting of $N$ Euclidean vectors. In our example dataset each vector is 2-dimensional, but $K$-means works just as well in higher dimensions. Importantly, we have no other additional information. Unlike classification, we have no class labels to work with, only the raw coordinates of each data point.
We would like to partition the dataset above into $K$ sub-groups such that each data point is associated with a unique cluster $k$. Intuitively, we would like similar data points to be clustered together and dissimilar data points to be clustered apart. It just so happens that in our dataset it is visually obvious how we should assign points to clusters, but in practice our data is rarely this nice. Our goal is essentially to categorize or classify this unlabeled data. In doing so we will depart from classification and optimize not just the boundary of the clusters, but the labels associated with each data point as well.
To begin, we'll introduce two variables. The first is a set of $K$ vectors $\{\boldsymbol{\mu}_k\}$, where $\boldsymbol{\mu}_k$ is the centroid or prototype of the $k$-th cluster. You can think of $\boldsymbol{\mu}_k$ as the center of mass for cluster $k$. We will also introduce an indicator variable $r_{nk} \in \{0, 1\}$, where $r_{nk} = 1$ if the $n$-th data point belongs to cluster $k$ and $r_{nk} = 0$ otherwise. Our goal, then, is to optimize our set of cluster centroids $\{\boldsymbol{\mu}_k\}$ and our cluster assignments $\{r_{nk}\}$ such that the sum of the squared distances from each data point to its assigned centroid is minimized. In plain language, if we are going to assign a point to a particular cluster, we would like the center of that cluster to be as close as possible to the point. We can write out this objective with the following equation

$$J = \sum_{n=1}^{N} \sum_{k=1}^{K} r_{nk} \, \lVert \mathbf{x}_n - \boldsymbol{\mu}_k \rVert^2$$
Our job then is to choose the values for $r_{nk}$ and $\boldsymbol{\mu}_k$ that minimize $J$.
Finding $r_{nk}$
Intuitively, the value for $r_{nk}$ that minimizes $J$ is simply the one which assigns its associated data point to the closest cluster center. More formally

$$r_{nk} = \begin{cases} 1 & \text{if } k = \arg\min_j \lVert \mathbf{x}_n - \boldsymbol{\mu}_j \rVert^2 \\ 0 & \text{otherwise} \end{cases}$$
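As a minimal sketch of this assignment step (assuming the data sits in a NumPy array X of shape (N, D) and the centroids in mu of shape (K, D); these variable names are mine, not fixed by anything above):

import numpy as np

def assign(X, mu):
    # Squared Euclidean distance from every point to every centroid, shape (N, K).
    sq_dists = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)
    # r_nk = 1 for the closest centroid; returned here as the index of that centroid.
    return sq_dists.argmin(axis=1)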
Finding $\boldsymbol{\mu}_k$
To find the center of the $k$-th cluster we can find the value for $\boldsymbol{\mu}_k$ which minimizes $J$. We'll start by taking its derivative with respect to $\boldsymbol{\mu}_k$

$$\frac{\partial J}{\partial \boldsymbol{\mu}_k} = -2 \sum_{n=1}^{N} r_{nk} \, (\mathbf{x}_n - \boldsymbol{\mu}_k)$$
Now we can set the derivative equal to $0$ and solve for $\boldsymbol{\mu}_k$

$$\boldsymbol{\mu}_k = \frac{\sum_{n=1}^{N} r_{nk} \, \mathbf{x}_n}{\sum_{n=1}^{N} r_{nk}}$$
The denominator in the above equation is the number of data points assigned to the $k$-th cluster, and so $\boldsymbol{\mu}_k$ is simply the average of its assigned points.
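Continuing the sketch, the centroid update is just a per-cluster average (using the assignment indices returned by assign above; clusters that end up empty keep their previous centroid here, a choice the equation itself doesn't dictate):

import numpy as np

def update_centroids(X, assignments, mu):
    new_mu = mu.copy()
    for k in range(len(mu)):
        members = X[assignments == k]
        if len(members) > 0:
            # mu_k becomes the mean of the points currently assigned to cluster k.
            new_mu[k] = members.mean(axis=0)
    return new_mu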
Now notice that the equations for our two parameters each contain the other parameter as a variable. This motivates an iterative, algorithmic solution in which we optimize one parameter while holding the other fixed, and then optimize the second parameter holding the first one fixed. Our procedure will go like this:
$K$-means Algorithm
(1) Initialize our cluster centers $\{\boldsymbol{\mu}_k\}$
(2) Compute our assignments $r_{nk}$ with our current values for $\{\boldsymbol{\mu}_k\}$
(3) Using our new assignments, compute new cluster centers $\{\boldsymbol{\mu}_k\}$
(4) Repeat steps (2) and (3) until changes in $J$ between steps are negligible.
Implementation
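The full routine is listed in the appendix as fit. A hedged usage sketch, assuming the Old Faithful measurements have been loaded into a NumPy array and standardized (the hard-coded initial centroids in fit suggest standardized data; the file name and loading step here are illustrative, not the exact code behind the figures):

import numpy as np

# Hypothetical loading step: a CSV with eruption duration and waiting time columns.
data = np.loadtxt("old_faithful.csv", delimiter=",", skiprows=1)
X = (data - data.mean(axis=0)) / data.std(axis=0)  # standardize both features

mu_updates, assignment_updates, j_scores = fit(X, K=2)
print(f"{len(j_scores)} evaluations of J, final J = {j_scores[-1]:.3f}")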
As you can see, for this dataset the algorithm converges extremely quickly and we get a decent clustering in only a few optimization steps. In practice $K$-means is often a good first choice for clustering, but the technique does have some serious drawbacks.
Perhaps the most obvious drawback is that we have no principled way of choosing the number of clusters $K$. In fact, a larger number of clusters will always result in a lower value of $J$. In our toy example this isn't a problem, largely because our data is 2-dimensional and so we can visualize it, but if our data were higher dimensional then this would be more difficult. There do exist some metrics for determining the optimal cluster number, but these have their flaws as well. In addition, $K$-means can be quite sensitive to the initialization of $\{\boldsymbol{\mu}_k\}$, particularly when our data is not as easily separable as our toy example. There do exist initialization strategies which help to alleviate this, but there are more problems. One is that our objective function models our clusters as equally sized spheres, an assumption which will not hold in many (maybe most) circumstances. Finally, $K$-means gives us back hard cluster assignments, when sometimes we'd like a confidence or probability that a data point belongs to a particular cluster.
Luckily, our next technique will allow us to address almost all of these shortcomings.
Gaussian Mixture Model
Gaussian mixture models (GMMs) are a powerful probabilistic tool for solving the clustering problem and will help to ameliorate many of the problems we faced with $K$-means. A GMM is just a linear combination of $K$ Gaussian distributions, where each component $k$ is weighted by some scalar $\pi_k$ such that $\sum_{k=1}^{K} \pi_k = 1$. This constraint allows us to interpret the combination itself as a probability distribution. We will interpret each Gaussian component as the distribution of a particular cluster. For any one data point $\mathbf{x}$, the GMM says it has a probability given by

$$p(\mathbf{x}) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$$

where $\boldsymbol{\mu}_k$ and $\boldsymbol{\Sigma}_k$ are the mean and covariance of the $k$-th component respectively.
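To make the density concrete, here is a small sketch that evaluates $p(\mathbf{x})$ for a two-component mixture with SciPy; the particular weights, means, and covariances are illustrative values, not fitted parameters:

import numpy as np
from scipy.stats import multivariate_normal

pi = np.array([0.4, 0.6])                           # mixture weights, sum to 1
mus = [np.array([0.0, 0.0]), np.array([3.0, 3.0])]  # component means
sigmas = [np.eye(2), 0.5 * np.eye(2)]               # component covariances

x = np.array([1.0, 1.0])
# p(x) = sum_k pi_k * N(x | mu_k, Sigma_k)
p_x = sum(pi[k] * multivariate_normal(mus[k], sigmas[k]).pdf(x) for k in range(2))
print(p_x)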
To start we will introduce a set of variables $\{\mathbf{z}_n\}$ just like we did with $K$-means, where $z_{nk} = 1$ indicates that the $n$-th data point comes from cluster $k$. Constructing a set of unobserved variables like this is extremely common in unsupervised learning problems; oftentimes we refer to these variables as latent variables. We said before that we interpret our mixture components as our clusters, and so we have

$$p(z_{nk} = 1) = \pi_k$$
This just means that the probability a data point comes from cluster $k$ is equal to its weighting in the mixture. We can also write out the distribution of $\mathbf{x}_n$ conditioned on the value of $\mathbf{z}_n$.

$$p(\mathbf{x}_n \mid z_{nk} = 1) = \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$$
This is just the distribution of the $k$-th cluster.
To fit this model to our data we will rely on maximum likelihood estimation (MLE), which I wrote about in my previous post. MLE will allow us to determine the optimal values for our model parameters $\boldsymbol{\pi}$, $\boldsymbol{\mu}$, and $\boldsymbol{\Sigma}$. To start we first must construct our likelihood function

$$p(\mathbf{X} \mid \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}) = \prod_{n=1}^{N} p(\mathbf{x}_n)$$
We can use the sum and product rules to decompose our likelihood into the marginal distribution over $\mathbf{z}_n$ and the conditional distribution of $\mathbf{x}_n$ given $\mathbf{z}_n$. This is useful because we already wrote down the forms of these above.

$$p(\mathbf{X} \mid \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}) = \prod_{n=1}^{N} \sum_{k=1}^{K} p(z_{nk} = 1) \, p(\mathbf{x}_n \mid z_{nk} = 1) = \prod_{n=1}^{N} \sum_{k=1}^{K} \pi_k \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$$
Now we can take the log of this expression to get our log-likelihood

$$\ln p(\mathbf{X} \mid \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}) = \sum_{n=1}^{N} \ln \left( \sum_{k=1}^{K} \pi_k \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \right)$$
The MLE procedure now tells us to find the values for $\boldsymbol{\mu}_k$, $\boldsymbol{\Sigma}_k$, and $\pi_k$ that maximize $\ln p(\mathbf{X} \mid \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Sigma})$. First let's find the maximum likelihood estimate for $\boldsymbol{\mu}_k$.
Finding $\boldsymbol{\mu}_k$
We'll start by taking the derivative of $\ln p(\mathbf{X} \mid \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Sigma})$ with respect to $\boldsymbol{\mu}_k$

$$\frac{\partial}{\partial \boldsymbol{\mu}_k} \ln p(\mathbf{X} \mid \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}) = \sum_{n=1}^{N} \frac{\partial}{\partial \boldsymbol{\mu}_k} \ln \left( \sum_{j=1}^{K} \pi_j \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j) \right)$$

Here we can apply the chain rule (A2)

$$= \sum_{n=1}^{N} \frac{1}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)} \, \frac{\partial}{\partial \boldsymbol{\mu}_k} \Big( \pi_k \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \Big) = \sum_{n=1}^{N} \frac{\pi_k \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)} \, \boldsymbol{\Sigma}_k^{-1} (\mathbf{x}_n - \boldsymbol{\mu}_k)$$

That last step is just another application of the chain rule, as well as equation 86 from the matrix cookbook.
Now to clean things up a bit, let's define a new quantity $\gamma(z_{nk})$

$$\gamma(z_{nk}) = \frac{\pi_k \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)}$$
This quantity will appear a lot going forward and abstracting it this way will help us notationally; however, this term also has a rich meaning: it is the conditional probability of $z_{nk} = 1$ given $\mathbf{x}_n$. To see this we can use Bayes' rule

$$p(z_{nk} = 1 \mid \mathbf{x}_n) = \frac{p(z_{nk} = 1) \, p(\mathbf{x}_n \mid z_{nk} = 1)}{\sum_{j=1}^{K} p(z_{nj} = 1) \, p(\mathbf{x}_n \mid z_{nj} = 1)} = \frac{\pi_k \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)} = \gamma(z_{nk})$$
This understanding allows us to view $\pi_k$ as the prior probability that $z_{nk} = 1$ before observing $\mathbf{x}_n$, and $\gamma(z_{nk})$ as the posterior probability that $z_{nk} = 1$ after we observe $\mathbf{x}_n$. $\gamma(z_{nk})$ is also often referred to as the responsibility because it can be viewed as the ability of the $k$-th cluster to explain the value of $\mathbf{x}_n$. Also, notice that $\sum_{k=1}^{K} \gamma(z_{nk}) = 1$. This should make sense because in our model each data point has to come from a cluster.
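In code, the responsibilities are just the weighted component densities normalized across components. A vectorized sketch, mirroring the compute_gamma helper in the appendix but with the component parameters stored in Python lists rather than stacked arrays:

import numpy as np
from scipy.stats import multivariate_normal

def responsibilities(X, pi, mus, sigmas):
    K = len(pi)
    # Numerator: pi_k * N(x_n | mu_k, Sigma_k) for every point and every component.
    weighted = np.column_stack(
        [pi[k] * multivariate_normal(mus[k], sigmas[k]).pdf(X) for k in range(K)]
    )
    # Normalize across components so each row sums to 1.
    return weighted / weighted.sum(axis=1, keepdims=True)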
Now we can plug $\gamma(z_{nk})$ into our derivative and set it equal to $0$ to solve for $\boldsymbol{\mu}_k$

$$0 = \sum_{n=1}^{N} \gamma(z_{nk}) \, \boldsymbol{\Sigma}_k^{-1} (\mathbf{x}_n - \boldsymbol{\mu}_k) \quad\Longrightarrow\quad \boldsymbol{\mu}_k = \frac{\sum_{n=1}^{N} \gamma(z_{nk}) \, \mathbf{x}_n}{\sum_{n=1}^{N} \gamma(z_{nk})} = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk}) \, \mathbf{x}_n$$

where $N_k = \sum_{n=1}^{N} \gamma(z_{nk})$.
So after all of that, we have this nice result that the center of cluster $k$ is a weighted average of the data where the weights are the responsibilities, or the ability of that cluster to explain the data. The denominator $N_k$ is the sum of the responsibilities of the $k$-th cluster across every data point, and so you can think about it almost like the effective number of data points assigned to cluster $k$.
I will leave derivations of the maximum likelihood estimates for $\boldsymbol{\Sigma}_k$ and $\pi_k$ to the appendix (A3), but they too have nice interpretations:

$$\boldsymbol{\Sigma}_k = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk}) \, (\mathbf{x}_n - \boldsymbol{\mu}_k)(\mathbf{x}_n - \boldsymbol{\mu}_k)^\top \qquad\qquad \pi_k = \frac{N_k}{N}$$
Like $\boldsymbol{\mu}_k$, our covariance estimate $\boldsymbol{\Sigma}_k$ is a weighted empirical covariance of the data where the weights are the responsibilities of the $k$-th cluster, whilst $\pi_k$ is the effective fraction of the data assigned to cluster $k$.
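Taken together, the three update equations translate directly into one parameter-update step. A sketch under the same assumed layout as the previous snippet, with gamma the (N, K) matrix of responsibilities; the appendix code does the equivalent with stacked arrays:

import numpy as np

def update_parameters(X, gamma):
    N, D = X.shape
    N_k = gamma.sum(axis=0)                 # effective number of points per cluster
    mus = (gamma.T @ X) / N_k[:, None]      # weighted means, shape (K, D)
    sigmas = []
    for k in range(gamma.shape[1]):
        diff = X - mus[k]
        # Weighted empirical covariance for component k.
        sigmas.append((gamma[:, k, None] * diff).T @ diff / N_k[k])
    pi = N_k / N                            # effective fraction of the data per cluster
    return mus, sigmas, pi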
Notice that, as was the case in $K$-means, our values for $\gamma(z_{nk})$ are dependent on our parameters, whilst our parameters are themselves dependent on $\gamma(z_{nk})$. Like with $K$-means, this motivates an iterative approach in which we alternate between computing the responsibilities and optimizing the parameters. The procedure will go like this:
GMM Algorithm
(1) Initialize our parameters $\boldsymbol{\mu}_k$, $\boldsymbol{\Sigma}_k$, and $\pi_k$
(2) Compute responsibilities $\gamma(z_{nk})$ with our current parameters
(3) Using our new responsibilities, update our parameters $\boldsymbol{\mu}_k$, $\boldsymbol{\Sigma}_k$, and $\pi_k$
(4) Repeat steps (2) and (3) until changes in the log-likelihood between steps are negligible.
Implementation
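The full routine is listed in the appendix as fit_gmm. A usage sketch on the same standardized array X assumed in the $K$-means example:

mus, sigmas, pis, gammas, LLs = fit_gmm(X, K=2)
print(f"Converged after {len(LLs)} iterations, final log-likelihood = {LLs[-1]:.3f}")
# Soft assignment of the first data point: posterior probability of each cluster.
print(gammas[-1][0])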
The GMM helps to solve many of the problems we faced with $K$-means. Being a probabilistic model means that we gain access to methods for cross-validating our choice of $K$. Unlike $K$-means, increasing our number of clusters will not allow us to arbitrarily increase our log-likelihood on held-out data. We also gain flexibility in the shape that our clusters can take, though we are still limited to Gaussian ellipsoids. And finally, we can quantify our uncertainty for outlier data points which don't fall cleanly into any one cluster. In terms of initialization, this is still a shortcoming of the GMM, as there aren't many principled ways to ensure good results. In practice we usually run $K$-means++ first and use its output to initialize the GMM's clusters, and this usually works well.
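As one hedged example of what that buys us, we can score a held-out split under the fitted parameters and compare the result across candidate values of $K$. The sketch below assumes the parameter layout used by the appendix code and a held-out array X_val; note that the appendix fit_gmm hard-codes its initialization for $K = 2$, so actually scanning over $K$ would require generalizing that step:

import numpy as np
from scipy.stats import multivariate_normal

def heldout_log_likelihood(X_val, pi, mu, sigma):
    # mu has shape (D, K) and sigma has shape (D, D, K), matching the appendix code.
    K = len(pi)
    probs = np.zeros((len(X_val), K))
    for k in range(K):
        probs[:, k] = pi[k] * multivariate_normal(mu[:, k], sigma[..., k]).pdf(X_val)
    return np.log(probs.sum(axis=1)).sum()

# e.g. score the final parameters from a model fit on a training split:
# ll = heldout_log_likelihood(X_val, pis[-1], mus[-1], sigmas[-1])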
Beyond being two of the most popular clustering techniques, I chose to write about $K$-means and GMMs because they are often used to introduce the Expectation Maximization (EM) algorithm. I didn't explicitly mention the EM algorithm in this post, but briefly, it is an algorithm which allows us to perform MLE with latent variable models. I hope to cover the EM algorithm in a future post when my understanding of it is better, and when I do I will link it here.
Appendix
A1
$K$-means implementation
import numpy as np

def fit(X, K=2):
    def J(assignments, mu):
        # Sum of squared distances from each point to its assigned centroid.
        j = 0
        for k in range(K):
            j += (np.linalg.norm(X[assignments == k] - mu[k], axis=1) ** 2).sum()
        return j

    def assign_clusters(mu):
        # Assign each point to the centroid with the smallest Euclidean distance.
        distances = np.zeros(shape=(K, len(X)))
        for k in range(K):
            distances[k] = np.linalg.norm(X - mu[k], axis=1)
        return distances.argmin(axis=0)

    def update_mu(assignments, mu):
        # Each centroid becomes the mean of the points currently assigned to it.
        new_mu = np.zeros_like(mu)
        for k in range(K):
            new_mu[k] = X[assignments == k].mean(axis=0)
        return new_mu

    mu_updates = []
    assignment_updates = []
    j_scores = []

    old_mu = np.array([[1.0, -1.5],
                       [-1.0, 1.5]])
    old_assignments = assign_clusters(old_mu)
    old_j = np.inf

    mu_updates.append(old_mu)
    assignment_updates.append(old_assignments)
    j_scores.append(J(old_assignments, old_mu))

    while True:
        new_mu = update_mu(old_assignments, old_mu)
        j_scores.append(J(old_assignments, new_mu))
        new_assignments = assign_clusters(new_mu)
        new_j = J(new_assignments, new_mu)
        j_scores.append(new_j)
        if abs(new_j - old_j) < 1e-4:
            break
        mu_updates.append(new_mu)
        assignment_updates.append(new_assignments)
        old_mu, old_assignments, old_j = new_mu, new_assignments, new_j

    return mu_updates, assignment_updates, j_scores
A2
Chain rule application in GMM derivation.
This sometimes gets called the “log-derivative trick”, which isn’t really a trick, but can be cleverly used to help compute difficult expectations. Maybe I will write a post about it one day.
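For reference, the identity in question is just the chain rule applied to the logarithm:

$$\frac{\partial}{\partial \theta} \ln f(\theta) = \frac{1}{f(\theta)} \frac{\partial f(\theta)}{\partial \theta} \quad\Longleftrightarrow\quad \frac{\partial f(\theta)}{\partial \theta} = f(\theta) \, \frac{\partial}{\partial \theta} \ln f(\theta)$$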
A3
MLE for covariance matrices and mixture weights.
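A brief sketch of these derivations, following the same approach as the $\boldsymbol{\mu}_k$ derivation in the main text (the full algebra is omitted). For $\pi_k$ we maximize the log-likelihood subject to the constraint $\sum_k \pi_k = 1$ by introducing a Lagrange multiplier $\lambda$:

$$\frac{\partial}{\partial \pi_k}\left[\ln p(\mathbf{X} \mid \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}) + \lambda\left(\sum_{j=1}^{K}\pi_j - 1\right)\right] = \sum_{n=1}^{N} \frac{\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{j=1}^{K}\pi_j\,\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)} + \lambda = 0$$

Multiplying through by $\pi_k$ and summing over $k$ gives $\lambda = -N$, and therefore $\pi_k = N_k / N$. The $\boldsymbol{\Sigma}_k$ case follows the same pattern as $\boldsymbol{\mu}_k$, using the matrix cookbook identities for derivatives with respect to a covariance matrix, and lands on the responsibility-weighted empirical covariance quoted in the main text.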
A4
GMM implementation
import numpy as np
from scipy.stats import multivariate_normal

outer = np.outer

def fit_gmm(X, K=2):
    def compute_LL():
        # Log-likelihood: sum over n of log(sum_k pi_k * N(x_n | mu_k, sigma_k)).
        probs = np.zeros((len(X), K))
        for k in range(K):
            probs[:, k] = multivariate_normal(mu[:, k], sigma[..., k]).pdf(X) * pi[k]
        return np.log(probs.sum(axis=1)).sum()

    def compute_gamma():
        # Responsibilities: posterior probability of each component for each point.
        gamma = np.zeros((len(X), K))
        for k in range(K):
            gamma[:, k] = multivariate_normal(mu[:, k], sigma[..., k]).pdf(X) * pi[k]
        gamma /= gamma.sum(axis=1, keepdims=True)
        return gamma

    def compute_Nk():
        # Effective number of points assigned to each component.
        return np.sum(gamma, axis=0)

    def compute_mu():
        # Weighted average of the data, weighted by the responsibilities.
        mu = np.zeros((2, K))
        for k in range(K):
            mu[:, k] = (gamma[:, k, np.newaxis] * X).sum(axis=0) / N_k[k]
        return mu

    def compute_sigma():
        # Weighted empirical covariance for each component.
        sigma = np.zeros((2, 2, K))
        for k in range(K):
            for n in range(len(X)):
                sigma[..., k] += (gamma[n, k] * outer(X[n] - mu[:, k], X[n] - mu[:, k])) / N_k[k]
        return sigma

    def compute_pi():
        # Effective fraction of the data assigned to each component.
        return N_k / len(X)

    LLs, mus, sigmas, pis, gammas = [], [], [], [], []

    # Initialization (hard-coded for the 2-dimensional, K=2 Old Faithful example).
    mu = np.array([[1.2, -1.5],
                   [-2.0, 1.5]])
    sigma = np.zeros((2, 2, 2))
    sigma[..., 0] = np.eye(2) * 0.1
    sigma[..., 1] = np.eye(2) * 0.1
    pi = np.array([0.5, 0.5])

    mus.append(mu)
    sigmas.append(sigma)
    pis.append(pi)
    LL_old = -np.inf

    while True:
        gamma = compute_gamma()   # compute responsibilities with current parameters
        N_k = compute_Nk()
        mu = compute_mu()         # update parameters with new responsibilities
        sigma = compute_sigma()
        pi = compute_pi()
        LL_new = compute_LL()
        mus.append(mu)
        sigmas.append(sigma)
        pis.append(pi)
        gammas.append(gamma)
        LLs.append(LL_new)
        if LL_new - LL_old < 1e-9:
            gammas.append(compute_gamma())
            break
        else:
            LL_old = LL_new

    return mus, sigmas, pis, gammas, LLs