Simulating and Fitting a Gaussian Mixture Model
This tutorial demonstrates how to use StateSpaceDynamics.jl to create a Gaussian Mixture Model (GMM) and fit it with the EM algorithm. Unlike Hidden Markov Models, which model temporal sequences, GMMs are designed for clustering and density estimation of independent observations: each data point is assumed to come from one of several Gaussian components, with no temporal dependence between points.
GMMs are fundamental in machine learning for unsupervised clustering, density estimation, anomaly detection, and as building blocks for more complex models. The key insight is that complex data distributions can often be well-approximated as mixtures of simpler Gaussian distributions, each representing a different "mode" or cluster in the data.
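To make the "weighted combination" idea concrete, here is a minimal stand-alone sketch using Distributions.jl's MixtureModel (not part of StateSpaceDynamics.jl); the two components and their weights are purely illustrative. It checks that the mixture density at a point is exactly the weighted sum of the component densities.
using Distributions
components = [Normal(-1.0, 0.5), Normal(2.0, 1.0)] # two illustrative 1-D Gaussian components
weights = [0.7, 0.3]                               # mixing weights (sum to 1)
mix = MixtureModel(components, weights)
x = 0.5
manual = sum(w * pdf(c, x) for (w, c) in zip(weights, components)) # weighted sum of component densities
pdf(mix, x) ≈ manual # true: the mixture density equals this weighted sum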
Load Required Packages
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using Distributions
using StatsPlots
using Combinatorics
using LaTeXStrings
Set up reproducible random number generation
rng = StableRNG(1234);
Create a True Gaussian Mixture Model
We'll create a "ground truth" GMM with known parameters, generate data from it, then see how well we can recover these parameters using only the observed data.
k = 3 # Number of mixture components (clusters)
D = 2 # Data dimensionality (2D for easy visualization)
Define the true component means: $\boldsymbol{\mu}_i \in \mathbb{R}^D$ for $i = 1, \ldots, k$. Each column represents the mean vector $\boldsymbol{\mu}_i$ for one component.
true_μs = [
-1.0 1.0 0.0; # $x_1$ coordinates of the 3 component centers
-1.0 -1.5 2.0 # $x_2$ coordinates of the 3 component centers
]; # Shape: $(D, k) = (2, 3)$
Define covariance matrices $\boldsymbol{\Sigma}_i$ for each component. Using isotropic (spherical) covariances for simplicity.
true_Σs = [Matrix{Float64}(0.3 * I(2)) for _ in 1:k];
Define mixing weights $\pi_i$ (must sum to 1). These represent $P(\text{component} = i)$ for a random sample.
true_πs = [0.5, 0.2, 0.3]; # Component 1 most likely, component 2 least likely
Construct the complete GMM
true_gmm = GaussianMixtureModel(k, true_μs, true_Σs, true_πs);
print("Created GMM: $k components, $D dimensions\n")
for i in 1:k
print("Component $i: μ = $(true_μs[:, i]), π = $(true_πs[i])\n")
end
Created GMM: 3 components, 2 dimensions
Component 1: μ = [-1.0, -1.0], π = 0.5
Component 2: μ = [1.0, -1.5], π = 0.2
Component 3: μ = [0.0, 2.0], π = 0.3
Sample Data from the True GMM
Generate synthetic data from our true model. We'll sample both component assignments (for evaluation) and the actual observations.
n = 500 # Number of data points to generate
Determine which component each sample comes from
labels = rand(rng, Categorical(true_πs), n);
Count samples per component for verification
component_counts = [sum(labels .== i) for i in 1:k]
print("Samples per component: $(component_counts) (expected: $(round.(n .* true_πs)))\n");Samples per component: [252, 96, 152] (expected: [250.0, 100.0, 150.0])Generate the actual data points
X = Matrix{Float64}(undef, D, n)
for i in 1:n
component = labels[i]
X[:, i] = rand(rng, MvNormal(true_μs[:, component], true_Σs[component]))
end
Visualize the generated data colored by true component membership
p1 = scatter(X[1, :], X[2, :];
group=labels,
title="True GMM Components",
xlabel=L"x_1", ylabel=L"x_2",
markersize=4,
alpha=0.7,
palette=:Set1_3,
legend=:topright
)
for i in 1:k
scatter!(p1, [true_μs[1, i]], [true_μs[2, i]];
marker=:star, markersize=10, color=i,
markerstrokewidth=2, markerstrokecolor=:black,
label="")
end
p1
Fit GMM Using EM Algorithm
Now we simulate the realistic scenario: we observe only the data points $\mathbf{X}$, not the true component labels or parameters. Our goal is to recover the underlying mixture structure using EM.
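For intuition, EM climbs the observed-data log-likelihood $\ell(\theta) = \sum_{j=1}^{n} \log \sum_{i=1}^{k} \pi_i \, \mathcal{N}(\mathbf{x}_j; \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i)$. Here is a minimal sketch of that objective, written directly with Distributions.jl rather than any StateSpaceDynamics.jl internals, and evaluated at the true parameters for reference:
function mixture_loglikelihood(X, πs, μs, Σs) # πs, μs, Σs use the same layout as true_πs, true_μs, true_Σs above
    comps = [MvNormal(μs[:, i], Σs[i]) for i in eachindex(πs)]
    return sum(log(sum(πs[i] * pdf(comps[i], x) for i in eachindex(πs))) for x in eachcol(X))
end
mixture_loglikelihood(X, true_πs, true_μs, true_Σs) # log-likelihood of the data under the true parameters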
Initialize a GMM with the correct number of components but unknown parameters
fit_gmm = GaussianMixtureModel(k, D)
print("Running EM algorithm...")Running EM algorithm...Fit the model using EM algorithm
class_probabilities, lls = fit!(fit_gmm, X;
maxiter=100,
tol=1e-6,
initialize_kmeans=true # K-means initialization helps convergence
);
print("EM converged in $(length(lls)) iterations\n")
print("Log-likelihood improved by $(round(lls[end] - lls[1], digits=1))\n");Iteration 1: Log-likelihood = -1332.4099623018458
Iteration 2: Log-likelihood = -1327.8314292708865
Iteration 3: Log-likelihood = -1326.92067517518
Iteration 4: Log-likelihood = -1326.6159879293753
Iteration 5: Log-likelihood = -1326.467570821675
Iteration 6: Log-likelihood = -1326.3638270426625
Iteration 7: Log-likelihood = -1326.2777493096357
Iteration 8: Log-likelihood = -1326.2023974460274
Iteration 9: Log-likelihood = -1326.1354203562712
Iteration 10: Log-likelihood = -1326.0755374357018
Iteration 11: Log-likelihood = -1326.0217694221124
Iteration 12: Log-likelihood = -1325.973278100635
Iteration 13: Log-likelihood = -1325.9293340939412
Iteration 14: Log-likelihood = -1325.8893078547082
Iteration 15: Log-likelihood = -1325.852662475567
Iteration 16: Log-likelihood = -1325.8189450452733
Iteration 17: Log-likelihood = -1325.7877769101033
Iteration 18: Log-likelihood = -1325.758843667468
Iteration 19: Log-likelihood = -1325.731885545676
Iteration 20: Log-likelihood = -1325.7066885738543
Iteration 21: Log-likelihood = -1325.683076740728
Iteration 22: Log-likelihood = -1325.6609051972093
Iteration 23: Log-likelihood = -1325.6400544671556
Iteration 24: Log-likelihood = -1325.6204255803232
Iteration 25: Log-likelihood = -1325.6019360188163
Iteration 26: Log-likelihood = -1325.5845163637287
Iteration 27: Log-likelihood = -1325.5681075344178
Iteration 28: Log-likelihood = -1325.552658524376
Iteration 29: Log-likelihood = -1325.5381245509134
Iteration 30: Log-likelihood = -1325.5244655495499
Iteration 31: Log-likelihood = -1325.5116449561958
Iteration 32: Log-likelihood = -1325.4996287308434
Iteration 33: Log-likelihood = -1325.488384585261
Iteration 34: Log-likelihood = -1325.477881383886
Iteration 35: Log-likelihood = -1325.4680886924598
Iteration 36: Log-likelihood = -1325.4589764527625
Iteration 37: Log-likelihood = -1325.4505147645657
Iteration 38: Log-likelihood = -1325.4426737579938
Iteration 39: Log-likelihood = -1325.43542354092
Iteration 40: Log-likelihood = -1325.428734207256
Iteration 41: Log-likelihood = -1325.4225758930315
Iteration 42: Log-likelihood = -1325.4169188681706
Iteration 43: Log-likelihood = -1325.4117336529428
Iteration 44: Log-likelihood = -1325.4069911493527
Iteration 45: Log-likelihood = -1325.4026627787734
Iteration 46: Log-likelihood = -1325.3987206186696
Iteration 47: Log-likelihood = -1325.3951375324243
Iteration 48: Log-likelihood = -1325.3918872877184
Iteration 49: Log-likelihood = -1325.3889446600956
Iteration 50: Log-likelihood = -1325.386285519564
Iteration 51: Log-likelihood = -1325.3838868990313
Iteration 52: Log-likelihood = -1325.3817270443728
Iteration 53: Log-likelihood = -1325.3797854465192
Iteration 54: Log-likelihood = -1325.3780428565944
Iteration 55: Log-likelihood = -1325.376481285554
Iteration 56: Log-likelihood = -1325.3750839899778
Iteration 57: Log-likelihood = -1325.3738354459229
Iteration 58: Log-likelihood = -1325.372721312708
Iteration 59: Log-likelihood = -1325.3717283885492
Iteration 60: Log-likelihood = -1325.3708445598402
Iteration 61: Log-likelihood = -1325.3700587457647
Iteration 62: Log-likelihood = -1325.3693608397173
Iteration 63: Log-likelihood = -1325.3687416489377
Iteration 64: Log-likelihood = -1325.3681928334458
Iteration 65: Log-likelihood = -1325.3677068453032
Iteration 66: Log-likelihood = -1325.3672768689937
Iteration 67: Log-likelihood = -1325.3668967635374
Iteration 68: Log-likelihood = -1325.3665610068822
Iteration 69: Log-likelihood = -1325.3662646429218
Iteration 70: Log-likelihood = -1325.3660032313826
Iteration 71: Log-likelihood = -1325.3657728007604
Iteration 72: Log-likelihood = -1325.365569804385
Iteration 73: Log-likelihood = -1325.365391079605
Iteration 74: Log-likelihood = -1325.365233810087
Iteration 75: Log-likelihood = -1325.3650954911204
Iteration 76: Log-likelihood = -1325.3649738978172
Iteration 77: Log-likelihood = -1325.3648670560872
Iteration 78: Log-likelihood = -1325.3647732162158
Iteration 79: Log-likelihood = -1325.3646908289093
Iteration 80: Log-likelihood = -1325.3646185235916
Iteration 81: Log-likelihood = -1325.3645550888625
Iteration 82: Log-likelihood = -1325.364499454864
Iteration 83: Log-likelihood = -1325.364450677479
Iteration 84: Log-likelihood = -1325.3644079241253
Iteration 85: Log-likelihood = -1325.3643704610927
Iteration 86: Log-likelihood = -1325.3643376421846
Iteration 87: Log-likelihood = -1325.3643088986253
Iteration 88: Log-likelihood = -1325.3642837300529
Iteration 89: Log-likelihood = -1325.3642616965142
Iteration 90: Log-likelihood = -1325.3642424113561
Iteration 91: Log-likelihood = -1325.3642255349266
Iteration 92: Log-likelihood = -1325.3642107689748
Iteration 93: Log-likelihood = -1325.3641978517248
Iteration 94: Log-likelihood = -1325.3641865534828
Iteration 95: Log-likelihood = -1325.3641766727762
Iteration 96: Log-likelihood = -1325.3641680329517
Iteration 97: Log-likelihood = -1325.364160479159
Iteration 98: Log-likelihood = -1325.3641538756922
Iteration 99: Log-likelihood = -1325.3641481036734
Iteration 100: Log-likelihood = -1325.3641430589653
EM converged in 100 iterations
Log-likelihood improved by 7.0
Plot EM convergence
p2 = plot(lls, xlabel="EM Iteration", ylabel="Log-Likelihood",
title="EM Algorithm Convergence", legend=false,
marker=:circle, markersize=3, lw=2, color=:darkblue)
if length(lls) < 100
annotate!(p2, length(lls)*0.7, lls[end]*0.95,
text("Converged in $(length(lls)) iterations", 10)) # Add convergence annotation
end
Visualize Fitted Model
Create a visualization showing both the data and the fitted GMM's probability-density contours. First, build a grid over which to evaluate the contours.
x_range = range(extrema(X[1, :])..., length=100)
y_range = range(extrema(X[2, :])..., length=100)
xs = collect(x_range)
ys = collect(y_range)
p3 = scatter(X[1, :], X[2, :];
markersize=3,
alpha=0.5,
color=:gray,
xlabel=L"x_1",
ylabel=L"x_2",
title="Fitted GMM Components",
legend=:topright,
label="Data points"
)
p3
colors = [:red, :green, :blue] # Plot probability density contours for each learned component
for i in 1:fit_gmm.k
comp_dist = MvNormal(fit_gmm.μₖ[:, i], fit_gmm.Σₖ[i])
Z_i = [fit_gmm.πₖ[i] * pdf(comp_dist, [x, y]) for y in ys, x in xs]
contour!(p3, xs, ys, Z_i;
levels=6,
linewidth=2,
c=colors[i],
label="Component $i (π=$(round(fit_gmm.πₖ[i], digits=2)))"
)
scatter!(p3, [fit_gmm.μₖ[1, i]], [fit_gmm.μₖ[2, i]];
marker=:star, markersize=8, color=colors[i],
markerstrokewidth=2, markerstrokecolor=:black,
label="")
end
Component Assignment Analysis
Use the fitted model to assign each data point to its most likely component, then compare with the true assignments.
Get posterior probabilities: $P(\text{component } i | \mathbf{x}_j)$
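These responsibilities follow from Bayes' rule: $\gamma_{ij} = \pi_i \, \mathcal{N}(\mathbf{x}_j; \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i) \big/ \sum_{l=1}^{k} \pi_l \, \mathcal{N}(\mathbf{x}_j; \boldsymbol{\mu}_l, \boldsymbol{\Sigma}_l)$. As a sanity check, here is a minimal hand-rolled sketch built only from the fitted fields πₖ, μₖ, Σₖ used above (not any StateSpaceDynamics.jl internals); it should closely match the class_probabilities returned by fit!:
function responsibilities(gmm, X) # γ[i, j] = P(component i | xⱼ) via Bayes' rule
    k, n = gmm.k, size(X, 2)
    γ = Matrix{Float64}(undef, k, n)
    for j in 1:n
        unnorm = [gmm.πₖ[i] * pdf(MvNormal(gmm.μₖ[:, i], gmm.Σₖ[i]), X[:, j]) for i in 1:k]
        γ[:, j] = unnorm ./ sum(unnorm)
    end
    return γ
end
γ = responsibilities(fit_gmm, X);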
predicted_labels = [argmax(class_probabilities[:, j]) for j in 1:n];
Calculate the assignment accuracy, accounting for possible label permutation, since EM can converge with the components in a different order.
function best_permutation_accuracy(true_labels, pred_labels, k)
best_acc = 0.0
best_perm = collect(1:k)
for perm in Combinatorics.permutations(1:k)
mapped_pred = [perm[pred_labels[i]] for i in 1:length(pred_labels)]
acc = mean(true_labels .== mapped_pred)
if acc > best_acc
best_acc = acc
best_perm = perm
end
end
return best_acc, best_perm
end
accuracy, best_perm = best_permutation_accuracy(labels, predicted_labels, k)
print("Component assignment accuracy: $(round(accuracy*100, digits=1))%\n");Component assignment accuracy: 97.0%Final Comparison Visualization
Side-by-side comparison of true vs learned component assignments
p_true = scatter(X[1, :], X[2, :]; group=labels, title="True Components",
xlabel=L"x_1", ylabel=L"x_2", markersize=3, alpha=0.7,
palette=:Set1_3, legend=false)
remapped_predicted = [best_perm[predicted_labels[i]] for i in 1:n] # Apply best permutation to predicted labels for fair comparison
p_learned = scatter(X[1, :], X[2, :]; group=remapped_predicted, title="Learned Components",
xlabel=L"x_1", ylabel=L"x_2", markersize=3, alpha=0.7,
palette=:Set1_3, legend=false)
p4 = plot(p_true, p_learned, layout=(1, 2), size=(800, 350))
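Beyond cluster assignments, we can check how well the parameters themselves were recovered. The sketch below reuses best_perm (in the accuracy function above, learned component i was matched to true component best_perm[i]) to compare each learned mean with the corresponding true mean:
for i in 1:k
    μ_learned = round.(fit_gmm.μₖ[:, i], digits=2)
    μ_true = true_μs[:, best_perm[i]]
    print("Learned μ$(i) = $(μ_learned) vs true μ$(best_perm[i]) = $(μ_true), error = $(round(norm(fit_gmm.μₖ[:, i] - μ_true), digits=3))\n")
end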
Summary
This tutorial demonstrated the complete Gaussian Mixture Model workflow:
Key Concepts:
- Mixture modeling: Complex distributions as weighted combinations of simpler Gaussians
- EM algorithm: Iterative parameter learning via expectation-maximization
- Soft clustering: Probabilistic component assignments rather than hard clusters
- Label permutation: Handling component identifiability issues
Applications:
- Unsupervised clustering and density estimation
- Anomaly detection via likelihood thresholding (see the sketch after this list)
- Dimensionality reduction (when extended to factor analysis)
- Building blocks for more complex probabilistic models
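To illustrate the anomaly-detection idea, here is a minimal sketch that assembles the mixture density by hand from the fitted fields (not via any StateSpaceDynamics.jl scoring API) and flags the lowest-density points; the 1% cutoff is an arbitrary illustrative choice:
mixture_logpdf(gmm, x) = log(sum(gmm.πₖ[i] * pdf(MvNormal(gmm.μₖ[:, i], gmm.Σₖ[i]), x) for i in 1:gmm.k))
scores = [mixture_logpdf(fit_gmm, X[:, j]) for j in 1:n]
threshold = quantile(scores, 0.01)       # empirical 1% quantile of the log-density
anomalies = findall(scores .< threshold) # indices of candidate anomalies
print("Flagged $(length(anomalies)) low-density points as candidate anomalies\n")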
Technical Insights:
- K-means initialization significantly improves EM convergence
- Log-likelihood monitoring ensures proper algorithm behavior
- Parameter recovery quality depends on component separation and sample size
GMMs provide a flexible, interpretable framework for modeling heterogeneous data with multiple underlying modes or clusters, forming the foundation for many advanced machine learning techniques.
This page was generated using Literate.jl.