Simulating and Fitting a Poisson Mixture Model
This tutorial demonstrates how to build and fit a Poisson Mixture Model (PMM) with StateSpaceDynamics.jl using the Expectation-Maximization (EM) algorithm. We'll cover simulation, fitting, diagnostics, interpretation, and practical considerations.
What is a Poisson Mixture Model?
A PMM assumes each observation $x_i \in \{0,1,2,\ldots\}$ is drawn from one of $k$ Poisson distributions with rates $\lambda_1,\ldots,\lambda_k$. The component assignment is a latent categorical variable $z_i \in \{1,\ldots,k\}$ with mixing weights $\pi_1,\ldots,\pi_k$ where $\sum_j \pi_j = 1$.
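Marginalizing over the latent assignment gives the mixture pmf $p(x_i \mid \boldsymbol{\pi}, \boldsymbol{\lambda}) = \sum_{j=1}^{k} \pi_j \, \frac{\lambda_j^{x_i} e^{-\lambda_j}}{x_i!}$, which is the quantity EM maximizes over the data.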
Generative process:
- Draw $z_i \sim \text{Categorical}(\boldsymbol{\pi})$
- Given $z_i = j$, draw $x_i \sim \text{Poisson}(\lambda_j)$
PMMs are useful for count data from heterogeneous sub-populations (e.g., spike counts from different neuron types, customer transaction counts from different segments, or event frequencies across different regimes).
EM Algorithm Overview
EM maximizes the marginal log-likelihood $\log p(\mathbf{x} | \boldsymbol{\pi}, \boldsymbol{\lambda})$ by iterating:
- E-step: Compute responsibilities $\gamma_{ij} = P(z_i = j | x_i, \boldsymbol{\theta})$
- M-step: Update parameters to maximize expected complete-data log-likelihood
For Poisson mixtures, the M-step has closed-form updates: $\pi_j \leftarrow \frac{1}{n} \sum_i \gamma_{ij}, \quad \lambda_j \leftarrow \frac{\sum_i \gamma_{ij} x_i}{\sum_i \gamma_{ij}}$
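Concretely, the E-step evaluates $\gamma_{ij} = \frac{\pi_j \, \mathrm{Poisson}(x_i; \lambda_j)}{\sum_{l=1}^{k} \pi_l \, \mathrm{Poisson}(x_i; \lambda_l)}$. The following is a minimal, self-contained sketch of this iteration using only Distributions.jl; it is shown for illustration and is independent of the StateSpaceDynamics.jl implementation (the `poisson_mixture_em` name and its interface are introduced here, not part of the package).

using Distributions

# Illustrative EM for a Poisson mixture (sketch only; no guards for empty components).
function poisson_mixture_em(x::AbstractVector{<:Integer}, k::Int; maxiter=100, tol=1e-6)
    n = length(x)
    # Crude initialization: spread initial rates across the observed count range.
    λs = [minimum(x) + (j / (k + 1)) * (maximum(x) - minimum(x)) + 1.0 for j in 1:k]
    πs = fill(1.0 / k, k)
    γ = zeros(n, k)
    ll_old = -Inf
    for _ in 1:maxiter
        # E-step: unnormalized responsibilities, then row-normalize.
        for i in 1:n, j in 1:k
            γ[i, j] = πs[j] * pdf(Poisson(λs[j]), x[i])
        end
        ll = sum(log.(sum(γ; dims=2)))   # marginal log-likelihood at current parameters
        γ ./= sum(γ; dims=2)
        # M-step: closed-form updates from the formulas above.
        Nj = vec(sum(γ; dims=1))
        πs = Nj ./ n
        λs = (γ' * x) ./ Nj
        abs(ll - ll_old) < tol && break
        ll_old = ll
    end
    return λs, πs
end

The `fit!` call used later in this tutorial performs the same E/M iteration, with k-means initialization and convergence monitoring handled by the package.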
Load Required Packages
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using StatsPlots
using Distributions
Fix RNG for reproducible simulation and k-means seeding
rng = StableRNG(1234);
Create True Poisson Mixture Model
We'll simulate from a mixture of $k=3$ Poisson components with distinct rates and mixing weights. These parameters create well-separated components that should be recoverable by EM.
k = 3
true_λs = [5.0, 10.0, 25.0] # Poisson rates per component
true_πs = [0.25, 0.45, 0.30] # Mixing weights (sum to 1)
true_pmm = PoissonMixtureModel(k, true_λs, true_πs);
print("True model: k=$k components\n")
for i in 1:k
print("Component $i: λ=$(true_λs[i]), π=$(true_πs[i])\n")
end
True model: k=3 components
Component 1: λ=5.0, π=0.25
Component 2: λ=10.0, π=0.45
Component 3: λ=25.0, π=0.3
Generate Synthetic Data
Draw $n$ independent samples. The labels record the true component membership of each observation; in practice these are unknown and must be inferred.
n = 500
labels = rand(rng, Categorical(true_πs), n)
data = [rand(rng, Poisson(true_λs[labels[i]])) for i in 1:n];
print("Generated $n samples with count range [$(minimum(data)), $(maximum(data))]\n");Generated 500 samples with count range [1, 38]Visualize samples by true component membership Components with larger $\lambda$ shift mass toward higher counts
p1 = histogram(data;
group=labels,
bins=0:1:maximum(data),
bar_position=:dodge,
xlabel="Count", ylabel="Frequency",
title="Poisson Mixture Samples by True Component",
alpha=0.7, legend=:topright
)
Fit Poisson Mixture Model with EM
Construct a model with $k$ components and fit it with the EM algorithm. Key options:
- maxiter: maximum number of EM iterations
- tol: convergence tolerance (relative log-likelihood improvement)
- initialize_kmeans=true: use k-means for stable initialization
fit_pmm = PoissonMixtureModel(k)
_, lls = fit!(fit_pmm, data; maxiter=100, tol=1e-6, initialize_kmeans=true);
print("EM converged in $(length(lls)) iterations\n")
print("Log-likelihood improved by $(round(lls[end] - lls[1], digits=1))\n");Iteration 1: Log-likelihood = -1672.1675174811173
Iteration 2: Log-likelihood = -1672.129439734492
Iteration 3: Log-likelihood = -1672.1189633052302
Iteration 4: Log-likelihood = -1672.1151985972867
Iteration 5: Log-likelihood = -1672.1134191008116
Iteration 6: Log-likelihood = -1672.1122975214316
Iteration 7: Log-likelihood = -1672.1114517209635
Iteration 8: Log-likelihood = -1672.1107624634603
Iteration 9: Log-likelihood = -1672.1101848655817
Iteration 10: Log-likelihood = -1672.1096962746053
Iteration 11: Log-likelihood = -1672.1092816947269
Iteration 12: Log-likelihood = -1672.108929554561
Iteration 13: Log-likelihood = -1672.1086303457685
Iteration 14: Log-likelihood = -1672.1083760795443
Iteration 15: Log-likelihood = -1672.1081599935058
Iteration 16: Log-likelihood = -1672.107976349145
Iteration 17: Log-likelihood = -1672.107820272706
Iteration 18: Log-likelihood = -1672.1076876236193
Iteration 19: Log-likelihood = -1672.1075748837702
Iteration 20: Log-likelihood = -1672.1074790637347
Iteration 21: Log-likelihood = -1672.1073976232242
Iteration 22: Log-likelihood = -1672.1073284035733
Iteration 23: Log-likelihood = -1672.1072695703267
Iteration 24: Log-likelihood = -1672.107219564539
Iteration 25: Log-likelihood = -1672.107177061344
Iteration 26: Log-likelihood = -1672.1071409348096
Iteration 27: Log-likelihood = -1672.10711022802
Iteration 28: Log-likelihood = -1672.1070841277174
Iteration 29: Log-likelihood = -1672.1070619427257
Iteration 30: Log-likelihood = -1672.1070430855855
Iteration 31: Log-likelihood = -1672.1070270570317
Iteration 32: Log-likelihood = -1672.1070134326892
Iteration 33: Log-likelihood = -1672.1070018518944
Iteration 34: Log-likelihood = -1672.1069920080827
Iteration 35: Log-likelihood = -1672.106983640694
Iteration 36: Log-likelihood = -1672.1069765282668
Iteration 37: Log-likelihood = -1672.1069704825513
Iteration 38: Log-likelihood = -1672.1069653435484
Iteration 39: Log-likelihood = -1672.106960975268
Iteration 40: Log-likelihood = -1672.1069572621004
Iteration 41: Log-likelihood = -1672.1069541058043
Iteration 42: Log-likelihood = -1672.1069514228418
Iteration 43: Log-likelihood = -1672.1069491422309
Iteration 44: Log-likelihood = -1672.1069472036404
Iteration 45: Log-likelihood = -1672.1069455557601
Iteration 46: Log-likelihood = -1672.1069441550019
Iteration 47: Log-likelihood = -1672.1069429642948
Iteration 48: Log-likelihood = -1672.1069419521493
Iteration 49: Log-likelihood = -1672.1069410917778
Converged at iteration 49
EM converged in 49 iterations
Log-likelihood improved by 0.1
Display learned parameters
print("Learned parameters:\n")
for i in 1:k
print("Component $i: λ=$(round(fit_pmm.λₖ[i], digits=2)), π=$(round(fit_pmm.πₖ[i], digits=3))\n")
end
Learned parameters:
Component 1: λ=5.22, π=0.319
Component 2: λ=24.87, π=0.295
Component 3: λ=10.57, π=0.387
Monitor EM Convergence
EM guarantees a non-decreasing log-likelihood at every iteration, so the trace should rise monotonically; a plateau indicates convergence, while any decrease points to an implementation problem.
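As a quick sanity check (illustrative; uses the `lls` vector returned by `fit!` above), the trace can be verified to be non-decreasing up to numerical round-off:

# Illustrative check: EM should never decrease the log-likelihood (up to round-off).
@assert all(diff(lls) .>= -1e-8) "log-likelihood decreased between EM iterations"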
p2 = plot(lls;
xlabel="Iteration", ylabel="Log-Likelihood",
title="EM Convergence",
marker=:circle, markersize=3, lw=2,
legend=false, color=:darkblue
)
annotate!(p2, length(lls)*0.7, lls[end]*0.98,
text("Final LL: $(round(lls[end], digits=1))", 10))Visual Model Assessment
Overlay the fitted component PMFs and the overall mixture PMF on a normalized histogram. The components should account for the major modes and the tail behavior of the data.
p3 = histogram(data;
bins=0:1:maximum(data), normalize=true, alpha=0.3,
xlabel="Count", ylabel="Probability Density",
title="Data vs. Fitted Mixture Components",
label="Data", color=:gray
)
x_range = collect(0:maximum(data))
colors = [:red, :green, :blue]
3-element Vector{Symbol}:
 :red
 :green
 :blue
Plot individual component PMFs
for i in 1:k
λᵢ = fit_pmm.λₖ[i]
πᵢ = fit_pmm.πₖ[i]
pmf_i = πᵢ .* pdf.(Poisson(λᵢ), x_range)
plot!(p3, x_range, pmf_i;
lw=2, color=colors[i],
label="Component $i (λ=$(round(λᵢ, digits=1)))"
)
end
Plot overall mixture PMF
mixture_pmf = reduce(+, (πᵢ .* pdf.(Poisson(λᵢ), x_range)
for (λᵢ, πᵢ) in zip(fit_pmm.λₖ, fit_pmm.πₖ)))
plot!(p3, x_range, mixture_pmf;
lw=3, linestyle=:dash, color=:black,
label="Mixture PMF"
)
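For a quick consistency check (illustrative; `ll_direct` is a name introduced here), the mixture PMF can also be used to evaluate the data log-likelihood directly, which should agree with the final EM log-likelihood above up to the convergence tolerance:

# Illustrative cross-check: evaluate the fitted mixture log-likelihood directly
# from the learned rates and weights, then compare with the value reported by fit!.
ll_direct = sum(log(sum(πⱼ * pdf(Poisson(λⱼ), x) for (λⱼ, πⱼ) in zip(fit_pmm.λₖ, fit_pmm.πₖ))) for x in data)
print("Direct mixture log-likelihood: $(round(ll_direct, digits=2)) (final EM value: $(round(lls[end], digits=2)))\n")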
Posterior Responsibilities (Soft Clustering)
Responsibilities $\gamma_{ij} = P(z_i = j | x_i, \hat{\boldsymbol{\theta}})$ quantify how likely it is that each observation was generated by each component. They provide soft assignments and a measure of assignment uncertainty.
function responsibilities_pmm(λs::AbstractVector, πs::AbstractVector, x::AbstractVector)
k, n = length(λs), length(x)
Γ = zeros(n, k)
for i in 1:n
for j in 1:k
Γ[i, j] = πs[j] * pdf(Poisson(λs[j]), x[i])
end
row_sum = sum(Γ[i, :]) # Normalize to get probabilities
if row_sum > 0
Γ[i, :] ./= row_sum
end
end
return Γ
end
Γ = responsibilities_pmm(fit_pmm.λₖ, fit_pmm.πₖ, data);
Hard assignments (if needed) are the argmax over responsibilities
hard_labels = [argmax(Γ[i, :]) for i in 1:n];
Calculate assignment accuracy compared to true labels
accuracy = mean(labels .== hard_labels)
print("Component assignment accuracy: $(round(accuracy*100, digits=1))%\n");Component assignment accuracy: 26.4%Information Criteria for Model Selection
When $k$ is unknown, compare models using AIC/BIC:
- AIC = $2p - 2\text{LL}$
- BIC = $p \log(n) - 2\text{LL}$
where the parameter count is $p = (k-1) + k = 2k-1$ ($k-1$ free mixing weights plus $k$ rates)
function compute_ic(lls::AbstractVector, n::Int, k::Int)
ll = last(lls)
p = 2k - 1
return (AIC = 2p - 2ll, BIC = p*log(n) - 2ll)
end
ic = compute_ic(lls, n, k)
print("Information criteria: AIC = $(round(ic.AIC, digits=1)), BIC = $(round(ic.BIC, digits=1))\n");Information criteria: AIC = 3354.2, BIC = 3375.3Parameter Recovery Assessment
print("\n=== Parameter Recovery Assessment ===\n")
λ_errors = [abs(true_λs[i] - fit_pmm.λₖ[i]) / true_λs[i] for i in 1:k]
π_errors = [abs(true_πs[i] - fit_pmm.πₖ[i]) for i in 1:k]
print("Rate recovery errors (%):\n")
for i in 1:k
print("Component $i: $(round(λ_errors[i]*100, digits=1))%\n")
end
print("Mixing weight recovery errors:\n")
for i in 1:k
print("Component $i: $(round(π_errors[i], digits=3))\n")
end
=== Parameter Recovery Assessment ===
Rate recovery errors (%):
Component 1: 4.3%
Component 2: 148.7%
Component 3: 57.7%
Mixing weight recovery errors:
Component 1: 0.069
Component 2: 0.155
Component 3: 0.087
The large per-index errors for components 2 and 3 are again an artifact of label switching rather than of a poor fit: learned component 2 (λ≈24.9) corresponds to true component 3 (λ=25.0), and learned component 3 (λ≈10.6) to true component 2 (λ=10.0). The sketch below matches components by rate before comparing.
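A minimal sketch of matching components before comparison, assuming the rates are distinct enough to identify components by sorting them (the helper names `perm_true`, `perm_fit`, `label_map`, and `matched_accuracy` are introduced here for illustration only):

# Match true and learned components by rate order, then recompute recovery errors.
perm_true = sortperm(true_λs)        # order of true components by rate
perm_fit  = sortperm(fit_pmm.λₖ)     # order of learned components by rate
for (t, f) in zip(perm_true, perm_fit)
    λ_err = abs(true_λs[t] - fit_pmm.λₖ[f]) / true_λs[t]
    π_err = abs(true_πs[t] - fit_pmm.πₖ[f])
    print("True component $t ↔ learned component $f: rate error $(round(λ_err*100, digits=1))%, weight error $(round(π_err, digits=3))\n")
end

# Map learned labels onto true-component indices and recompute assignment accuracy.
label_map = Dict(f => t for (t, f) in zip(perm_true, perm_fit))
matched_accuracy = mean(labels .== [label_map[l] for l in hard_labels])
print("Matched assignment accuracy: $(round(matched_accuracy*100, digits=1))%\n")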
Model Selection Example
Demonstrate fitting multiple values of $k$ and comparing them via AIC and BIC.
print("\n=== Model Selection Demo ===\n")
k_range = 1:5
bic_scores = Float64[]
aic_scores = Float64[]
for k_test in k_range
pmm_test = PoissonMixtureModel(k_test)
_, lls_test = fit!(pmm_test, data; maxiter=50, tol=1e-6, initialize_kmeans=true)
ic_test = compute_ic(lls_test, n, k_test)
push!(aic_scores, ic_test.AIC)
push!(bic_scores, ic_test.BIC)
print("k=$k_test: AIC=$(round(ic_test.AIC, digits=1)), BIC=$(round(ic_test.BIC, digits=1))\n")
end
=== Model Selection Demo ===
Iteration 1: Log-likelihood = -2446.0120010620803
Iteration 2: Log-likelihood = -2446.0120010620803
Converged at iteration 2
k=1: AIC=4894.0, BIC=4898.2
Iteration 1: Log-likelihood = -1761.8022965825448
Iteration 2: Log-likelihood = -1715.352850647611
Iteration 3: Log-likelihood = -1708.7084011822544
Iteration 4: Log-likelihood = -1707.3610148976156
Iteration 5: Log-likelihood = -1707.050835002773
Iteration 6: Log-likelihood = -1706.975641252702
Iteration 7: Log-likelihood = -1706.956966434844
Iteration 8: Log-likelihood = -1706.9522733161446
Iteration 9: Log-likelihood = -1706.9510869599253
Iteration 10: Log-likelihood = -1706.9507861831203
Iteration 11: Log-likelihood = -1706.9507098145746
Iteration 12: Log-likelihood = -1706.950690409858
Iteration 13: Log-likelihood = -1706.9506854774052
Iteration 14: Log-likelihood = -1706.9506842233998
Iteration 15: Log-likelihood = -1706.9506839045575
Converged at iteration 15
k=2: AIC=3419.9, BIC=3432.5
Iteration 1: Log-likelihood = -1704.4316314015664
Iteration 2: Log-likelihood = -1702.0009371283588
Iteration 3: Log-likelihood = -1701.3774456831416
Iteration 4: Log-likelihood = -1700.8068447082185
Iteration 5: Log-likelihood = -1700.1628921713354
Iteration 6: Log-likelihood = -1699.409453476575
Iteration 7: Log-likelihood = -1698.518482803539
Iteration 8: Log-likelihood = -1697.4635278141225
Iteration 9: Log-likelihood = -1696.223523507781
Iteration 10: Log-likelihood = -1694.7901379287714
Iteration 11: Log-likelihood = -1693.1766666013102
Iteration 12: Log-likelihood = -1691.4243647779394
Iteration 13: Log-likelihood = -1689.6003290044268
Iteration 14: Log-likelihood = -1687.7838599611991
Iteration 15: Log-likelihood = -1686.0462771316372
Iteration 16: Log-likelihood = -1684.4353544678936
Iteration 17: Log-likelihood = -1682.9720601105487
Iteration 18: Log-likelihood = -1681.6575360914424
Iteration 19: Log-likelihood = -1680.4828248577417
Iteration 20: Log-likelihood = -1679.435915373086
Iteration 21: Log-likelihood = -1678.5050522055772
Iteration 22: Log-likelihood = -1677.6795986497034
Iteration 23: Log-likelihood = -1676.949907626829
Iteration 24: Log-likelihood = -1676.3070244255534
Iteration 25: Log-likelihood = -1675.742502663587
Iteration 26: Log-likelihood = -1675.248350599719
Iteration 27: Log-likelihood = -1674.8170505065107
Iteration 28: Log-likelihood = -1674.4415992449321
Iteration 29: Log-likelihood = -1674.1155413835731
Iteration 30: Log-likelihood = -1673.8329845447959
Iteration 31: Log-likelihood = -1673.5885965802495
Iteration 32: Log-likelihood = -1673.3775880972648
Iteration 33: Log-likelihood = -1673.1956846400537
Iteration 34: Log-likelihood = -1673.0390923431908
Iteration 35: Log-likelihood = -1672.9044600516227
Iteration 36: Log-likelihood = -1672.788840116226
Iteration 37: Log-likelihood = -1672.6896494328412
Iteration 38: Log-likelihood = -1672.6046318003002
Iteration 39: Log-likelihood = -1672.5318223026807
Iteration 40: Log-likelihood = -1672.469514145654
Iteration 41: Log-likelihood = -1672.4162281733363
Iteration 42: Log-likelihood = -1672.370685143608
Iteration 43: Log-likelihood = -1672.331780733925
Iteration 44: Log-likelihood = -1672.2985631762544
Iteration 45: Log-likelihood = -1672.270213371892
Iteration 46: Log-likelihood = -1672.246027307626
Iteration 47: Log-likelihood = -1672.225400580539
Iteration 48: Log-likelihood = -1672.207814834355
Iteration 49: Log-likelihood = -1672.192825913611
Iteration 50: Log-likelihood = -1672.1800535503344
k=3: AIC=3354.4, BIC=3375.4
Iteration 1: Log-likelihood = -1691.1282468237698
Iteration 2: Log-likelihood = -1686.7221433618172
Iteration 3: Log-likelihood = -1684.307975354153
Iteration 4: Log-likelihood = -1682.551140806212
Iteration 5: Log-likelihood = -1681.126874376617
Iteration 6: Log-likelihood = -1679.914822065005
Iteration 7: Log-likelihood = -1678.8642923426971
Iteration 8: Log-likelihood = -1677.9490597951021
Iteration 9: Log-likelihood = -1677.1516070534465
Iteration 10: Log-likelihood = -1676.4577962715437
Iteration 11: Log-likelihood = -1675.8552166366267
Iteration 12: Log-likelihood = -1675.3327361115012
Iteration 13: Log-likelihood = -1674.8803818373672
Iteration 14: Log-likelihood = -1674.489274235453
Iteration 15: Log-likelihood = -1674.151550647657
Iteration 16: Log-likelihood = -1673.8602763795789
Iteration 17: Log-likelihood = -1673.609352287308
Iteration 18: Log-likelihood = -1673.3934258193026
Iteration 19: Log-likelihood = -1673.2078087522507
Iteration 20: Log-likelihood = -1673.048402495043
Iteration 21: Log-likelihood = -1672.911630737669
Iteration 22: Log-likelihood = -1672.7943788570947
Iteration 23: Log-likelihood = -1672.6939394525145
Iteration 24: Log-likelihood = -1672.6079634480407
Iteration 25: Log-likelihood = -1672.5344162786366
Iteration 26: Log-likelihood = -1672.4715387349484
Iteration 27: Log-likelihood = -1672.4178120825854
Iteration 28: Log-likelihood = -1672.3719270967624
Iteration 29: Log-likelihood = -1672.3327566705498
Iteration 30: Log-likelihood = -1672.299331669417
Iteration 31: Log-likelihood = -1672.270819719012
Iteration 32: Log-likelihood = -1672.2465066290301
Iteration 33: Log-likelihood = -1672.2257801734374
Iteration 34: Log-likelihood = -1672.2081159663041
Iteration 35: Log-likelihood = -1672.1930651928171
Iteration 36: Log-likelihood = -1672.1802439755634
Iteration 37: Log-likelihood = -1672.169324176416
Iteration 38: Log-likelihood = -1672.1600254548882
Iteration 39: Log-likelihood = -1672.152108421962
Iteration 40: Log-likelihood = -1672.1453687472426
Iteration 41: Log-likelihood = -1672.1396320929139
Iteration 42: Log-likelihood = -1672.1347497636796
Iteration 43: Log-likelihood = -1672.1305949752684
Iteration 44: Log-likelihood = -1672.1270596565316
Iteration 45: Log-likelihood = -1672.1240517110436
Iteration 46: Log-likelihood = -1672.1214926738687
Iteration 47: Log-likelihood = -1672.11931570786
Iteration 48: Log-likelihood = -1672.1174638912166
Iteration 49: Log-likelihood = -1672.1158887547767
Iteration 50: Log-likelihood = -1672.1145490331307
k=4: AIC=3358.2, BIC=3387.7
Iteration 1: Log-likelihood = -1683.2462066167125
Iteration 2: Log-likelihood = -1677.7711607516476
Iteration 3: Log-likelihood = -1675.9431349526744
Iteration 4: Log-likelihood = -1674.9279472507706
Iteration 5: Log-likelihood = -1674.236206804308
Iteration 6: Log-likelihood = -1673.733476430186
Iteration 7: Log-likelihood = -1673.360583625359
Iteration 8: Log-likelihood = -1673.0814635668223
Iteration 9: Log-likelihood = -1672.8709659630629
Iteration 10: Log-likelihood = -1672.710909651946
Iteration 11: Log-likelihood = -1672.58809901837
Iteration 12: Log-likelihood = -1672.4929748991071
Iteration 13: Log-likelihood = -1672.4186104017836
Iteration 14: Log-likelihood = -1672.3599656842187
Iteration 15: Log-likelihood = -1672.3133471490353
Iteration 16: Log-likelihood = -1672.2760222749541
Iteration 17: Log-likelihood = -1672.2459481799947
Iteration 18: Log-likelihood = -1672.221580918183
Iteration 19: Log-likelihood = -1672.2017411441482
Iteration 20: Log-likelihood = -1672.185518865462
Iteration 21: Log-likelihood = -1672.1722053011408
Iteration 22: Log-likelihood = -1672.161243625786
Iteration 23: Log-likelihood = -1672.1521929728842
Iteration 24: Log-likelihood = -1672.1447018346055
Iteration 25: Log-likelihood = -1672.1384881903925
Iteration 26: Log-likelihood = -1672.133324505908
Iteration 27: Log-likelihood = -1672.1290262955283
Iteration 28: Log-likelihood = -1672.1254433203214
Iteration 29: Log-likelihood = -1672.1224527549812
Iteration 30: Log-likelihood = -1672.119953840649
Iteration 31: Log-likelihood = -1672.1178636695336
Iteration 32: Log-likelihood = -1672.1161138392636
Iteration 33: Log-likelihood = -1672.1146477812206
Iteration 34: Log-likelihood = -1672.1134186151469
Iteration 35: Log-likelihood = -1672.1123874177363
Iteration 36: Log-likelihood = -1672.1115218189234
Iteration 37: Log-likelihood = -1672.110794859267
Iteration 38: Log-likelihood = -1672.1101840564895
Iteration 39: Log-likelihood = -1672.1096706402484
Iteration 40: Log-likelihood = -1672.1092389232028
Iteration 41: Log-likelihood = -1672.1088757825496
Iteration 42: Log-likelihood = -1672.1085702317575
Iteration 43: Log-likelihood = -1672.1083130660525
Iteration 44: Log-likelihood = -1672.1080965682822
Iteration 45: Log-likelihood = -1672.107914264717
Iteration 46: Log-likelihood = -1672.1077607216998
Iteration 47: Log-likelihood = -1672.107631376398
Iteration 48: Log-likelihood = -1672.1075223955067
Iteration 49: Log-likelihood = -1672.1074305573666
Iteration 50: Log-likelihood = -1672.1073531533136
k=5: AIC=3362.2, BIC=3400.1
Plot information criteria vs. number of components
p4 = plot(k_range, [aic_scores bic_scores];
xlabel="Number of Components (k)", ylabel="Information Criterion",
title="Model Selection via Information Criteria",
label=["AIC" "BIC"], marker=:circle, lw=2
)
optimal_k_aic = k_range[argmin(aic_scores)]
optimal_k_bic = k_range[argmin(bic_scores)]
print("Optimal k: AIC suggests k=$optimal_k_aic, BIC suggests k=$optimal_k_bic\n");Optimal k: AIC suggests k=3, BIC suggests k=3Summary
This tutorial demonstrated the complete Poisson Mixture Model workflow:
Key Concepts:
- Discrete mixtures: Model count data as mixture of Poisson distributions
- EM algorithm: Iterative optimization with closed-form M-step updates
- Soft clustering: Posterior responsibilities provide probabilistic assignments
- Model selection: Information criteria help choose appropriate number of components
Applications:
- Spike count analysis in neuroscience
- Customer transaction modeling in business analytics
- Event frequency analysis in reliability engineering
- Gene expression count clustering in bioinformatics
Technical Insights:
- Initialization strategy significantly affects final solution quality
- Label switching is a fundamental identifiability issue in mixture models
- Information criteria provide principled approach to model complexity selection
- Component separation quality affects parameter recovery accuracy
Extensions:
- Zero-inflated Poisson mixtures for excess zero counts
- Negative Binomial mixtures for overdispersed count data
- Bayesian approaches for uncertainty quantification
- Mixture regression models for count data with covariates
Poisson mixture models provide a flexible framework for modeling heterogeneous count data, enabling both clustering and density estimation while maintaining interpretable probabilistic structure.
This page was generated using Literate.jl.