Simulating and Fitting a Gaussian Mixture Model
This tutorial demonstrates how to use StateSpaceDynamics.jl to create a Gaussian Mixture Model (GMM) and fit it to data with the expectation-maximization (EM) algorithm.
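For reference, a GMM with K components defines the density

p(x) = ∑ₖ πₖ 𝒩(x; μₖ, Σₖ),  k = 1, …, K,

where the mixing weights πₖ are non-negative and sum to one. EM alternates between computing each sample's posterior responsibility for every component (E-step) and re-estimating μₖ, Σₖ, and πₖ from those responsibilities (M-step).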
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using Distributions
using StatsPlots
rng = StableRNG(1234);
Create a true GaussianMixtureModel to simulate from
k = 3 # number of mixture components
D = 2 # data dimension
true_μs = [
-1.0 1.0 0.0;
-1.0 -1.5 2.0
] # shape (D, K)
true_Σs = [Matrix{Float64}(0.3 * I(D)) for _ in 1:k] # isotropic covariances
true_πs = [0.5, 0.2, 0.3]
true_gmm = GaussianMixtureModel(k, true_μs, true_Σs, true_πs)
GaussianMixtureModel{Float64, Matrix{Float64}, Vector{Float64}}(3, [-1.0 1.0 0.0; -1.0 -1.5 2.0], [[0.3 0.0; 0.0 0.3], [0.3 0.0; 0.0 0.3], [0.3 0.0; 0.0 0.3]], [0.5, 0.2, 0.3])
Sample data from the true GMM
n = 500
500
Generate component labels (for plotting)
labels = rand(rng, Categorical(true_πs), n)
500-element Vector{Int64}:
3
1
2
1
3
1
1
2
2
3
⋮
1
1
3
1
3
1
1
2
1
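As a quick sanity check, the empirical frequency of each label should be close to true_πs at this sample size (a minimal sketch using only Base Julia):

# Empirical frequency of each component; should roughly match true_πs.
empirical_πs = [count(==(j), labels) / n for j in 1:k]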
Generate samples from the GMM
X = Matrix{Float64}(undef, D, n) # each column holds one D-dimensional sample
for i in 1:n
X[:, i] = rand(rng, MvNormal(true_μs[:, labels[i]], true_Σs[labels[i]]))
end
p1 = scatter(
X[1, :], X[2, :];
group=labels,
title="GMM Samples",
xlabel="x₁", ylabel="x₂",
markersize=4,
alpha=0.8,
legend=false,
)
p1
Fit a new GaussianMixtureModel to the data
fit_gmm = GaussianMixtureModel(k, D)
class_probabilities, lls = fit!(fit_gmm, X;
maxiter=100, tol=1e-6, initialize_kmeans=true)
([0.9999951135892645 4.414883238842431e-6 … 8.905753114892089e-10 3.672409253921998e-7; 2.9329503175003706e-8 0.9937244994317322 … 3.390969963330243e-5 0.9899321992425393; 4.857081232344198e-6 0.006271085685029012 … 0.9999660894097914 0.010067433516535456], [-1381.37940397858, -1346.1979506292728, -1337.4800124364558, -1333.3203701042023, -1331.1942812522202, -1330.0614587093876, -1329.423328446764, -1329.0430243835804, -1328.8039060395492, -1328.6458503543663 … -1325.4765754320758, -1325.4668773947649, -1325.4578535490584, -1325.4494748627392, -1325.4417120539758, -1325.4345356242895, -1325.427915930257, -1325.4218232860674, -1325.4162280888536, -1325.4111009588971])
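fit! appears to return the posterior responsibilities along with the per-iteration log-likelihoods. Judging from the printed output, the responsibilities are stored with one row per component and one column per sample; under that assumption, hard cluster assignments follow by taking the argmax of each column:

# Hard assignment per sample: the component with the largest responsibility.
# Note that EM recovers components only up to a permutation of labels.
pred_labels = [argmax(view(class_probabilities, :, i)) for i in 1:n]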
Plot log-likelihoods to visualize EM convergence
p2 = plot(
lls;
xlabel="Iteration",
ylabel="Log-Likelihood",
title="EM Convergence",
label="log-likelihood",
marker=:circle,
)
p2
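Each EM iteration is guaranteed not to decrease the log-likelihood, so the recorded trace should be non-decreasing up to floating-point noise; a quick check:

# EM monotonicity: successive log-likelihoods should never drop
# (allowing a tiny tolerance for floating-point rounding).
@assert all(diff(lls) .> -1e-8)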
Visualize model contours over the data
xs = collect(range(minimum(X[1, :]) - 1, stop=maximum(X[1, :]) + 1, length=150))
ys = collect(range(minimum(X[2, :]) - 1, stop=maximum(X[2, :]) + 1, length=150))
p3 = scatter(
X[1, :], X[2, :];
markersize=3, alpha=0.5,
xlabel="x₁", ylabel="x₂",
title="Data & Fitted GMM Contours by Component",
legend=:topright,
)
colors = [:red, :green, :blue]
for i in 1:fit_gmm.k
comp_dist = MvNormal(fit_gmm.μₖ[:, i], fit_gmm.Σₖ[i])
# Weighted component density on the grid; rows index y and columns index x,
# matching the orientation Plots' contour expects.
Z_i = [fit_gmm.πₖ[i] * pdf(comp_dist, [x, y]) for y in ys, x in xs]
contour!(
p3, xs, ys, Z_i;
levels=10,
linewidth=2,
c=colors[i],
label="Comp $i",
)
end
p3
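Finally, the fitted parameters can be compared against the ground truth. EM identifies the mixture only up to a permutation of the component labels, so match each fitted mean to its nearest true mean before comparing (a minimal sketch):

for i in 1:fit_gmm.k
    # Match fitted component i to the closest true component by mean distance.
    dists = [norm(fit_gmm.μₖ[:, i] - true_μs[:, j]) for j in 1:k]
    j = argmin(dists)
    println("fitted component $i ↔ true component $j, ‖Δμ‖ = ",
            round(dists[j]; digits=3))
end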
This page was generated using Literate.jl.