Simulating and Fitting a Poisson Mixture Model
This tutorial demonstrates how to use StateSpaceDynamics.jl
to create a Poisson Mixture Model and fit it using the EM algorithm.
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using StatsPlots
using Distributions
# Seeded random number generator. It is passed explicitly to every `rand` call
# below so that the simulated labels and counts are reproducible from run to run.
rng = StableRNG(1234);
Create a true PoissonMixtureModel to simulate from
# Ground-truth parameters of the mixture used for the simulation.
true_λs = [5.0, 10.0, 25.0]  # Poisson mean of each component
true_πs = [0.25, 0.45, 0.3]  # mixing weight of each component
k = length(true_λs)          # number of mixture components (3)

# Bundle the parameters into a model object; this is the value displayed below.
true_pmm = PoissonMixtureModel(k, true_λs, true_πs)
PoissonMixtureModel{Float64, Vector{Float64}}(3, [5.0, 10.0, 25.0], [0.25, 0.45, 0.3])
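Generate data from the true parameters: draw a component label for each sample, then draw a count from that component's Poisson distribution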
# Number of observations to simulate.
n = 500

# Step 1: draw the latent component index of every observation from the mixing weights.
labels = rand(rng, Categorical(true_πs), n)

# Step 2: for each label (in order), draw one count from that component's Poisson.
data = map(z -> rand(rng, Poisson(true_λs[z])), labels)  # Vector{Int}
500-element Vector{Int64}:
33
9
9
8
25
10
5
15
5
30
⋮
1
10
21
7
24
25
5
7
9
Plot histogram of Poisson samples by component
# Histogram of the simulated counts, split by the true component label of each sample.
p1 = histogram(
    data;
    group=labels,
    # The data are integer counts. Half-integer edges give every count, including the
    # largest one, its own unit-width bin centered on that integer, instead of placing
    # the values exactly on bin boundaries.
    bins=-0.5:1:(maximum(data) + 0.5),
    bar_position=:dodge,
    xlabel="Count",
    ylabel="Frequency",
    title="Poisson Mixture Samples by Component",
    alpha=0.7,
    legend=:topright,
)
p1
Fit a new PoissonMixtureModel to the data
# Build a new k-component model to be fitted; only the number of components is
# supplied here, so the starting parameters come from the constructor.
fit_pmm = PoissonMixtureModel(k)
# Fit the model to the simulated counts. `fit!` returns a 2-tuple: the first element
# (a matrix in the output shown below) is discarded, and the second is kept as `lls`,
# the log-likelihood trace plotted in the next step.
# NOTE(review): presumably `maxiter` caps the number of EM iterations, `tol` is the
# convergence tolerance, and `initialize_kmeans=true` seeds the parameters with
# k-means; confirm against the StateSpaceDynamics.jl documentation.
_, lls = fit!(fit_pmm, data; maxiter=100, tol=1e-6, initialize_kmeans=true)
([1.5417854767112443e-14 0.23281667300994824 … 0.5551251910469738 0.23281667300994824; 0.9999988264615685 0.0007939518191292314 … 8.328588068036678e-5 0.0007939518191292314; 1.1735384162374018e-6 0.7663893751709226 … 0.44479152307234593 0.7663893751709226], [-1701.6126496215945, -1688.5801867165533, -1682.2780488227118, -1679.1137695443554, -1677.4220685907212, -1676.4090593967044, -1675.7141513588833, -1675.1814043453699, -1674.7444722444727, -1674.3740832673652 … -1672.106957545556, -1672.1069543474243, -1672.1069516287737, -1672.1069493177001, -1672.106947353116, -1672.1069456830764, -1672.1069442634239, -1672.1069430566226, -1672.1069420307567, -1672.1069411586964])
Plot log-likelihoods to visualize EM convergence
# Trace of the log-likelihood recorded by the fit, one marker per entry of `lls`.
p2 = plot(
    lls;
    title="EM Convergence (Poisson Mixture)",
    xlabel="Iteration", ylabel="Log-Likelihood",
    label="log_likelihood", marker=:circle,
)
p2
Plot the fitted component PMFs (each scaled by its mixing weight) and the overall mixture PMF over the normalized data histogram
# Normalized histogram of the data, used as the backdrop for the fitted PMFs.
p3 = histogram(
    data;
    # Half-integer edges give every integer count (including the largest) its own
    # unit-width bin centered on that count. The bars then line up with the PMF values
    # drawn at the integers below instead of being shifted by half a bin.
    bins=-0.5:1:(maximum(data) + 0.5),
    normalize=true,
    alpha=0.3,
    label="Data",
    xlabel="Count",
    ylabel="Density",
    title="Poisson Mixtures: Data and PMFs",
)

# Support on which the PMFs are evaluated: every integer from 0 to the largest count.
x = 0:maximum(data)
colors = [:red, :green, :blue]

# One curve per fitted component: its Poisson PMF scaled by its mixing weight.
for i in 1:k
    λi = fit_pmm.λₖ[i]
    πi = fit_pmm.πₖ[i]
    pmf_i = πi .* pdf.(Poisson(λi), x)
    plot!(
        p3, x, pmf_i;
        lw=2,
        c=colors[i],
        label="Comp $i (λ=$(round(λi, sigdigits=3)))",
    )
end

# Mixture PMF: the sum of the weighted component PMFs over the same support.
mix_pmf = sum(πi .* pdf.(Poisson(λi), x) for (λi, πi) in zip(fit_pmm.λₖ, fit_pmm.πₖ))
plot!(
    p3, x, mix_pmf;
    lw=3, ls=:dash, c=:black,
    label="Mixture",
)
p3
This page was generated using Literate.jl.