Simulating and Fitting a Linear Dynamical System
This tutorial demonstrates how to use StateSpaceDynamics.jl
to simulate a latent linear dynamical system and fit it using the EM algorithm.
Load Packages
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using StableRNGs
# Seeded StableRNG: every random draw below (true C, simulated data, model
# initialization) is taken from this generator, so the tutorial output is reproducible.
rng = StableRNG(123);
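Create a State-Space Model
We build a linear-Gaussian state-space model with a 2-dimensional latent state $x_t$ and 10-dimensional observations $y_t$:
$$x_{t+1} = A x_t + w_t,\quad w_t \sim \mathcal{N}(0, Q), \qquad y_t = C x_t + v_t,\quad v_t \sim \mathcal{N}(0, R),$$
where the initial latent state is Gaussian with mean $x_0$ and covariance $P_0$.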
# Dimensions of the observations y_t and of the latent state x_t.
obs_dim = 10
latent_dim = 2

# True dynamics: a 2-D rotation by 0.25 rad scaled by 0.95. Its eigenvalues have
# modulus 0.95 < 1, so noiseless trajectories spiral in toward the origin.
# NOTE: this rotation construction is inherently 2-D, so `latent_dim` must stay 2
# unless A is built differently.
A = 0.95 * [cos(0.25) -sin(0.25); sin(0.25) cos(0.25)]
Q = Matrix(0.1 * I(latent_dim))      # latent (process) noise covariance
x0 = zeros(latent_dim)               # mean of the initial latent state
P0 = Matrix(0.1 * I(latent_dim))     # covariance of the initial latent state
C = randn(rng, obs_dim, latent_dim)  # random emission matrix mapping x_t to y_t
R = Matrix(0.5 * I(obs_dim))         # observation noise covariance

true_gaussian_sm = GaussianStateModel(;A=A, Q=Q, x0=x0, P0=P0)
true_gaussian_om = GaussianObservationModel(;C=C, R=R)
true_lds = LinearDynamicalSystem(;
    state_model=true_gaussian_sm,
    obs_model=true_gaussian_om,
    latent_dim=latent_dim,
    obs_dim=obs_dim,
    # Six flags, presumably one per parameter (A, Q, x0, P0, C, R) marking it as
    # learnable -- TODO confirm against the LinearDynamicalSystem docs.
    fit_bool=fill(true, 6)
)
LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}(GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}([0.9204668006251124 -0.2350337612917968; 0.2350337612917968 0.9204668006251124], [0.1 0.0; 0.0 0.1], [0.0, 0.0], [0.1 0.0; 0.0 0.1]), GaussianObservationModel{Float64, Matrix{Float64}}([-0.12683768965424458 -0.9995722599695167; 0.6668851724871252 -1.4919831226368483; … ; -0.9671975288083468 -0.3274601670258862; -1.3641880343579902 -0.5067518363436612], [0.5 0.0 … 0.0 0.0; 0.0 0.5 … 0.0 0.0; … ; 0.0 0.0 … 0.5 0.0; 0.0 0.0 … 0.0 0.5]), 2, 10, Bool[1, 1, 1, 1, 1, 1])
Simulate Latent and Observed Data
# Length of the simulated sequence (a single trial is drawn).
tSteps = 500
# Sample latent trajectories and matching observations from the true model; both
# come back with a trailing trial dimension (indexed with `[:, :, 1]` below).
latents, observations = rand(rng, true_lds; ntrials=1, tsteps=tSteps)
([0.3302411483795398 0.30287095528698343 … -0.06670541876666242 -0.17921452288240727; -0.31796663176380235 -0.3171044904594342 … 1.4125315970448002 1.2103765402396673;;;], [-0.10983625829671911 1.6264648985386785 … -2.386760414932026 -1.1755792249265034; 1.1615730946048897 1.4717874431185893 … -2.974657918966271 -1.9995874486572618; … ; -0.5933847030613173 -1.1111410198929672 … -0.7690775611482195 0.23287537282242812; -1.2316865274907531 -0.7610659908580879 … 0.3534417293032205 -2.098209673926211;;;])
Plot Vector Field of Latent Dynamics
# Regular grid of points in latent space at which to draw the flow field.
x = y = -3:0.5:3
X = repeat(x', length(y), 1)   # X[i, j] == x[j]
Y = repeat(y, 1, length(x))    # Y[i, j] == y[i]

# Displacement produced by one noiseless step of the dynamics: A*z - z.
U = zeros(size(X))
V = zeros(size(Y))
for idx in eachindex(X, Y)
    v = A * [X[idx], Y[idx]]
    U[idx] = v[1] - X[idx]
    V[idx] = v[2] - Y[idx]
end

# Normalize arrows to unit length so only the direction is shown. The grid contains
# the origin, a fixed point with zero displacement; guard it so 0/0 does not
# produce NaN arrows (ifelse selects 0.0 there).
magnitude = @. sqrt(U^2 + V^2)
U_norm = @. ifelse(magnitude > 0, U / magnitude, 0.0)
V_norm = @. ifelse(magnitude > 0, V / magnitude, 0.0)

p = quiver(X, Y, quiver=(U_norm, V_norm), color=:blue, alpha=0.3,
    linewidth=1, arrow=arrow(:closed, :head, 0.1, 0.1))
# Overlay the simulated latent trajectory (trial 1).
plot!(latents[1, :, 1], latents[2, :, 1], xlabel="x₁", ylabel="x₂",
    color=:black, linewidth=1.5, title="Latent Dynamics", legend=false)
Plot Latent States and Observations
# Drop the trial dimension (the simulation used ntrials=1).
states = latents[:, :, 1]
emissions = observations[:, :, 1]
plot(size=(800, 600), layout=@layout[a{0.3h}; b])

# Top panel: latent dimension d is drawn shifted up by lim_states * (d - 1).
lim_states = maximum(abs.(states))
for d in 1:latent_dim
    plot!(1:tSteps, states[d, :] .+ lim_states * (d-1), color=:black,
        linewidth=2, label="", subplot=1)
end
plot!(subplot=1, yticks=(lim_states .* (0:latent_dim-1), [L"x_%$d" for d in 1:latent_dim]),
    xticks=[], xlims=(0, tSteps), title="Simulated Latent States",
    yformatter=y->"", tickfontsize=12)

# Bottom panel: emission n is drawn shifted down by lim_emissions * (n - 1), so the
# tick at -lim_emissions * (n - 1) must carry the label y_n (positions and labels
# are paired element-wise).
lim_emissions = maximum(abs.(emissions))
for n in 1:obs_dim
    plot!(1:tSteps, emissions[n, :] .- lim_emissions * (n-1), color=:black,
        linewidth=2, label="", subplot=2)
end
plot!(subplot=2, yticks=(-lim_emissions .* (0:obs_dim-1), [L"y_{%$n}" for n in 1:obs_dim]),
    xlabel="time", xlims=(0, tSteps), title="Simulated Emissions",
    yformatter=y->"", tickfontsize=12)
plot!(link=:x, size=(800, 600), left_margin=10Plots.mm)
Initialize a Model and Perform Smoothing
# Deliberately uninformed starting point for EM. Every size is tied to
# `latent_dim`/`obs_dim` so the initialization matches the true model's shape.
A_init = random_rotation_matrix(latent_dim, rng)  # random rotation from StateSpaceDynamics
Q_init = Matrix(0.1 * I(latent_dim))
C_init = randn(rng, obs_dim, latent_dim)
R_init = Matrix(0.5 * I(obs_dim))
x0_init = zeros(latent_dim)
P0_init = Matrix(0.1 * I(latent_dim))
gaussian_sm_init = GaussianStateModel(;A=A_init, Q=Q_init, x0=x0_init, P0=P0_init)
gaussian_om_init = GaussianObservationModel(;C=C_init, R=R_init)
naive_ssm = LinearDynamicalSystem(;
    state_model=gaussian_sm_init,
    obs_model=gaussian_om_init,
    latent_dim=latent_dim,
    obs_dim=obs_dim,
    fit_bool=fill(true, 6)
)
# Smooth the data under the untrained model. Only the first return value is kept;
# it is indexed below as x_smooth[dim, time, trial]. The other two outputs are
# discarded (presumably smoothed covariances -- confirm in the smooth docs).
x_smooth, _, _ = StateSpaceDynamics.smooth(naive_ssm, observations)
plot()
# Overlay true (black) and smoothed (red) latents for each dimension, stacking the
# dimensions vertically with the same offsets as the "Simulated Latent States" panel.
for k in 1:latent_dim
    offset = lim_states * (k - 1)
    plot!(1:tSteps, states[k, :] .+ offset, color=:black, linewidth=2, label="", subplot=1)
    plot!(1:tSteps, x_smooth[k, :, 1] .+ offset, color=:firebrick, linewidth=2, label="", subplot=1)
end
plot!(subplot=1,
    title="True vs. Predicted Latent States (Pre-EM)",
    yticks=(lim_states .* (0:latent_dim-1), [L"x_%$k" for k in 1:latent_dim]),
    xticks=[], xlims=(0, tSteps), yformatter=y->"", tickfontsize=12)
Fit the Model Using the EM Algorithm
# Fit naive_ssm in place with EM: at most 100 iterations, with `tol` presumably the
# convergence threshold. `elbo` holds the objective per iteration (plotted below).
elbo, _ = fit!(naive_ssm, observations; tol=1e-6, max_iter=100)
# Re-smooth the observations under the fitted parameters.
x_smooth, _, _ = StateSpaceDynamics.smooth(naive_ssm, observations)
plot()
# Same overlay as the pre-EM figure, now using the fitted model's smoothed latents.
# Note: the LDS likelihood is unchanged by an invertible change of latent basis
# (A -> M*A*inv(M), C -> C*inv(M), ...), so fitted latents can match the truth only
# up to such a transform.
for k in 1:latent_dim
    offset = lim_states * (k - 1)
    plot!(1:tSteps, states[k, :] .+ offset, color=:black, linewidth=2, label="", subplot=1)
    plot!(1:tSteps, x_smooth[k, :, 1] .+ offset, color=:firebrick, linewidth=2, label="", subplot=1)
end
plot!(subplot=1,
    title="True vs. Predicted Latent States (Post-EM)",
    yticks=(lim_states .* (0:latent_dim-1), [L"x_%$k" for k in 1:latent_dim]),
    xticks=[], xlims=(0, tSteps), yformatter=y->"", tickfontsize=12)
Confirm That the Model Converges
# EM never decreases its objective, so this curve should rise and then plateau.
plot(elbo; legend=false, title="ELBO (Marginal Loglikelihood)",
    xlabel="iteration", ylabel="ELBO")
This page was generated using Literate.jl.