Simulating and Fitting a Poisson Linear Dynamical System
This tutorial demonstrates how to use StateSpaceDynamics.jl to simulate and fit a Linear Dynamical System (LDS) with Poisson observations using the Laplace-EM algorithm. Unlike the standard Gaussian LDS, this model is designed for count data (e.g., neural spike counts, customer arrivals, or discrete event data) where observations are non-negative integers following Poisson distributions.
The key insight is that while latent dynamics remain continuous and Gaussian, the observations are discrete counts whose rates depend on the latent state through an exponential link function: $\lambda_i(t) = \exp(\mathbf{C}_i^T \mathbf{x}_t + d_i)$.
Load Required Packages
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using StableRNGs
Set up reproducible random number generation
rng = StableRNG(54321);
Why Poisson Linear Dynamical Systems?
Many real-world phenomena involve discrete events generated by underlying continuous processes. A traditional Gaussian LDS assumes continuous observations and fails when the data are:
- Neural spike counts (non-negative integers)
- Customer arrivals per time window
- Gene expression counts
- Social media posts or interactions
Poisson LDS elegantly handles this by maintaining Gaussian latent dynamics while modeling observations as count data with rates that depend on the hidden state.
Create a Poisson Linear Dynamical System
We define a system where continuous latent dynamics generate discrete count observations, the typical setting for neural spike trains and other event-count data.
obs_dim = 10 # Number of observed count variables (e.g., neurons)
latent_dim = 2; # Number of latent state dimensions
Define the latent dynamics: the same spiral structure as the Gaussian LDS. Latent states evolve smoothly according to linear dynamics.
A = 0.95 * [cos(0.25) -sin(0.25); sin(0.25) cos(0.25)] # Rotation with contraction
b = zeros(latent_dim) # bias
Q = Matrix(0.1 * I(latent_dim)) # Process noise covariance
x0 = zeros(latent_dim) # Initial state mean
P0 = Matrix(0.1 * I(latent_dim)); # Initial state covariance
Poisson observation model parameters: For Poisson observations, the rate parameter $\lambda_i$ is modeled as $\log(\lambda_i) = \mathbf{C}_i^T \mathbf{x}_t + d_i$, where $\mathbf{C}$ maps latent states to log-rates and $d_i$ provides baseline log-rates.
log_d = log.(fill(0.1, obs_dim)); # Log baseline rates (small positive rates)
Observation matrix $\mathbf{C}$: maps 2D latent states to log-rates for each observed dimension. Use positive values so latent activity increases firing rates.
C = permutedims([abs.(randn(rng, obs_dim))'; abs.(randn(rng, obs_dim))']);
Understanding Poisson LDS Parameters
Latent dynamics parameters (same as Gaussian LDS):
- $A$: How latent states evolve (rotation plus contraction yields a stable inward spiral)
- $Q$: Process noise (uncertainty in latent evolution)
- $x_0$, $P_0$: Initial state distribution
Observation parameters (unique to Poisson case):
- $C[i,:]$: How latent dimensions affect log-rate of observation $i$
- Positive $C[i,j]$: latent dimension $j$ increases firing rate of unit $i$
- Negative $C[i,j]$: latent dimension $j$ decreases firing rate of unit $i$
- $d[i]$: Baseline log-rate for observation $i$ when latent state = 0
- $\exp(d[i])$ gives the baseline firing rate
The Exponential Link Function
The key innovation is how we connect continuous latent states to discrete counts. Instead of linear observations $y = \mathbf{C} \mathbf{x} + \text{noise}$, we use:
\[\lambda_i(t) = \exp(\mathbf{C}_i^T \mathbf{x}_t + d_i)\]
\[y_i(t) \sim \text{Poisson}(\lambda_i(t))\]
The exponential ensures rates are always positive (required for Poisson), and the log-linear relationship means latent states multiplicatively affect firing rates.
The baseline parameter $d_i$ sets the minimum firing rate when latent states are zero.
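As a quick sanity check of this link, the rate and sampling steps can be written out directly with the C and log_d defined above. This is a minimal sketch: it assumes Distributions.jl (not among the packages loaded earlier) for the Poisson draw, and uses a separate RNG stream so the tutorial's rng is untouched.
using Distributions # assumed dependency, used only for this illustration
rng_demo = StableRNG(1) # separate stream; keeps the tutorial's rng unchanged
x_demo = [1.0, -0.5] # an arbitrary latent state
λ_demo = exp.(C * x_demo .+ log_d) # per-unit rates: always positive
y_demo = [rand(rng_demo, Poisson(λ)) for λ in λ_demo]; # one vector of counts
Note that adding a constant to a latent coordinate multiplies the corresponding rates rather than adding to them, which is the multiplicative effect described above.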
Construct the model components
state_model = GaussianStateModel(; A, Q, b, x0, P0) # Gaussian latent dynamics
obs_model = PoissonObservationModel(; C, log_d); # Poisson observations
Create the complete Poisson Linear Dynamical System
true_plds = LinearDynamicalSystem(;
state_model=state_model,
obs_model=obs_model,
latent_dim=latent_dim,
obs_dim=obs_dim,
fit_bool=fill(true, 6) # Learn all parameters: A, Q, C, log_d, x0, P0
);
Simulate Latent States and Count Observations
Generate synthetic data from our Poisson LDS. Latent states evolve according to linear dynamics, while observations are drawn from Poisson distributions whose rates depend exponentially on the current latent state.
Generate both latent trajectories and count observations
tSteps = 500
latents, observations = rand(rng, true_plds; tsteps=tSteps, ntrials=1);
Visualize Latent Dynamics
Show the underlying continuous dynamics that drive discrete observations. This vector field illustrates how latent states evolve deterministically (ignoring noise).
Create grid for vector field
x = y = -3:0.5:3
X = repeat(x', length(y), 1)
Y = repeat(y, 1, length(x))
U = zeros(size(X)) # Flow in x-direction
V = zeros(size(Y)); # Flow in y-direction
for i in 1:size(X, 1), j in 1:size(X, 2)
v = A * [X[i,j], Y[i,j]]
U[i,j] = v[1] - X[i,j]
V[i,j] = v[2] - Y[i,j]
end
magnitude = @. sqrt(U^2 + V^2) # Normalize arrow lengths for cleaner visualization
U_norm = U ./ magnitude
V_norm = V ./ magnitude;
Plot vector field with simulated trajectory
p1 = quiver(X, Y, quiver=(U_norm, V_norm), color=:blue, alpha=0.3,
linewidth=1, arrow=arrow(:closed, :head, 0.1, 0.1))
plot!(latents[1, :, 1], latents[2, :, 1], xlabel=L"x_1", ylabel=L"x_2",
color=:black, linewidth=1.5, title="Latent Dynamics", legend=false)
p1
Visualize Latent States and Spike Observations
Create visualizations highlighting the contrast between continuous latent dynamics and discrete count observations (spike trains).
states = latents[:, :, 1]
emissions = observations[:, :, 1] # a 10×500 matrix of non-negative integer counts (output abridged)
Two-panel layout: continuous latent states above, discrete spike rasters below
lim_states = maximum(abs.(states))
p2 = plot(size=(800, 600), layout=@layout[a{0.3h}; b])
for d in 1:latent_dim
plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black,
linewidth=2, label="", subplot=1) # Plot smooth latent state trajectories
end
plot!(subplot=1, yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
xticks=[], xlims=(0, tSteps), title="Simulated Latent States",
yformatter=y->"", tickfontsize=12)
colors = palette(:default, obs_dim)
for f in 1:obs_dim
spike_times = findall(x -> x > 0, emissions[f, :])
for t in spike_times
plot!([t, t], [f-0.4, f+0.4], color=colors[f], linewidth=1, label="", subplot=2)
end
end
plot!(subplot=2, yticks=(1:obs_dim, [L"y_{%$d}" for d in 1:obs_dim]),
xlims=(0, tSteps), ylims=(0.5, obs_dim + 0.5), title="Spike Raster Plot",
xlabel="Time", tickfontsize=12, grid=false)
p2
The Inference Challenge
Unlike the Gaussian LDS, where exact inference is possible via Kalman filtering and smoothing, Poisson observations break the conjugate Gaussian structure: the posterior $p(\mathbf{x}_t \mid y_{1:T})$ is no longer Gaussian and requires approximation.
Laplace-EM Algorithm:
- E-step: Use a Laplace approximation to make the posterior locally Gaussian
- M-step: Update parameters using the expected sufficient statistics
- Iteration: Repeat until the ELBO converges
The Laplace approximation finds the mode of the posterior and approximates it with a Gaussian, enabling tractable inference at the cost of some accuracy.
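To make this concrete, here is a minimal one-dimensional sketch of the idea (illustrative only, not the package's internal routine): with a scalar latent $x \sim \mathcal{N}(0, 1)$ and one count $y \sim \text{Poisson}(e^x)$, Newton's method climbs to the posterior mode, and the curvature there gives the variance of the Gaussian approximation.
function laplace_1d(y; iters=20)
    x = 0.0 # start the mode search at the prior mean
    for _ in 1:iters
        g = y - exp(x) - x # gradient of log p(y|x) + log p(x)
        h = -exp(x) - 1 # Hessian; always negative, so the log-posterior is concave
        x -= g / h # Newton step toward the posterior mode
    end
    return x, 1 / (exp(x) + 1) # mode and variance (-1 / Hessian at the mode)
end
x_mode, var_approx = laplace_1d(4); # approximate posterior ≈ N(x_mode, var_approx)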
Initialize and Fit Poisson LDS
In practice, we only observe spike counts, not latent states. Our goal is to infer both the latent dynamics and the mapping from latent states to firing rates, starting from a randomly initialized model.
Random initialization (simulating lack of prior knowledge)
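Note: random_rotation_matrix is assumed here to come from the packages loaded above; if it is not available in your environment, a minimal two-dimensional stand-in could be defined as follows.
function random_rotation_matrix(dim::Int, rng) # hypothetical helper, 2D case only
    @assert dim == 2 "this sketch only covers the 2D case"
    θ = 2π * rand(rng) # random rotation angle
    return [cos(θ) -sin(θ); sin(θ) cos(θ)]
end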
A_init = random_rotation_matrix(latent_dim, rng) # Random rotation matrix
Q_init = Matrix(0.1 * I(latent_dim)) # Process noise guess
C_init = randn(rng, obs_dim, latent_dim) # Random observation mapping
log_d_init = log.(fill(0.1, obs_dim)) # Baseline log-rate guess
x0_init = zeros(latent_dim) # Start from origin
P0_init = Matrix(0.1 * I(latent_dim)); # Initial uncertainty
Construct naive model
sm_init = GaussianStateModel(; A=A_init, Q=Q_init, b=b, x0=x0_init, P0=P0_init)
om_init = PoissonObservationModel(; C=C_init, log_d=log_d_init)
naive_plds = LinearDynamicalSystem(;
state_model=sm_init,
obs_model=om_init,
latent_dim=latent_dim,
obs_dim=obs_dim,
fit_bool=fill(true, 6)
);
For Poisson observations, smoothing requires Laplace approximations, since the posterior is no longer Gaussian (unlike the linear-Gaussian case)
smoothed_x_pre, smoothed_p_pre = smooth(naive_plds, observations);
Compare true vs. initial estimated latent states
p3 = plot()
for d in 1:latent_dim
plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black,
linewidth=2, label=(d==1 ? "True" : ""), alpha=0.8)
plot!(1:tSteps, smoothed_x_pre[d, :, 1] .+ lim_states * (d-1), color=:red,
linewidth=2, label=(d==1 ? "Initial Est." : ""), alpha=0.8)
end
plot!(yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
xlabel="Time", xlims=(0, tSteps), title="Pre-EM: True vs. Initial Estimates",
yformatter=y->"", tickfontsize=12, legend=:topright)
p3
Fit Using Laplace-EM Algorithm
Fit the model, using fewer iterations due to the computational cost
elbo, _ = fit!(naive_plds, observations; max_iter=25, tol=1e-6);
print("Laplace-EM completed in $(length(elbo)) iterations\n")
Laplace-EM completed in 25 iterations
Parameter identifiability:
- Scale ambiguity: rescaling the latent states by $\alpha^{-1}$ while using $\alpha \mathbf{C}$ (with $Q$, $\mathbf{x}_0$, and $P_0$ rescaled to match) gives the same likelihood
- Can be resolved by constraining the norm of $\mathbf{C}$ or fixing one of its elements
- Rotation ambiguity in latent space (same as Gaussian LDS)
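A quick numerical check of these ambiguities (a sketch using the parameters defined above): for any invertible $S$, replacing $\mathbf{x}_t$ with $S\mathbf{x}_t$ and $\mathbf{C}$ with $\mathbf{C}S^{-1}$ (together with $A \to S A S^{-1}$, $Q \to S Q S^\top$, and matching changes to $\mathbf{x}_0$, $P_0$) leaves every rate, and hence the likelihood, unchanged.
S = [2.0 0.5; 0.0 1.0] # an arbitrary invertible latent-space transform
x_t = randn(StableRNG(2), latent_dim) # an arbitrary latent state (separate RNG stream)
rates_orig = exp.(C * x_t .+ log_d)
rates_tran = exp.((C / S) * (S * x_t) .+ log_d) # C S⁻¹ applied to S x_t
maximum(abs.(rates_orig .- rates_tran)) # ≈ 0 up to floating-point error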
Perform smoothing with learned parameters
smoothed_x_post, smoothed_p_post = smooth(naive_plds, observations);
Compare true vs. learned latent state estimates
p4 = plot()
for d in 1:latent_dim
plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black,
linewidth=2, label=(d==1 ? "True" : ""), alpha=0.8)
plot!(1:tSteps, smoothed_x_post[d, :, 1] .+ lim_states * (d-1), color=:red,
linewidth=2, label=(d==1 ? "Post-EM Est." : ""), alpha=0.8)
end
plot!(yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
xlabel="Time", xlims=(0, tSteps), title="Post-EM: True vs. Learned Estimates",
yformatter=y->"", tickfontsize=12, legend=:topright)
p4
Monitor ELBO Convergence
The Evidence Lower Bound (ELBO) tracks the algorithm's progress. For a Poisson LDS, the ELBO includes both the data likelihood and Laplace approximation terms, so convergence may be less smooth than in the Gaussian case.
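Because the E-step is approximate, the ELBO is not guaranteed to increase at every iteration; a quick scan of the returned elbo vector flags any decreases (a small sketch, not part of the fitting API).
drops = findall(<(0), diff(elbo)) # iterations where the ELBO went down
isempty(drops) || println("ELBO decreased at iterations $drops (can happen under the Laplace approximation)")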
p5 = plot(elbo, xlabel="Iteration", ylabel="ELBO",
title="Laplace-EM Convergence", legend=false,
linewidth=2, marker=:circle, markersize=3, color=:darkgreen)
if length(elbo) > 1
improvement = elbo[end] - elbo[1]
annotate!(p5, length(elbo)*0.7, elbo[end]*0.95,
text("Improvement: $(round(improvement, digits=1))", 10)) # Add convergence annotation
end
p5
Summary
This tutorial demonstrated fitting a Poisson Linear Dynamical System:
Key Concepts:
- Hybrid model: Continuous Gaussian latent dynamics generate discrete Poisson observations
- Exponential link: $\log(\lambda_i) = \mathbf{C}_i^T \mathbf{x}_t + d_i$ connects latent states to count rates
- Laplace-EM: Handles non-conjugate Poisson-Gaussian combination through approximations
- Count data modeling: Extends LDS framework to spike trains and event sequences
Technical Insights:
- More computationally intensive than the Gaussian LDS due to the required approximations
- Convergence can be slower and less smooth than for conjugate models
- Parameter recovery quality depends on observation density and latent state separation
- Laplace approximations become more accurate with higher count rates
Advantages:
- Principled probabilistic framework for count data
- Maintains interpretable continuous latent dynamics
- Enables simultaneous state estimation and parameter learning
- Provides uncertainty quantification for both states and parameters
The Poisson LDS successfully bridges continuous dynamical systems and discrete observation models, enabling principled analysis of count data with underlying temporal structure.
This page was generated using Literate.jl.