Simulating and Fitting a Poisson Linear Dynamical System

This tutorial demonstrates how to use StateSpaceDynamics.jl to simulate and fit a Linear Dynamical System (LDS) with Poisson observations using the Laplace-EM algorithm.

Load Packages

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using StableRNGs
rng = StableRNG(123);

Create a Poisson Linear Dynamical System

obs_dim = 10
latent_dim = 2

A = 0.95 * [cos(0.25) -sin(0.25); sin(0.25) cos(0.25)]
Q = Matrix(0.1 * I(latent_dim))
x0 = zeros(latent_dim)
P0 = Matrix(0.1 * I(latent_dim))

log_d = log.(fill(0.1, obs_dim))
C = permutedims([abs.(randn(rng, obs_dim))'; abs.(randn(rng, obs_dim))'])

state_model = GaussianStateModel(; A, Q, x0, P0)
obs_model = PoissonObservationModel(; C, log_d)

true_plds = LinearDynamicalSystem(;
    state_model=state_model,
    obs_model=obs_model,
    latent_dim=latent_dim,
    obs_dim=obs_dim,
    fit_bool=fill(true, 6)
)
LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, PoissonObservationModel{Float64, Matrix{Float64}, Vector{Float64}}}(GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}([0.9204668006251124 -0.2350337612917968; 0.2350337612917968 0.9204668006251124], [0.1 0.0; 0.0 0.1], [0.0, 0.0], [0.1 0.0; 0.0 0.1]), PoissonObservationModel{Float64, Matrix{Float64}, Vector{Float64}}([0.12683768965424458 0.9995722599695167; 0.6668851724871252 1.4919831226368483; … ; 0.9671975288083468 0.3274601670258862; 1.3641880343579902 0.5067518363436612], [-2.3025850929940455, -2.3025850929940455, -2.3025850929940455, -2.3025850929940455, -2.3025850929940455, -2.3025850929940455, -2.3025850929940455, -2.3025850929940455, -2.3025850929940455, -2.3025850929940455]), 2, 10, Bool[1, 1, 1, 1, 1, 1])

Simulate Latent States and Observations

tSteps = 500
latents, observations = rand(rng, true_plds; tsteps=tSteps, ntrials=1)
([0.3302411483795398 -0.033641736953193446 … 1.112518117190165 0.7407506813783344; -0.31796663176380235 -0.44250663498959253 … -0.9105379304111921 -1.1451694994517603;;;], [1.0 0.0 … 0.0 1.0; 1.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 1.0 0.0; 1.0 1.0 … 2.0 3.0;;;])

Plot Vector Field of Latent Dynamics

x = y = -3:0.5:3
X = repeat(x', length(y), 1)
Y = repeat(y, 1, length(x))
U = zeros(size(X))
V = zeros(size(Y))

for i in 1:size(X, 1), j in 1:size(X, 2)
    v = A * [X[i,j], Y[i,j]]
    U[i,j] = v[1] - X[i,j]
    V[i,j] = v[2] - Y[i,j]
end

magnitude = @. sqrt(U^2 + V^2)
U_norm = U ./ magnitude
V_norm = V ./ magnitude

p = quiver(X, Y, quiver=(U_norm, V_norm), color=:blue, alpha=0.3,
           linewidth=1, arrow=arrow(:closed, :head, 0.1, 0.1))
plot!(latents[1, :, 1], latents[2, :, 1], xlabel="x₁", ylabel="x₂",
      color=:black, linewidth=1.5, title="Latent Dynamics", legend=false)
Example block output

Plot Latent States and Observations

states = latents[:, :, 1]
emissions = observations[:, :, 1]
time_bins = size(states, 2)

plot(size=(800, 600), layout=@layout[a{0.3h}; b])

lim_states = maximum(abs.(states))
for d in 1:latent_dim
    plot!(1:time_bins, states[d, :] .+ lim_states * (d-1), color=:black,
          linewidth=2, label="", subplot=1)
end

plot!(subplot=1, yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
      xticks=[], xlims=(0, time_bins), title="Simulated Latent States",
      yformatter=y->"", tickfontsize=12)

colors = palette(:default, obs_dim)
for f in 1:obs_dim
    spike_times = findall(x -> x > 0, emissions[f, :])
    for t in spike_times
        plot!([t, t], [f-0.4, f+0.4], color=colors[f], linewidth=1, label="", subplot=2)
    end
end

plot!(subplot=2, yticks=(1:obs_dim, [L"y_{%$d}" for d in 1:obs_dim]),
      xlims=(0, time_bins), ylims=(0.5, obs_dim + 0.5), title="Simulated Emissions",
      xlabel="Time", tickfontsize=12, grid=false)
Example block output

Initialize Model and Smooth

Initialize with random parameters

A_init = random_rotation_matrix(latent_dim, rng)
Q_init = Matrix(0.1 * I(latent_dim))
C_init = randn(rng, obs_dim, latent_dim)
log_d_init = log.(fill(0.1, obs_dim))
x0_init = zeros(latent_dim)
P0_init = Matrix(0.1 * I(latent_dim))

sm_init = GaussianStateModel(; A=A_init, Q=Q_init, x0=x0_init, P0=P0_init)
om_init = PoissonObservationModel(; C=C_init, log_d=log_d_init)

naive_plds = LinearDynamicalSystem(;
    state_model=sm_init,
    obs_model=om_init,
    latent_dim=latent_dim,
    obs_dim=obs_dim,
    fit_bool=fill(true, 6)
)

smoothed_x, smoothed_p, _ = smooth(naive_plds, observations)

plot()
for d in 1:latent_dim
    plot!(1:time_bins, states[d, :] .+ lim_states * (d-1), color=:black, linewidth=2, label="", subplot=1)
    plot!(1:time_bins, smoothed_x[d, :, 1] .+ lim_states * (d-1), color=:red, linewidth=2, label="", subplot=1)
end

plot!(subplot=1, yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
      xticks=[], xlims=(0, time_bins), title="True vs. Predicted Latent States (Pre-EM)",
      yformatter=y->"", tickfontsize=12)
Example block output

Fit the Poisson LDS Using Laplace EM

elbo, _ = fit!(naive_plds, observations; max_iter=25, tol=1e-6)

smoothed_x, smoothed_p, _ = smooth(naive_plds, observations)

plot()
for d in 1:latent_dim
    plot!(1:time_bins, states[d, :] .+ lim_states * (d-1), color=:black, linewidth=2, label="", subplot=1)
    plot!(1:time_bins, smoothed_x[d, :, 1] .+ lim_states * (d-1), color=:red, linewidth=2, label="", subplot=1)
end

plot!(subplot=1, yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
      xticks=[], xlims=(0, time_bins), title="True vs. Predicted Latent States (Post-EM)",
      yformatter=y->"", tickfontsize=12)
Example block output

ELBO Convergence

plot(elbo, xlabel="iteration", ylabel="ELBO", title="ELBO over Iterations", legend=false)
Example block output

This page was generated using Literate.jl.