Simulating and Fitting a Poisson Mixture Model
This tutorial demonstrates how to build and fit a Poisson Mixture Model (PMM) with StateSpaceDynamics.jl using the Expectation-Maximization (EM) algorithm. We'll cover simulation, fitting, diagnostics, interpretation, and practical considerations.
What is a Poisson Mixture Model?
A PMM assumes each observation $x_i \in \{0,1,2,\ldots\}$ is drawn from one of $k$ Poisson distributions with rates $\lambda_1,\ldots,\lambda_k$. The component assignment is a latent categorical variable $z_i \in \{1,\ldots,k\}$ with mixing weights $\pi_1,\ldots,\pi_k$ where $\sum_j \pi_j = 1$.
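Marginalizing out the assignment $z_i$ gives the mixture pmf $p(x_i \mid \boldsymbol{\pi}, \boldsymbol{\lambda}) = \sum_{j=1}^{k} \pi_j \frac{\lambda_j^{x_i} e^{-\lambda_j}}{x_i!}$; the log of this quantity, summed over observations, is the marginal log-likelihood that EM maximizes.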
Generative process:
- Draw $z_i \sim \text{Categorical}(\boldsymbol{\pi})$
- Given $z_i = j$, draw $x_i \sim \text{Poisson}(\lambda_j)$
PMMs are useful for count data from heterogeneous sub-populations (e.g., spike counts from different neuron types, customer transaction counts from different segments, or event frequencies across different regimes).
EM Algorithm Overview
EM maximizes the marginal log-likelihood $\log p(\mathbf{x} | \boldsymbol{\pi}, \boldsymbol{\lambda})$ by iterating:
- E-step: Compute responsibilities $\gamma_{ij} = P(z_i = j | x_i, \boldsymbol{\theta})$
- M-step: Update parameters to maximize expected complete-data log-likelihood
For Poisson mixtures, the M-step has closed-form updates: $\pi_j \leftarrow \frac{1}{n} \sum_i \gamma_{ij}, \quad \lambda_j \leftarrow \frac{\sum_i \gamma_{ij} x_i}{\sum_i \gamma_{ij}}$
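To make these updates concrete, here is a minimal, self-contained sketch of a single EM iteration for a Poisson mixture. The em_step name and signature are illustrative only, not part of StateSpaceDynamics.jl; fit! below performs equivalent updates internally. Distributions is loaded here just for the Poisson pmf (it is also loaded in the next section).
using Distributions
function em_step(x::AbstractVector{<:Integer}, πs::AbstractVector, λs::AbstractVector)
    n, k = length(x), length(λs)
    γ = zeros(n, k)
    # E-step: responsibilities γ[i, j] ∝ π_j · Poisson(x_i; λ_j)
    for i in 1:n, j in 1:k
        γ[i, j] = πs[j] * pdf(Poisson(λs[j]), x[i])
    end
    γ ./= sum(γ, dims=2)          # normalize each row to sum to 1
    # M-step: closed-form updates
    Nⱼ = vec(sum(γ, dims=1))      # effective number of observations per component
    new_πs = Nⱼ ./ n              # updated mixing weights
    new_λs = (γ' * x) ./ Nⱼ       # updated Poisson rates (responsibility-weighted means)
    return new_πs, new_λs
end
Iterating this until the log-likelihood improvement falls below a tolerance is, in essence, the EM loop; a production implementation would also evaluate the responsibilities in log space to avoid underflow for large counts.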
Load Required Packages
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using StatsPlots
using Distributions
Fix RNG for reproducible simulation and k-means seeding
rng = StableRNG(1234);
Create True Poisson Mixture Model
We'll simulate from a mixture of $k=3$ Poisson components with distinct rates and mixing weights. These parameters create well-separated components that should be recoverable by EM.
k = 3
true_λs = [5.0, 10.0, 25.0] # Poisson rates per component
true_πs = [0.25, 0.45, 0.30] # Mixing weights (sum to 1)
true_pmm = PoissonMixtureModel(k, true_λs, true_πs);
print("True model: k=$k components\n")
for i in 1:k
print("Component $i: λ=$(true_λs[i]), π=$(true_πs[i])\n")
end
True model: k=3 components
Component 1: λ=5.0, π=0.25
Component 2: λ=10.0, π=0.45
Component 3: λ=25.0, π=0.3
Generate Synthetic Data
Draw $n$ independent samples. The labels record each observation's true component membership (in practice these are unknown and must be inferred).
n = 500
labels = rand(rng, Categorical(true_πs), n)
data = [rand(rng, Poisson(true_λs[labels[i]])) for i in 1:n];
print("Generated $n samples with count range [$(minimum(data)), $(maximum(data))]\n");Generated 500 samples with count range [1, 38]Visualize samples by true component membership Components with larger $\lambda$ shift mass toward higher counts
p1 = histogram(data;
group=labels,
bins=0:1:maximum(data),
bar_position=:dodge,
xlabel="Count", ylabel="Frequency",
title="Poisson Mixture Samples by True Component",
alpha=0.7, legend=:topright
)
Fit Poisson Mixture Model with EM
Construct a model with $k$ components and fit it using the EM algorithm. Key options:
- maxiter: Maximum number of EM iterations
- tol: Convergence tolerance (relative log-likelihood improvement)
- initialize_kmeans=true: Use k-means for stable initialization
fit_pmm = PoissonMixtureModel(k)
_, lls = fit!(fit_pmm, data; maxiter=100, tol=1e-6, initialize_kmeans=true);
print("EM converged in $(length(lls)) iterations\n")
print("Log-likelihood improved by $(round(lls[end] - lls[1], digits=1))\n");Iteration 1: Log-likelihood = -1675.2855560047683
Iteration 2: Log-likelihood = -1672.8617414276803
Iteration 3: Log-likelihood = -1672.3074358415222
Iteration 4: Log-likelihood = -1672.164521940615
Iteration 5: Log-likelihood = -1672.1258016768434
Iteration 6: Log-likelihood = -1672.1147947830093
Iteration 7: Log-likelihood = -1672.111371874729
Iteration 8: Log-likelihood = -1672.1100880678769
Iteration 9: Log-likelihood = -1672.1094447975577
Iteration 10: Log-likelihood = -1672.1090217273575
Iteration 11: Log-likelihood = -1672.108696263604
Iteration 12: Log-likelihood = -1672.1084290039594
Iteration 13: Log-likelihood = -1672.1082043790832
Iteration 14: Log-likelihood = -1672.1080141125772
Iteration 15: Log-likelihood = -1672.1078525389155
Iteration 16: Log-likelihood = -1672.1077152192565
Iteration 17: Log-likelihood = -1672.1075984832678
Iteration 18: Log-likelihood = -1672.1074992384272
Iteration 19: Log-likelihood = -1672.107414862758
Iteration 20: Log-likelihood = -1672.1073431287916
Iteration 21: Log-likelihood = -1672.107282143003
Iteration 22: Log-likelihood = -1672.1072302954524
Iteration 23: Log-likelihood = -1672.1071862172657
Iteration 24: Log-likelihood = -1672.107148744525
Iteration 25: Log-likelihood = -1672.1071168876272
Iteration 26: Log-likelihood = -1672.10708980514
Iteration 27: Log-likelihood = -1672.1070667816903
Iteration 28: Log-likelihood = -1672.1070472090382
Iteration 29: Log-likelihood = -1672.1070305700794
Iteration 30: Log-likelihood = -1672.1070164251605
Iteration 31: Log-likelihood = -1672.1070044005103
Iteration 32: Log-likelihood = -1672.106994178339
Iteration 33: Log-likelihood = -1672.106985488515
Iteration 34: Log-likelihood = -1672.1069781013432
Iteration 35: Log-likelihood = -1672.1069718215851
Iteration 36: Log-likelihood = -1672.106966483238
Iteration 37: Log-likelihood = -1672.106961945188
Iteration 38: Log-likelihood = -1672.1069580874637
Iteration 39: Log-likelihood = -1672.1069548080945
Iteration 40: Log-likelihood = -1672.106952020372
Iteration 41: Log-likelihood = -1672.106949650591
Iteration 42: Log-likelihood = -1672.1069476360994
Iteration 43: Log-likelihood = -1672.1069459236369
Iteration 44: Log-likelihood = -1672.1069444679154
Iteration 45: Log-likelihood = -1672.106943230453
Iteration 46: Log-likelihood = -1672.1069421785267
Iteration 47: Log-likelihood = -1672.1069412843055
Converged at iteration 47
EM converged in 47 iterations
Log-likelihood improved by 3.2
Display learned parameters
print("Learned parameters:\n")
for i in 1:k
print("Component $i: λ=$(round(fit_pmm.λₖ[i], digits=2)), π=$(round(fit_pmm.πₖ[i], digits=3))\n")
end
Learned parameters:
Component 1: λ=5.22, π=0.319
Component 2: λ=24.87, π=0.295
Component 3: λ=10.58, π=0.386
Monitor EM Convergence
Each EM iteration is guaranteed not to decrease the log-likelihood, so the trace should rise monotonically and flatten out as the algorithm converges; a drop would indicate a bug or a numerical problem.
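A quick programmatic check of this property (assuming lls is the log-likelihood trace returned by fit! above, and allowing a small floating-point tolerance):
@assert all(diff(lls) .>= -1e-8) "log-likelihood decreased during EM"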
p2 = plot(lls;
xlabel="Iteration", ylabel="Log-Likelihood",
title="EM Convergence",
marker=:circle, markersize=3, lw=2,
legend=false, color=:darkblue
)
annotate!(p2, length(lls)*0.7, lls[end]*0.98,
text("Final LL: $(round(lls[end], digits=1))", 10))Visual Model Assessment
Overlay fitted component PMFs and overall mixture PMF on normalized histogram. Components should explain major modes and tail behavior in the data.
p3 = histogram(data;
bins=0:1:maximum(data), normalize=true, alpha=0.3,
xlabel="Count", ylabel="Probability Density",
title="Data vs. Fitted Mixture Components",
label="Data", color=:gray
)
x_range = collect(0:maximum(data))
colors = [:red, :green, :blue]
3-element Vector{Symbol}:
:red
:green
:blue
Plot individual component PMFs
for i in 1:k
λᵢ = fit_pmm.λₖ[i]
πᵢ = fit_pmm.πₖ[i]
pmf_i = πᵢ .* pdf.(Poisson(λᵢ), x_range)
plot!(p3, x_range, pmf_i;
lw=2, color=colors[i],
label="Component $i (λ=$(round(λᵢ, digits=1)))"
)
end
Plot overall mixture PMF
mixture_pmf = reduce(+, (πᵢ .* pdf.(Poisson(λᵢ), x_range)
for (λᵢ, πᵢ) in zip(fit_pmm.λₖ, fit_pmm.πₖ)))
plot!(p3, x_range, mixture_pmf;
lw=3, linestyle=:dash, color=:black,
label="Mixture PMF"
)
Posterior Responsibilities (Soft Clustering)
Responsibilities $\gamma_{ij} = P(z_i = j | x_i, \hat{\boldsymbol{\theta}})$ quantify how likely each observation belongs to each component. These provide soft assignments and uncertainty quantification.
function responsibilities_pmm(λs::AbstractVector, πs::AbstractVector, x::AbstractVector)
k, n = length(λs), length(x)
Γ = zeros(n, k)
for i in 1:n
for j in 1:k
Γ[i, j] = πs[j] * pdf(Poisson(λs[j]), x[i])
end
row_sum = sum(Γ[i, :]) # Normalize to get probabilities
if row_sum > 0
Γ[i, :] ./= row_sum
end
end
return Γ
end
Γ = responsibilities_pmm(fit_pmm.λₖ, fit_pmm.πₖ, data);
Hard assignments (if needed) are argmax over responsibilities
hard_labels = [argmax(Γ[i, :]) for i in 1:n];
Calculate assignment accuracy compared to true labels
accuracy = mean(labels .== hard_labels)
print("Component assignment accuracy: $(round(accuracy*100, digits=1))%\n");Component assignment accuracy: 26.4%Information Criteria for Model Selection
Information Criteria for Model Selection
When $k$ is unknown, compare models using AIC/BIC:
- AIC = $2p - 2\text{LL}$
- BIC = $p \log(n) - 2\text{LL}$
where parameter count $p = (k-1) + k = 2k-1$ (mixing weights + rates)
function compute_ic(lls::AbstractVector, n::Int, k::Int)
ll = last(lls)
p = 2k - 1
return (AIC = 2p - 2ll, BIC = p*log(n) - 2ll)
end
ic = compute_ic(lls, n, k)
print("Information criteria: AIC = $(round(ic.AIC, digits=1)), BIC = $(round(ic.BIC, digits=1))\n");Information criteria: AIC = 3354.2, BIC = 3375.3Parameter Recovery Assessment
print("\n=== Parameter Recovery Assessment ===\n")
λ_errors = [abs(true_λs[i] - fit_pmm.λₖ[i]) / true_λs[i] for i in 1:k]
π_errors = [abs(true_πs[i] - fit_pmm.πₖ[i]) for i in 1:k]
print("Rate recovery errors (%):\n")
for i in 1:k
print("Component $i: $(round(λ_errors[i]*100, digits=1))%\n")
end
print("Mixing weight recovery errors:\n")
for i in 1:k
print("Component $i: $(round(π_errors[i], digits=3))\n")
end
=== Parameter Recovery Assessment ===
Rate recovery errors (%):
Component 1: 4.4%
Component 2: 148.7%
Component 3: 57.7%
Mixing weight recovery errors:
Component 1: 0.069
Component 2: 0.155
Component 3: 0.086
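As with the assignment accuracy, the large rate errors above mainly reflect component ordering rather than poor recovery. A hedged alternative, valid here because the true rates are well separated, is to compare the rates after sorting both sets:
sorted_true, sorted_fit = sort(true_λs), sort(fit_pmm.λₖ)
matched_λ_errors = abs.(sorted_true .- sorted_fit) ./ sorted_true
print("Rate recovery errors after sorting (%): ",
      join(round.(matched_λ_errors .* 100, digits=1), ", "), "\n")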
Model Selection Example
Demonstrate fitting multiple values of k and comparing via BIC
print("\n=== Model Selection Demo ===\n")
k_range = 1:5
bic_scores = Float64[]
aic_scores = Float64[]
for k_test in k_range
pmm_test = PoissonMixtureModel(k_test)
_, lls_test = fit!(pmm_test, data; maxiter=50, tol=1e-6, initialize_kmeans=true)
ic_test = compute_ic(lls_test, n, k_test)
push!(aic_scores, ic_test.AIC)
push!(bic_scores, ic_test.BIC)
print("k=$k_test: AIC=$(round(ic_test.AIC, digits=1)), BIC=$(round(ic_test.BIC, digits=1))\n")
end
=== Model Selection Demo ===
Iteration 1: Log-likelihood = -2446.0120010620803
Iteration 2: Log-likelihood = -2446.0120010620803
Converged at iteration 2
k=1: AIC=4894.0, BIC=4898.2
Iteration 1: Log-likelihood = -1725.159275168668
Iteration 2: Log-likelihood = -1710.378241489074
Iteration 3: Log-likelihood = -1707.7232060184288
Iteration 4: Log-likelihood = -1707.1363604250537
Iteration 5: Log-likelihood = -1706.9966164146529
Iteration 6: Log-likelihood = -1706.9622050675277
Iteration 7: Log-likelihood = -1706.9535934824933
Iteration 8: Log-likelihood = -1706.9514211429553
Iteration 9: Log-likelihood = -1706.950870967497
Iteration 10: Log-likelihood = -1706.9507313492327
Iteration 11: Log-likelihood = -1706.9506958826319
Iteration 12: Log-likelihood = -1706.9506868686449
Iteration 13: Log-likelihood = -1706.9506845771148
Iteration 14: Log-likelihood = -1706.950683994491
Converged at iteration 14
k=2: AIC=3419.9, BIC=3432.5
Iteration 1: Log-likelihood = -1704.3195555659797
Iteration 2: Log-likelihood = -1703.9458362949379
Iteration 3: Log-likelihood = -1703.685813777206
Iteration 4: Log-likelihood = -1703.4548339987298
Iteration 5: Log-likelihood = -1703.22801395101
Iteration 6: Log-likelihood = -1702.99125277468
Iteration 7: Log-likelihood = -1702.734046523778
Iteration 8: Log-likelihood = -1702.4470356075224
Iteration 9: Log-likelihood = -1702.1206307410766
Iteration 10: Log-likelihood = -1701.7440444750641
Iteration 11: Log-likelihood = -1701.3045009861003
Iteration 12: Log-likelihood = -1700.7865554458565
Iteration 13: Log-likelihood = -1700.1715808539734
Iteration 14: Log-likelihood = -1699.4376412036665
Iteration 15: Log-likelihood = -1698.5601980673152
Iteration 16: Log-likelihood = -1697.5143712319518
Iteration 17: Log-likelihood = -1696.2796252279393
Iteration 18: Log-likelihood = -1694.8473472323265
Iteration 19: Log-likelihood = -1693.2301850950457
Iteration 20: Log-likelihood = -1691.4691485771789
Iteration 21: Log-likelihood = -1689.6322896258673
Iteration 22: Log-likelihood = -1687.8012392847806
Iteration 23: Log-likelihood = -1686.0501875578743
Iteration 24: Log-likelihood = -1684.4290298662092
Iteration 25: Log-likelihood = -1682.9594009945486
Iteration 26: Log-likelihood = -1681.6419020754574
Iteration 27: Log-likelihood = -1680.4665287788532
Iteration 28: Log-likelihood = -1679.4202795784063
Iteration 29: Log-likelihood = -1678.4906960585263
Iteration 30: Log-likelihood = -1677.6667267525277
Iteration 31: Log-likelihood = -1676.93851372879
Iteration 32: Log-likelihood = -1676.29701045982
Iteration 33: Log-likelihood = -1675.7337392608997
Iteration 34: Log-likelihood = -1675.2407043802618
Iteration 35: Log-likelihood = -1674.810394909604
Iteration 36: Log-likelihood = -1674.4358180750003
Iteration 37: Log-likelihood = -1674.110529465094
Iteration 38: Log-likelihood = -1673.8286473588853
Iteration 39: Log-likelihood = -1673.5848495942018
Iteration 40: Log-likelihood = -1673.374356023739
Iteration 41: Log-likelihood = -1673.1929006932346
Iteration 42: Log-likelihood = -1673.0366975083916
Iteration 43: Log-likelihood = -1672.9024023745687
Iteration 44: Log-likelihood = -1672.7870740176859
Iteration 45: Log-likelihood = -1672.6881350550498
Iteration 46: Log-likelihood = -1672.6033343911613
Iteration 47: Log-likelihood = -1672.5307116423744
Iteration 48: Log-likelihood = -1672.4685640184512
Iteration 49: Log-likelihood = -1672.4154158859096
Iteration 50: Log-likelihood = -1672.3699910894566
k=3: AIC=3354.7, BIC=3375.8
Iteration 1: Log-likelihood = -1687.5997063100372
Iteration 2: Log-likelihood = -1682.9755525697433
Iteration 3: Log-likelihood = -1680.6701152086302
Iteration 4: Log-likelihood = -1679.1779264126162
Iteration 5: Log-likelihood = -1678.0925511352536
Iteration 6: Log-likelihood = -1677.2360490179935
Iteration 7: Log-likelihood = -1676.524644894624
Iteration 8: Log-likelihood = -1675.9175528261112
Iteration 9: Log-likelihood = -1675.3932239514377
Iteration 10: Log-likelihood = -1674.938519100328
Iteration 11: Log-likelihood = -1674.5440234482487
Iteration 12: Log-likelihood = -1674.2021279893906
Iteration 13: Log-likelihood = -1673.906280318686
Iteration 14: Log-likelihood = -1673.6506961498142
Iteration 15: Log-likelihood = -1673.430238095617
Iteration 16: Log-likelihood = -1673.2403490970394
Iteration 17: Log-likelihood = -1673.0770012850276
Iteration 18: Log-likelihood = -1672.9366485749838
Iteration 19: Log-likelihood = -1672.8161805953034
Iteration 20: Log-likelihood = -1672.7128781918946
Iteration 21: Log-likelihood = -1672.6243712023809
Iteration 22: Log-likelihood = -1672.548599048343
Iteration 23: Log-likelihood = -1672.4837744570677
Iteration 24: Log-likelihood = -1672.4283504278974
Iteration 25: Log-likelihood = -1672.3809904190255
Iteration 26: Log-likelihood = -1672.3405416386001
Iteration 27: Log-likelihood = -1672.306011266212
Iteration 28: Log-likelihood = -1672.2765453968887
Iteration 29: Log-likelihood = -1672.2514104831673
Iteration 30: Log-likelihood = -1672.2299770458126
Iteration 31: Log-likelihood = -1672.2117054272253
Iteration 32: Log-likelihood = -1672.1961333710767
Iteration 33: Log-likelihood = -1672.1828652241895
Iteration 34: Log-likelihood = -1672.1715625721565
Iteration 35: Log-likelihood = -1672.1619361359371
Iteration 36: Log-likelihood = -1672.1537387732171
Iteration 37: Log-likelihood = -1672.1467594442347
Iteration 38: Log-likelihood = -1672.140818017025
Iteration 39: Log-likelihood = -1672.1357608013006
Iteration 40: Log-likelihood = -1672.1314567133008
Iteration 41: Log-likelihood = -1672.127793985931
Iteration 42: Log-likelihood = -1672.1246773493472
Iteration 43: Log-likelihood = -1672.1220256166228
Iteration 44: Log-likelihood = -1672.1197696179124
Iteration 45: Log-likelihood = -1672.1178504338989
Iteration 46: Log-likelihood = -1672.116217886124
Iteration 47: Log-likelihood = -1672.1148292474613
Iteration 48: Log-likelihood = -1672.1136481411163
Iteration 49: Log-likelihood = -1672.1126436011173
Iteration 50: Log-likelihood = -1672.1117892707553
k=4: AIC=3358.2, BIC=3387.7
Iteration 1: Log-likelihood = -1680.758984117847
Iteration 2: Log-likelihood = -1677.0495363952361
Iteration 3: Log-likelihood = -1675.6898936254063
Iteration 4: Log-likelihood = -1674.9034547640667
Iteration 5: Log-likelihood = -1674.369182966404
Iteration 6: Log-likelihood = -1673.9796850340795
Iteration 7: Log-likelihood = -1673.6852604734079
Iteration 8: Log-likelihood = -1673.4576541516078
Iteration 9: Log-likelihood = -1673.2787097269809
Iteration 10: Log-likelihood = -1673.1359501009015
Iteration 11: Log-likelihood = -1673.0204862984735
Iteration 12: Log-likelihood = -1672.9258536033838
Iteration 13: Log-likelihood = -1672.8472860515194
Iteration 14: Log-likelihood = -1672.7812328478365
Iteration 15: Log-likelihood = -1672.725024117896
Iteration 16: Log-likelihood = -1672.6766351794518
Iteration 17: Log-likelihood = -1672.6345181522745
Iteration 18: Log-likelihood = -1672.5974804778466
Iteration 19: Log-likelihood = -1672.5645965068131
Iteration 20: Log-likelihood = -1672.5351426108011
Iteration 21: Log-likelihood = -1672.5085491696302
Iteration 22: Log-likelihood = -1672.484364764128
Iteration 23: Log-likelihood = -1672.4622292693593
Iteration 24: Log-likelihood = -1672.4418534908457
Iteration 25: Log-likelihood = -1672.423003648747
Iteration 26: Log-likelihood = -1672.4054894813567
Iteration 27: Log-likelihood = -1672.3891550702785
Iteration 28: Log-likelihood = -1672.3738717261883
Iteration 29: Log-likelihood = -1672.3595324445862
Iteration 30: Log-likelihood = -1672.3460475650793
Iteration 31: Log-likelihood = -1672.3333413581743
Iteration 32: Log-likelihood = -1672.3213493308087
Iteration 33: Log-likelihood = -1672.3100160913516
Iteration 34: Log-likelihood = -1672.299293652206
Iteration 35: Log-likelihood = -1672.2891400759668
Iteration 36: Log-likelihood = -1672.279518392472
Iteration 37: Log-likelihood = -1672.2703957302351
Iteration 38: Log-likelihood = -1672.2617426178897
Iteration 39: Log-likelihood = -1672.2535324212374
Iteration 40: Log-likelihood = -1672.2457408883747
Iteration 41: Log-likelihood = -1672.238345781472
Iteration 42: Log-likelihood = -1672.2313265780558
Iteration 43: Log-likelihood = -1672.224664228118
Iteration 44: Log-likelihood = -1672.2183409562736
Iteration 45: Log-likelihood = -1672.2123401002395
Iteration 46: Log-likelihood = -1672.2066459785617
Iteration 47: Log-likelihood = -1672.2012437821681
Iteration 48: Log-likelihood = -1672.1961194849953
Iteration 49: Log-likelihood = -1672.1912597701623
Iteration 50: Log-likelihood = -1672.186651968744
k=5: AIC=3362.4, BIC=3400.3
Plot information criteria vs number of components
p4 = plot(k_range, [aic_scores bic_scores];
xlabel="Number of Components (k)", ylabel="Information Criterion",
title="Model Selection via Information Criteria",
label=["AIC" "BIC"], marker=:circle, lw=2
)
optimal_k_aic = k_range[argmin(aic_scores)]
optimal_k_bic = k_range[argmin(bic_scores)]
print("Optimal k: AIC suggests k=$optimal_k_aic, BIC suggests k=$optimal_k_bic\n");Optimal k: AIC suggests k=3, BIC suggests k=3Summary
This tutorial demonstrated the complete Poisson Mixture Model workflow:
Key Concepts:
- Discrete mixtures: Model count data as mixture of Poisson distributions
- EM algorithm: Iterative optimization with closed-form M-step updates
- Soft clustering: Posterior responsibilities provide probabilistic assignments
- Model selection: Information criteria help choose appropriate number of components
Applications:
- Spike count analysis in neuroscience
- Customer transaction modeling in business analytics
- Event frequency analysis in reliability engineering
- Gene expression count clustering in bioinformatics
Technical Insights:
- Initialization strategy significantly affects final solution quality
- Label switching is a fundamental identifiability issue in mixture models
- Information criteria provide principled approach to model complexity selection
- Component separation quality affects parameter recovery accuracy
Extensions:
- Zero-inflated Poisson mixtures for excess zero counts
- Negative Binomial mixtures for overdispersed count data
- Bayesian approaches for uncertainty quantification
- Mixture regression models for count data with covariates
Poisson mixture models provide a flexible framework for modeling heterogeneous count data, enabling both clustering and density estimation while maintaining interpretable probabilistic structure.
This page was generated using Literate.jl.