Simulating and Fitting a Gaussian Mixture Model
This tutorial demonstrates how to use StateSpaceDynamics.jl to create a Gaussian Mixture Model (GMM) and fit it using the EM algorithm. Unlike Hidden Markov Models, which model temporal sequences, GMMs are designed for clustering and density estimation of independent observations: each data point is assumed to come from one of several Gaussian components, with no temporal dependence between points.
GMMs are fundamental in machine learning for unsupervised clustering, density estimation, anomaly detection, and as building blocks for more complex models. The key insight is that complex data distributions can often be well-approximated as mixtures of simpler Gaussian distributions, each representing a different "mode" or cluster in the data.
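Formally, a GMM with $k$ components defines the density

$$
p(\mathbf{x}) = \sum_{i=1}^{k} \pi_i \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i), \qquad \pi_i \ge 0, \quad \sum_{i=1}^{k} \pi_i = 1,
$$

where $\boldsymbol{\mu}_i$ and $\boldsymbol{\Sigma}_i$ are the mean and covariance of component $i$ and $\pi_i$ is its mixing weight.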
Load Required Packages
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using Distributions
using StatsPlots
using Combinatorics
using LaTeXStrings
Set up reproducible random number generation
rng = StableRNG(1234);
Create a True Gaussian Mixture Model
We'll create a "ground truth" GMM with known parameters, generate data from it, then see how well we can recover these parameters using only the observed data.
k = 3 # Number of mixture components (clusters)
D = 2 # Data dimensionality (2D for easy visualization)
Define the true component means $\boldsymbol{\mu}_i \in \mathbb{R}^D$ for $i = 1, \ldots, k$. Each column represents the mean vector $\boldsymbol{\mu}_i$ for one component.
true_μs = [
-1.0 1.0 0.0; # $x_1$ coordinates of the 3 component centers
-1.0 -1.5 2.0 # $x_2$ coordinates of the 3 component centers
]; # Shape: $(D, k) = (2, 3)$
Define covariance matrices $\boldsymbol{\Sigma}_i$ for each component, using isotropic (spherical) covariances for simplicity.
true_Σs = [Matrix{Float64}(0.3 * I(2)) for _ in 1:k];
Define mixing weights $\pi_i$ (must sum to 1). These represent $P(\text{component} = i)$ for a random sample.
true_πs = [0.5, 0.2, 0.3]; # Component 1 most likely, component 2 least likely
Construct the complete GMM
true_gmm = GaussianMixtureModel(k, true_μs, true_Σs, true_πs);
print("Created GMM: $k components, $D dimensions\n")
for i in 1:k
print("Component $i: μ = $(true_μs[:, i]), π = $(true_πs[i])\n")
end
Created GMM: 3 components, 2 dimensions
Component 1: μ = [-1.0, -1.0], π = 0.5
Component 2: μ = [1.0, -1.5], π = 0.2
Component 3: μ = [0.0, 2.0], π = 0.3
Sample Data from the True GMM
Generate synthetic data from our true model. We'll sample both component assignments (for evaluation) and the actual observations.
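Concretely, each sample is drawn in two steps: first a component label $z_j \sim \mathrm{Categorical}(\pi_1, \ldots, \pi_k)$, then the observation $\mathbf{x}_j \mid z_j = i \sim \mathcal{N}(\boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i)$. The code below mirrors these two steps.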
n = 500 # Number of data points to generate
Determine which component each sample comes from
labels = rand(rng, Categorical(true_πs), n);Count samples per component for verification
component_counts = [sum(labels .== i) for i in 1:k]
print("Samples per component: $(component_counts) (expected: $(round.(n .* true_πs)))\n");
Samples per component: [252, 96, 152] (expected: [250.0, 100.0, 150.0])
Generate the actual data points
X = Matrix{Float64}(undef, D, n)
for i in 1:n
component = labels[i]
X[:, i] = rand(rng, MvNormal(true_μs[:, component], true_Σs[component]))
end
Visualize the generated data colored by true component membership
p1 = scatter(X[1, :], X[2, :];
group=labels,
title="True GMM Components",
xlabel=L"x_1", ylabel=L"x_2",
markersize=4,
alpha=0.7,
palette=:Set1_3,
legend=:topright
)
for i in 1:k
scatter!(p1, [true_μs[1, i]], [true_μs[2, i]];
marker=:star, markersize=10, color=i,
markerstrokewidth=2, markerstrokecolor=:black,
label="")
end
p1
Fit GMM Using EM Algorithm
Now we simulate the realistic scenario: observe only data points $\mathbf{X}$, not the true component labels or parameters. Our goal is to recover the underlying mixture structure using EM.
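For reference, each EM iteration alternates two steps. The E-step computes the responsibilities (posterior component probabilities)

$$
\gamma_{ij} = \frac{\pi_i \, \mathcal{N}(\mathbf{x}_j \mid \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i)}{\sum_{l=1}^{k} \pi_l \, \mathcal{N}(\mathbf{x}_j \mid \boldsymbol{\mu}_l, \boldsymbol{\Sigma}_l)},
$$

and the M-step re-estimates the parameters from the responsibility-weighted data:

$$
N_i = \sum_{j=1}^{n} \gamma_{ij}, \qquad \boldsymbol{\mu}_i = \frac{1}{N_i} \sum_{j=1}^{n} \gamma_{ij} \mathbf{x}_j, \qquad \boldsymbol{\Sigma}_i = \frac{1}{N_i} \sum_{j=1}^{n} \gamma_{ij} (\mathbf{x}_j - \boldsymbol{\mu}_i)(\mathbf{x}_j - \boldsymbol{\mu}_i)^\top, \qquad \pi_i = \frac{N_i}{n}.
$$

Each iteration is guaranteed not to decrease the log-likelihood.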
Initialize a GMM with the correct number of components but unknown parameters
fit_gmm = GaussianMixtureModel(k, D)
print("Running EM algorithm...")
Running EM algorithm...
Fit the model using the EM algorithm
class_probabilities, lls = fit!(fit_gmm, X;
maxiter=100,
tol=1e-6,
initialize_kmeans=true # K-means initialization helps convergence
);
print("EM ran for $(length(lls)) iterations\n")
print("Log-likelihood improved by $(round(lls[end] - lls[1], digits=1))\n");
Iteration 1: Log-likelihood = -1389.9540484217016
Iteration 2: Log-likelihood = -1356.263583656021
Iteration 3: Log-likelihood = -1348.0824392868528
Iteration 4: Log-likelihood = -1344.470128888572
Iteration 5: Log-likelihood = -1341.5868151824936
Iteration 6: Log-likelihood = -1338.847180965278
Iteration 7: Log-likelihood = -1336.2652281481735
Iteration 8: Log-likelihood = -1334.0844814520226
Iteration 9: Log-likelihood = -1332.4054672863692
Iteration 10: Log-likelihood = -1331.1637585686196
Iteration 11: Log-likelihood = -1330.2713861251125
Iteration 12: Log-likelihood = -1329.6506981349426
Iteration 13: Log-likelihood = -1329.2311196337544
Iteration 14: Log-likelihood = -1328.9515142006971
Iteration 15: Log-likelihood = -1328.7642888100243
Iteration 16: Log-likelihood = -1328.6359927189026
Iteration 17: Log-likelihood = -1328.544818255028
Iteration 18: Log-likelihood = -1328.4771801193533
Iteration 19: Log-likelihood = -1328.4247761830434
Iteration 20: Log-likelihood = -1328.3825238614113
Iteration 21: Log-likelihood = -1328.3472578593226
Iteration 22: Log-likelihood = -1328.316952791291
Iteration 23: Log-likelihood = -1328.290269818344
Iteration 24: Log-likelihood = -1328.266292684737
Iteration 25: Log-likelihood = -1328.2443720311148
Iteration 26: Log-likelihood = -1328.2240315079316
Iteration 27: Log-likelihood = -1328.2049095770592
Iteration 28: Log-likelihood = -1328.1867223292675
Iteration 29: Log-likelihood = -1328.1692389654418
Iteration 30: Log-likelihood = -1328.152265084379
Iteration 31: Log-likelihood = -1328.1356308765678
Iteration 32: Log-likelihood = -1328.1191824401378
Iteration 33: Log-likelihood = -1328.1027750868407
Iteration 34: Log-likelihood = -1328.0862678958356
Iteration 35: Log-likelihood = -1328.069519012055
Iteration 36: Log-likelihood = -1328.0523813376958
Iteration 37: Log-likelihood = -1328.034698367206
Iteration 38: Log-likelihood = -1328.0162999930194
Iteration 39: Log-likelihood = -1327.99699818013
Iteration 40: Log-likelihood = -1327.9765824892388
Iteration 41: Log-likelihood = -1327.9548155400196
Iteration 42: Log-likelihood = -1327.9314286679785
Iteration 43: Log-likelihood = -1327.9061182631644
Iteration 44: Log-likelihood = -1327.8785436026378
Iteration 45: Log-likelihood = -1327.8483273978366
Iteration 46: Log-likelihood = -1327.8150607234527
Iteration 47: Log-likelihood = -1327.778314343119
Iteration 48: Log-likelihood = -1327.7376584511835
Iteration 49: Log-likelihood = -1327.6926921566685
Iteration 50: Log-likelihood = -1327.6430823074584
Iteration 51: Log-likelihood = -1327.5886084361318
Iteration 52: Log-likelihood = -1327.5292072784596
Iteration 53: Log-likelihood = -1327.465007804374
Iteration 54: Log-likelihood = -1327.3963476455203
Iteration 55: Log-likelihood = -1327.323765086477
Iteration 56: Log-likelihood = -1327.2479665969258
Iteration 57: Log-likelihood = -1327.1697758827947
Iteration 58: Log-likelihood = -1327.090074165956
Iteration 59: Log-likelihood = -1327.009741773066
Iteration 60: Log-likelihood = -1326.9296087550024
Iteration 61: Log-likelihood = -1326.8504187143749
Iteration 62: Log-likelihood = -1326.7728067080518
Iteration 63: Log-likelihood = -1326.6972897734129
Iteration 64: Log-likelihood = -1326.6242674291925
Iteration 65: Log-likelihood = -1326.5540291945072
Iteration 66: Log-likelihood = -1326.486766420419
Iteration 67: Log-likelihood = -1326.4225862451162
Iteration 68: Log-likelihood = -1326.3615260645256
Iteration 69: Log-likelihood = -1326.3035674401706
Iteration 70: Log-likelihood = -1326.2486487954152
Iteration 71: Log-likelihood = -1326.1966765705747
Iteration 72: Log-likelihood = -1326.1475347283908
Iteration 73: Log-likelihood = -1326.1010926436602
Iteration 74: Log-likelihood = -1326.057211494823
Iteration 75: Log-likelihood = -1326.0157493186052
Iteration 76: Log-likelihood = -1325.9765649049875
Iteration 77: Log-likelihood = -1325.939520708951
Iteration 78: Log-likelihood = -1325.9044849445645
Iteration 79: Log-likelihood = -1325.8713330107512
Iteration 80: Log-likelihood = -1325.8399483795322
Iteration 81: Log-likelihood = -1325.8102230587806
Iteration 82: Log-likelihood = -1325.7820577236162
Iteration 83: Log-likelihood = -1325.7553615942368
Iteration 84: Log-likelihood = -1325.7300521236768
Iteration 85: Log-likelihood = -1325.7060545466402
Iteration 86: Log-likelihood = -1325.6833013304226
Iteration 87: Log-likelihood = -1325.661731560489
Iteration 88: Log-likelihood = -1325.6412902867307
Iteration 89: Log-likelihood = -1325.6219278512485
Iteration 90: Log-likelihood = -1325.6035992146017
Iteration 91: Log-likelihood = -1325.5862632944454
Iteration 92: Log-likelihood = -1325.5698823283244
Iteration 93: Log-likelihood = -1325.5544212705483
Iteration 94: Log-likelihood = -1325.5398472316917
Iteration 95: Log-likelihood = -1325.5261289679147
Iteration 96: Log-likelihood = -1325.5132364260055
Iteration 97: Log-likelihood = -1325.5011403486728
Iteration 98: Log-likelihood = -1325.4898119431143
Iteration 99: Log-likelihood = -1325.4792226143236
Iteration 100: Log-likelihood = -1325.4693437629173
EM ran for 100 iterations
Log-likelihood improved by 64.5
Plot EM convergence
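Since EM should never decrease the log-likelihood, the returned trajectory can be sanity-checked before plotting. A minimal check, assuming lls holds the per-iteration log-likelihoods returned by fit! above:

@assert all(diff(lls) .>= -1e-8) "log-likelihood decreased during EM"  # tolerance allows tiny numerical jitter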
p2 = plot(lls, xlabel="EM Iteration", ylabel="Log-Likelihood",
title="EM Algorithm Convergence", legend=false,
marker=:circle, markersize=3, lw=2, color=:darkblue)
if length(lls) < 100
annotate!(p2, length(lls)*0.7, lls[end]*0.95,
text("Converged in $(length(lls)) iterations", 10)) # Add convergence annotation
end
Visualize Fitted Model
Create a visualization showing both the data and the fitted GMM's probability-density contours. First, create a grid over which to evaluate the contours.
x_range = range(extrema(X[1, :])..., length=100)
y_range = range(extrema(X[2, :])..., length=100)
xs = collect(x_range)
ys = collect(y_range)
p3 = scatter(X[1, :], X[2, :];
markersize=3,
alpha=0.5,
color=:gray,
xlabel=L"x_1",
ylabel=L"x_2",
title="Fitted GMM Components",
legend=:topright,
label="Data points"
)
p3
colors = [:red, :green, :blue] # Plot probability density contours for each learned component
for i in 1:fit_gmm.k
comp_dist = MvNormal(fit_gmm.μₖ[:, i], fit_gmm.Σₖ[i])
Z_i = [fit_gmm.πₖ[i] * pdf(comp_dist, [x, y]) for y in ys, x in xs]
contour!(p3, xs, ys, Z_i;
levels=6,
linewidth=2,
c=colors[i],
label="Component $i (π=$(round(fit_gmm.πₖ[i], digits=2)))"
)
scatter!(p3, [fit_gmm.μₖ[1, i]], [fit_gmm.μₖ[2, i]];
marker=:star, markersize=8, color=colors[i],
markerstrokewidth=2, markerstrokecolor=:black,
label="")
end
Component Assignment Analysis
Use the fitted model to assign each data point to its most likely component and compare with the true assignments.
Get posterior probabilities: $P(\text{component } i | \mathbf{x}_j)$
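For example, the soft (probabilistic) assignment of a single point can be inspected directly; a minimal sketch, assuming class_probabilities stores one column per data point, as in the assignment below:

j = 1  # inspect the first data point
print("Posterior over components: $(round.(class_probabilities[:, j], digits=3))\n")
print("Hard assignment: component $(argmax(class_probabilities[:, j]))\n")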
predicted_labels = [argmax(class_probabilities[:, j]) for j in 1:n];
Calculate the assignment accuracy, accounting for possible label permutation, since EM can recover the components in a different order.
function best_permutation_accuracy(true_labels, pred_labels, k)
best_acc = 0.0
best_perm = collect(1:k)
for perm in Combinatorics.permutations(1:k)
mapped_pred = [perm[pred_labels[i]] for i in 1:length(pred_labels)]
acc = mean(true_labels .== mapped_pred)
if acc > best_acc
best_acc = acc
best_perm = perm
end
end
return best_acc, best_perm
end
accuracy, best_perm = best_permutation_accuracy(labels, predicted_labels, k)
print("Component assignment accuracy: $(round(accuracy*100, digits=1))%\n");
Component assignment accuracy: 95.8%
Final Comparison Visualization
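Before the plots, the matched parameters can also be compared numerically. A minimal sketch using best_perm and the fitted fields from above, where learned component c corresponds to true component best_perm[c]:

for c in 1:k
    t = best_perm[c]  # true component matched to learned component c
    print("Learned component $c: μ̂ = $(round.(fit_gmm.μₖ[:, c], digits=2)), π̂ = $(round(fit_gmm.πₖ[c], digits=2))")
    print(" ↔ true component $t: μ = $(true_μs[:, t]), π = $(true_πs[t])\n")
end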
Side-by-side comparison of true vs learned component assignments
p_true = scatter(X[1, :], X[2, :]; group=labels, title="True Components",
xlabel=L"x_1", ylabel=L"x_2", markersize=3, alpha=0.7,
palette=:Set1_3, legend=false)
remapped_predicted = [best_perm[predicted_labels[i]] for i in 1:n] # Apply best permutation to predicted labels for fair comparison
p_learned = scatter(X[1, :], X[2, :]; group=remapped_predicted, title="Learned Components",
xlabel=L"x_1", ylabel=L"x_2", markersize=3, alpha=0.7,
palette=:Set1_3, legend=false)
p4 = plot(p_true, p_learned, layout=(1, 2), size=(800, 350))
Summary
This tutorial demonstrated the complete Gaussian Mixture Model workflow:
Key Concepts:
- Mixture modeling: Complex distributions as weighted combinations of simpler Gaussians
- EM algorithm: Iterative parameter learning via expectation-maximization
- Soft clustering: Probabilistic component assignments rather than hard clusters
- Label permutation: Handling component identifiability issues
Applications:
- Unsupervised clustering and density estimation
- Anomaly detection via likelihood thresholding (see the sketch after this list)
- Dimensionality reduction (when extended to factor analysis)
- Building blocks for more complex probabilistic models
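To illustrate the anomaly-detection item above, here is a minimal sketch that thresholds the fitted mixture log-density. It reuses the fitted fields fit_gmm.πₖ, fit_gmm.μₖ, and fit_gmm.Σₖ via Distributions.jl's MixtureModel (assuming fit_gmm.πₖ sums to one, as it should after EM); the 1% cutoff is an arbitrary illustrative choice:

using Statistics: quantile

mix = MixtureModel([MvNormal(fit_gmm.μₖ[:, i], fit_gmm.Σₖ[i]) for i in 1:fit_gmm.k], fit_gmm.πₖ)
logp = [logpdf(mix, X[:, j]) for j in 1:n]   # log-density of each observed point under the fitted mixture
cutoff = quantile(logp, 0.01)                # flag the lowest 1% as candidate anomalies
anomaly_idx = findall(logp .< cutoff)
print("Flagged $(length(anomaly_idx)) low-likelihood points\n")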
Technical Insights:
- K-means initialization significantly improves EM convergence
- Log-likelihood monitoring ensures proper algorithm behavior
- Parameter recovery quality depends on component separation and sample size
GMMs provide a flexible, interpretable framework for modeling heterogeneous data with multiple underlying modes or clusters, forming the foundation for many advanced machine learning techniques.
This page was generated using Literate.jl.