Simulating and Fitting a Poisson Mixture Model

This tutorial demonstrates how to use StateSpaceDynamics.jl to simulate data from a Poisson mixture model and then fit a new model to that data with the expectation–maximization (EM) algorithm.

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using StableRNGs
using StatsPlots
using Distributions

# StableRNG yields the same random stream across Julia versions, so the simulated data
# (and every plot below) is reproducible. The trailing `;` suppresses display in the docs.
seed = 1234
rng = StableRNG(seed);

Create the ground-truth `PoissonMixtureModel` (known Poisson rates λ and mixing weights π) that we will simulate data from.

# Ground-truth parameters: one Poisson rate and one mixing weight per component.
true_λs = [5.0, 10.0, 25.0]  # Poisson means
true_πs = [0.25, 0.45, 0.3]   # Mixing weights (sum to 1)

# Derive the component count from the parameters so the two can never disagree.
k = length(true_λs)

true_pmm = PoissonMixtureModel(k, true_λs, true_πs)
PoissonMixtureModel{Float64, Vector{Float64}}(3, [5.0, 10.0, 25.0], [0.25, 0.45, 0.3])

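Generate data from the true model: for each of the `n` samples, first draw a component label from the mixing weights, then draw a count from that component's Poisson distribution.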

n = 500

# Latent component assignment for every sample, drawn from the mixing weights.
labels = rand(rng, Categorical(true_πs), n)

# One Poisson count per sample, using the rate of its assigned component.
data = [rand(rng, Poisson(true_λs[z])) for z in labels]  # Vector{Int}
500-element Vector{Int64}:
 33
  9
  9
  8
 25
 10
  5
 15
  5
 30
  ⋮
  1
 10
 21
  7
 24
 25
  5
  7
  9

Plot a histogram of the Poisson samples, grouped by their true component label.

# With edges `0:1:maximum(data)` every integer count falls exactly on a bin edge. Each bar
# is then drawn half a unit away from the value it represents, and whether the largest
# count is kept depends on the edge-closure convention. Half-integer edges put exactly one
# integer at the centre of each bin.
count_bins = -0.5:1:(maximum(data) + 0.5)

p1 = histogram(
    data;
    group=labels,            # one coloured series per true component
    bins=count_bins,
    bar_position=:dodge,     # place grouped bars side by side rather than stacked
    xlabel="Count",
    ylabel="Frequency",
    title="Poisson Mixture Samples by Component",
    alpha=0.7,
    legend=:topright,
)
p1
Example block output

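Fit a new `PoissonMixtureModel` to the data with EM. `fit!` updates the model in place and returns a tuple. The first element is printed below as a k × n matrix (apparently the per-sample component responsibilities) and is discarded here. The second element is the log-likelihood at each iteration.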

# Start from an unfitted model with the same number of components, then run EM in place.
fit_pmm = PoissonMixtureModel(k)
_, lls = fit!(
    fit_pmm, data;
    maxiter=100,             # upper bound on the number of EM iterations
    tol=1e-6,                # convergence tolerance (presumably on the log-likelihood change)
    initialize_kmeans=true,  # presumably seeds the components via k-means — confirm in docs
)
([1.167299485809293e-6 0.7670285733361216 … 0.4456154946697474 0.7670285733361216; 0.9999988327004986 0.0007940905571800423 … 8.334359187998249e-5 0.0007940905571800423; 1.5235751770191232e-14 0.23217733610669858 … 0.5543011617383728 0.23217733610669858], [-1677.6322074652844, -1674.3624894485276, -1673.5404197496819, -1673.2181725367448, -1673.0264176702262, -1672.882289515502, -1672.764176466248, -1672.6648070344565, -1672.5805981923975, -1672.5091100384484  …  -1672.1069591732414, -1672.1069557303338, -1672.1069528037415, -1672.1069503160506, -1672.1069482014225, -1672.10694640391, -1672.1069448759567, -1672.1069435771435, -1672.106942473089, -1672.1069415346014])

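Plot the log-likelihood at each EM iteration to check convergence. EM never decreases the log-likelihood, so the curve should rise and then level off.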

# Plot options for the convergence trace; one marker per EM iteration.
convergence_opts = (
    xlabel="Iteration",
    ylabel="Log-Likelihood",
    title="EM Convergence (Poisson Mixture)",
    marker=:circle,
    label="log_likelihood",
)
p2 = plot(lls; convergence_opts...)
p2
Example block output

Overlay the fitted component PMFs (each scaled by its mixing weight) and the full mixture PMF on the normalized data histogram.

# Half-integer edges centre each bar on an integer count, so the bars line up with the PMF
# values plotted at x = 0, 1, …, maximum(data). (Edges `0:1:max` would shift every bar half
# a unit to the right of its PMF point.)
p3 = histogram(
    data;
    bins=-0.5:1:(maximum(data) + 0.5),
    normalize=true,   # bin width is 1, so bar heights are empirical probabilities
    alpha=0.3,
    label="Data",
    xlabel="Count",
    ylabel="Density",
    title="Poisson Mixtures: Data and PMFs",
)

x = collect(0:maximum(data))
colors = [:red, :green, :blue]

# Weighted component PMFs πᵢ · Poisson(x | λᵢ), one per fitted component. They are
# computed once and reused for both the per-component curves and the mixture curve.
comp_pmfs = [πi .* pdf.(Poisson(λi), x) for (λi, πi) in zip(fit_pmm.λₖ, fit_pmm.πₖ)]

# Iterate the model's own components; `mod1` cycles the palette if there are more than 3.
for (i, (λi, pmf_i)) in enumerate(zip(fit_pmm.λₖ, comp_pmfs))
    plot!(
        p3, x, pmf_i;
        lw=2,
        c=colors[mod1(i, length(colors))],
        label="Comp $i (λ=$(round(λi, sigdigits=3)))",
    )
end

# The mixture PMF is the sum of the weighted component PMFs.
mix_pmf = reduce(+, comp_pmfs)
plot!(
    p3, x, mix_pmf;
    lw=3, ls=:dash, c=:black,
    label="Mixture",
)

p3
Example block output

This page was generated using Literate.jl.