K-Means and Gaussian Mixture Models

K-means and Gaussian Mixtures are popular and effective clustering techniques. In this post I introduce the problem of clustering and motivate their use.

In machine learning, clustering encompasses a family of problems and techniques for grouping unlabeled data into categories. In contrast to classification, where our job is to learn relationships between data and their associated labels, in clustering our job is to learn the labels themselves by leveraging the patterns and structure present in the data. To make this more concrete, imagine you are given two datasets containing images of cats and dogs. The datasets are identical except that in the first dataset each image comes with a label, cat or dog, and in the second you are given only the image. In classification our job is to learn what features of the image are most predictive of the label (cat ears, dog nose, etc.), whereas in clustering our job is to learn how to group similar images together without any pre-existing labels. That is, we want to identify clusters that correspond to cat and dog based solely on the inherent features and patterns observed in the images.

The vast majority of categorical data is unlabeled, so effective algorithmic techniques for extracting meaningful information from it are critical. From the grouping of genetic sequences to identify disease patterns, to image segmentation, uses for clustering abound. In this post I introduce the problem of clustering and derive two popular techniques: $K$-means and Gaussian mixture models (GMM). To begin I will apply $K$-means to the Old Faithful dataset to give a concrete example of clustering before using the technique's shortcomings to motivate GMMs.


The $K$-means algorithm is a simple and intuitive approach to clustering that doesn’t require any probability theory to understand. Suppose we are given some data set $\mathcal{D}$ consisting of $N$ Euclidean vectors.

$$\mathcal{D} = \{\mathbf{x}_1, \dots, \mathbf{x}_n, \dots, \mathbf{x}_N\}$$

In our example dataset each vector is $2$-dimensional, but $K$-means still works in higher dimensions. Importantly, we have no additional information. Unlike classification, we have no class labels to work with, only the raw coordinates of each data point.

Figure 1. The normalized Old Faithful dataset. In the unmodified data the axes represent the duration of an eruption and the time until the next eruption for the Old Faithful geyser in Yellowstone National Park.

We would like to partition the dataset above into $K$ sub-groups such that each data point $\mathbf{x}_n$ is associated with a unique cluster $k$. Intuitively, we would like similar data points to be clustered together and dissimilar data points to be clustered apart. It just so happens that in our dataset it is visually obvious how we should assign points to clusters, but in practice our data is rarely this nice. Our goal is essentially to categorize or classify this unlabeled data. In doing so we will depart from classification and optimize not just the boundary of the clusters, but the labels associated with each data point as well.

To begin, we’ll introduce two variables. The first is a set of $K$ vectors $\{\boldsymbol{\mu}_1, \dots, \boldsymbol{\mu}_K\}$ where $\boldsymbol{\mu}_k$ is the centroid or prototype of the $k^{th}$ cluster. You can think of $\boldsymbol{\mu}_k$ as the center of mass for cluster $k$. We will also introduce an indicator variable $z_{n,k}$ where $z_{n,k} = 1$ if the $n^{th}$ data point belongs to cluster $k$ and $0$ otherwise. Our goal, then, is to optimize our set of cluster centroids $\{\boldsymbol{\mu}_k\}$ and our cluster assignments $z_{n, k}$ such that the sum of the squared distances from each data point to its assigned centroid is minimized. In plain language, if we are going to assign a point to a particular cluster, we would like the center of that cluster to be as close as possible to the point. We can write out this objective with the following equation

$$J = \sum_{n = 1}^N\sum_{k = 1}^K z_{n, k}\Vert\mathbf{x}_n - \boldsymbol{\mu}_k\Vert^2_2$$

Our job then is to choose the values for $z_{n, k}$ and $\boldsymbol{\mu}_k$ that minimize $J$

$$\operatorname*{arg\,min}_{z_{n, k}, \boldsymbol{\mu}_{k}}\sum_{n = 1}^N\sum_{k = 1}^K z_{n, k}\Vert\mathbf{x}_n - \boldsymbol{\mu}_k\Vert^2_2$$

Finding $z_{n, k}$

Intuitively, the value for $z_{n, k}$ that minimizes $J$ is simply the one which assigns its associated data point to the closest cluster center. More formally

$$z_{n, k} = \begin{cases} 1 & \text{if } k = \operatorname*{arg\,min}_j \Vert \mathbf{x}_n - \boldsymbol{\mu}_j \Vert^2_2 \\ 0 & \text{otherwise} \end{cases}$$
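
As a rough illustration of this assignment rule, here is a minimal NumPy sketch. The function name `assign_clusters`, the toy points and the centroids are all assumptions made up for the example, not values from the Old Faithful data.

```python
import numpy as np

def assign_clusters(X, mu):
    """Return the index of the nearest centroid for each row of X."""
    # Squared Euclidean distance from every point to every centroid, shape (N, K)
    sq_dists = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)
    return sq_dists.argmin(axis=1)

# Toy data chosen for illustration (hypothetical values)
X = np.array([[0.0, 0.0], [0.2, -0.1], [3.0, 3.1], [2.9, 3.0]])
mu = np.array([[0.0, 0.0], [3.0, 3.0]])

labels = assign_clusters(X, mu)          # cluster index per point
z = np.eye(len(mu), dtype=int)[labels]   # one-hot matrix with z[n, k] = z_{n,k}
```

Storing the assignments as integer labels is usually more convenient in code, while the one-hot matrix `z` matches the indicator notation used in the equations.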

Finding $\boldsymbol{\mu}_k$

To find the values for $\boldsymbol{\mu}_k$ that minimize $J$ we’ll start by taking its derivative with respect to $\boldsymbol{\mu}_k$. Only the terms belonging to cluster $k$ depend on $\boldsymbol{\mu}_k$, so the sum over clusters drops out

$$\begin{aligned} \frac{\partial J}{\partial \boldsymbol{\mu}_k} &= \frac{\partial}{\partial\boldsymbol{\mu}_k}\sum_{n = 1}^N\sum_{j = 1}^K z_{n, j}\Vert\mathbf{x}_n - \boldsymbol{\mu}_j\Vert^2_2\\ &= \frac{\partial}{\partial\boldsymbol{\mu}_k}\sum_{n=1}^N z_{n, k}\Vert\mathbf{x}_n - \boldsymbol{\mu}_k\Vert^2_2\\ &= \frac{\partial}{\partial\boldsymbol{\mu}_k}\sum_{n=1}^N z_{n, k}(\mathbf{x}_n - \boldsymbol{\mu}_k)^\top(\mathbf{x}_n - \boldsymbol{\mu}_k)\\ &= -2\sum_{n=1}^N z_{n, k}(\mathbf{x}_n - \boldsymbol{\mu}_k) \end{aligned}$$

Now we can set the derivative equal to $0$ and solve for $\boldsymbol{\mu}_k$

$$\begin{aligned} 0 &= -2\sum_{n=1}^N z_{n, k}(\mathbf{x}_n - \boldsymbol{\mu}_k) \\ 0 &= \sum_{n=1}^N \left(z_{n, k}\mathbf{x}_n - z_{n, k}\boldsymbol{\mu}_k\right) \\ \sum_{n=1}^N z_{n, k}\boldsymbol{\mu}_k &= \sum_{n=1}^N z_{n, k}\mathbf{x}_n \\ \boldsymbol{\mu}_k \sum_{n=1}^N z_{n, k} &= \sum_{n=1}^N z_{n, k}\mathbf{x}_n \\ \boldsymbol{\mu}_k &= \frac{\sum_{n=1}^N z_{n, k}\mathbf{x}_n}{\sum_{n=1}^N z_{n, k}} \end{aligned}$$

The denominator in the above equation is the number of data points assigned to the $k^{th}$ cluster, and so $\boldsymbol{\mu}_k$ is simply the average of its assigned points.
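
The centroid update can be sketched in a couple of lines of NumPy. This is only an illustration: `update_centroids` is a hypothetical helper, the data are made up, and the sketch assumes every cluster has at least one assigned point (otherwise the division fails).

```python
import numpy as np

def update_centroids(X, z):
    """Given one-hot assignments z of shape (N, K), return the (K, D) cluster means."""
    counts = z.sum(axis=0)                # number of points assigned to each cluster
    return (z.T @ X) / counts[:, None]    # sum of assigned points divided by the count

# Hypothetical toy data: two points per cluster
X = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 14.0]])
z = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])

mu = update_centroids(X, z)   # rows are the means of each cluster's points
```

Here `z.T @ X` computes the numerator $\sum_n z_{n,k}\mathbf{x}_n$ for every cluster at once, and `counts` is the denominator $\sum_n z_{n,k}$.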

Now notice that the equations for our two parameters each contains the other parameter as a variable. This motivates an iterative, algorithmic solution in which we optimize one parameter while fixing the other, and then optimize the second parameter holding the first one fixed. Our procedure will go like this:

$K$-means Algorithm

(1) Initialize our cluster centers $\boldsymbol{\mu}_k$
(2) Compute our assignments $z_{n, k}$ with our current value for $\boldsymbol{\mu}_k$
(3) Using our new assignments, compute new cluster centers $\boldsymbol{\mu}_k$
(4) Repeat steps (2) and (3) until changes in $J$ between steps are negligible.

Figure 2. Iterations of the $K$-means algorithm on the rescaled Old Faithful data set
Figure 3. $K$-means objective function during fitting.

As you can see, for this dataset the algorithm converges extremely quickly and we get a decent clustering in only a few optimization steps. In practice $K$-means is often a good first choice for clustering, but the technique does have some serious drawbacks.

Perhaps the most obvious drawback is that we have no principled way of choosing the number of clusters $K$. In fact, the minimum achievable value of $J$ never increases as $K$ grows, and it reaches zero when every data point is its own cluster, so $J$ alone cannot tell us which $K$ to pick. In our toy example this isn’t a problem, largely because our data is 2-dimensional and so we can visualize it, but if our data were higher dimensional then this would be more difficult. There do exist some metrics for determining the optimal cluster number, but these have their flaws as well. In addition, $K$-means can be quite sensitive to the initialization of $\{\boldsymbol{\mu}_k\}$, particularly when our data is not as easily separable as our toy example. There do exist initialization strategies which help to alleviate this, but there are more problems. One is that our objective function $J$ models our clusters as equally sized spheres, an assumption which will not hold in many (maybe most) circumstances. Finally, $K$-means gives us back hard cluster assignments, when sometimes we’d like to have a confidence or probability that a data point belongs to a particular cluster.

Luckily, our next technique will allow us to address almost all of these shortcomings.

Gaussian Mixture Model

Gaussian mixture models (GMM) are a powerful probabilistic tool for solving the clustering problem that will help to ameliorate many of the problems we faced with $K$-means. A GMM is just a linear combination of $K$ Gaussian distributions where each component is weighted by some scalar $\pi_k \geq 0$ such that $\sum_{k=1}^K\pi_k = 1$. These constraints allow us to interpret the combination itself as a probability distribution. We will interpret each Gaussian component as the distribution of a particular cluster. For any one data point $\mathbf{x}$ the GMM assigns a probability density given by

$$p(\mathbf{x}) = \sum_{k=1}^K \pi_k \cdot \mathcal{N}(\mathbf{x}\mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$$

where $\boldsymbol{\mu}_k$ and $\boldsymbol{\Sigma}_k$ are the mean and covariance of the $k^{th}$ component respectively.

Figure 4. Example GMM with two components

To start we will introduce a set of variables $z_n$ just like we did with $K$-means, where $z_n \in \{1, \dots, K\}$ indicates the cluster identity for the $n^{th}$ data point. Constructing a set of unobserved variables like this is extremely common in unsupervised learning problems. Often we refer to these $z$ variables as latent variables. We said before that we interpret our mixture components as our clusters, and so we have

$$p(z = k) = \pi_k$$

This just means that the probability a data point comes from cluster $k$ is equal to its weighting in the mixture. We can also write out the distribution of $\mathbf{x}$ conditioned on the value of $z$.

$$p(\mathbf{x} \mid z=k) = \mathcal{N}(\mathbf{x}\mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)$$

This is just the distribution of the $k^{th}$ cluster.
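
One way to build intuition for the latent variable is to read the two distributions above as a recipe for generating data: first draw $z$ from $p(z)$, then draw $\mathbf{x}$ from the Gaussian of that cluster. The sketch below does exactly that with NumPy. The function name `sample_gmm` and all parameter values are assumptions chosen for illustration.

```python
import numpy as np

def sample_gmm(n, pis, mus, sigmas, seed=0):
    """Draw n points by sampling z ~ p(z) first, then x ~ p(x | z)."""
    rng = np.random.default_rng(seed)
    z = rng.choice(len(pis), size=n, p=pis)   # latent cluster label for each point
    x = np.empty((n, mus.shape[1]))
    for k in range(len(pis)):
        idx = z == k
        # Points with label k come from the k-th Gaussian component
        x[idx] = rng.multivariate_normal(mus[k], sigmas[k], size=idx.sum())
    return x, z

# Hypothetical two-component mixture in 2-D
pis = np.array([0.3, 0.7])
mus = np.array([[-1.0, -1.0], [1.0, 1.0]])
sigmas = np.array([0.1 * np.eye(2), 0.2 * np.eye(2)])

x, z = sample_gmm(5000, pis, mus, sigmas)
```

In a clustering problem we only ever observe `x`; the labels `z` are what we would like to recover.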

To fit this model to our data we will rely on maximum likelihood estimation (MLE) which I wrote about in my previous post. MLE will allow us to determine the optimal values for our model parameters $\boldsymbol{\Theta} = \{\pi_k, \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k\}$. To start we first must construct our likelihood function. Assuming the data points are independent and identically distributed, it factorizes as

$$p(\{\mathbf{x}\}\mid \boldsymbol{\Theta}) = \prod_{n=1}^N p(\mathbf{x}_n)$$

Because of how we formulated our latent variables we can rewrite $p(\mathbf{x}_n)$ as a marginal of the joint distribution between $z$ and $\mathbf{x}$.

$$\begin{aligned} \prod_{n=1}^N p(\mathbf{x}_n) &= \prod_{n=1}^N \sum_{k=1}^K p(\mathbf{x}_n, z=k) \\ &= \prod_{n=1}^N\sum_{k=1}^K p(z=k)p(\mathbf{x}_n\mid z=k) \\ &= \prod_{n=1}^N\sum_{k=1}^K \pi_k \cdot \mathcal{N}(\mathbf{x}_n\mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \end{aligned}$$

Now we can take the log of this expression to get our log-likelihood

$$\mathcal{L} = \log p(\{\mathbf{x}\}\mid \boldsymbol{\Theta}) = \sum_{n=1}^N \log \left[ \sum_{k=1}^K \pi_k \cdot \mathcal{N}(\mathbf{x}_n\mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)\right]$$

The MLE procedure now tells us to find the values for $\{\pi_k\}, \{\boldsymbol{\mu}_k\}$ and $\{\boldsymbol{\Sigma}_k\}$ that maximize $\mathcal{L}$. First let’s find the maximum likelihood estimate for $\boldsymbol{\mu}_k$.
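
For reference, here is one way the log-likelihood could be evaluated numerically. It is a sketch, not the implementation used for the figures: `log_gaussian` and `gmm_log_likelihood` are hypothetical helpers, and the inner sum is computed with the log-sum-exp trick (an addition for numerical stability that the derivation itself does not need).

```python
import numpy as np

def log_gaussian(X, mu, sigma):
    """Log density of N(mu, sigma) evaluated at each row of X."""
    d = X.shape[1]
    diff = X - mu
    maha = np.einsum("ni,ij,nj->n", diff, np.linalg.inv(sigma), diff)
    _, logdet = np.linalg.slogdet(sigma)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)

def gmm_log_likelihood(X, pis, mus, sigmas):
    # log(pi_k) + log N(x_n | mu_k, Sigma_k) for every n and k, shape (N, K)
    log_terms = np.stack(
        [np.log(p) + log_gaussian(X, m, S) for p, m, S in zip(pis, mus, sigmas)],
        axis=1,
    )
    # log-sum-exp over k: subtract the row maximum before exponentiating
    row_max = log_terms.max(axis=1, keepdims=True)
    log_px = row_max[:, 0] + np.log(np.exp(log_terms - row_max).sum(axis=1))
    return log_px.sum()

# Hypothetical parameters: two unit-covariance components
pis = np.array([0.5, 0.5])
mus = np.array([[0.0, 0.0], [2.0, 0.0]])
sigmas = np.array([np.eye(2), np.eye(2)])

ll = gmm_log_likelihood(np.array([[0.0, 0.0]]), pis, mus, sigmas)
```

Working in log space matters for points far from every component, where the individual densities would underflow to zero and a naive `log(sum(...))` would return `-inf`.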

Finding $\boldsymbol{\mu}_k$

We’ll start by taking the derivative of $\mathcal{L}$ with respect to $\boldsymbol{\mu}_k$

$$\begin{aligned} \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mu}_k} &= \frac{\partial}{\partial\boldsymbol{\mu}_k}\sum_{n=1}^N \log \left[ \sum_{j=1}^K \pi_j \cdot \mathcal{N}(\mathbf{x}_n\mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)\right]\\ &=\sum_{n=1}^N \frac{\partial}{\partial\boldsymbol{\mu}_k} \log \left[ \sum_{j=1}^K \pi_j \cdot \mathcal{N}(\mathbf{x}_n\mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)\right] \end{aligned}$$

Here we can apply the chain rule (A2). In the numerator only the $j = k$ term depends on $\boldsymbol{\mu}_k$, so the other terms vanish

$$\begin{aligned} \frac{\partial\mathcal{L}}{\partial\boldsymbol{\mu}_k} &= \sum_{n=1}^N \frac{\frac{\partial}{\partial\boldsymbol{\mu}_k} \Bigl(\sum_{j=1}^K\pi_j\cdot\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)\Bigr)} {\sum_{j=1}^K\pi_j\cdot\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)} \\ &=\sum_{n=1}^N\frac{\pi_k\cdot\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)\cdot\boldsymbol{\Sigma}_k^{-1}(\mathbf{x}_n - \boldsymbol{\mu}_k)} {\sum_{j=1}^K\pi_j\cdot\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)} \end{aligned}$$

Refer to A3 for that last step.

Now to clean things up a bit let’s define a new quantity $\gamma_{n, k}$ as

$$\gamma_{n, k} = \frac{\pi_k\cdot\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{j=1}^K\pi_j\cdot\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)}$$

This quantity will appear a lot going forward and abstracting it this way will help us notationally. However, this term also has a rich meaning: it is the conditional probability of $z$ given $\mathbf{x}$. To see this we can use Bayes' rule

$$\begin{aligned} p(z_n = k \mid \mathbf{x}_n) &= \frac{p(z_n = k)p(\mathbf{x}_n \mid z_n=k)}{\sum_{j=1}^K p(z_n = j)p(\mathbf{x}_n \mid z_n = j)}\\ &=\frac{\pi_k\cdot\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{j=1}^K\pi_j\cdot\mathcal{N}(\mathbf{x}_n \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)} \\ &=\gamma_{n, k} \end{aligned}$$

This understanding allows us to view $\pi_k$ as the prior probability that $z_n = k$ before observing $\mathbf{x}_n$, and $\gamma_{n, k}$ as the posterior probability that $z_n = k$ after we observe $\mathbf{x}_n$. $\gamma_{n, k}$ is also often referred to as the responsibility because it can be viewed as the ability of the $k^{th}$ cluster to explain the value of $\mathbf{x}_n$. Also, notice that $\sum_{k=1}^K \gamma_{n, k} = 1$. This should make sense because in our model each data point has to come from a cluster.
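
To make the responsibilities concrete, the following sketch evaluates $\gamma_{n,k}$ for three points lying between two unit-covariance components. The helper names and all numbers are assumptions for illustration. SciPy's `multivariate_normal` could replace the hand-written density, but the sketch sticks to NumPy to stay self-contained.

```python
import numpy as np

def gaussian_pdf(X, mu, sigma):
    """Density of N(mu, sigma) evaluated at each row of X."""
    d = X.shape[1]
    diff = X - mu
    maha = np.einsum("ni,ij,nj->n", diff, np.linalg.inv(sigma), diff)
    return np.exp(-0.5 * maha) / np.sqrt((2 * np.pi) ** d * np.linalg.det(sigma))

def responsibilities(X, pis, mus, sigmas):
    # Numerator pi_k * N(x_n | mu_k, Sigma_k) for every n and k, shape (N, K)
    weighted = np.stack(
        [p * gaussian_pdf(X, m, S) for p, m, S in zip(pis, mus, sigmas)], axis=1
    )
    # Dividing by the row sum applies Bayes' rule and makes each row sum to 1
    return weighted / weighted.sum(axis=1, keepdims=True)

# Hypothetical parameters: equal weights, means at (0, 0) and (2, 0)
pis = np.array([0.5, 0.5])
mus = np.array([[0.0, 0.0], [2.0, 0.0]])
sigmas = np.array([np.eye(2), np.eye(2)])
X = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])

gamma = responsibilities(X, pis, mus, sigmas)
```

The middle point is equidistant from both means, so its responsibility is split evenly, while the two outer points are each claimed mostly by the component centered on them.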

Now we can plug $\gamma_{n, k}$ into our derivative and set it equal to $0$ to solve for $\boldsymbol{\mu}_k$. Multiplying both sides by $\boldsymbol{\Sigma}_k$ removes the inverse covariance

$$\begin{aligned} 0 &= \sum_{n=1}^N \gamma_{n, k} \boldsymbol{\Sigma}_k^{-1}(\mathbf{x}_n - \boldsymbol{\mu}_k)\\ &=\boldsymbol{\Sigma}_k^{-1} \sum_{n=1}^N \left(\gamma_{n, k}\mathbf{x}_n - \gamma_{n, k}\boldsymbol{\mu}_k\right)\\ \sum_{n=1}^N \gamma_{n, k}\boldsymbol{\mu}_k &= \sum_{n=1}^N \gamma_{n, k}\mathbf{x}_n\\ \boldsymbol{\mu}_k &= \frac{\sum_{n=1}^N \gamma_{n, k}\mathbf{x}_n}{\sum_{n=1}^N \gamma_{n, k}}\\ \boldsymbol{\mu}_k &= \frac{\sum_{n=1}^N \gamma_{n, k}\mathbf{x}_n}{N_k} \end{aligned}$$

where in the last line we have defined $N_k = \sum_{n=1}^N \gamma_{n, k}$. So after all of that, we have this nice result that the center of cluster $k$ is a weighted average of the data where the weights are the responsibilities, or the ability of that cluster to explain the data. The denominator $N_k$ is the sum of the responsibility of the $k^{th}$ cluster across every data point, and so you can think about it almost like the effective number of data points assigned to cluster $k$.

I will leave derivations of the maximum likelihood estimates for $\pi_k$ and $\boldsymbol{\Sigma}_k$ in the appendix, but they too have nice interpretations.

$$\boldsymbol{\Sigma}_k = \frac{1}{N_k}\sum_{n=1}^N\gamma_{n, k}(\mathbf{x}_n - \boldsymbol{\mu}_k)(\mathbf{x}_n - \boldsymbol{\mu}_k)^\top\qquad \pi_k = \frac{N_k}{N}$$

Like $\boldsymbol{\mu}_k$, our covariance estimate is a weighted empirical covariance of the data where the weights are the responsibilities of the $k^{th}$ cluster, whilst $\pi_k$ is the effective fraction of the data assigned to cluster $k$.
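
The three update equations can be written compactly with NumPy. The sketch below is a vectorized illustration under assumed names (`m_step` is hypothetical) that takes a responsibility matrix of shape $(N, K)$ as given. It is fed one-hot responsibilities here so the result is easy to check by hand.

```python
import numpy as np

def m_step(X, gamma):
    """Update (pi, mu, Sigma) from data X of shape (N, D) and responsibilities gamma of shape (N, K)."""
    N_k = gamma.sum(axis=0)                     # effective number of points per cluster
    mus = (gamma.T @ X) / N_k[:, None]          # responsibility-weighted means, shape (K, D)
    diff = X[:, None, :] - mus[None, :, :]      # x_n - mu_k for every pair, shape (N, K, D)
    # Responsibility-weighted sum of outer products, shape (K, D, D)
    sigmas = np.einsum("nk,nki,nkj->kij", gamma, diff, diff) / N_k[:, None, None]
    pis = N_k / len(X)                          # effective fraction of points per cluster
    return pis, mus, sigmas

# Hypothetical toy data with hard (one-hot) responsibilities
X = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 14.0]])
gamma = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])

pis, mus, sigmas = m_step(X, gamma)
```

With one-hot responsibilities the mean update reduces to the $K$-means centroid update, which is one way to see $K$-means as a hard-assignment relative of this procedure.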

Notice that, as was the case in $K$-means, our values for $\gamma_{n, k}$ are dependent on our parameters, whilst our parameters are themselves dependent on $\gamma_{n, k}$. Like with $K$-means this motivates an iterative approach in which we alternate between computing the responsibilities and optimizing the parameters. The procedure will go like this:

GMM Algorithm

(1) Initialize our parameters $\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k, \pi_k$
(2) Compute responsibilities $\gamma_{n, k}$
(3) Using our new responsibilities, update our parameters
(4) Repeat steps (2) and (3) until changes in $\mathcal{L}$ between steps are negligible.

Figure 5. Iterations of our GMM fitting procedure.

The GMM helps to solve many of the problems we faced with $K$-means. Being a probabilistic model means that we gain access to methods for cross-validating our choice of $K$. Unlike $J$ in $K$-means, which we can always push lower by adding clusters, the log-likelihood of held-out data typically stops improving once we add more clusters than the data supports. We also gain flexibility in the shape that our clusters can take, though we are still limited to Gaussian ellipsoids. And finally we can quantify our uncertainty for outlier data points which don’t fall cleanly into any one cluster. Initialization is still a shortcoming of the GMM, as there aren’t many principled ways to ensure good results. In practice we usually run $K$-means++ first and use the result as our cluster initializations for the GMM, which usually works well.
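
For readers unfamiliar with it, the seeding step of $K$-means++ can be sketched as follows. This is a simplified illustration of the general idea (choose each new center with probability proportional to its squared distance from the nearest center already chosen), not the code used for the figures. The function name and the synthetic data are assumptions, and the sketch assumes the data contains at least $K$ distinct points.

```python
import numpy as np

def kmeans_pp_init(X, K, seed=0):
    """Pick K rows of X as initial centers using k-means++ style seeding."""
    rng = np.random.default_rng(seed)
    centers = [X[rng.integers(len(X))]]          # first center: uniform at random
    for _ in range(1, K):
        chosen = np.array(centers)
        # Squared distance from each point to its nearest already-chosen center
        d2 = ((X[:, None, :] - chosen[None, :, :]) ** 2).sum(axis=2).min(axis=1)
        # Far-away points are proportionally more likely to become the next center
        centers.append(X[rng.choice(len(X), p=d2 / d2.sum())])
    return np.array(centers)

# Synthetic data: two tight, well-separated blobs (hypothetical values)
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(0.0, 0.1, (50, 2)), rng.normal(10.0, 0.1, (50, 2))])

init = kmeans_pp_init(X, K=2)
```

The returned centers could then serve as the starting means for $K$-means or for the GMM, with the covariances and mixture weights initialized separately.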

Beyond being two of the most popular clustering techniques, I chose to write about $K$-means and GMMs because they are often used to introduce the Expectation Maximization (EM) algorithm. I didn’t explicitly mention the EM algorithm in this post but briefly it is an algorithm which allows us to perform MLE with latent variable models. I hope to cover the EM algorithm in a future post when my understanding of it is better, and when I do I will link it here.


$K$-means implementation

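import numpy as np

def fit(X, K=2):
    """Run K-means on 2-D data X of shape (N, 2) and return the fitting history."""

    def J(assignments, mu):
        # Objective: sum of squared distances from each point to its assigned centroid
        j = 0
        for k in range(K):
            j += (np.linalg.norm(X[assignments == k] - mu[k], axis=1) ** 2).sum()
        return j

    def assign_clusters(mu):
        # Assign every point to its nearest centroid
        distances = np.zeros(shape=(K, len(X)))
        for k in range(K):
            distances[k] = np.linalg.norm(X - mu[k], axis=1)
        return distances.argmin(axis=0)

    def update_mu(assignments, mu):
        # Move each centroid to the mean of its assigned points
        # (axis=0 averages per coordinate; assumes no cluster is empty)
        new_mu = np.zeros_like(mu)
        for k in range(K):
            new_mu[k] = X[assignments == k].mean(axis=0)
        return new_mu

    mu_updates = []
    assignment_updates = []
    j_scores = []

    # Hand-picked starting centroids for the 2-D, K=2 example
    old_mu = np.array([[1.0, -1.5],
                       [-1.0, 1.5]])
    old_assignments = assign_clusters(old_mu)
    old_j = np.inf
    mu_updates.append(old_mu)
    assignment_updates.append(old_assignments)
    j_scores.append(J(old_assignments, old_mu))

    while True:
        # Step (3): update centroids with the assignments held fixed
        new_mu = update_mu(old_assignments, old_mu)
        j_scores.append(J(old_assignments, new_mu))

        # Step (2): update assignments with the centroids held fixed
        new_assignments = assign_clusters(new_mu)
        new_j = J(new_assignments, new_mu)
        mu_updates.append(new_mu)
        assignment_updates.append(new_assignments)

        # Stop once the objective has effectively stopped changing
        if abs(new_j - old_j) < 1e-4:
            break

        old_mu, old_assignments, old_j = new_mu, new_assignments, new_j
    return mu_updates, assignment_updates, j_scores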

A2. Chain rule application in GMM derivation

$$\frac{d}{dx}\log(g(x)) = \frac{g^\prime(x)}{g(x)}$$

A3. Derivative of a multivariate normal distribution with respect to $\boldsymbol{\mu}$
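
A short sketch of the standard argument, assuming $\boldsymbol{\Sigma}$ is symmetric and positive definite and writing $D$ for the dimension of $\mathbf{x}$. Only the exponent of the density depends on $\boldsymbol{\mu}$, so the chain rule gives

$$\begin{aligned} \mathcal{N}(\mathbf{x}\mid \boldsymbol{\mu}, \boldsymbol{\Sigma}) &= \frac{1}{(2\pi)^{D/2}\vert\boldsymbol{\Sigma}\vert^{1/2}}\exp\left(-\frac{1}{2}(\mathbf{x} - \boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\mathbf{x} - \boldsymbol{\mu})\right)\\ \frac{\partial}{\partial\boldsymbol{\mu}}\left[-\frac{1}{2}(\mathbf{x} - \boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\mathbf{x} - \boldsymbol{\mu})\right] &= \boldsymbol{\Sigma}^{-1}(\mathbf{x} - \boldsymbol{\mu})\\ \frac{\partial}{\partial\boldsymbol{\mu}}\mathcal{N}(\mathbf{x}\mid \boldsymbol{\mu}, \boldsymbol{\Sigma}) &= \mathcal{N}(\mathbf{x}\mid \boldsymbol{\mu}, \boldsymbol{\Sigma})\cdot\boldsymbol{\Sigma}^{-1}(\mathbf{x} - \boldsymbol{\mu}) \end{aligned}$$

The second line uses the symmetry of $\boldsymbol{\Sigma}^{-1}$, and the third is the factor that appears in the derivative of $\mathcal{L}$ with respect to $\boldsymbol{\mu}_k$.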

GMM implementation

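import numpy as np
from scipy.stats import multivariate_normal

outer = np.outer

def fit_gmm(X, K=2):
    """Fit a K-component GMM to 2-D data X of shape (N, 2) and return the fitting history."""

    def compute_LL():
        # Log-likelihood: sum_n log sum_k pi_k * N(x_n | mu_k, Sigma_k)
        probs = np.zeros((len(X), K))
        for k in range(K):
            probs[:, k] = multivariate_normal(mu[:, k], sigma[..., k]).pdf(X) * pi[k]
        return np.log(probs.sum(axis=1)).sum()

    def compute_gamma():
        # Responsibilities: weighted component densities, normalized per data point
        gamma = np.zeros((len(X), K))
        for k in range(K):
            gamma[:, k] = multivariate_normal(mu[:, k], sigma[..., k]).pdf(X) * pi[k]
        gamma /= gamma.sum(axis=1, keepdims=True)
        return gamma

    def compute_Nk():
        # Effective number of points assigned to each cluster
        return np.sum(gamma, axis=0)

    def compute_mu():
        # Responsibility-weighted means; column k holds mu_k
        mu = np.zeros((2, K))
        for k in range(K):
            mu[:, k] = (gamma[:, k, np.newaxis] * X).sum(axis=0) / N_k[k]
        return mu

    def compute_sigma():
        # Responsibility-weighted covariances; sigma[..., k] holds Sigma_k
        sigma = np.zeros((2, 2, K))
        for k in range(K):
            for n in range(len(X)):
                sigma[..., k] += (gamma[n, k] * outer(X[n] - mu[:, k], X[n] - mu[:, k])) / N_k[k]
        return sigma

    def compute_pi():
        # Effective fraction of the data assigned to each cluster
        return N_k / len(X)

    LLs, mus, sigmas, pis, gammas = [], [], [], [], []

    # Hand-picked initial parameters for the 2-D, K=2 example
    mu = np.array([[1.2, -1.5],
                   [-2.0, 1.5]])
    sigma = np.zeros((2, 2, 2))
    sigma[..., 0] = np.eye(2) * 0.1
    sigma[..., 1] = np.eye(2) * 0.1
    pi = np.array([0.5, 0.5])
    LL_old = -np.inf

    while True:
        # Step (2): responsibilities under the current parameters
        gamma = compute_gamma()
        # Step (3): parameter updates under the new responsibilities
        N_k = compute_Nk()
        mu = compute_mu()
        sigma = compute_sigma()
        pi = compute_pi()
        LL_new = compute_LL()

        # Record the history of this iteration
        gammas.append(gamma)
        mus.append(mu)
        sigmas.append(sigma)
        pis.append(pi)
        LLs.append(LL_new)

        # Stop once the log-likelihood has effectively stopped improving
        if LL_new - LL_old < 1e-9:
            break
        LL_old = LL_new

    return mus, sigmas, pis, gammas, LLs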