Simulating and Fitting a Poisson Mixture Model

This tutorial demonstrates how to build and fit a Poisson Mixture Model (PMM) with StateSpaceDynamics.jl using the Expectation-Maximization (EM) algorithm. We'll cover simulation, fitting, diagnostics, interpretation, and practical considerations.

What is a Poisson Mixture Model?

A PMM assumes each observation $x_i \in \{0,1,2,\ldots\}$ is drawn from one of $k$ Poisson distributions with rates $\lambda_1,\ldots,\lambda_k$. The component assignment is a latent categorical variable $z_i \in \{1,\ldots,k\}$ with mixing weights $\pi_1,\ldots,\pi_k$ where $\sum_j \pi_j = 1$.
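
Marginalizing over the latent assignment gives the mixture pmf $p(x_i \mid \boldsymbol{\pi}, \boldsymbol{\lambda}) = \sum_{j=1}^{k} \pi_j \, \text{Poisson}(x_i \mid \lambda_j)$; summing its logarithm over all observations gives the marginal log-likelihood that EM maximizes below.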

Generative process:

  1. Draw $z_i \sim \text{Categorical}(\boldsymbol{\pi})$
  2. Given $z_i = j$, draw $x_i \sim \text{Poisson}(\lambda_j)$

PMMs are useful for count data from heterogeneous sub-populations (e.g., spike counts from different neuron types, customer transaction counts from different segments, or event frequencies across different regimes).

EM Algorithm Overview

EM maximizes the marginal log-likelihood $\log p(\mathbf{x} | \boldsymbol{\pi}, \boldsymbol{\lambda})$ by iterating:

  • E-step: Compute responsibilities $\gamma_{ij} = P(z_i = j | x_i, \boldsymbol{\theta})$
  • M-step: Update parameters to maximize expected complete-data log-likelihood

For Poisson mixtures, the M-step has closed-form updates: $\pi_j \leftarrow \frac{1}{n} \sum_i \gamma_{ij}, \quad \lambda_j \leftarrow \frac{\sum_i \gamma_{ij} x_i}{\sum_i \gamma_{ij}}$
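
In the E-step, Bayes' rule gives $\gamma_{ij} = \frac{\pi_j \, \text{Poisson}(x_i \mid \lambda_j)}{\sum_{l=1}^{k} \pi_l \, \text{Poisson}(x_i \mid \lambda_l)}$. As a concrete, package-independent illustration, here is a minimal sketch of one EM sweep written directly against Distributions.jl; the function name em_step and its signature are purely illustrative and are not part of the StateSpaceDynamics.jl API.

using Distributions

function em_step(x::AbstractVector{<:Integer}, λs::AbstractVector{<:Real}, πs::AbstractVector{<:Real})
    n, K = length(x), length(λs)
    Γ = zeros(n, K)
    # E-step: unnormalized responsibilities πⱼ · Poisson(xᵢ | λⱼ), then normalize each row
    for i in 1:n, j in 1:K
        Γ[i, j] = πs[j] * pdf(Poisson(λs[j]), x[i])
    end
    Γ ./= sum(Γ; dims=2)
    # M-step: closed-form updates of the mixing weights and rates
    Nj = vec(sum(Γ; dims=1))      # expected number of observations per component
    πs_new = Nj ./ n
    λs_new = (Γ' * x) ./ Nj       # responsibility-weighted sample means
    return λs_new, πs_new
end

Repeating this sweep until the log-likelihood stops improving is, in essence, what fit! automates below (the package implementation may differ in initialization and numerical details).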

Load Required Packages

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using StatsPlots
using Distributions

Fix RNG for reproducible simulation and k-means seeding

rng = StableRNG(1234);

Create True Poisson Mixture Model

We'll simulate from a mixture of $k=3$ Poisson components with distinct rates and mixing weights. These parameters create well-separated components that should be recoverable by EM.

k = 3
true_λs = [5.0, 10.0, 25.0]   # Poisson rates per component
true_πs = [0.25, 0.45, 0.30]  # Mixing weights (sum to 1)

true_pmm = PoissonMixtureModel(k, true_λs, true_πs);

print("True model: k=$k components\n")
for i in 1:k
    print("Component $i: λ=$(true_λs[i]), π=$(true_πs[i])\n")
end
True model: k=3 components
Component 1: λ=5.0, π=0.25
Component 2: λ=10.0, π=0.45
Component 3: λ=25.0, π=0.3

Generate Synthetic Data

Draw $n$ independent samples. The labels record the true component membership of each observation; in practice these are unknown and must be inferred.

n = 500
labels = rand(rng, Categorical(true_πs), n)
data = [rand(rng, Poisson(true_λs[labels[i]])) for i in 1:n];

print("Generated $n samples with count range [$(minimum(data)), $(maximum(data))]\n");
Generated 500 samples with count range [1, 38]

Visualize the samples by true component membership. Components with a larger $\lambda$ shift probability mass toward higher counts.

p1 = histogram(data;
    group=labels,
    bins=0:1:maximum(data),
    bar_position=:dodge,
    xlabel="Count", ylabel="Frequency",
    title="Poisson Mixture Samples by True Component",
    alpha=0.7, legend=:topright
)
(Figure: Poisson Mixture Samples by True Component)

Fit Poisson Mixture Model with EM

Construct a model with $k$ components and fit it using the EM algorithm. Key options:

  • maxiter: Maximum EM iterations
  • tol: Convergence tolerance (relative log-likelihood improvement)
  • initialize_kmeans=true: Use k-means for stable initialization

fit_pmm = PoissonMixtureModel(k)
_, lls = fit!(fit_pmm, data; maxiter=100, tol=1e-6, initialize_kmeans=true);

print("EM converged in $(length(lls)) iterations\n")
print("Log-likelihood improved by $(round(lls[end] - lls[1], digits=1))\n");
Iteration 1: Log-likelihood = -1675.2855560047683
Iteration 2: Log-likelihood = -1672.8617414276803
Iteration 3: Log-likelihood = -1672.3074358415222
Iteration 4: Log-likelihood = -1672.164521940615
Iteration 5: Log-likelihood = -1672.1258016768434
Iteration 6: Log-likelihood = -1672.1147947830093
Iteration 7: Log-likelihood = -1672.111371874729
Iteration 8: Log-likelihood = -1672.1100880678769
Iteration 9: Log-likelihood = -1672.1094447975577
Iteration 10: Log-likelihood = -1672.1090217273575
Iteration 11: Log-likelihood = -1672.108696263604
Iteration 12: Log-likelihood = -1672.1084290039594
Iteration 13: Log-likelihood = -1672.1082043790832
Iteration 14: Log-likelihood = -1672.1080141125772
Iteration 15: Log-likelihood = -1672.1078525389155
Iteration 16: Log-likelihood = -1672.1077152192565
Iteration 17: Log-likelihood = -1672.1075984832678
Iteration 18: Log-likelihood = -1672.1074992384272
Iteration 19: Log-likelihood = -1672.107414862758
Iteration 20: Log-likelihood = -1672.1073431287916
Iteration 21: Log-likelihood = -1672.107282143003
Iteration 22: Log-likelihood = -1672.1072302954524
Iteration 23: Log-likelihood = -1672.1071862172657
Iteration 24: Log-likelihood = -1672.107148744525
Iteration 25: Log-likelihood = -1672.1071168876272
Iteration 26: Log-likelihood = -1672.10708980514
Iteration 27: Log-likelihood = -1672.1070667816903
Iteration 28: Log-likelihood = -1672.1070472090382
Iteration 29: Log-likelihood = -1672.1070305700794
Iteration 30: Log-likelihood = -1672.1070164251605
Iteration 31: Log-likelihood = -1672.1070044005103
Iteration 32: Log-likelihood = -1672.106994178339
Iteration 33: Log-likelihood = -1672.106985488515
Iteration 34: Log-likelihood = -1672.1069781013432
Iteration 35: Log-likelihood = -1672.1069718215851
Iteration 36: Log-likelihood = -1672.106966483238
Iteration 37: Log-likelihood = -1672.106961945188
Iteration 38: Log-likelihood = -1672.1069580874637
Iteration 39: Log-likelihood = -1672.1069548080945
Iteration 40: Log-likelihood = -1672.106952020372
Iteration 41: Log-likelihood = -1672.106949650591
Iteration 42: Log-likelihood = -1672.1069476360994
Iteration 43: Log-likelihood = -1672.1069459236369
Iteration 44: Log-likelihood = -1672.1069444679154
Iteration 45: Log-likelihood = -1672.106943230453
Iteration 46: Log-likelihood = -1672.1069421785267
Iteration 47: Log-likelihood = -1672.1069412843055
Converged at iteration 47
EM converged in 47 iterations
Log-likelihood improved by 3.2

Display learned parameters

print("Learned parameters:\n")
for i in 1:k
    print("Component $i: λ=$(round(fit_pmm.λₖ[i], digits=2)), π=$(round(fit_pmm.πₖ[i], digits=3))\n")
end
Learned parameters:
Component 1: λ=5.22, π=0.319
Component 2: λ=24.87, π=0.295
Component 3: λ=10.58, π=0.386

Monitor EM Convergence

EM guarantees a non-decreasing log-likelihood at every iteration, so the curve should rise monotonically and then flatten out; a decrease would indicate a bug or numerical problem.

p2 = plot(lls;
    xlabel="Iteration", ylabel="Log-Likelihood",
    title="EM Convergence",
    marker=:circle, markersize=3, lw=2,
    legend=false, color=:darkblue
)

annotate!(p2, length(lls)*0.7, lls[1] + 0.9*(lls[end] - lls[1]),  # place the label inside the (negative) LL range
    text("Final LL: $(round(lls[end], digits=1))", 10))
(Figure: EM Convergence, log-likelihood vs. iteration)

Visual Model Assessment

Overlay the fitted component PMFs and the overall mixture PMF on a normalized histogram of the data. The components should account for the major modes and the tail behavior.

p3 = histogram(data;
    bins=0:1:maximum(data), normalize=true, alpha=0.3,
    xlabel="Count", ylabel="Probability Density",
    title="Data vs. Fitted Mixture Components",
    label="Data", color=:gray
)

x_range = collect(0:maximum(data))
colors = [:red, :green, :blue]
3-element Vector{Symbol}:
 :red
 :green
 :blue

Plot individual component PMFs

for i in 1:k
    λᵢ = fit_pmm.λₖ[i]
    πᵢ = fit_pmm.πₖ[i]
    pmf_i = πᵢ .* pdf.(Poisson(λᵢ), x_range)
    plot!(p3, x_range, pmf_i;
        lw=2, color=colors[i],
        label="Component $i (λ=$(round(λᵢ, digits=1)))"
    )
end

Plot overall mixture PMF

mixture_pmf = reduce(+, (πᵢ .* pdf.(Poisson(λᵢ), x_range)
                        for (λᵢ, πᵢ) in zip(fit_pmm.λₖ, fit_pmm.πₖ)))
plot!(p3, x_range, mixture_pmf;
    lw=3, linestyle=:dash, color=:black,
    label="Mixture PMF"
)
(Figure: Data vs. Fitted Mixture Components)

Posterior Responsibilities (Soft Clustering)

Responsibilities $\gamma_{ij} = P(z_i = j | x_i, \hat{\boldsymbol{\theta}})$ quantify how likely each observation belongs to each component. These provide soft assignments and uncertainty quantification.

function responsibilities_pmm(λs::AbstractVector, πs::AbstractVector, x::AbstractVector)
    k, n = length(λs), length(x)
    Γ = zeros(n, k)

    for i in 1:n
        for j in 1:k
            Γ[i, j] = πs[j] * pdf(Poisson(λs[j]), x[i])
        end

        row_sum = sum(Γ[i, :]) # Normalize to get probabilities
        if row_sum > 0
            Γ[i, :] ./= row_sum
        end
    end
    return Γ
end

Γ = responsibilities_pmm(fit_pmm.λₖ, fit_pmm.πₖ, data);

Hard assignments (if needed) are argmax over responsibilities

hard_labels = [argmax(Γ[i, :]) for i in 1:n];

Calculate assignment accuracy compared to the true labels. Because EM's component indices are arbitrary, this naive index-wise comparison is confounded by label switching; see the alignment sketch below.

accuracy = mean(labels .== hard_labels)
print("Component assignment accuracy: $(round(accuracy*100, digits=1))%\n");
Component assignment accuracy: 26.4%
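
This low raw accuracy is expected: EM's component indices are arbitrary, and here the fitted ordering does not match the true one (compare the fitted rates of roughly 5.2, 24.9, and 10.6 with the true rates 5, 10, and 25). This is the label-switching issue noted in the summary. A minimal fix, sketched below, is to match components by their rates before comparing; it assumes the true rates in true_λs are listed in increasing order, which they are here.

perm = sortperm(fit_pmm.λₖ)                   # fitted components ordered by increasing rate
aligned_labels = invperm(perm)[hard_labels]   # map each fitted index to the matching true index
aligned_accuracy = mean(labels .== aligned_labels)
print("Assignment accuracy after matching components by rate: $(round(aligned_accuracy*100, digits=1))%\n");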

Information Criteria for Model Selection

When $k$ is unknown, compare models using AIC/BIC:

  • AIC = $2p - 2\text{LL}$
  • BIC = $p \log(n) - 2\text{LL}$

where parameter count $p = (k-1) + k = 2k-1$ (mixing weights + rates)
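
For the $k=3$ fit above, $p = 2 \cdot 3 - 1 = 5$ and the final log-likelihood is $\approx -1672.1$, so AIC $\approx 2 \cdot 5 + 3344.2 = 3354.2$ and BIC $\approx 5 \log(500) + 3344.2 \approx 3375.3$, matching the output below.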

function compute_ic(lls::AbstractVector, n::Int, k::Int)
    ll = last(lls)
    p = 2k - 1
    return (AIC = 2p - 2ll, BIC = p*log(n) - 2ll)
end

ic = compute_ic(lls, n, k)
print("Information criteria: AIC = $(round(ic.AIC, digits=1)), BIC = $(round(ic.BIC, digits=1))\n");
Information criteria: AIC = 3354.2, BIC = 3375.3

Parameter Recovery Assessment

print("\n=== Parameter Recovery Assessment ===\n")

λ_errors = [abs(true_λs[i] - fit_pmm.λₖ[i]) / true_λs[i] for i in 1:k]
π_errors = [abs(true_πs[i] - fit_pmm.πₖ[i]) for i in 1:k]

print("Rate recovery errors (%):\n")
for i in 1:k
    print("Component $i: $(round(λ_errors[i]*100, digits=1))%\n")
end

print("Mixing weight recovery errors:\n")
for i in 1:k
    print("Component $i: $(round(π_errors[i], digits=3))\n")
end

=== Parameter Recovery Assessment ===
Rate recovery errors (%):
Component 1: 4.4%
Component 2: 148.7%
Component 3: 57.7%
Mixing weight recovery errors:
Component 1: 0.069
Component 2: 0.155
Component 3: 0.086
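
The large index-wise errors for components 2 and 3 are again an artifact of label switching rather than a poor fit: they pair the fitted rate of roughly 24.9 with the true rate 10, and roughly 10.6 with 25. Repeating the comparison after matching components by rate (same assumption as before: true_λs is sorted in increasing order) brings the rate errors down to roughly 4%, 6%, and 0.5%.

perm = sortperm(fit_pmm.λₖ)        # match fitted components to true components by rate (as above)
λ_matched = fit_pmm.λₖ[perm]
π_matched = fit_pmm.πₖ[perm]

print("Rate recovery errors after matching (%):\n")
for i in 1:k
    print("Component $i: $(round(abs(true_λs[i] - λ_matched[i]) / true_λs[i] * 100, digits=1))%\n")
end

print("Mixing weight recovery errors after matching:\n")
for i in 1:k
    print("Component $i: $(round(abs(true_πs[i] - π_matched[i]), digits=3))\n")
end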

Model Selection Example

Demonstrate fitting multiple values of $k$ and comparing the fits via AIC and BIC

print("\n=== Model Selection Demo ===\n")

k_range = 1:5
bic_scores = Float64[]
aic_scores = Float64[]

for k_test in k_range
    pmm_test = PoissonMixtureModel(k_test)
    _, lls_test = fit!(pmm_test, data; maxiter=50, tol=1e-6, initialize_kmeans=true)

    ic_test = compute_ic(lls_test, n, k_test)
    push!(aic_scores, ic_test.AIC)
    push!(bic_scores, ic_test.BIC)

    print("k=$k_test: AIC=$(round(ic_test.AIC, digits=1)), BIC=$(round(ic_test.BIC, digits=1))\n")
end

=== Model Selection Demo ===
Iteration 1: Log-likelihood = -2446.0120010620803
Iteration 2: Log-likelihood = -2446.0120010620803
Converged at iteration 2
k=1: AIC=4894.0, BIC=4898.2
Iteration 1: Log-likelihood = -1725.159275168668
Iteration 2: Log-likelihood = -1710.378241489074
Iteration 3: Log-likelihood = -1707.7232060184288
Iteration 4: Log-likelihood = -1707.1363604250537
Iteration 5: Log-likelihood = -1706.9966164146529
Iteration 6: Log-likelihood = -1706.9622050675277
Iteration 7: Log-likelihood = -1706.9535934824933
Iteration 8: Log-likelihood = -1706.9514211429553
Iteration 9: Log-likelihood = -1706.950870967497
Iteration 10: Log-likelihood = -1706.9507313492327
Iteration 11: Log-likelihood = -1706.9506958826319
Iteration 12: Log-likelihood = -1706.9506868686449
Iteration 13: Log-likelihood = -1706.9506845771148
Iteration 14: Log-likelihood = -1706.950683994491
Converged at iteration 14
k=2: AIC=3419.9, BIC=3432.5
Iteration 1: Log-likelihood = -1704.3195555659797
Iteration 2: Log-likelihood = -1703.9458362949379
Iteration 3: Log-likelihood = -1703.685813777206
Iteration 4: Log-likelihood = -1703.4548339987298
Iteration 5: Log-likelihood = -1703.22801395101
Iteration 6: Log-likelihood = -1702.99125277468
Iteration 7: Log-likelihood = -1702.734046523778
Iteration 8: Log-likelihood = -1702.4470356075224
Iteration 9: Log-likelihood = -1702.1206307410766
Iteration 10: Log-likelihood = -1701.7440444750641
Iteration 11: Log-likelihood = -1701.3045009861003
Iteration 12: Log-likelihood = -1700.7865554458565
Iteration 13: Log-likelihood = -1700.1715808539734
Iteration 14: Log-likelihood = -1699.4376412036665
Iteration 15: Log-likelihood = -1698.5601980673152
Iteration 16: Log-likelihood = -1697.5143712319518
Iteration 17: Log-likelihood = -1696.2796252279393
Iteration 18: Log-likelihood = -1694.8473472323265
Iteration 19: Log-likelihood = -1693.2301850950457
Iteration 20: Log-likelihood = -1691.4691485771789
Iteration 21: Log-likelihood = -1689.6322896258673
Iteration 22: Log-likelihood = -1687.8012392847806
Iteration 23: Log-likelihood = -1686.0501875578743
Iteration 24: Log-likelihood = -1684.4290298662092
Iteration 25: Log-likelihood = -1682.9594009945486
Iteration 26: Log-likelihood = -1681.6419020754574
Iteration 27: Log-likelihood = -1680.4665287788532
Iteration 28: Log-likelihood = -1679.4202795784063
Iteration 29: Log-likelihood = -1678.4906960585263
Iteration 30: Log-likelihood = -1677.6667267525277
Iteration 31: Log-likelihood = -1676.93851372879
Iteration 32: Log-likelihood = -1676.29701045982
Iteration 33: Log-likelihood = -1675.7337392608997
Iteration 34: Log-likelihood = -1675.2407043802618
Iteration 35: Log-likelihood = -1674.810394909604
Iteration 36: Log-likelihood = -1674.4358180750003
Iteration 37: Log-likelihood = -1674.110529465094
Iteration 38: Log-likelihood = -1673.8286473588853
Iteration 39: Log-likelihood = -1673.5848495942018
Iteration 40: Log-likelihood = -1673.374356023739
Iteration 41: Log-likelihood = -1673.1929006932346
Iteration 42: Log-likelihood = -1673.0366975083916
Iteration 43: Log-likelihood = -1672.9024023745687
Iteration 44: Log-likelihood = -1672.7870740176859
Iteration 45: Log-likelihood = -1672.6881350550498
Iteration 46: Log-likelihood = -1672.6033343911613
Iteration 47: Log-likelihood = -1672.5307116423744
Iteration 48: Log-likelihood = -1672.4685640184512
Iteration 49: Log-likelihood = -1672.4154158859096
Iteration 50: Log-likelihood = -1672.3699910894566
k=3: AIC=3354.7, BIC=3375.8
Iteration 1: Log-likelihood = -1687.5997063100372
Iteration 2: Log-likelihood = -1682.9755525697433
Iteration 3: Log-likelihood = -1680.6701152086302
Iteration 4: Log-likelihood = -1679.1779264126162
Iteration 5: Log-likelihood = -1678.0925511352536
Iteration 6: Log-likelihood = -1677.2360490179935
Iteration 7: Log-likelihood = -1676.524644894624
Iteration 8: Log-likelihood = -1675.9175528261112
Iteration 9: Log-likelihood = -1675.3932239514377
Iteration 10: Log-likelihood = -1674.938519100328
Iteration 11: Log-likelihood = -1674.5440234482487
Iteration 12: Log-likelihood = -1674.2021279893906
Iteration 13: Log-likelihood = -1673.906280318686
Iteration 14: Log-likelihood = -1673.6506961498142
Iteration 15: Log-likelihood = -1673.430238095617
Iteration 16: Log-likelihood = -1673.2403490970394
Iteration 17: Log-likelihood = -1673.0770012850276
Iteration 18: Log-likelihood = -1672.9366485749838
Iteration 19: Log-likelihood = -1672.8161805953034
Iteration 20: Log-likelihood = -1672.7128781918946
Iteration 21: Log-likelihood = -1672.6243712023809
Iteration 22: Log-likelihood = -1672.548599048343
Iteration 23: Log-likelihood = -1672.4837744570677
Iteration 24: Log-likelihood = -1672.4283504278974
Iteration 25: Log-likelihood = -1672.3809904190255
Iteration 26: Log-likelihood = -1672.3405416386001
Iteration 27: Log-likelihood = -1672.306011266212
Iteration 28: Log-likelihood = -1672.2765453968887
Iteration 29: Log-likelihood = -1672.2514104831673
Iteration 30: Log-likelihood = -1672.2299770458126
Iteration 31: Log-likelihood = -1672.2117054272253
Iteration 32: Log-likelihood = -1672.1961333710767
Iteration 33: Log-likelihood = -1672.1828652241895
Iteration 34: Log-likelihood = -1672.1715625721565
Iteration 35: Log-likelihood = -1672.1619361359371
Iteration 36: Log-likelihood = -1672.1537387732171
Iteration 37: Log-likelihood = -1672.1467594442347
Iteration 38: Log-likelihood = -1672.140818017025
Iteration 39: Log-likelihood = -1672.1357608013006
Iteration 40: Log-likelihood = -1672.1314567133008
Iteration 41: Log-likelihood = -1672.127793985931
Iteration 42: Log-likelihood = -1672.1246773493472
Iteration 43: Log-likelihood = -1672.1220256166228
Iteration 44: Log-likelihood = -1672.1197696179124
Iteration 45: Log-likelihood = -1672.1178504338989
Iteration 46: Log-likelihood = -1672.116217886124
Iteration 47: Log-likelihood = -1672.1148292474613
Iteration 48: Log-likelihood = -1672.1136481411163
Iteration 49: Log-likelihood = -1672.1126436011173
Iteration 50: Log-likelihood = -1672.1117892707553
k=4: AIC=3358.2, BIC=3387.7
Iteration 1: Log-likelihood = -1680.758984117847
Iteration 2: Log-likelihood = -1677.0495363952361
Iteration 3: Log-likelihood = -1675.6898936254063
Iteration 4: Log-likelihood = -1674.9034547640667
Iteration 5: Log-likelihood = -1674.369182966404
Iteration 6: Log-likelihood = -1673.9796850340795
Iteration 7: Log-likelihood = -1673.6852604734079
Iteration 8: Log-likelihood = -1673.4576541516078
Iteration 9: Log-likelihood = -1673.2787097269809
Iteration 10: Log-likelihood = -1673.1359501009015
Iteration 11: Log-likelihood = -1673.0204862984735
Iteration 12: Log-likelihood = -1672.9258536033838
Iteration 13: Log-likelihood = -1672.8472860515194
Iteration 14: Log-likelihood = -1672.7812328478365
Iteration 15: Log-likelihood = -1672.725024117896
Iteration 16: Log-likelihood = -1672.6766351794518
Iteration 17: Log-likelihood = -1672.6345181522745
Iteration 18: Log-likelihood = -1672.5974804778466
Iteration 19: Log-likelihood = -1672.5645965068131
Iteration 20: Log-likelihood = -1672.5351426108011
Iteration 21: Log-likelihood = -1672.5085491696302
Iteration 22: Log-likelihood = -1672.484364764128
Iteration 23: Log-likelihood = -1672.4622292693593
Iteration 24: Log-likelihood = -1672.4418534908457
Iteration 25: Log-likelihood = -1672.423003648747
Iteration 26: Log-likelihood = -1672.4054894813567
Iteration 27: Log-likelihood = -1672.3891550702785
Iteration 28: Log-likelihood = -1672.3738717261883
Iteration 29: Log-likelihood = -1672.3595324445862
Iteration 30: Log-likelihood = -1672.3460475650793
Iteration 31: Log-likelihood = -1672.3333413581743
Iteration 32: Log-likelihood = -1672.3213493308087
Iteration 33: Log-likelihood = -1672.3100160913516
Iteration 34: Log-likelihood = -1672.299293652206
Iteration 35: Log-likelihood = -1672.2891400759668
Iteration 36: Log-likelihood = -1672.279518392472
Iteration 37: Log-likelihood = -1672.2703957302351
Iteration 38: Log-likelihood = -1672.2617426178897
Iteration 39: Log-likelihood = -1672.2535324212374
Iteration 40: Log-likelihood = -1672.2457408883747
Iteration 41: Log-likelihood = -1672.238345781472
Iteration 42: Log-likelihood = -1672.2313265780558
Iteration 43: Log-likelihood = -1672.224664228118
Iteration 44: Log-likelihood = -1672.2183409562736
Iteration 45: Log-likelihood = -1672.2123401002395
Iteration 46: Log-likelihood = -1672.2066459785617
Iteration 47: Log-likelihood = -1672.2012437821681
Iteration 48: Log-likelihood = -1672.1961194849953
Iteration 49: Log-likelihood = -1672.1912597701623
Iteration 50: Log-likelihood = -1672.186651968744
k=5: AIC=3362.4, BIC=3400.3

Plot information criteria vs number of components

p4 = plot(k_range, [aic_scores bic_scores];
    xlabel="Number of Components (k)", ylabel="Information Criterion",
    title="Model Selection via Information Criteria",
    label=["AIC" "BIC"], marker=:circle, lw=2
)

optimal_k_aic = k_range[argmin(aic_scores)]
optimal_k_bic = k_range[argmin(bic_scores)]

print("Optimal k: AIC suggests k=$optimal_k_aic, BIC suggests k=$optimal_k_bic\n");
Optimal k: AIC suggests k=3, BIC suggests k=3

Summary

This tutorial demonstrated the complete Poisson Mixture Model workflow:

Key Concepts:

  • Discrete mixtures: Model count data as mixture of Poisson distributions
  • EM algorithm: Iterative optimization with closed-form M-step updates
  • Soft clustering: Posterior responsibilities provide probabilistic assignments
  • Model selection: Information criteria help choose appropriate number of components

Applications:

  • Spike count analysis in neuroscience
  • Customer transaction modeling in business analytics
  • Event frequency analysis in reliability engineering
  • Gene expression count clustering in bioinformatics

Technical Insights:

  • Initialization strategy significantly affects final solution quality (see the multi-start sketch after this list)
  • Label switching is a fundamental identifiability issue in mixture models
  • Information criteria provide principled approach to model complexity selection
  • Component separation quality affects parameter recovery accuracy
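
Because EM only finds a local optimum, a common safeguard against poor initialization is to fit from several starting points and keep the run with the highest final log-likelihood. The multi-start sketch below is illustrative only: it assumes that constructing PoissonMixtureModel(k, λ₀, π₀) with explicit parameters and then calling fit! with initialize_kmeans=false uses those values as the starting point (only initialize_kmeans=true is exercised above, so check the package documentation before relying on this behavior).

restart_models = PoissonMixtureModel[]
restart_lls = Float64[]
for r in 1:5
    λ₀ = sort(rand(rng, k)) .* maximum(data)   # crude random initial rates (assumed to be respected by fit!)
    π₀ = fill(1.0 / k, k)                      # uniform initial mixing weights
    candidate = PoissonMixtureModel(k, λ₀, π₀)
    _, lls_r = fit!(candidate, data; maxiter=100, tol=1e-6, initialize_kmeans=false)
    push!(restart_models, candidate)
    push!(restart_lls, last(lls_r))
end
best_model = restart_models[argmax(restart_lls)];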

Extensions:

  • Zero-inflated Poisson mixtures for excess zero counts
  • Negative Binomial mixtures for overdispersed count data
  • Bayesian approaches for uncertainty quantification
  • Mixture regression models for count data with covariates

Poisson mixture models provide a flexible framework for modeling heterogeneous count data, enabling both clustering and density estimation while maintaining interpretable probabilistic structure.


This page was generated using Literate.jl.