Simulating and Fitting a Switching Linear Dynamical System
Load Packages
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using Statistics
using StableRNGs
# Seeded random number generator shared by every random draw on this page, so the
# simulated data (and anything derived from it) are reproducible. StableRNGs is used
# because its stream is intended to stay fixed across Julia versions.
# The trailing `;` suppresses display of the RNG object in the rendered page.
rng = StableRNG(123);
Simulate data from an SLDS model
state_dim = 2
obs_dim = 10
K = 2 ## two states
# Create the HMM parameters: a row-stochastic transition matrix between the K
# discrete regimes and the initial regime distribution (always start in regime 1).
A_hmm = [0.92 0.08; 0.06 0.94]
π₀ = [1.0, 0.0]
# Create the state models: each dynamics matrix is a rotation scaled by 0.95
# (a damped oscillator); the regimes differ in rotation angle per step and in
# process-noise covariance.
A₁ = 0.95 * [cos(0.05) -sin(0.05); sin(0.05) cos(0.05)] ## slower oscillator
A₂ = 0.95 * [cos(0.55) -sin(0.55); sin(0.55) cos(0.55)] ## faster oscillator
Q₁ = [0.001 0.0; 0.0 0.001]
Q₂ = [0.1 0.0; 0.0 0.1]
# Assume same initial distribution for ease
x0 = [0.0, 0.0]
P0 = [0.1 0.0; 0.0 0.1]
# Create the observation models: one random obs_dim × state_dim loading matrix per
# regime, drawn from the seeded `rng`.
C₁ = randn(rng, obs_dim, state_dim)
C₂ = randn(rng, obs_dim, state_dim)
R = Matrix(0.1 * I(obs_dim)) ## Assume same noise covariance for both states
# Put it all together for an SLDS model.
# NOTE(review): `fill(true, 6)` is passed positionally to LinearDynamicalSystem;
# presumably one flag per fittable parameter — confirm against the package docs.
model = SwitchingLinearDynamicalSystem(
    A_hmm,
    [LinearDynamicalSystem(GaussianStateModel(A₁, Q₁, x0, P0), GaussianObservationModel(C₁, R), state_dim, obs_dim, fill(true, 6)),
     LinearDynamicalSystem(GaussianStateModel(A₂, Q₂, x0, P0), GaussianObservationModel(C₂, R), state_dim, obs_dim, fill(true, 6))],
    π₀,
    K)
# Simulate T time steps. Below, `x` is used as the continuous latent trajectory,
# `y` as the observations passed to `fit!`, and `z` as the discrete regime sequence.
# Sample with `T` (not a repeated literal) so every later use of `T` matches the
# actual length of the simulated data.
T = 1000
x, y, z = rand(rng, model, T)
([-0.0076254535387565315 0.0010301616170909963 … 0.10819269449912033 0.07067327030863799; -0.04123506436656016 -0.03295684397718122 … -0.13492749101543738 -0.1695991047478436], [-0.18526210459309506 -0.35462255865221975 … 0.4766254627635034 0.5128027106162572; 0.3217524900300324 -0.07272113836410676 … -0.08291298797591545 -0.3269680913990665; … ; -0.07844180835082344 -0.3034603533211502 … 0.006499480950309909 0.6288491495372703; 0.49270132361714775 -0.3907296529253596 … -0.16056295474176502 0.3443776345048929], [1, 1, 1, 2, 2, 2, 2, 1, 1, 1 … 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Plot the true dynamics
# Plot the true latent trajectories over a background shaded by the active
# discrete regime `z`. Shading is drawn first so it sits behind the lines instead
# of washing them out.
p1 = plot()
# Regime start indices, plus a sentinel T + 1 marking the end of the last regime.
transition_points = [1; findall(diff(z) .!= 0) .+ 1; T + 1]
# Gather [start, stop] pairs per regime so each regime is drawn as ONE `vspan!`
# series: every regime gets its own legend entry (the old code labelled only the
# very first segment), and we avoid creating one series per segment.
spans = Dict{eltype(z), Vector{Int}}()
for i in 1:(length(transition_points) - 1)
    start_idx = transition_points[i]
    # Extend each span to the next regime's start so adjacent spans touch and
    # single-step regimes still have a visible width; clamp to the time axis.
    stop_idx = min(transition_points[i + 1], T)
    append!(get!(spans, z[start_idx], Int[]), (start_idx, stop_idx))
end
for state_value in sort!(collect(keys(spans)))
    # Choose color based on state value
    bg_color = state_value == 1 ? :lightblue : :lightyellow
    vspan!(p1, spans[state_value]; fillalpha=0.5, color=bg_color,
           label="State $state_value")
end
# Pin palette indices so the lines keep the default first two series colors even
# though shading series were added before them.
plot!(p1, 1:T, x[1, :]; label="x₁", linewidth=1.5, color=1)
plot!(p1, 1:T, x[2, :]; label="x₂", linewidth=1.5, color=2)
# Adjust the plot appearance (once)
title!(p1, "Latent Dynamics with State")
xlabel!(p1, "Time")
ylabel!(p1, "State Value")
ylims!(p1, -3, 3)
p1
Create a new SLDS model with different initial parameters and fit it to the data
# Create a model to start with for EM, using reasonable guesses
# Initial guess for EM, deliberately different from the generating parameters.
# All random draws use the seeded `rng` (not the global RNG) so that the fitted
# result shown on this page is reproducible.
A = [0.9 0.1; 0.1 0.9]
A ./= sum(A, dims=2) ## Normalize rows to sum to 1
πₖ = rand(rng, K)
πₖ ./= sum(πₖ) ## Normalize to sum to 1
Q = Matrix(0.001 * I(state_dim))
x0 = [0.0; 0.0]
P0 = Matrix(0.001 * I(state_dim))
# set up the observation parameters
C = randn(rng, obs_dim, state_dim)
R = Matrix(0.1 * I(obs_dim))
# One rotation angle per regime for the initial (damped) dynamics matrices.
init_freqs = [0.7, 0.1]
length(init_freqs) == K ||
    throw(DimensionMismatch("need $K initial frequencies, got $(length(init_freqs))"))
# Give every regime its own copies of the shared starting arrays.
# NOTE(review): guards against aliasing if `fit!` updates these arrays in place;
# whether it does is not visible from this page.
B = [StateSpaceDynamics.LinearDynamicalSystem(
         StateSpaceDynamics.GaussianStateModel(
             0.95 * [cos(f) -sin(f); sin(f) cos(f)], copy(Q), copy(x0), copy(P0)),
         StateSpaceDynamics.GaussianObservationModel(copy(C), copy(R)),
         state_dim, obs_dim, fill(true, 6)) for f in init_freqs]
learned_model = SwitchingLinearDynamicalSystem(A, B, πₖ, model.K)
SwitchingLinearDynamicalSystem{Float64, Matrix{Float64}, Vector{Float64}, Vector{LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}}}([0.9 0.1; 0.1 0.9], LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}[LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}(GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}([0.726600077920264 -0.6120068028758064; 0.6120068028758064 0.726600077920264], [0.001 0.0; 0.0 0.001], [0.0, 0.0], [0.001 0.0; 0.0 0.001]), GaussianObservationModel{Float64, Matrix{Float64}}([-0.1472353199006129 -1.2895524120041055; -1.617921615366992 -0.8985674611863382; … ; 0.5171917119667017 1.2902391318721487; 0.5040654748790333 -0.8166354582745979], [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), 2, 10, Bool[1, 1, 1, 1, 1, 1]), LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}(GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}([0.9452539570141245 -0.09484174581448675; 0.09484174581448675 0.9452539570141245], [0.001 0.0; 0.0 0.001], [0.0, 0.0], [0.001 0.0; 0.0 0.001]), GaussianObservationModel{Float64, Matrix{Float64}}([-0.1472353199006129 -1.2895524120041055; -1.617921615366992 -0.8985674611863382; … ; 0.5171917119667017 1.2902391318721487; 0.5040654748790333 -0.8166354582745979], [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), 2, 10, Bool[1, 1, 1, 1, 1, 1])], [0.5371353992007545, 0.4628646007992456], 2)
Fit the model to the data
# Run 25 iterations of EM on the observations `y`. Of the returned tuple, this page
# later uses `mls` as the per-iteration ELBO trace, `FB.γ` as log-responsibilities
# and `FS[k].x_smooth` as smoothed latents; `param_diff` is presumably the
# per-iteration parameter change (not used below).
(mls, param_diff, FB, FS) = fit!(learned_model, y; max_iter = 25)
([-108520.53377120892, -21432.718521314888, -18941.159992874993, -18080.67064584985, -17673.080310470614, -17389.98729484739, -17337.685201216478, -17317.11416134528, -17305.94073924976, -17299.525960088402 … -17289.471381679385, -17289.124253178124, -17288.874889808976, -17288.693241921923, -17288.559176320206, -17288.4589627697, -17288.383093351145, -17288.324901094107, -17288.279662110937, -17288.24399510062], [15.41406489508258, 1.2455455401448887, 0.766441319387983, 0.5215340937129568, 0.3584842712569053, 0.31554434590528085, 0.05285337246329432, 0.030804684418305514, 0.02177184535554386, 0.015535849028198961 … 0.0024822093611916157, 0.0019192424842317595, 0.0015109756433789618, 0.001212422921748818, 0.0009920652146195637, 0.0008277543692955701, 0.000703879511404568, 0.0006093974023595819, 0.0005364599265866845, 0.00047945773285424764], StateSpaceDynamics.ForwardBackward{Float64, Vector{Float64}, Matrix{Float64}, Array{Float64, 3}}([-14.694810681925606 -12.176441796056054 … -78.88360771217958 -78.27505211005185; -6.223170076583583 -3.5205740403480466 … -4.8676744134035275 -5.236145097666078], [-227.2813237521779 -20.918735284779686 … -6145.11461584854 -6149.457692703976; -6.223170076583583 -9.82770216109264 … -6068.663517181784 -6073.983620323611], [-6070.2574284570155 -6066.653071365399 … -7.8172550999681105 0.0; -6067.76045024708 -6064.155919417809 … -5.320103141827089 0.0], [-223.5551318855296 -13.588186326514005 … -78.94825062489599 -75.47407238036521; 0.0 -1.2552372936625034e-6 … 0.0 0.0], [-232.20592429230692 -223.55530688897215; -13.588186326514005 -1.2552372936625034e-6;;; -31.947992125233213 -13.588186337141451; -23.29737471134831 -1.2553136912174523e-6;;; -40.65738717833119 -23.297199736785842; -22.29775510977288 -2.837623469531536e-10;;; … ;;; -156.20556924795437 -80.00887682518169; -81.13426009071645 0.0;;; -155.14494304766959 -81.13426009071736; -78.94825062489599 0.0;;; -149.48475533731835 -78.9482506248969; -75.47407238036521 0.0]), 
StateSpaceDynamics.FilterSmooth{Float64}[StateSpaceDynamics.FilterSmooth{Float64}([0.07469660946354777 0.044640338658525486 … 0.013566239498118722 0.0023104747176090107; -0.02533398729927876 0.012839852370994623 … 0.01914703708548579 0.023424939522191473], [0.0005964170678630932 -1.5904160488392509e-6; -1.5904160488392509e-6 0.0006202585002008555;;; 0.008806381618659674 -0.0003347828289762794; -0.0003347828289762794 0.00889515303893694;;; 0.0117468287254638 -0.0002484057666476841; -0.0002484057666476841 0.014322102747581935;;; … ;;; 0.06135244931529858 -0.0022147608022419084; -0.0022147608022419084 0.08209453651876804;;; 0.061418075169005305 -0.002230295810062663; -0.002230295810062663 0.08218693286454543;;; 0.06148187136202957 -0.0022387232932911325; -0.0022387232932911325 0.08225544207019979], [0.07469660946354777 0.044640338658525486 … 0.013566239498118722 0.0023104747176090107; -0.02533398729927876 0.012839852370994623 … 0.01914703708548579 0.023424939522191473;;;], [0.006176000533212866 -0.0018939533714975442; -0.0018939533714975442 0.0012620694126808732;;; 0.010799141454207519 0.00023839252919039195; 0.00023839252919039195 0.009060014847845877;;; 0.011985879754268831 1.667823073475775e-5; 1.667823073475775e-5 0.014616054737747039;;; … ;;; 0.061842973503694616 -0.0020421302486981975; -0.0020421302486981975 0.08215529052177502;;; 0.061602118023125624 -0.0019705425192816017; -0.0019705425192816017 0.08255354189369839;;; 0.061487209655450285 -0.002184600562763589; -0.002184600562763589 0.08280416986181811;;;;], [0.0 0.0; 0.0 0.0;;; 0.0037075765539731818 -0.0013453702819465993; 0.0012676203943428202 0.00012502529823163954;;; 0.00577457073591737 -0.0029401988921241144; 0.004812344226248055 0.006315227375413061;;; … ;;; 0.048011265171042275 -0.036068583307385035; 0.033828537663663776 0.06537523900755088;;; 0.047817942277521425 -0.03584328946452893; 0.03408977496059626 0.06567902733342645;;; 0.04760514747729177 -0.03595515750643187; 0.03400879482835221 
0.06604451732421852;;;;]), StateSpaceDynamics.FilterSmooth{Float64}([0.07649798325185714 0.006920001624199251 … 0.019016451611786514 0.04476843234669529; -0.02551117737561867 -0.023057591834502157 … -0.06000470017471798 -0.02012865430106065], [5.494278749554643e-5 -2.5637211632171733e-5; -2.5637211632171733e-5 5.5263655940742426e-5;;; 0.0011161547674384262 -0.000551742390531601; -0.000551742390531601 0.0011992495222875239;;; 0.0011757779647676898 -0.0005859012950416559; -0.0005859012950416559 0.0012409506196976346;;; … ;;; 0.0011792795354740645 -0.0005880495632020828; -0.0005880495632020828 0.001242830812933902;;; 0.0011932420128077186 -0.0005962729306702498; -0.0005962729306702498 0.0012510121296017947;;; 0.0014427965986964874 -0.0007339297011768738; -0.0007339297011768738 0.0014618109830884223], [0.07649798325185714 0.006920001624199251 … 0.019016451611786514 0.04476843234669529; -0.02551117737561867 -0.023057591834502157 … -0.06000470017471798 -0.02012865430106065;;;], [0.005906884229096962 -0.0019771908312474057; -0.0019771908312474057 0.0007060838270310204;;; 0.0011640411899173466 -0.0007113009634764794; -0.0007113009634764794 0.0017309020634940246;;; 0.0013970485851560138 0.0007310345533937276; 0.0007310345533937276 0.009078954715544332;;; … ;;; 0.0014298948156483 -0.00214854679316776; -0.00214854679316776 0.010959523278922249;;; 0.0015548674447111365 -0.001737349408022532; -0.001737349408022532 0.004851576172659595;;; 0.0034470091334771204 -0.0016350579994839245; -0.0016350579994839245 0.0018669737070600297;;;;], [0.0 0.0; 0.0 0.0;;; 0.00054228436709415 -0.00018295227157454278; -0.0017707287851226972 0.0005984911471506806;;; 0.0001599332862343951 0.0002047450318821412; -0.0007573380876248817 0.002263943708218194;;; … ;;; 0.0011064085766312928 -0.0015183536256469087; -0.005314301569001719 0.008770911404797872;;; 0.0005819543316200788 -0.0020233752396667964; -0.0011055213495137902 0.006147967495435067;;; 0.0011912007962570088 -0.0028684673721751826; 
-0.0005730477165535 0.0014817444426570656;;;;])])
Plot the ELBO over iterations
# ELBO trace across EM iterations, with axis labels set at construction time.
plot(
    mls;
    label = "ELBO",
    linewidth = 1.5,
    xlabel = "Iteration",
    ylabel = "ELBO",
)
Compare the true and learned models
# Plot the latent states as a weighted function of the responsibilities for each state
# Posterior-mean-style latent estimate: at each time step, blend each regime's
# smoothed latent state by that regime's responsibility. `FB.γ` holds log
# responsibilities, so exponentiate first. Result is state_dim × T.
latents = zeros(state_dim, T)
resp = exp.(FB.γ)
for t in 1:T, k in 1:K
    @views latents[:, t] .+= resp[k, t] .* FS[k].x_smooth[:, t]
end
# Plot the learned latent states on top of the original with improved styling
# Overlay true (black) and learned (colored) latent coordinates. Each coordinate is
# shifted into its own horizontal band: x₁ around +2, x₂ around -2.
plt = plot(size=(800, 500), background_color=:white, margin=5Plots.mm)
for (row, shift, lbl) in ((1, 2, "x₁ (True)"), (2, -2, "x₂ (True)"))
    plot!(plt, x[row, :] .+ shift; label=lbl, linewidth=2, color=:black, alpha=0.8)
end
for (row, shift, lbl, col) in ((1, 2, "x₁ (Learned)", :firebrick),
                               (2, -2, "x₂ (Learned)", :royalblue))
    plot!(plt, latents[row, :] .+ shift; label=lbl, linewidth=1.5, color=col)
end
title!(plt, "SLDS: True vs Learned Latent States")
xlabel!(plt, "Time")
ylabel!(plt, "") # no y label; the custom ticks below name each band
yticks!(plt, [-2, 2], ["x₂", "x₁"])
# Faint dashed guide at the center of each band.
for level in (2, -2)
    hline!(plt, [level], color=:gray, alpha=0.3, linestyle=:dash, label="")
end
xlims!(plt, 0, T)
This page was generated using Literate.jl.