Simulating and Fitting a Poisson Mixture Model

This tutorial demonstrates how to use StateSpaceDynamics.jl to create a Poisson mixture model, simulate data from it, and fit a new model to that data using the expectation–maximization (EM) algorithm.

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using StatsPlots
using Distributions

# Seeded RNG used for every random draw below, so the tutorial output is reproducible.
rng = StableRNG(1_234);

Create the ground-truth `PoissonMixtureModel` that we will simulate data from

k = 3
true_λs = [5.0, 10.0, 25.0]  # Poisson means
true_πs = [0.25, 0.45, 0.3]   # Mixing weights

true_pmm = PoissonMixtureModel(k, true_λs, true_πs)
PoissonMixtureModel{Float64, Vector{Float64}}(3, [5.0, 10.0, 25.0], [0.25, 0.45, 0.3])

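Generate data from the true model: first draw a component label for each sample from the mixing weights, then draw a count from that component's Poisson distribution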

n = 500                                      # number of simulated samples
labels = rand(rng, Categorical(true_πs), n)  # true component of each sample
data = [rand(rng, Poisson(true_λs[z])) for z in labels]  # Vector{Int}, one count per label
500-element Vector{Int64}:
 33
  9
  9
  8
 25
 10
  5
 15
  5
 30
  ⋮
  1
 10
 21
  7
 24
 25
  5
  7
  9

Plot a histogram of the Poisson samples, grouped by their true component

# One dodged bar series per true component label, with unit-width bins over the data range.
p1 = histogram(
    data;
    group=labels,
    bar_position=:dodge,
    bins=0:1:maximum(data),
    title="Poisson Mixture Samples by Component",
    xlabel="Count",
    ylabel="Frequency",
    legend=:topright,
    alpha=0.7,
)
p1
Example block output

Fit a new `PoissonMixtureModel` with the same number of components to the data, using k-means initialization

fit_pmm = PoissonMixtureModel(k)  # fresh k-component model; fit! updates it in place
# fit! returns a 2-tuple. The second element is the per-iteration log-likelihood trace.
# NOTE(review): the first element looks like a k×n matrix of responsibilities — confirm.
(_, lls) = fit!(fit_pmm, data; initialize_kmeans=true, maxiter=100, tol=1e-6)
([1.5417854767112443e-14 0.23281667300994824 … 0.5551251910469738 0.23281667300994824; 0.9999988264615685 0.0007939518191292314 … 8.328588068036678e-5 0.0007939518191292314; 1.1735384162374018e-6 0.7663893751709226 … 0.44479152307234593 0.7663893751709226], [-1701.6126496215945, -1688.5801867165533, -1682.2780488227118, -1679.1137695443554, -1677.4220685907212, -1676.4090593967044, -1675.7141513588833, -1675.1814043453699, -1674.7444722444727, -1674.3740832673652  …  -1672.106957545556, -1672.1069543474243, -1672.1069516287737, -1672.1069493177001, -1672.106947353116, -1672.1069456830764, -1672.1069442634239, -1672.1069430566226, -1672.1069420307567, -1672.1069411586964])

Plot the log-likelihood at each EM iteration to visualize convergence

# Log-likelihood trace returned by fit!, one marker per EM iteration.
p2 = plot(
    lls;
    title="EM Convergence (Poisson Mixture)",
    xlabel="Iteration",
    ylabel="Log-Likelihood",
    label="log_likelihood",
    marker=:circle,
)
p2
Example block output

Overlay the fitted model's weighted component PMFs and the full mixture PMF on the normalized data histogram

# Normalized histogram of the observed counts, used as the backdrop for the fitted PMFs.
p3 = histogram(
    data;
    bins=0:1:maximum(data),
    normalize=true,
    alpha=0.3,
    label="Data",
    xlabel="Count",
    ylabel="Density",
    title="Poisson Mixtures: Data and PMFs",
)

x = collect(0:maximum(data))
colors = [:red, :green, :blue]

# Weighted PMF πᵢ · Poisson(x; λᵢ) for each fitted component. These are computed once
# and reused for both the per-component curves and the mixture total below.
component_pmfs = [πi .* pdf.(Poisson(λi), x) for (λi, πi) in zip(fit_pmm.λₖ, fit_pmm.πₖ)]

# Iterate over the fitted model's own components rather than a separately defined k.
for (i, (λi, pmf_i)) in enumerate(zip(fit_pmm.λₖ, component_pmfs))
    plot!(
        p3, x, pmf_i;
        lw=2,
        c=colors[mod1(i, length(colors))],  # cycle the palette so any number of components works
        label="Comp $i (λ=$(round(λi, sigdigits=3)))",
    )
end

# The mixture PMF is the sum of the weighted component PMFs.
mix_pmf = sum(component_pmfs)
plot!(
    p3, x, mix_pmf;
    lw=3, ls=:dash, c=:black,
    label="Mixture",
)

p3
Example block output

This page was generated using Literate.jl.