Simulating and Fitting a Gaussian Mixture Model

This tutorial demonstrates how to use StateSpaceDynamics.jl to create a Gaussian Mixture Model (GMM) and fit it with the EM algorithm. Unlike Hidden Markov Models, which model temporal sequences, GMMs are designed for clustering and density estimation of independent observations: each data point is assumed to come from one of several Gaussian components, with no temporal dependence between points.

GMMs are fundamental in machine learning for unsupervised clustering, density estimation, anomaly detection, and as building blocks for more complex models. The key insight is that complex data distributions can often be well-approximated as mixtures of simpler Gaussian distributions, each representing a different "mode" or cluster in the data.
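
Formally, a GMM with $k$ components models the density of an observation $\mathbf{x} \in \mathbb{R}^D$ as

$$p(\mathbf{x}) = \sum_{i=1}^{k} \pi_i \, \mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i), \qquad \pi_i \ge 0, \quad \sum_{i=1}^{k} \pi_i = 1,$$

where $\boldsymbol{\mu}_i$ and $\boldsymbol{\Sigma}_i$ are the mean and covariance of component $i$, and the mixing weight $\pi_i$ is the prior probability that a point is drawn from that component.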

Load Required Packages

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using Distributions
using StatsPlots
using Combinatorics
using LaTeXStrings

Set up reproducible random number generation

rng = StableRNG(1234);

Create a True Gaussian Mixture Model

We'll create a "ground truth" GMM with known parameters, generate data from it, then see how well we can recover these parameters using only the observed data.

k = 3  # Number of mixture components (clusters)
D = 2  # Data dimensionality (2D for easy visualization)
2

Define the true component means $\boldsymbol{\mu}_i \in \mathbb{R}^D$ for $i = 1, \ldots, k$. Each column of the matrix below is the mean vector $\boldsymbol{\mu}_i$ of one component.

true_μs = [
    -1.0  1.0  0.0;   # $x_1$ coordinates of the 3 component centers
    -1.0 -1.5  2.0    # $x_2$ coordinates of the 3 component centers
];  # Shape: $(D, k) = (2, 3)$

Define a covariance matrix $\boldsymbol{\Sigma}_i$ for each component. We use isotropic (spherical) covariances for simplicity.

true_Σs = [Matrix{Float64}(0.3 * I(2)) for _ in 1:k];

Define the mixing weights $\pi_i$, which must sum to 1. They represent $P(\text{component} = i)$ for a random sample.

true_πs = [0.5, 0.2, 0.3];  # Component 1 most likely, component 2 least likely

Construct the complete GMM

true_gmm = GaussianMixtureModel(k, true_μs, true_Σs, true_πs);

print("Created GMM: $k components, $D dimensions\n")
for i in 1:k
    print("Component $i: μ = $(true_μs[:, i]), π = $(true_πs[i])\n")
end
Created GMM: 3 components, 2 dimensions
Component 1: μ = [-1.0, -1.0], π = 0.5
Component 2: μ = [1.0, -1.5], π = 0.2
Component 3: μ = [0.0, 2.0], π = 0.3

Sample Data from the True GMM

Generate synthetic data from our true model. We'll sample both component assignments (for evaluation) and the actual observations.

n = 500  # Number of data points to generate
500

Determine which component each sample comes from

labels = rand(rng, Categorical(true_πs), n);

Count samples per component for verification

component_counts = [sum(labels .== i) for i in 1:k]
print("Samples per component: $(component_counts) (expected: $(round.(n .* true_πs)))\n");
Samples per component: [252, 96, 152] (expected: [250.0, 100.0, 150.0])

Generate the actual data points

X = Matrix{Float64}(undef, D, n)
for i in 1:n
    component = labels[i]
    X[:, i] = rand(rng, MvNormal(true_μs[:, component], true_Σs[component]))
end

Visualize the generated data colored by true component membership

p1 = scatter(X[1, :], X[2, :];
    group=labels,
    title="True GMM Components",
    xlabel=L"x_1", ylabel=L"x_2",
    markersize=4,
    alpha=0.7,
    palette=:Set1_3,
    legend=:topright
)

for i in 1:k
    scatter!(p1, [true_μs[1, i]], [true_μs[2, i]];
        marker=:star, markersize=10, color=i,
        markerstrokewidth=2, markerstrokecolor=:black,
        label="")
end

p1
Example block output

Fit GMM Using EM Algorithm

Now we simulate the realistic scenario: we observe only the data points $\mathbf{X}$, not the true component labels or parameters. Our goal is to recover the underlying mixture structure using EM.
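
Under the hood, each EM iteration alternates an E-step, which computes every component's posterior responsibility for every point given the current parameter estimates, with an M-step, which re-estimates $\boldsymbol{\mu}_i$, $\boldsymbol{\Sigma}_i$, and $\pi_i$ from those responsibilities. The E-step can be sketched as follows; this is only an illustration (using the true parameters for concreteness, and an illustrative variable name γ), not the package's internal implementation, which maintains its own running estimates.

# E-step sketch: responsibilities γ[i, j] = P(component i | x_j) under the given parameters
γ = Matrix{Float64}(undef, k, n)
for j in 1:n
    densities = [true_πs[i] * pdf(MvNormal(true_μs[:, i], true_Σs[i]), X[:, j]) for i in 1:k]
    γ[:, j] = densities ./ sum(densities)   # normalize so each column sums to 1
end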

Initialize a GMM with the correct number of components but unknown parameters

fit_gmm = GaussianMixtureModel(k, D)

print("Running EM algorithm...")
Running EM algorithm...

Fit the model using the EM algorithm

class_probabilities, lls = fit!(fit_gmm, X;
    maxiter=100,
    tol=1e-6,
    initialize_kmeans=true  # K-means initialization helps convergence
);

print("EM converged in $(length(lls)) iterations\n")
print("Log-likelihood improved by $(round(lls[end] - lls[1], digits=1))\n");
Iteration 1: Log-likelihood = -1389.9540484217016
Iteration 2: Log-likelihood = -1356.263583656021
Iteration 3: Log-likelihood = -1348.0824392868528
⋮
Iteration 99: Log-likelihood = -1325.4792226143236
Iteration 100: Log-likelihood = -1325.4693437629173
EM ran for 100 iterations
Log-likelihood improved by 64.5

Plot EM convergence

p2 = plot(lls, xlabel="EM Iteration", ylabel="Log-Likelihood",
          title="EM Algorithm Convergence", legend=false,
          marker=:circle, markersize=3, lw=2, color=:darkblue)

if length(lls) < 100
    annotate!(p2, length(lls)*0.7, lls[end]*0.95,
        text("Converged in $(length(lls)) iterations", 10)) # Add convergence annotation
end
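
Because each EM iteration can only increase (or leave unchanged) the log-likelihood, the returned lls vector should be non-decreasing; a quick sanity check, allowing a small numerical tolerance:

@assert all(diff(lls) .>= -1e-8)  # EM should never decrease the log-likelihood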

Visualize Fitted Model

Create a visualization showing both the data and the fitted GMM's probability contours. First, build a grid of points on which to evaluate the component densities.

x_range = range(extrema(X[1, :])..., length=100)
y_range = range(extrema(X[2, :])..., length=100)
xs = collect(x_range)
ys = collect(y_range)

p3 = scatter(X[1, :], X[2, :];
    markersize=3,
    alpha=0.5,
    color=:gray,
    xlabel=L"x_1",
    ylabel=L"x_2",
    title="Fitted GMM Components",
    legend=:topright,
    label="Data points"
)

p3

Plot probability density contours for each learned component.

colors = [:red, :green, :blue]
for i in 1:fit_gmm.k
    comp_dist = MvNormal(fit_gmm.μₖ[:, i], fit_gmm.Σₖ[i])
    Z_i = [fit_gmm.πₖ[i] * pdf(comp_dist, [x, y]) for y in ys, x in xs]

    contour!(p3, xs, ys, Z_i;
        levels=6,
        linewidth=2,
        c=colors[i],
        label="Component $i (π=$(round(fit_gmm.πₖ[i], digits=2)))"
    )

    scatter!(p3, [fit_gmm.μₖ[1, i]], [fit_gmm.μₖ[2, i]];
        marker=:star, markersize=8, color=colors[i],
        markerstrokewidth=2, markerstrokecolor=:black,
        label="")
end

p3

Component Assignment Analysis

Use the fitted model to assign each data point to its most likely component and compare the result with the true assignments.

Assign each point to the component with the highest posterior probability $P(\text{component } i \mid \mathbf{x}_j)$, using the class_probabilities matrix returned by fit!

predicted_labels = [argmax(class_probabilities[:, j]) for j in 1:n];

Calculate the assignment accuracy, accounting for possible label permutation: EM can recover the components in a different order than the true labels.

function best_permutation_accuracy(true_labels, pred_labels, k)
    best_acc = 0.0
    best_perm = collect(1:k)

    for perm in Combinatorics.permutations(1:k)
        mapped_pred = [perm[pred_labels[i]] for i in 1:length(pred_labels)]
        acc = mean(true_labels .== mapped_pred)
        if acc > best_acc
            best_acc = acc
            best_perm = perm
        end
    end

    return best_acc, best_perm
end

accuracy, best_perm = best_permutation_accuracy(labels, predicted_labels, k)
print("Component assignment accuracy: $(round(accuracy*100, digits=1))%\n");
Component assignment accuracy: 95.8%

Final Comparison Visualization

Side-by-side comparison of true vs learned component assignments

p_true = scatter(X[1, :], X[2, :]; group=labels, title="True Components",
                xlabel=L"x_1", ylabel=L"x_2", markersize=3, alpha=0.7,
                palette=:Set1_3, legend=false)

remapped_predicted = [best_perm[predicted_labels[i]] for i in 1:n] # Apply best permutation to predicted labels for fair comparison
p_learned = scatter(X[1, :], X[2, :]; group=remapped_predicted, title="Learned Components",
                   xlabel=L"x_1", ylabel=L"x_2", markersize=3, alpha=0.7,
                   palette=:Set1_3, legend=false)

p4 = plot(p_true, p_learned, layout=(1, 2), size=(800, 350))
Example block output

Summary

This tutorial demonstrated the complete Gaussian Mixture Model workflow:

Key Concepts:

  • Mixture modeling: Complex distributions as weighted combinations of simpler Gaussians
  • EM algorithm: Iterative parameter learning via expectation-maximization
  • Soft clustering: Probabilistic component assignments rather than hard clusters (see the sketch after this list)
  • Label permutation: Handling component identifiability issues
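
To make the soft-clustering point concrete, the class_probabilities matrix returned by fit! gives each point's posterior distribution over components rather than a single hard label. A minimal sketch (the 0.9 cutoff is an arbitrary illustrative choice):

max_prob = vec(maximum(class_probabilities, dims=1))   # each point's highest posterior probability
n_ambiguous = count(<(0.9), max_prob)                  # points without a clearly dominant component
print("$(n_ambiguous) of $n points have a maximum posterior below 0.9\n")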

Applications:

  • Unsupervised clustering and density estimation
  • Anomaly detection via likelihood thresholding (sketched after this list)
  • Dimensionality reduction (when extended to factor analysis)
  • Building blocks for more complex probabilistic models
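
As an example of likelihood thresholding, one can score each point under the fitted mixture density and flag the lowest-scoring points as candidate anomalies. This is a minimal sketch: mix_density is an illustrative helper, and the 1% cutoff is arbitrary.

mix_density(x) = sum(fit_gmm.πₖ[i] * pdf(MvNormal(fit_gmm.μₖ[:, i], fit_gmm.Σₖ[i]), x) for i in 1:fit_gmm.k)
log_scores = [log(mix_density(X[:, j])) for j in 1:n]   # log-density of each observation under the fitted GMM
threshold = quantile(log_scores, 0.01)                  # flag the lowest 1% of points
outliers = findall(log_scores .<= threshold)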

Technical Insights:

  • K-means initialization significantly improves EM convergence
  • Log-likelihood monitoring ensures proper algorithm behavior
  • Parameter recovery quality depends on component separation and sample size

GMMs provide a flexible, interpretable framework for modeling heterogeneous data with multiple underlying modes or clusters, forming the foundation for many advanced machine learning techniques.


This page was generated using Literate.jl.