Simulating and Fitting a Gaussian Mixture Model

This tutorial demonstrates how to use StateSpaceDynamics.jl to create a Gaussian Mixture Model (GMM) and fit it using the EM algorithm. Unlike Hidden Markov Models which model temporal sequences, GMMs are designed for clustering and density estimation of independent observations. Each data point is assumed to come from one of several Gaussian components, but there's no temporal dependence.
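
Formally, a GMM with $k$ components models the density of an observation $\mathbf{x} \in \mathbb{R}^D$ as a weighted sum of Gaussians,

$$p(\mathbf{x}) = \sum_{i=1}^{k} \pi_i \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i), \qquad \sum_{i=1}^{k} \pi_i = 1,$$

where the mixing weights $\pi_i$ give the probability that a point was generated by component $i$, and each component has its own mean $\boldsymbol{\mu}_i$ and covariance $\boldsymbol{\Sigma}_i$.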

GMMs are fundamental in machine learning for unsupervised clustering, density estimation, anomaly detection, and as building blocks for more complex models. The key insight is that complex data distributions can often be well-approximated as mixtures of simpler Gaussian distributions, each representing a different "mode" or cluster in the data.

Load Required Packages

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using Distributions
using StatsPlots
using Combinatorics
using LaTeXStrings

Set up reproducible random number generation

rng = StableRNG(1234);

Create a True Gaussian Mixture Model

We'll create a "ground truth" GMM with known parameters, generate data from it, then see how well we can recover these parameters using only the observed data.

k = 3;  # Number of mixture components (clusters)
D = 2;  # Data dimensionality (2D for easy visualization)

Define the true component means $\boldsymbol{\mu}_i \in \mathbb{R}^D$ for $i = 1, \ldots, k$. Each column of the matrix below is the mean vector $\boldsymbol{\mu}_i$ of one component.

true_μs = [
    -1.0  1.0  0.0;   # $x_1$ coordinates of the 3 component centers
    -1.0 -1.5  2.0    # $x_2$ coordinates of the 3 component centers
];  # Shape: $(D, k) = (2, 3)$

Define a covariance matrix $\boldsymbol{\Sigma}_i$ for each component. For simplicity we use identical isotropic (spherical) covariances, $\boldsymbol{\Sigma}_i = 0.3\,I$.

true_Σs = [Matrix{Float64}(0.3 * I(2)) for _ in 1:k];

Define the mixing weights $\pi_i$, which must sum to 1. Each $\pi_i$ is the probability $P(\text{component} = i)$ that a random sample comes from component $i$.

true_πs = [0.5, 0.2, 0.3];  # Component 1 most likely, component 2 least likely

Construct the complete GMM

true_gmm = GaussianMixtureModel(k, true_μs, true_Σs, true_πs);

print("Created GMM: $k components, $D dimensions\n")
for i in 1:k
    print("Component $i: μ = $(true_μs[:, i]), π = $(true_πs[i])\n")
end
Created GMM: 3 components, 2 dimensions
Component 1: μ = [-1.0, -1.0], π = 0.5
Component 2: μ = [1.0, -1.5], π = 0.2
Component 3: μ = [0.0, 2.0], π = 0.3

Sample Data from the True GMM

Generate synthetic data from our true model. We'll sample both component assignments (for evaluation) and the actual observations.
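
Equivalently, sampling proceeds in two stages: draw a latent component label, then draw the observation from that component's Gaussian,

$$z_j \sim \mathrm{Categorical}(\pi_1, \ldots, \pi_k), \qquad \mathbf{x}_j \mid z_j = i \sim \mathcal{N}(\boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i).$$

The code below follows exactly this recipe, and we keep the labels $z_j$ so the fitted model can be evaluated against the ground truth later.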

n = 500;  # Number of data points to generate

Determine which component each sample comes from

labels = rand(rng, Categorical(true_πs), n);

Count samples per component for verification

component_counts = [sum(labels .== i) for i in 1:k]
print("Samples per component: $(component_counts) (expected: $(round.(n .* true_πs)))\n");
Samples per component: [252, 96, 152] (expected: [250.0, 100.0, 150.0])

Generate the actual data points

X = Matrix{Float64}(undef, D, n)
for i in 1:n
    component = labels[i]
    X[:, i] = rand(rng, MvNormal(true_μs[:, component], true_Σs[component]))
end

Visualize the generated data colored by true component membership

p1 = scatter(X[1, :], X[2, :];
    group=labels,
    title="True GMM Components",
    xlabel=L"x_1", ylabel=L"x_2",
    markersize=4,
    alpha=0.7,
    palette=:Set1_3,
    legend=:topright
)

for i in 1:k
    scatter!(p1, [true_μs[1, i]], [true_μs[2, i]];
        marker=:star, markersize=10, color=i,
        markerstrokewidth=2, markerstrokecolor=:black,
        label="")
end

p1
(Plot: sampled data points colored by true component membership, with the true component means marked as stars.)

Fit GMM Using EM Algorithm

Now we simulate the realistic scenario: we observe only the data points $\mathbf{X}$, not the true component labels or parameters. Our goal is to recover the underlying mixture structure using EM.
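
For reference, each EM iteration alternates two steps. The E-step computes the posterior responsibility of component $i$ for point $\mathbf{x}_j$,

$$\gamma_{ij} = \frac{\pi_i \, \mathcal{N}(\mathbf{x}_j \mid \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i)}{\sum_{l=1}^{k} \pi_l \, \mathcal{N}(\mathbf{x}_j \mid \boldsymbol{\mu}_l, \boldsymbol{\Sigma}_l)},$$

and the M-step re-estimates the parameters from the responsibility-weighted data,

$$N_i = \sum_{j=1}^{n} \gamma_{ij}, \qquad \boldsymbol{\mu}_i = \frac{1}{N_i} \sum_{j=1}^{n} \gamma_{ij}\,\mathbf{x}_j, \qquad \boldsymbol{\Sigma}_i = \frac{1}{N_i} \sum_{j=1}^{n} \gamma_{ij}\,(\mathbf{x}_j - \boldsymbol{\mu}_i)(\mathbf{x}_j - \boldsymbol{\mu}_i)^\top, \qquad \pi_i = \frac{N_i}{n}.$$

Each iteration is guaranteed not to decrease the observed-data log-likelihood, which is exactly the quantity we monitor below.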

Initialize a GMM with the correct number of components but unknown parameters

fit_gmm = GaussianMixtureModel(k, D)

print("Running EM algorithm...")
Running EM algorithm...

Fit the model using the EM algorithm

class_probabilities, lls = fit!(fit_gmm, X;
    maxiter=100,
    tol=1e-6,
    initialize_kmeans=true  # K-means initialization helps convergence
);

print("EM converged in $(length(lls)) iterations\n")
print("Log-likelihood improved by $(round(lls[end] - lls[1], digits=1))\n");
Iteration 1: Log-likelihood = -1332.4099623018458
Iteration 2: Log-likelihood = -1327.8314292708865
Iteration 3: Log-likelihood = -1326.92067517518
Iteration 4: Log-likelihood = -1326.6159879293753
Iteration 5: Log-likelihood = -1326.467570821675
Iteration 6: Log-likelihood = -1326.3638270426625
Iteration 7: Log-likelihood = -1326.2777493096357
Iteration 8: Log-likelihood = -1326.2023974460274
Iteration 9: Log-likelihood = -1326.1354203562712
Iteration 10: Log-likelihood = -1326.0755374357018
Iteration 11: Log-likelihood = -1326.0217694221124
Iteration 12: Log-likelihood = -1325.973278100635
Iteration 13: Log-likelihood = -1325.9293340939412
Iteration 14: Log-likelihood = -1325.8893078547082
Iteration 15: Log-likelihood = -1325.852662475567
Iteration 16: Log-likelihood = -1325.8189450452733
Iteration 17: Log-likelihood = -1325.7877769101033
Iteration 18: Log-likelihood = -1325.758843667468
Iteration 19: Log-likelihood = -1325.731885545676
Iteration 20: Log-likelihood = -1325.7066885738543
Iteration 21: Log-likelihood = -1325.683076740728
Iteration 22: Log-likelihood = -1325.6609051972093
Iteration 23: Log-likelihood = -1325.6400544671556
Iteration 24: Log-likelihood = -1325.6204255803232
Iteration 25: Log-likelihood = -1325.6019360188163
Iteration 26: Log-likelihood = -1325.5845163637287
Iteration 27: Log-likelihood = -1325.5681075344178
Iteration 28: Log-likelihood = -1325.552658524376
Iteration 29: Log-likelihood = -1325.5381245509134
Iteration 30: Log-likelihood = -1325.5244655495499
Iteration 31: Log-likelihood = -1325.5116449561958
Iteration 32: Log-likelihood = -1325.4996287308434
Iteration 33: Log-likelihood = -1325.488384585261
Iteration 34: Log-likelihood = -1325.477881383886
Iteration 35: Log-likelihood = -1325.4680886924598
Iteration 36: Log-likelihood = -1325.4589764527625
Iteration 37: Log-likelihood = -1325.4505147645657
Iteration 38: Log-likelihood = -1325.4426737579938
Iteration 39: Log-likelihood = -1325.43542354092
Iteration 40: Log-likelihood = -1325.428734207256
Iteration 41: Log-likelihood = -1325.4225758930315
Iteration 42: Log-likelihood = -1325.4169188681706
Iteration 43: Log-likelihood = -1325.4117336529428
Iteration 44: Log-likelihood = -1325.4069911493527
Iteration 45: Log-likelihood = -1325.4026627787734
Iteration 46: Log-likelihood = -1325.3987206186696
Iteration 47: Log-likelihood = -1325.3951375324243
Iteration 48: Log-likelihood = -1325.3918872877184
Iteration 49: Log-likelihood = -1325.3889446600956
Iteration 50: Log-likelihood = -1325.386285519564
Iteration 51: Log-likelihood = -1325.3838868990313
Iteration 52: Log-likelihood = -1325.3817270443728
Iteration 53: Log-likelihood = -1325.3797854465192
Iteration 54: Log-likelihood = -1325.3780428565944
Iteration 55: Log-likelihood = -1325.376481285554
Iteration 56: Log-likelihood = -1325.3750839899778
Iteration 57: Log-likelihood = -1325.3738354459229
Iteration 58: Log-likelihood = -1325.372721312708
Iteration 59: Log-likelihood = -1325.3717283885492
Iteration 60: Log-likelihood = -1325.3708445598402
Iteration 61: Log-likelihood = -1325.3700587457647
Iteration 62: Log-likelihood = -1325.3693608397173
Iteration 63: Log-likelihood = -1325.3687416489377
Iteration 64: Log-likelihood = -1325.3681928334458
Iteration 65: Log-likelihood = -1325.3677068453032
Iteration 66: Log-likelihood = -1325.3672768689937
Iteration 67: Log-likelihood = -1325.3668967635374
Iteration 68: Log-likelihood = -1325.3665610068822
Iteration 69: Log-likelihood = -1325.3662646429218
Iteration 70: Log-likelihood = -1325.3660032313826
Iteration 71: Log-likelihood = -1325.3657728007604
Iteration 72: Log-likelihood = -1325.365569804385
Iteration 73: Log-likelihood = -1325.365391079605
Iteration 74: Log-likelihood = -1325.365233810087
Iteration 75: Log-likelihood = -1325.3650954911204
Iteration 76: Log-likelihood = -1325.3649738978172
Iteration 77: Log-likelihood = -1325.3648670560872
Iteration 78: Log-likelihood = -1325.3647732162158
Iteration 79: Log-likelihood = -1325.3646908289093
Iteration 80: Log-likelihood = -1325.3646185235916
Iteration 81: Log-likelihood = -1325.3645550888625
Iteration 82: Log-likelihood = -1325.364499454864
Iteration 83: Log-likelihood = -1325.364450677479
Iteration 84: Log-likelihood = -1325.3644079241253
Iteration 85: Log-likelihood = -1325.3643704610927
Iteration 86: Log-likelihood = -1325.3643376421846
Iteration 87: Log-likelihood = -1325.3643088986253
Iteration 88: Log-likelihood = -1325.3642837300529
Iteration 89: Log-likelihood = -1325.3642616965142
Iteration 90: Log-likelihood = -1325.3642424113561
Iteration 91: Log-likelihood = -1325.3642255349266
Iteration 92: Log-likelihood = -1325.3642107689748
Iteration 93: Log-likelihood = -1325.3641978517248
Iteration 94: Log-likelihood = -1325.3641865534828
Iteration 95: Log-likelihood = -1325.3641766727762
Iteration 96: Log-likelihood = -1325.3641680329517
Iteration 97: Log-likelihood = -1325.364160479159
Iteration 98: Log-likelihood = -1325.3641538756922
Iteration 99: Log-likelihood = -1325.3641481036734
Iteration 100: Log-likelihood = -1325.3641430589653
EM ran for 100 iterations
Log-likelihood improved by 7.0

Plot EM convergence

p2 = plot(lls, xlabel="EM Iteration", ylabel="Log-Likelihood",
          title="EM Algorithm Convergence", legend=false,
          marker=:circle, markersize=3, lw=2, color=:darkblue)

if length(lls) < 100  # Annotate only if EM stopped before reaching maxiter
    annotate!(p2, length(lls)*0.7, lls[1] + 0.9*(lls[end] - lls[1]),
        text("Converged in $(length(lls)) iterations", 10))
end

Visualize Fitted Model

Create a visualization showing both the data and the fitted GMM as probability-density contours. First, build a grid spanning the data range on which to evaluate the contours.

x_range = range(extrema(X[1, :])..., length=100)
y_range = range(extrema(X[2, :])..., length=100)
xs = collect(x_range)
ys = collect(y_range)

p3 = scatter(X[1, :], X[2, :];
    markersize=3,
    alpha=0.5,
    color=:gray,
    xlabel=L"x_1",
    ylabel=L"x_2",
    title="Fitted GMM Components",
    legend=:topright,
    label="Data points"
)

colors = [:red, :green, :blue];  # One color per learned component

# Overlay the probability-density contours and the learned mean of each component
for i in 1:fit_gmm.k
    comp_dist = MvNormal(fit_gmm.μₖ[:, i], fit_gmm.Σₖ[i])
    Z_i = [fit_gmm.πₖ[i] * pdf(comp_dist, [x, y]) for y in ys, x in xs]

    contour!(p3, xs, ys, Z_i;
        levels=6,
        linewidth=2,
        c=colors[i],
        label="Component $i (π=$(round(fit_gmm.πₖ[i], digits=2)))"
    )

    scatter!(p3, [fit_gmm.μₖ[1, i]], [fit_gmm.μₖ[2, i]];
        marker=:star, markersize=8, color=colors[i],
        markerstrokewidth=2, markerstrokecolor=:black,
        label="")
end

p3

Component Assignment Analysis

Use fitted model to assign each data point to its most likely component and compare with true assignments.

Assign each point to its most likely component using the posterior probabilities $P(\text{component } i \mid \mathbf{x}_j)$ returned by fit!

predicted_labels = [argmax(class_probabilities[:, j]) for j in 1:n];

Calculate the assignment accuracy, accounting for possible label permutation: EM can recover the same components in a different order than the ground truth.

function best_permutation_accuracy(true_labels, pred_labels, k)
    best_acc = 0.0
    best_perm = collect(1:k)

    for perm in Combinatorics.permutations(1:k)
        mapped_pred = [perm[pred_labels[i]] for i in 1:length(pred_labels)]
        acc = mean(true_labels .== mapped_pred)
        if acc > best_acc
            best_acc = acc
            best_perm = perm
        end
    end

    return best_acc, best_perm
end

accuracy, best_perm = best_permutation_accuracy(labels, predicted_labels, k)
print("Component assignment accuracy: $(round(accuracy*100, digits=1))%\n");
Component assignment accuracy: 97.0%

Final Comparison Visualization

Side-by-side comparison of true vs learned component assignments

p_true = scatter(X[1, :], X[2, :]; group=labels, title="True Components",
                xlabel=L"x_1", ylabel=L"x_2", markersize=3, alpha=0.7,
                palette=:Set1_3, legend=false)

remapped_predicted = [best_perm[predicted_labels[i]] for i in 1:n] # Apply best permutation to predicted labels for fair comparison
p_learned = scatter(X[1, :], X[2, :]; group=remapped_predicted, title="Learned Components",
                   xlabel=L"x_1", ylabel=L"x_2", markersize=3, alpha=0.7,
                   palette=:Set1_3, legend=false)

p4 = plot(p_true, p_learned, layout=(1, 2), size=(800, 350))
(Plot: side-by-side scatter of the data colored by true component assignments (left) and learned component assignments (right).)

Summary

This tutorial demonstrated the complete Gaussian Mixture Model workflow:

Key Concepts:

  • Mixture modeling: Complex distributions as weighted combinations of simpler Gaussians
  • EM algorithm: Iterative parameter learning via expectation-maximization
  • Soft clustering: Probabilistic component assignments rather than hard clusters (see the snippet after this list)
  • Label permutation: Handling component identifiability issues
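
The soft/hard distinction can be inspected directly from the responsibilities returned by fit! earlier in this tutorial (class_probabilities, indexed as class_probabilities[:, j] above, so one column per data point):

print("Posterior over components for point 1: $(round.(class_probabilities[:, 1]; digits=3))\n")
print("Hard assignment for point 1: component $(argmax(class_probabilities[:, 1]))\n")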

Applications:

  • Unsupervised clustering and density estimation
  • Anomaly detection via likelihood thresholding (a short sketch follows this list)
  • Dimensionality reduction via extensions such as mixtures of factor analyzers
  • Building blocks for more complex probabilistic models
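
As a rough illustration of the anomaly-detection idea, the sketch below scores each point by its log-density under the fitted mixture and flags the lowest-scoring 1% as candidates; the helper function and the 1% cutoff are ad hoc choices for this example, not part of the StateSpaceDynamics.jl API.

# Log-density of a point under the fitted mixture: log Σᵢ πᵢ N(x | μᵢ, Σᵢ)
mixture_logpdf(x) = log(sum(fit_gmm.πₖ[i] * pdf(MvNormal(fit_gmm.μₖ[:, i], fit_gmm.Σₖ[i]), x) for i in 1:fit_gmm.k))

scores = [mixture_logpdf(X[:, j]) for j in 1:n];

# Flag the lowest-scoring 1% of points as candidate anomalies
threshold = sort(scores)[ceil(Int, 0.01 * n)]
outliers = findall(scores .<= threshold)
print("Flagged $(length(outliers)) low-likelihood points\n")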

Technical Insights:

  • K-means initialization significantly improves EM convergence (see the sketch after this list)
  • Log-likelihood monitoring ensures proper algorithm behavior
  • Parameter recovery quality depends on component separation and sample size
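
To gauge the effect of initialization here, one could refit from a fresh model with the k-means warm start disabled and compare the final log-likelihoods. This is only a sketch: it assumes initialize_kmeans=false is an accepted option of fit! (only =true appears above), and the outcome will vary with the random state.

# Refit from a fresh model, skipping the k-means warm start (assumed kwarg value)
gmm_nokm = GaussianMixtureModel(k, D)
_, lls_nokm = fit!(gmm_nokm, X; maxiter=100, tol=1e-6, initialize_kmeans=false);

print("Final log-likelihood with k-means init:    $(round(lls[end], digits=2))\n")
print("Final log-likelihood without k-means init: $(round(lls_nokm[end], digits=2))\n")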

GMMs provide a flexible, interpretable framework for modeling heterogeneous data with multiple underlying modes or clusters, forming the foundation for many advanced machine learning techniques.


This page was generated using Literate.jl.