Simulating and Fitting a Linear Dynamical System

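This tutorial demonstrates how to use StateSpaceDynamics.jl to simulate a latent linear dynamical system (LDS) and fit its parameters with the expectation-maximization (EM) algorithm. The model has a latent state $x_t$ and an observation $y_t$ at each time step $t$:

$$x_1 \sim \mathcal{N}(x_0, P_0), \qquad x_{t+1} = A x_t + w_t,\ w_t \sim \mathcal{N}(0, Q), \qquad y_t = C x_t + v_t,\ v_t \sim \mathcal{N}(0, R).$$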

Load Packages

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using StableRNGs
# Seeded StableRNG shared by every random draw below (true observation matrix,
# simulated data, initial parameter guesses), so the tutorial is reproducible.
# The trailing `;` suppresses display of the RNG object.
rng = StableRNG(123);

Create a State-Space Model

obs_dim = 10
latent_dim = 2

# Latent dynamics: a 2-D rotation by 0.25 rad scaled by 0.95 (spectral radius
# < 1), so noise-free trajectories spiral in toward the origin. This `A` is
# inherently 2×2, so `latent_dim` must stay 2 unless `A` is redefined.
A = 0.95 * [cos(0.25) -sin(0.25); sin(0.25) cos(0.25)]
Q = Matrix(0.1 * I(latent_dim))   # state (process) noise covariance

# Initial-state mean and covariance, sized from `latent_dim` rather than a literal 2.
x0 = zeros(latent_dim)
P0 = Matrix(0.1 * I(latent_dim))

# Observation model: random loading matrix and isotropic observation noise.
C = randn(rng, obs_dim, latent_dim)
R = Matrix(0.5 * I(obs_dim))

true_gaussian_sm = GaussianStateModel(; A=A, Q=Q, x0=x0, P0=P0)
true_gaussian_om = GaussianObservationModel(; C=C, R=R)
true_lds = LinearDynamicalSystem(;
    state_model=true_gaussian_sm,
    obs_model=true_gaussian_om,
    latent_dim=latent_dim,
    obs_dim=obs_dim,
    # Six flags, presumably one per parameter (A, Q, x0, P0, C, R) marking it as
    # learnable; all enabled here. TODO confirm ordering against the package docs.
    fit_bool=fill(true, 6)
)
LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}(GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}([0.9204668006251124 -0.2350337612917968; 0.2350337612917968 0.9204668006251124], [0.1 0.0; 0.0 0.1], [0.0, 0.0], [0.1 0.0; 0.0 0.1]), GaussianObservationModel{Float64, Matrix{Float64}}([-0.12683768965424458 -0.9995722599695167; 0.6668851724871252 -1.4919831226368483; … ; -0.9671975288083468 -0.3274601670258862; -1.3641880343579902 -0.5067518363436612], [0.5 0.0 … 0.0 0.0; 0.0 0.5 … 0.0 0.0; … ; 0.0 0.0 … 0.5 0.0; 0.0 0.0 … 0.0 0.5]), 2, 10, Bool[1, 1, 1, 1, 1, 1])

Simulate Latent and Observed Data

# Draw one trial of `tSteps` time steps from the true model. Both outputs carry a
# trailing trial dimension (they are indexed as `[:, :, 1]` later on).
tSteps = 500
(latents, observations) = rand(rng, true_lds; ntrials=1, tsteps=tSteps)
([0.3302411483795398 0.30287095528698343 … -0.06670541876666242 -0.17921452288240727; -0.31796663176380235 -0.3171044904594342 … 1.4125315970448002 1.2103765402396673;;;], [-0.10983625829671911 1.6264648985386785 … -2.386760414932026 -1.1755792249265034; 1.1615730946048897 1.4717874431185893 … -2.974657918966271 -1.9995874486572618; … ; -0.5933847030613173 -1.1111410198929672 … -0.7690775611482195 0.23287537282242812; -1.2316865274907531 -0.7610659908580879 … 0.3534417293032205 -2.098209673926211;;;])

Plot Vector Field of Latent Dynamics

# Grid on which to draw the flow field of the noise-free dynamics p ↦ A p.
x = y = -3:0.5:3
X = repeat(x', length(y), 1)   # X[i, j] == x[j]
Y = repeat(y, 1, length(x))    # Y[i, j] == y[i]

# One-step displacement A*p - p at every grid point p = (X[k], Y[k]).
U = zeros(size(X))
V = zeros(size(Y))
for k in eachindex(X, Y)
    v = A * [X[k], Y[k]]
    U[k] = v[1] - X[k]
    V[k] = v[2] - Y[k]
end

# Normalize to unit arrows so only direction is shown. The grid contains the
# origin, a fixed point with zero displacement; guard it so 0/0 does not put
# NaN into the quiver data.
magnitude = hypot.(U, V)
U_norm = @. ifelse(iszero(magnitude), zero(U), U / magnitude)
V_norm = @. ifelse(iszero(magnitude), zero(V), V / magnitude)

p = quiver(X, Y, quiver=(U_norm, V_norm), color=:blue, alpha=0.3,
           linewidth=1, arrow=arrow(:closed, :head, 0.1, 0.1))
# Overlay the simulated latent trajectory of the first (only) trial.
plot!(latents[1, :, 1], latents[2, :, 1], xlabel="x₁", ylabel="x₂",
      color=:black, linewidth=1.5, title="Latent Dynamics", legend=false)
Example block output

Plot Latent States and Observations

# Single trial: rows are dimensions, columns are time steps.
states = latents[:, :, 1]
emissions = observations[:, :, 1]

plot(size=(800, 600), layout=@layout[a{0.3h}; b])

# Top panel: latent traces stacked upward; trace d is offset by lim_states*(d-1).
lim_states = maximum(abs.(states))
for d in 1:latent_dim
    plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black,
          linewidth=2, label="", subplot=1)
end

# Braced subscripts so labels stay correct for indices ≥ 10.
plot!(subplot=1, yticks=(lim_states .* (0:latent_dim-1), [L"x_{%$d}" for d in 1:latent_dim]),
      xticks=[], xlims=(0, tSteps), title="Simulated Latent States",
      yformatter=y->"", tickfontsize=12)

# Bottom panel: emission traces stacked downward; trace n is offset by
# -lim_emissions*(n-1), so y_1 sits at the top and y_obs_dim at the bottom.
lim_emissions = maximum(abs.(emissions))
for n in 1:obs_dim
    plot!(1:tSteps, emissions[n, :] .- lim_emissions * (n-1), color=:black,
          linewidth=2, label="", subplot=2)
end

# Tick positions run bottom-to-top (-lim*(obs_dim-1) first, 0 last), so the
# labels must run obs_dim:-1:1 to land on the matching traces.
plot!(subplot=2, yticks=(-lim_emissions .* (obs_dim-1:-1:0), [L"y_{%$n}" for n in obs_dim:-1:1]),
      xlabel="time", xlims=(0, tSteps), title="Simulated Emissions",
      yformatter=y->"", tickfontsize=12)

plot!(link=:x, size=(800, 600), left_margin=10Plots.mm)
Example block output

Initialize a Model and Perform Smoothing

# Naive initial guess: random dynamics and loading matrices, with the same noise
# scales and initial-state prior as the true model. All sizes come from
# `latent_dim`/`obs_dim` instead of literal 2s.
# (Assumes the first argument of `random_rotation_matrix` is the dimension.)
A_init = random_rotation_matrix(latent_dim, rng)
Q_init = Matrix(0.1 * I(latent_dim))
C_init = randn(rng, obs_dim, latent_dim)
R_init = Matrix(0.5 * I(obs_dim))
x0_init = zeros(latent_dim)
P0_init = Matrix(0.1 * I(latent_dim))

gaussian_sm_init = GaussianStateModel(; A=A_init, Q=Q_init, x0=x0_init, P0=P0_init)
gaussian_om_init = GaussianObservationModel(; C=C_init, R=R_init)

naive_ssm = LinearDynamicalSystem(;
    state_model=gaussian_sm_init,
    obs_model=gaussian_om_init,
    latent_dim=latent_dim,
    obs_dim=obs_dim,
    fit_bool=fill(true, 6)
)

# Smoothed latent estimates under the untrained parameters (pre-EM baseline);
# the other two return values are discarded here.
x_smooth, _, _ = StateSpaceDynamics.smooth(naive_ssm, observations)

# Overlay true (black) and smoothed (red) latent traces, pair d offset by
# lim_states*(d-1). An LDS's latent space is in general identified only up to an
# invertible linear transform, so the traces need not coincide exactly.
plot()
for d in 1:latent_dim
    # Label only the first pair so the legend has one entry per series type.
    plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black, linewidth=2,
          label=(d == 1 ? "true" : ""), subplot=1)
    plot!(1:tSteps, x_smooth[d, :, 1] .+ lim_states * (d-1), color=:firebrick, linewidth=2,
          label=(d == 1 ? "smoothed" : ""), subplot=1)
end
plot!(subplot=1, yticks=(lim_states .* (0:latent_dim-1), [L"x_{%$d}" for d in 1:latent_dim]),
      xticks=[], xlims=(0, tSteps), yformatter=y->"", tickfontsize=12,
      title="True vs. Predicted Latent States (Pre-EM)")
Example block output

Fit the Model Using the EM Algorithm

# Run EM on `naive_ssm` in place (at most 100 iterations, convergence tolerance
# 1e-6) and keep the ELBO trace it returns; the second return value is unused.
(elbo, _) = fit!(naive_ssm, observations; tol=1e-6, max_iter=100)

# Re-smooth the observations with the fitted parameters.
(x_smooth, _, _) = StateSpaceDynamics.smooth(naive_ssm, observations)

# Overlay true (black) and smoothed (red) latent traces after fitting, pair d
# offset by lim_states*(d-1). An LDS's latent space is in general identified
# only up to an invertible linear transform, so a fitted model may recover the
# dynamics while its latents differ from the true ones by such a transform.
plot()
for d in 1:latent_dim
    # Label only the first pair so the legend has one entry per series type.
    plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black, linewidth=2,
          label=(d == 1 ? "true" : ""), subplot=1)
    plot!(1:tSteps, x_smooth[d, :, 1] .+ lim_states * (d-1), color=:firebrick, linewidth=2,
          label=(d == 1 ? "smoothed" : ""), subplot=1)
end
plot!(subplot=1, yticks=(lim_states .* (0:latent_dim-1), [L"x_{%$d}" for d in 1:latent_dim]),
      xticks=[], xlims=(0, tSteps), yformatter=y->"", tickfontsize=12,
      title="True vs. Predicted Latent States (Post-EM)")
Example block output

Confirm That the Model Converges

# ELBO trace returned by `fit!`, plotted against EM iteration; EM should make it
# rise and then plateau once the fit has converged.
plot(elbo;
     xlabel="iteration",
     ylabel="ELBO",
     title="ELBO (Marginal Loglikelihood)",
     legend=false)
Example block output

This page was generated using Literate.jl.