Simulating and Fitting a Poisson Mixture Model
This tutorial demonstrates how to use StateSpaceDynamics.jl to simulate count data from a Poisson mixture model and then fit a new Poisson mixture model to that data with the expectation–maximization (EM) algorithm.
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using StatsPlots
using Distributions
# Fixed-seed RNG so the simulated labels and counts below are reproducible.
rng = StableRNG(1234);
Create the ground-truth PoissonMixtureModel that we will simulate data from.
true_λs = [5.0, 10.0, 25.0]  # Poisson mean of each component
true_πs = [0.25, 0.45, 0.3]  # mixing weights (sum to 1)
k = length(true_λs)          # number of mixture components (3)
true_pmm = PoissonMixtureModel(k, true_λs, true_πs)
PoissonMixtureModel{Float64, Vector{Float64}}(3, [5.0, 10.0, 25.0], [0.25, 0.45, 0.3])
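Generate data from the true model: draw a component label for each sample from the mixing weights, then draw a count from that component's Poisson distribution.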
n = 500
# Latent component of every sample, drawn from the mixing weights.
labels = rand(rng, Categorical(true_πs), n)
# One Poisson count per sample, using the rate of its sampled component.
data = map(z -> rand(rng, Poisson(true_λs[z])), labels) # Vector{Int}
500-element Vector{Int64}:
33
9
9
8
25
10
5
15
5
30
⋮
1
10
21
7
24
25
5
7
9
Plot a histogram of the Poisson samples, grouped (and colored) by the component that generated each one.
# Histogram of the simulated counts, one dodged series per true component.
# Bin edges sit at half-integers so that each integer count k owns the bin
# [k - 0.5, k + 0.5). Two things depend on this:
# - Bars are centred on the count they represent.
# - The largest count is kept. Histogram bins are left-closed, so with integer
#   edges ending at maximum(data) that value would fall outside every bin.
p1 = histogram(
    data;
    group=labels,
    bins=-0.5:1:(maximum(data) + 0.5),
    bar_position=:dodge,
    xlabel="Count",
    ylabel="Frequency",
    title="Poisson Mixture Samples by Component",
    alpha=0.7,
    legend=:topright,
)
p1
Fit a new PoissonMixtureModel to the data
# Fresh k-component model. `fit!` runs EM on it and its fitted parameters are
# read back afterwards. `initialize_kmeans=true` requests k-means initialization.
# Only the second returned value (the per-iteration log-likelihoods) is kept.
fit_pmm = PoissonMixtureModel(k)
_, lls = fit!(fit_pmm, data; initialize_kmeans=true, tol=1e-6, maxiter=100)
([1.167299485809293e-6 0.7670285733361216 … 0.4456154946697474 0.7670285733361216; 0.9999988327004986 0.0007940905571800423 … 8.334359187998249e-5 0.0007940905571800423; 1.5235751770191232e-14 0.23217733610669858 … 0.5543011617383728 0.23217733610669858], [-1677.6322074652844, -1674.3624894485276, -1673.5404197496819, -1673.2181725367448, -1673.0264176702262, -1672.882289515502, -1672.764176466248, -1672.6648070344565, -1672.5805981923975, -1672.5091100384484 … -1672.1069591732414, -1672.1069557303338, -1672.1069528037415, -1672.1069503160506, -1672.1069482014225, -1672.10694640391, -1672.1069448759567, -1672.1069435771435, -1672.106942473089, -1672.1069415346014])
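Plot the log-likelihood at each EM iteration to check convergence. EM never decreases the log-likelihood, so the curve should rise and then flatten out.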
# One marker per EM iteration; x is the iteration index 1:length(lls).
p2 = plot(
    lls;
    label="log_likelihood",
    marker=:circle,
    title="EM Convergence (Poisson Mixture)",
    xlabel="Iteration",
    ylabel="Log-Likelihood",
)
p2
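Overlay the fitted model on the normalized data histogram: each component's PMF scaled by its mixing weight, plus the full mixture PMF (their sum).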
# Normalized data histogram with the fitted model's PMFs overlaid.
# Half-integer bin edges centre each bar on its integer count, so the bars line
# up with the PMF values plotted at x = 0, 1, 2, .... They also keep the
# largest count, which left-closed integer edges ending at maximum(data) would drop.
p3 = histogram(
    data;
    bins=-0.5:1:(maximum(data) + 0.5),
    normalize=true,
    alpha=0.3,
    label="Data",
    xlabel="Count",
    ylabel="Density",
    title="Poisson Mixtures: Data and PMFs",
)
x = collect(0:maximum(data))
colors = [:red, :green, :blue]
# Running sum of the weighted component PMFs, i.e. the mixture PMF on x.
mix_pmf = zeros(length(x))
# Iterate over the fitted parameters themselves rather than 1:k, so this works
# for any number of components. Colors cycle when there are more components
# than entries in `colors`.
for (i, (λi, πi)) in enumerate(zip(fit_pmm.λₖ, fit_pmm.πₖ))
    pmf_i = πi .* pdf.(Poisson(λi), x)  # component PMF scaled by its weight
    mix_pmf .+= pmf_i
    plot!(
        p3, x, pmf_i;
        lw=2,
        c=colors[mod1(i, length(colors))],
        label="Comp $i (λ=$(round(λi, sigdigits=3)))",
    )
end
plot!(
    p3, x, mix_pmf;
    lw=3, ls=:dash, c=:black,
    label="Mixture",
)
p3
This page was generated using Literate.jl.