Author(s): Kaitai Dong

Originally published on Towards AI.

Figure 1: Gaussian mixture model illustration [Image by AI]

Introduction

In a time when deep learning (DL) and transformers steal the spotlight, it's easy to forget about classic algorithms like K-means, DBSCAN, and GMM. But here's a hot take: for anyone tackling real-world clustering and anomaly detection challenges, these statistical workhorses remain indispensable tools with surprising staying power.

Consider the everyday clustering puzzles: customer segmentation, social network analysis, or image segmentation. K-means has been used to solve these problems for decades with its simple centroid-based approach. When data forms irregular shapes, DBSCAN steps in with its density-based algorithm to identify non-convex clusters that leave K-means bewildered.

But real-world data rarely forms neat, separated bubbles. Enter the Gaussian Mixture Model (GMM) and its variants! GMMs acknowledge the fundamental uncertainty in cluster assignments. By modeling the probability density of normal behavior, they can identify observations that don't fit the expected pattern without requiring labeled examples. So before chasing the latest neural architecture for your clustering or segmentation task, consider statistical classics such as GMMs.

Many people can confidently explain how K-means works, but I bet my good dollars that not many have that confidence when it comes to GMMs. This article will discuss the math behind GMM and its variants in an understandable way (I will try my best!) and showcase why it deserves more attention for your next clustering task.

Remember this: classics never make a comeback. They wait for that perfect moment to take the spotlight from overdone, tired trends.

What is a Gaussian Mixture Model?

A Gaussian mixture is a function composed of several Gaussian distributions, each identified by k ∈ {1, …, K}, where K is the number of clusters in our dataset, which you must know in advance. Each Gaussian distribution k in the mixture contains the following parameters:

- A mean μₖ that defines its center.
- A covariance matrix Σₖ that describes its shape and orientation. This is equivalent to the dimensions of an ellipsoid in a multivariate scenario.
- A mixing coefficient πₖ that defines the weight of the Gaussian component, where 0 ≤ πₖ ≤ 1 and the πₖ sum to 1.

Mathematically, the mixture can be written as:

p(x) = Σₖ πₖ · N(x | μₖ, Σₖ), with the sum running over k = 1, …, K

where p(x) is the probability density at point x, and N(x | μₖ, Σₖ) is the multivariate Gaussian density with mean μₖ and covariance matrix Σₖ.

Equations always look scary, so let's take a step back and look at this multivariate Gaussian density N(x | μₖ, Σₖ) and the dimensions of each parameter in it. Assume the dataset contains N = 500 three-dimensional data points (D = 3). Then the dataset x is essentially a 500 × 3 matrix, each mean μₖ is a 3-dimensional vector, and each covariance matrix Σₖ is a 3 × 3 matrix. The output of the Gaussian density function, evaluated on the whole dataset, is a 500 × 1 vector.
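To make this concrete in code, here is a minimal sketch (my own illustration, not from the article) that evaluates the mixture density for the 500 × 3 example above; the component parameters are made-up placeholders.

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(42)
X = rng.normal(size=(500, 3))          # 500 three-dimensional points (D = 3)

# Hypothetical parameters for K = 3 components
weights = np.array([0.5, 0.3, 0.2])    # mixing coefficients, sum to 1
means = np.array([[0.0, 0.0, 0.0],
                  [2.0, 2.0, 2.0],
                  [-2.0, 1.0, 0.0]])   # each mu_k is a 3-dimensional vector
covs = np.array([np.eye(3)] * 3)       # each Sigma_k is a 3 x 3 matrix

# p(x) = sum_k pi_k * N(x | mu_k, Sigma_k), evaluated for every point
p_x = sum(w * multivariate_normal(mean=m, cov=c).pdf(X)
          for w, m, c in zip(weights, means, covs))
print(p_x.shape)   # (500,) -- one density value per data point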
When working with a GMM, we face a circular problem:

- To know which Gaussian cluster each data point belongs to, we need to know the parameters of each Gaussian (means, covariances, weights).
- But to estimate these parameters correctly, we need to know which data points belong to each Gaussian.

To break this cycle, enter the Expectation-Maximization (EM) algorithm, which makes educated guesses and then refines them iteratively.

Parameter estimation with the EM algorithm

The EM algorithm determines suitable values for the parameters of a GMM through the following steps:

Step 1: Initialize. Start with random guesses for the parameters (μ, Σ, π) of each Gaussian cluster.

Step 2: Expectation (E-step). For each data point, calculate how strongly it belongs to each Gaussian cluster by computing a set of responsibilities, i.e., the probabilities that the data point comes from each cluster.

Step 3: Maximization (M-step). Update each Gaussian cluster using all the instances in the dataset, with each instance weighted by the estimated probability (a.k.a. responsibility) that it belongs to that cluster. Specifically, the new means are the responsibility-weighted averages of all data points, the new covariances are the weighted spread around each new mean, and the new mixing weights are the fraction of the total responsibility each component receives. Note that each cluster's update is mostly influenced by the instances it is most responsible for.

Step 4: Repeat. Go back to Step 2 with these updated parameters and continue until the changes become minimal (convergence).

Often, people get confused by the M-step because a lot of terms are thrown in at once. I will use the previous example (500 three-dimensional data points with 3 Gaussian clusters) to break it down into more concrete terms.

For updating the means, we take a weighted average in which each point contributes according to its responsibility for the corresponding Gaussian cluster. Mathematically, for the kth Gaussian cluster:

new mean_k = (sum over points i of [responsibility_ik × point_i]) / (sum of all responsibilities for cluster k)

For updating the covariances, we use a similar weighted approach. For each point, calculate how far it is from the new mean, multiply this deviation by its transpose to get a matrix, weight this matrix by the point's responsibility, and sum these weighted matrices across all points. Finally, divide by the total responsibility for that cluster.

For updating the mixing weights, we simply sum all the responsibilities for cluster k and divide by the total number of data points.

Let's say for Gaussian cluster 2:

- The sum of responsibilities is 200 (out of 500 points).
- The responsibility-weighted sum of the points is (400, 600, 800).
- The responsibility-weighted sum of squared deviations gives a certain covariance matrix.

Then:

- New mean for cluster 2 = (400, 600, 800) / 200 = (2, 3, 4)
- New mixing weight = 200 / 500 = 0.4
- New covariance = (weighted sum of squared deviations) / 200

Hopefully it makes a lot more sense now!
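To make the E-step and M-step concrete in code, here is a minimal NumPy sketch of one EM iteration (my own illustration, not from the article), using the same 500 × 3 setup with K = 3 clusters and randomly generated data.

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))        # 500 points, D = 3
N, D = X.shape
K = 3

# Step 1: random initial guesses
weights = np.full(K, 1.0 / K)
means = X[rng.choice(N, size=K, replace=False)]
covs = np.array([np.eye(D)] * K)

# E-step: responsibilities resp[i, k] = P(cluster k | point i)
dens = np.column_stack([weights[k] * multivariate_normal(means[k], covs[k]).pdf(X)
                        for k in range(K)])         # shape (500, 3)
resp = dens / dens.sum(axis=1, keepdims=True)

# M-step: responsibility-weighted updates
Nk = resp.sum(axis=0)                               # total responsibility per cluster
new_means = (resp.T @ X) / Nk[:, None]              # weighted averages of the points
new_covs = np.empty_like(covs)
for k in range(K):
    diff = X - new_means[k]
    new_covs[k] = (resp[:, k, None] * diff).T @ diff / Nk[k]
new_weights = Nk / N                                # fraction of total responsibility per cluster

# In a full implementation, the E-step and M-step would be repeated until convergence.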
Clustering with GMM

Now that we have an estimate of the location, size, shape, orientation, and relative weight of each Gaussian cluster, a GMM can easily assign each data point to the most likely cluster (hard clustering) or estimate the probability that it belongs to a particular cluster (soft clustering).

The implementation of GMM in Python is quite simple and straightforward, thanks to the good old scikit-learn library. Here is sample code for a clustering task using the built-in GMM on randomly generated data points with 3 clusters, shown in Figure 2.

Figure 2: Data points for clustering [Image by author]

The Python code is given below:

from sklearn.mixture import GaussianMixture

gm = GaussianMixture(n_components=3, n_init=10, random_state=42)
gm.fit(X)

# To see the parameters the GMM has estimated
# (weights, means, and covariances)
gm.weights_
gm.means_
gm.covariances_

# To make predictions for the data points
# hard clustering
gm.predict(X)
# soft clustering
gm.predict_proba(X).round(2)

Figure 3 illustrates the cluster locations, decision boundaries, and density contours of the GMM (the model estimates the density at any given location).

Figure 3: Cluster locations, decision boundaries, and density contours of a trained GMM [Image by author]

It looks like the GMM has found a great solution! But it is worth noting that real-life data is not always so Gaussian and low-dimensional. EM can struggle to converge to the optimal solution when the problem has many dimensions and many clusters. To tackle this issue, you can limit the number of parameters the GMM has to learn. One way to do this is to limit the range of shapes and orientations that the clusters can have, which is achieved by imposing constraints on the covariance matrices via the covariance_type hyperparameter.

gm_full = GaussianMixture(n_components=3, n_init=10,
                          covariance_type="full",  # default value
                          random_state=42)

- "full" (default): No constraint; all clusters can take on any ellipsoidal shape of any size [1].
- "spherical": All clusters must be spherical, but they can have different diameters (i.e., different variances).
- "diag": Clusters can take on any ellipsoidal shape of any size, but the ellipsoid's axes must be parallel to the coordinate axes (i.e., the covariance matrices must be diagonal).
- "tied": All clusters must have the same shape, which can be any ellipsoid (i.e., they all share the same covariance matrix).

To show the difference from the default setting, Figure 4 illustrates the solution found by the EM algorithm when covariance_type is set to "tied".

Figure 4: Clustering result of the same task using GMM with tied clusters [Image by author]

It is also important to discuss the computational complexity of training a GMM. It largely depends on the number of data points m, the number of dimensions n, the number of clusters k, and the constraints on the covariance matrices (the 4 types mentioned above). If covariance_type is "spherical" or "diag", the complexity is O(kmn), assuming the data has a clustering structure. If covariance_type is "tied" or "full", the complexity becomes O(kmn² + kn³), which does not scale well [1].

Finding the right number of clusters

The given example is quite simple, partly because the number of clusters was already known when I generated the dataset. When you do not have this information prior to training, certain metrics are required to help determine the optimal number of clusters.

For a GMM, you can try to find the model that minimizes a theoretical information criterion, such as the Bayesian information criterion (BIC) or the Akaike information criterion (AIC), defined below:

BIC = log(m) * p - 2 * log(L)
AIC = 2 * p - 2 * log(L)

where m is the number of data points, p is the number of parameters learned by the model, and L is the maximized value of the model's likelihood function.
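The parameter count p can be worked out directly from the model structure. Here is a small sketch of that bookkeeping (my own illustration, not from the article) for a GMM with K components, D dimensions, and full covariance matrices.

def gmm_n_parameters(K, D):
    """Number of free parameters of a GMM with full covariance matrices."""
    mean_params = K * D                   # one D-dimensional mean per component
    cov_params = K * D * (D + 1) // 2     # each symmetric D x D covariance matrix
    weight_params = K - 1                 # mixing weights sum to 1, so one is implied
    return mean_params + cov_params + weight_params

print(gmm_n_parameters(K=3, D=2))   # 17 parameters for 3 full-covariance clusters in 2-D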
These values are simple to compute in Python:

gm.bic(X)
gm.aic(X)

BIC and AIC penalize models that have more parameters to learn (e.g., more clusters) and reward models that fit the data well. The lower the value, the better the model fits the data. In practice, you can choose a range of cluster counts k, plot the BIC or AIC against each k, and pick the k with the lowest BIC or AIC [2].

Variants of GMM

I find some variants of GMM quite useful and handy; they can often improve further on the classic form.

Bayesian GMM: It is capable of giving weights equal to or close to zero to unnecessary clusters. In practice, you can set the number of clusters n_components to a value that you have good reason to believe is greater than the actual optimal number of clusters, and the model will handle the rest automatically.

Robust GMM: It addresses GMM's over-sensitivity to outliers by modifying the objective function. Instead of maximizing the standard log-likelihood, it uses robust estimators that put less weight on points far from the cluster centers, providing a more stable outcome.

Online/Incremental GMM: It deals with the computational and memory limitations of standard GMMs. Parameters are updated after seeing each new data point or small batch, rather than requiring the full dataset. It also includes a forgetting mechanism that allows the model to discard older data and adapt to non-stationary distributions.

GMM vs K-means

Since real-life data is often complex, GMM generally outperforms K-means in clustering and segmentation tasks. I typically run K-means first as a baseline and then try GMM or its variants to see if the additional complexity provides any meaningful improvement. But let's compare them side by side and look at the main differences between these two classic algorithms.

Figure 5: Comparison between K-means and GMM over different aspects [Image by author]

Bonus point of GMM: a handy anomaly detection tool!

Using a GMM for anomaly detection is simple: any instance located in a low-density region can be considered an anomaly. The trick is that you must define what density threshold you want to use.
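As a minimal sketch of that thresholding idea (assuming the GMM gm already fitted on the data X, as in the earlier clustering example, and an arbitrary 2% cut-off), you can score every point with the model's log-density and flag the lowest-density points:

import numpy as np

# Log-density of every point under the fitted mixture
log_densities = gm.score_samples(X)

# Flag the points that fall in the lowest-density 2% (the percentile choice is up to you)
threshold = np.percentile(log_densities, 2)
anomalies = X[log_densities < threshold]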
As an example, I will use a GMM to identify abnormal network traffic patterns. The features included in this task are as follows:

- Packet size (bytes)
- Inter-arrival time (ms)
- Connection duration (s)
- Protocol-specific value
- Entropy of packet payload
- TCP window size

The raw dataset looks like this:

Figure 6: The headers and value formats of the network traffic dataset [Image by author]

The code snippet below shows how to preprocess the data and how to train and evaluate the model; we will then compare it with other common anomaly detection methods.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import f1_score

# raw_df has been shown previously
df = raw_df.copy()
X = df.drop(columns=['is_anomaly'])
y = df['is_anomaly']

# Split the data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y)

# Scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# We'll create models using only normal traffic data for training
# This is a common approach for anomaly detection
X_train_normal = X_train_scaled[y_train == 0]

# Try different numbers of components to find the best fit
n_components_range = range(1, 10)
bic_scores = []
aic_scores = []

for n_components in n_components_range:
    gmm = GaussianMixture(n_components=n_components,
                          covariance_type='full',
                          random_state=42)
    gmm.fit(X_train_normal)
    bic_scores.append(gmm.bic(X_train_normal))
    aic_scores.append(gmm.aic(X_train_normal))

# Choose the optimal number of components based on BIC
optimal_components = n_components_range[np.argmin(bic_scores)]
print(f"Optimal number of components based on BIC: {optimal_components}")

# Train the final model
gmm = GaussianMixture(n_components=optimal_components,
                      covariance_type='full',
                      random_state=42)
gmm.fit(X_train_normal)

# Calculate the negative log probability (higher means more anomalous)
# gmm_train_scores is very important for determining the threshold percentile in the evaluation
gmm_train_scores = -gmm.score_samples(X_train_scaled)
gmm_test_scores = -gmm.score_samples(X_test_scaled)


def evaluate_model(y_true, anomaly_scores, threshold_percentile=3):
    """
    Evaluate model performance with various metrics

    Parameters:
    y_true: True labels (0 for normal, 1 for anomaly)
    anomaly_scores: Scores where higher means more anomalous
    threshold_percentile: Percentile for threshold selection

    Returns:
    Dictionary of performance metrics
    """
    # Use a percentile threshold from the training scores
    threshold = np.percentile(anomaly_scores, 100 - threshold_percentile)

    # Predict anomalies
    y_pred = (anomaly_scores > threshold).astype(int)

    # Calculate evaluation metrics
    f1 = f1_score(y_true, y_pred)
    precision, recall, _ = precision_recall_curve(y_true, anomaly_scores)
    avg_precision = average_precision_score(y_true, anomaly_scores)

    return {
        'f1_score': f1,
        'avg_precision': avg_precision,
        'precision_curve': precision,
        'recall_curve': recall,
        'threshold': threshold,
        'y_pred': y_pred
    }


# Calculate metrics
gmm_results = evaluate_model(y_test, gmm_test_scores)

Let's throw in a few common anomaly detection methods, i.e., Isolation Forest, One-Class SVM, and Local Outlier Factor (LOF), and check their performance. Since irregular traffic patterns are rare, I will use PR-AUC as the evaluation metric for model effectiveness.
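The article does not list the code for the baseline methods, but a sketch along the following lines (my own, with placeholder hyperparameters rather than the author's settings) would produce comparable PR-AUC numbers, training each baseline on the same normal-only data and scoring the test set.

from sklearn.ensemble import IsolationForest
from sklearn.svm import OneClassSVM
from sklearn.neighbors import LocalOutlierFactor
from sklearn.metrics import average_precision_score

baselines = {
    'Isolation Forest': IsolationForest(random_state=42),
    'One-Class SVM': OneClassSVM(nu=0.05, gamma='scale'),
    'LOF': LocalOutlierFactor(n_neighbors=20, novelty=True),
}

pr_auc = {'GMM': average_precision_score(y_test, gmm_test_scores)}
for name, model in baselines.items():
    model.fit(X_train_normal)              # fit on normal traffic only, like the GMM
    # score_samples is higher for normal points, so negate it to get an anomaly score
    scores = -model.score_samples(X_test_scaled)
    pr_auc[name] = average_precision_score(y_test, scores)

print(pr_auc)   # PR-AUC (average precision) per method; closer to 1 is better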
The results are given in Figure 7, where the closer a score is to 1, the more accurate the model.

Figure 7: Comparative analysis for the network traffic detection task using the PR-AUC metric to evaluate GMM, Isolation Forest, One-Class SVM, and LOF [Image by author]

The results show that GMM is quite strong at identifying irregular network traffic and outperforms the other common methods! GMM can be a good starting point for anomaly detection tasks, especially if normal behavior includes multiple distinct patterns or if you need probabilistic anomaly scores.

Real-life cases are usually more complex than the steps I have shown, but hopefully this blog provides a good foundation for understanding how GMM works and how you can implement it for your clustering or anomaly detection tasks.

References

[1] Aurélien Géron. Hands-On Machine Learning with Scikit-Learn, Keras & TensorFlow. O'Reilly, 2023.
[2] Bayesian information criterion, Wikipedia. https://en.wikipedia.org/wiki/Bayesian_information_criterion