Simulating and Fitting a Gaussian Mixture Model

This tutorial demonstrates how to use StateSpaceDynamics.jl to create a Gaussian mixture model (GMM), sample data from it, and fit a fresh model to that data with the expectation-maximization (EM) algorithm.
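
For reference, a GMM models each observation x ∈ ℝᴰ as a draw from a π-weighted sum of Gaussian densities:

p(x) = ∑ⱼ πⱼ 𝒩(x; μⱼ, Σⱼ),  where the k weights πⱼ sum to 1.

EM alternates between computing each point's posterior responsibilities under the current parameters (E-step) and re-estimating the weights, means, and covariances from those responsibilities (M-step).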

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using Distributions
using StatsPlots

rng = StableRNG(1234);

Create a true GaussianMixtureModel to simulate from

k = 3  # number of mixture components
D = 2  # data dimension

true_μs = [
    -1.0  1.0  0.0;
    -1.0 -1.5  2.0
]  # component means, shape (D, k): one column per component

true_Σs = [Matrix{Float64}(0.3 * I(D)) for _ in 1:k]  # isotropic covariances
true_πs = [0.5, 0.2, 0.3]                             # mixing weights (sum to 1)

true_gmm = GaussianMixtureModel(k, true_μs, true_Σs, true_πs)
GaussianMixtureModel{Float64, Matrix{Float64}, Vector{Float64}}(3, [-1.0 1.0 0.0; -1.0 -1.5 2.0], [[0.3 0.0; 0.0 0.3], [0.3 0.0; 0.0 0.3], [0.3 0.0; 0.0 0.3]], [0.5, 0.2, 0.3])
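
As a quick sanity check on these parameters (this uses Distributions.jl directly, not the StateSpaceDynamics.jl API), the same density can be assembled with Distributions.jl's MixtureModel and evaluated at a point; this is just the π-weighted sum of the component pdfs:

check_mix = MixtureModel(
    [MvNormal(true_μs[:, j], true_Σs[j]) for j in 1:k],
    true_πs,
)
pdf(check_mix, [0.0, 0.0])  # mixture density at the origin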

Sample data from the true GMM

n = 500

Generate component labels (for plotting)

labels = rand(rng, Categorical(true_πs), n)
(Output truncated: labels is a 500-element Vector{Int64} of component indices in 1:3.)
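
As a sanity check, the empirical label frequencies should be close to true_πs:

label_props = [count(==(j), labels) for j in 1:k] ./ n  # ≈ [0.5, 0.2, 0.3] up to sampling noise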

Generate samples from the GMM

X = Matrix{Float64}(undef, D, n)
for i in 1:n
    # draw each point from the Gaussian of its sampled component
    X[:, i] = rand(rng, MvNormal(true_μs[:, labels[i]], true_Σs[labels[i]]))
end
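
Equivalently, the loop can be vectorized by drawing each component's points in one block, since rand on an MvNormal with a count returns a D × m matrix. A sketch, using a fresh RNG (rng2 is introduced here for illustration, so the tutorial's RNG stream is untouched; the draws therefore differ from X):

rng2 = StableRNG(5678)
X_alt = Matrix{Float64}(undef, D, n)
for j in 1:k
    idx = findall(==(j), labels)  # indices of points assigned to component j
    X_alt[:, idx] = rand(rng2, MvNormal(true_μs[:, j], true_Σs[j]), length(idx))
end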

p1 = scatter(
    X[1, :], X[2, :];
    group=labels,
    title="GMM Samples",
    xlabel="x₁", ylabel="x₂",
    markersize=4,
    alpha=0.8,
    legend=false,
)
p1
(Plot: "GMM Samples", scatter of the sampled points colored by true component.)

Fit a new GaussianMixtureModel to the data

fit_gmm = GaussianMixtureModel(k, D)  # fresh, untrained k-component model in D dimensions

class_probabilities, lls = fit!(fit_gmm, X;
    maxiter=100, tol=1e-6, initialize_kmeans=true)
(Output truncated: class_probabilities is a 3 × 500 matrix of posterior responsibilities, one column per point; lls holds the log-likelihood at each EM iteration, rising from about -1377.4 to -1325.4.)
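
Each column of class_probabilities holds one point's posterior responsibilities (one row per component, matching the k × n shape shown above), so hard assignments are the per-column argmax. Note that the fitted component indices may be a permutation of the true labels (label switching):

pred_labels = [argmax(col) for col in eachcol(class_probabilities)]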

Plot log-likelihoods to visualize EM convergence

p2 = plot(
    lls;
    xlabel="Iteration",
    ylabel="Log-Likelihood",
    title="EM Convergence",
    label="log_likelihood",
    marker=:circle,
)
p2
(Plot: "EM Convergence", log-likelihood versus iteration.)
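
Each EM iteration is guaranteed not to decrease the log-likelihood, so lls should be nondecreasing up to floating-point noise; a quick assertion:

@assert all(diff(lls) .> -1e-8) "log-likelihood decreased during EM"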

Visualize model contours over the data

xs = collect(range(minimum(X[1, :]) - 1, stop=maximum(X[1, :]) + 1, length=150))
ys = collect(range(minimum(X[2, :]) - 1, stop=maximum(X[2, :]) + 1, length=150))

p3 = scatter(
    X[1, :], X[2, :];
    markersize=3, alpha=0.5,
    xlabel="x₁", ylabel="x₂",
    title="Data & Fitted GMM Contours by Component",
    legend=:topright,
)

colors = [:red, :green, :blue]

for i in 1:fit_gmm.k
    comp_dist = MvNormal(fit_gmm.μₖ[:, i], fit_gmm.Σₖ[i])
    # iterating y first gives a length(ys) × length(xs) matrix,
    # matching the (x, y, Z) convention expected by contour!
    Z_i = [fit_gmm.πₖ[i] * pdf(comp_dist, [x, y]) for y in ys, x in xs]

    contour!(
        p3, xs, ys, Z_i;
        levels=10,
        linewidth=2,
        c=colors[i],
        label="Comp $i",
    )
end

p3
(Plot: "Data & Fitted GMM Contours by Component", weighted component densities drawn over the data.)
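
Finally, the fitted parameters can be compared against the ground truth. Since components may be recovered in a different order, match each fitted component to the true component with the nearest mean. A minimal sketch using the fields accessed above (πₖ, μₖ):

for j in 1:fit_gmm.k
    t = argmin([norm(fit_gmm.μₖ[:, j] - true_μs[:, m]) for m in 1:k])  # nearest true mean
    println("fitted comp $j ↔ true comp $t: π̂ = $(round(fit_gmm.πₖ[j]; digits=3)) vs π = $(true_πs[t])")
end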

This page was generated using Literate.jl.