Simulating and Fitting a Poisson Linear Dynamical System

This tutorial demonstrates how to use StateSpaceDynamics.jl to simulate and fit a Linear Dynamical System (LDS) with Poisson observations using the Laplace-EM algorithm. Unlike the standard Gaussian LDS, this model is designed for count data (e.g., neural spike counts, customer arrivals, or discrete event data) where observations are non-negative integers following Poisson distributions.

The key insight is that while latent dynamics remain continuous and Gaussian, the observations are discrete counts whose rates depend on the latent state through an exponential link function: $\lambda_i(t) = \exp(\mathbf{C}_i^T \mathbf{x}_t + d_i)$.

Load Required Packages

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using StableRNGs

Set up reproducible random number generation

rng = StableRNG(54321);

Why Poisson Linear Dynamical Systems?

Many real-world phenomena involve discrete events generated by underlying continuous processes. The traditional Gaussian LDS assumes continuous observations and fails when the data are:

  • Neural spike counts (non-negative integers)
  • Customer arrivals per time window
  • Gene expression counts
  • Social media posts or interactions

The Poisson LDS handles this elegantly by maintaining Gaussian latent dynamics while modeling the observations as count data whose rates depend on the hidden state.

Create a Poisson Linear Dynamical System

We define a system where continuous latent dynamics generate discrete count observations. This is particularly relevant in neuroscience (neural spike trains) and other domains where discrete events are generated by underlying continuous processes.

obs_dim = 10       # Number of observed count variables (e.g., neurons)
latent_dim = 2;    # Number of latent state dimensions

Define the latent dynamics: the same spiral structure as in the Gaussian LDS tutorial. Latent states evolve smoothly according to linear dynamics.

A = 0.95 * [cos(0.25) -sin(0.25); sin(0.25) cos(0.25)]  # Rotation with contraction
b = zeros(latent_dim)               # bias
Q = Matrix(0.1 * I(latent_dim))     # Process noise covariance
x0 = zeros(latent_dim)              # Initial state mean
P0 = Matrix(0.1 * I(latent_dim));   # Initial state covariance

Poisson observation model parameters: for Poisson observations, the rate parameter $\lambda_i$ is modeled as $\log(\lambda_i) = \mathbf{C}_i^T \mathbf{x}_t + d_i$, where $\mathbf{C}$ maps latent states to log-rates and $d_i$ provides the baseline log-rates.

log_d = log.(fill(0.1, obs_dim));    # Log baseline rates (small positive rates)

Observation matrix $\mathbf{C}$: maps the 2D latent states to log-rates for each observed dimension. We use positive values so that latent activity increases firing rates.

C = abs.(randn(rng, obs_dim, latent_dim));   # obs_dim × latent_dim, entrywise positive

Understanding Poisson LDS Parameters

Latent dynamics parameters (same as Gaussian LDS):

  • $A$: How latent states evolve (rotation + contraction creates stable oscillation)
  • $Q$: Process noise (uncertainty in latent evolution)
  • $x_0$, $P_0$: Initial state distribution

Observation parameters (unique to Poisson case):

  • $C[i,:]$: How latent dimensions affect log-rate of observation $i$
  • Positive $C[i,j]$: latent dimension $j$ increases firing rate of unit $i$
  • Negative $C[i,j]$: latent dimension $j$ decreases firing rate of unit $i$
  • $d[i]$: Baseline log-rate for observation $i$ when latent state = 0
  • $\exp(d[i])$ gives the baseline firing rate
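To make the baseline and gain interpretations concrete, here is a short sketch using the parameters defined above (`i`, `baseline_rate`, and `gain` are illustrative names, not package API):

i = 1                           # pick one observed unit
baseline_rate = exp(log_d[i])   # firing rate when the latent state is zero (0.1 here)
gain = exp.(C[i, :])            # multiplicative rate change per unit increase in each latent dim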

The key innovation is how we connect continuous latent states to discrete counts. Instead of linear observations $y = \mathbf{C} \mathbf{x} + \text{noise}$, we use:

\[\lambda_i(t) = \exp(\mathbf{C}_i^T \mathbf{x}_t + d_i)\]

\[y_i(t) \sim \text{Poisson}(\lambda_i(t))\]

The exponential ensures rates are always positive (required for Poisson), and the log-linear relationship means latent states multiplicatively affect firing rates.

The baseline parameter $d_i$ sets the minimum firing rate when latent states are zero.
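As a minimal sketch of this generative step (an illustration only, not the package's internal sampler, which `rand` below provides; assumes Distributions.jl is available):

using Distributions    # assumption: Distributions.jl is installed for Poisson sampling

x_t = [0.5, -0.3]               # an example latent state
λ = exp.(C * x_t .+ log_d)      # per-unit rates, guaranteed positive by the exponential
y_t = rand.(Poisson.(λ))        # one count per observed dimension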

Construct the model components

state_model = GaussianStateModel(; A, Q, b, x0, P0)          # Gaussian latent dynamics
obs_model = PoissonObservationModel(; C, log_d);           # Poisson observations

Create the complete Poisson Linear Dynamical System

true_plds = LinearDynamicalSystem(;
    state_model=state_model,
    obs_model=obs_model,
    latent_dim=latent_dim,
    obs_dim=obs_dim,
    fit_bool=fill(true, 6)  # Learn all parameters: A, Q, C, log_d, x0, P0
);

Simulate Latent States and Count Observations

Generate synthetic data from our Poisson LDS. Latent states evolve according to linear dynamics, while observations are drawn from Poisson distributions whose rates depend exponentially on the current latent state.

Generate both latent trajectories and count observations

tSteps = 500
latents, observations = rand(rng, true_plds; tsteps=tSteps, ntrials=1);

Visualize Latent Dynamics

Show the underlying continuous dynamics that drive discrete observations. This vector field illustrates how latent states evolve deterministically (ignoring noise).

Create grid for vector field

x = y = -3:0.5:3
X = repeat(x', length(y), 1)
Y = repeat(y, 1, length(x))
U = zeros(size(X))  # Flow in x-direction
V = zeros(size(Y));  # Flow in y-direction

for i in 1:size(X, 1), j in 1:size(X, 2)
    v = A * [X[i,j], Y[i,j]]
    U[i,j] = v[1] - X[i,j]
    V[i,j] = v[2] - Y[i,j]
end

magnitude = @. sqrt(U^2 + V^2)
U_norm = U ./ magnitude    # normalize arrow lengths for a cleaner visualization
V_norm = V ./ magnitude;

Plot vector field with simulated trajectory

p1 = quiver(X, Y, quiver=(U_norm, V_norm), color=:blue, alpha=0.3,
           linewidth=1, arrow=arrow(:closed, :head, 0.1, 0.1))
plot!(latents[1, :, 1], latents[2, :, 1], xlabel=L"x_1", ylabel=L"x_2",
      color=:black, linewidth=1.5, title="Latent Dynamics", legend=false)

p1
[Figure: "Latent Dynamics" vector field with the simulated latent trajectory overlaid]

Visualize Latent States and Spike Observations

Create visualizations highlighting the contrast between continuous latent dynamics and discrete count observations (spike trains).

states = latents[:, :, 1]
emissions = observations[:, :, 1]
10×500 Matrix{Float64}:
 0.0  1.0  0.0  11.0  7.0  2.0  4.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 1.0  0.0  2.0   1.0  3.0  6.0  1.0     1.0  1.0  0.0  0.0  1.0  0.0  0.0
 0.0  1.0  1.0   3.0  2.0  3.0  1.0     2.0  0.0  1.0  0.0  1.0  0.0  0.0
 2.0  2.0  2.0   2.0  3.0  0.0  2.0     1.0  1.0  0.0  0.0  0.0  1.0  1.0
 0.0  0.0  2.0   2.0  3.0  5.0  7.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 1.0  2.0  1.0   2.0  1.0  3.0  2.0  …  1.0  0.0  0.0  0.0  0.0  1.0  1.0
 2.0  2.0  2.0   4.0  1.0  0.0  0.0     1.0  0.0  0.0  1.0  0.0  1.0  0.0
 1.0  1.0  0.0   1.0  1.0  1.0  6.0     0.0  1.0  0.0  0.0  0.0  0.0  0.0
 0.0  2.0  0.0   3.0  5.0  1.0  4.0     0.0  0.0  0.0  0.0  0.0  2.0  0.0
 0.0  2.0  1.0   1.0  5.0  3.0  1.0     1.0  0.0  0.0  0.0  0.0  0.0  0.0

Two-panel layout: continuous latent states above, discrete spike rasters below

lim_states = maximum(abs.(states))

p2 = plot(size=(800, 600), layout=@layout[a{0.3h}; b])

for d in 1:latent_dim
    plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black,
          linewidth=2, label="", subplot=1)  # plot smooth latent-state trajectories
end

plot!(subplot=1, yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
      xticks=[], xlims=(0, tSteps), title="Simulated Latent States",
      yformatter=y->"", tickfontsize=12)

colors = palette(:default, obs_dim)
for f in 1:obs_dim
    spike_times = findall(x -> x > 0, emissions[f, :])
    for t in spike_times
        plot!([t, t], [f-0.4, f+0.4], color=colors[f], linewidth=1, label="", subplot=2)
    end
end

plot!(subplot=2, yticks=(1:obs_dim, [L"y_{%$d}" for d in 1:obs_dim]),
      xlims=(0, tSteps), ylims=(0.5, obs_dim + 0.5), title="Spike Raster Plot",
      xlabel="Time", tickfontsize=12, grid=false)

p2
[Figure: "Simulated Latent States" traces above a "Spike Raster Plot" of the count observations]

The Inference Challenge

Unlike the Gaussian LDS, where exact inference is possible via Kalman filtering, Poisson observations break the conjugate Gaussian structure. The posterior $p(x_t | y_{1:T})$ is no longer Gaussian, requiring approximations.

Laplace-EM Algorithm:

  • E-step: Use a Laplace approximation to make the posterior "locally Gaussian"
  • M-step: Update the parameters using expected sufficient statistics
  • Iteration: Repeat until the ELBO converges

The Laplace approximation finds the mode of the posterior and approximates it with a Gaussian, enabling tractable inference at the cost of some accuracy.
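To build intuition, here is a toy one-dimensional Laplace approximation (an illustration only, not the package's internals): Newton's method locates the mode of a log-density, and the negative inverse second derivative at the mode gives the approximate Gaussian variance. The example log-density is the unnormalized Poisson log-likelihood of a count of 2 with log-rate x.

# Approximate p(x) ∝ exp(f(x)) by N(x̂, -1/f''(x̂)) at the mode x̂
function laplace_1d(df, d2f; x=0.0, iters=50)
    for _ in 1:iters
        x -= df(x) / d2f(x)    # Newton step toward the mode
    end
    return x, -1 / d2f(x)      # mode and approximate variance
end

df(x)  = 2 - exp(x)            # f(x) = 2x - exp(x), so f'(x) = 2 - exp(x)
d2f(x) = -exp(x)               # f''(x) = -exp(x)

x̂, σ² = laplace_1d(df, d2f)    # x̂ = log(2) ≈ 0.693, σ² = 1/2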

Initialize and Fit Poisson LDS

In practice, we only observe spike counts, not latent states. Our goal is to infer both the latent dynamics and the mapping from latent states to firing rates. We start with a randomly initialized model.

Random initialization (simulating lack of prior knowledge)

A_init = random_rotation_matrix(latent_dim, rng)  # Random rotation matrix
Q_init = Matrix(0.1 * I(latent_dim))              # Process noise guess
C_init = randn(rng, obs_dim, latent_dim)          # Random observation mapping
log_d_init = log.(fill(0.1, obs_dim))             # Baseline log-rate guess
x0_init = zeros(latent_dim)                       # Start from origin
P0_init = Matrix(0.1 * I(latent_dim));             # Initial uncertainty

Construct naive model

sm_init = GaussianStateModel(; A=A_init, Q=Q_init, b=b, x0=x0_init, P0=P0_init)
om_init = PoissonObservationModel(; C=C_init, log_d=log_d_init)

naive_plds = LinearDynamicalSystem(;
    state_model=sm_init,
    obs_model=om_init,
    latent_dim=latent_dim,
    obs_dim=obs_dim,
    fit_bool=fill(true, 6)
);

For Poisson observations, smoothing requires a Laplace approximation, since the posterior is no longer Gaussian (unlike the linear-Gaussian case)

smoothed_x_pre, smoothed_p_pre = smooth(naive_plds, observations);

Compare true vs. initial estimated latent states

p3 = plot()
for d in 1:latent_dim
    plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black,
          linewidth=2, label=(d==1 ? "True" : ""), alpha=0.8)
    plot!(1:tSteps, smoothed_x_pre[d, :, 1] .+ lim_states * (d-1), color=:red,
          linewidth=2, label=(d==1 ? "Initial Est." : ""), alpha=0.8)
end

plot!(yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
      xlabel="Time", xlims=(0, tSteps), title="Pre-EM: True vs. Initial Estimates",
      yformatter=y->"", tickfontsize=12, legend=:topright)

p3
[Figure: "Pre-EM: True vs. Initial Estimates" of the latent states]

Fit Using Laplace-EM Algorithm

Fit the model. We use relatively few iterations here because each Laplace-EM iteration is computationally expensive.

elbo, _ = fit!(naive_plds, observations; max_iter=25, tol=1e-6);

print("Laplace-EM completed in $(length(elbo)) iterations\n")

Fitting Poisson LDS via LaPlaceEM... 100%|██████████████████████████████████████████████████| Time: 0:00:08 ( 0.33  s/it)
Laplace-EM completed in 25 iterations

Parameter identifiability:

  • Scale ambiguity: rescaling the latent states ($\mathbf{x} \to \alpha \mathbf{x}$) while setting $\mathbf{C} \to \mathbf{C}/\alpha$ (and scaling $Q$, $x_0$, $P_0$ accordingly) leaves the likelihood unchanged
  • Can be resolved by constraining the norm of $\mathbf{C}$ or fixing one element
  • Rotation ambiguity in latent space (same as for the Gaussian LDS)
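A quick numeric check of the scale ambiguity (`α`, `lr1`, `lr2` are illustrative names):

α = 2.0
x_t = randn(rng, latent_dim)
lr1 = C * x_t .+ log_d                   # log-rates under (C, x)
lr2 = (C ./ α) * (α .* x_t) .+ log_d     # log-rates under (C/α, αx)
maximum(abs.(lr1 .- lr2))                # ≈ 0, so the Poisson likelihood is identical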

Perform smoothing with learned parameters

smoothed_x_post, smoothed_p_post = smooth(naive_plds, observations);

Compare true vs. learned latent state estimates

p4 = plot()
for d in 1:latent_dim
    plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black,
          linewidth=2, label=(d==1 ? "True" : ""), alpha=0.8)
    plot!(1:tSteps, smoothed_x_post[d, :, 1] .+ lim_states * (d-1), color=:red,
          linewidth=2, label=(d==1 ? "Post-EM Est." : ""), alpha=0.8)
end

plot!(yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
      xlabel="Time", xlims=(0, tSteps), title="Post-EM: True vs. Learned Estimates",
      yformatter=y->"", tickfontsize=12, legend=:topright)

p4
[Figure: "Post-EM: True vs. Learned Estimates" of the latent states]

Monitor ELBO Convergence

The Evidence Lower Bound (ELBO) tracks the algorithm's progress. For a Poisson LDS, the ELBO includes both data-likelihood and Laplace-approximation terms, so convergence may be less smooth than in the Gaussian case.
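A quick diagnostic sketch: because the E-step uses an approximate posterior, the ELBO is not guaranteed to increase monotonically, and we can count any dips directly.

dips = count(<(0), diff(elbo))    # iterations where the ELBO decreased
println("ELBO decreased on $dips of $(length(elbo) - 1) iteration steps")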

p5 = plot(elbo, xlabel="Iteration", ylabel="ELBO",
          title="Laplace-EM Convergence", legend=false,
          linewidth=2, marker=:circle, markersize=3, color=:darkgreen)

if length(elbo) > 1
    improvement = elbo[end] - elbo[1]
    annotate!(p5, length(elbo)*0.7, elbo[end]*0.95,
        text("Improvement: $(round(improvement, digits=1))", 10)) # Add convergence annotation
end

p5
[Figure: "Laplace-EM Convergence" showing ELBO per iteration]

Summary

This tutorial demonstrated fitting a Poisson Linear Dynamical System:

Key Concepts:

  • Hybrid model: Continuous Gaussian latent dynamics generate discrete Poisson observations
  • Exponential link: $\log(\lambda_i) = \mathbf{C}_i^T \mathbf{x}_t + d_i$ connects latent states to count rates
  • Laplace-EM: Handles non-conjugate Poisson-Gaussian combination through approximations
  • Count data modeling: Extends LDS framework to spike trains and event sequences

Technical Insights:

  • More computationally intensive than the Gaussian LDS due to the required approximations
  • Convergence can be slower and less smooth than for conjugate models
  • Parameter recovery quality depends on observation density and latent state separation
  • Laplace approximations become more accurate with higher count rates

Advantages:

  • Principled probabilistic framework for count data
  • Maintains interpretable continuous latent dynamics
  • Enables simultaneous state estimation and parameter learning
  • Provides uncertainty quantification for both states and parameters

The Poisson LDS successfully bridges continuous dynamical systems and discrete observation models, enabling principled analysis of count data with underlying temporal structure.


This page was generated using Literate.jl.