Simulating and Fitting a Switching Linear Dynamical System

Load Packages

using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using Statistics
using StableRNGs

# Fixed-seed RNG used for the simulation below so the generated parameters and data are reproducible.
rng = StableRNG(123);

Simulate data from an SLDS model

state_dim = 2
obs_dim = 10
K = 2 ## two discrete states (regimes)

# Create the HMM parameters: row-stochastic regime transition matrix and the
# initial regime distribution (the chain always starts in regime 1).
A_hmm = [0.92 0.08; 0.06 0.94]
π₀ = [1.0, 0.0]

# Create the state models: 2-D rotations scaled by 0.95 (spectral radius < 1, so the
# dynamics are stable); the rotation angle per step sets the oscillation frequency.
A₁ = 0.95 * [cos(0.05) -sin(0.05); sin(0.05) cos(0.05)] ## slower oscillator
A₂ = 0.95 * [cos(0.55) -sin(0.55); sin(0.55) cos(0.55)] ## faster oscillator

# Regime 2 also has much larger process noise than regime 1.
Q₁ = [0.001 0.0; 0.0 0.001]
Q₂ = [0.1 0.0; 0.0 0.1]

# Assume same initial distribution for ease
x0 = [0.0, 0.0]
P0 = [0.1 0.0; 0.0 0.1]

# Create the observation models (random loadings drawn from the seeded rng)
C₁ = randn(rng, obs_dim, state_dim)
C₂ = randn(rng, obs_dim, state_dim)

R = Matrix(0.1 * I(obs_dim)) ## Assume same noise covariance for both states

# Put it all together for an SLDS model.
# NOTE(review): `fill(true, 6)` is passed straight to LinearDynamicalSystem; it presumably
# flags which parameters are updated when fitting — confirm against the package docs.
model = SwitchingLinearDynamicalSystem(
    A_hmm,
    [LinearDynamicalSystem(GaussianStateModel(A₁, Q₁, x0, P0), GaussianObservationModel(C₁, R), state_dim, obs_dim, fill(true, 6)),
     LinearDynamicalSystem(GaussianStateModel(A₂, Q₂, x0, P0), GaussianObservationModel(C₂, R), state_dim, obs_dim, fill(true, 6))],
    π₀,
    K)

# Simulate T time steps. The length comes from T (not a separate literal) so the
# sample always matches the plotting code below, which indexes 1:T.
# x: continuous latent trajectory, y: observations (fed to fit! below),
# z: discrete regime label at each time step.
T = 1000
x, y, z = rand(rng, model, T)
([-0.0076254535387565315 0.0005344715407875466 … -0.05961542435845691 -0.05411358762031381; -0.04123506436656016 -0.023410425346735073 … -0.1589869811371745 -0.13703553398805518], [-0.18526210459309506 -0.15796177119519147 … 0.5642837623533308 0.5449692850501915; 0.3217524900300324 0.14027914911950912 … 0.9285210380889222 0.44291807297934227; … ; -0.07844180835082344 0.4656190481762747 … -0.5557825462611323 -0.21417601279726053; 0.49270132361714775 0.32012586659081677 … 0.15877195176510678 1.1594527178911123], [1, 1, 2, 2, 2, 2, 2, 2, 2, 2  …  1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

Plot the true dynamics

# Plot the true latent trajectories on top of a background shaded by the discrete
# state z. The shading is drawn first so it sits behind the lines instead of
# washing them out with its 50% fill.
p1 = plot()

# Start index of every contiguous run of z, plus the sentinel T + 1 marking the end.
transition_points = [1; findall(diff(z) .!= 0) .+ 1; T + 1]
labeled_states = Set{eltype(z)}()

for i in 1:(length(transition_points) - 1)
    start_idx = transition_points[i]
    next_start = transition_points[i + 1]
    state_value = z[start_idx]

    # Choose color based on state value
    bg_color = state_value == 1 ? :lightblue : :lightyellow

    # Label each state only at its first occurrence so every state gets exactly
    # one legend entry (labelling only the first region hides the other states).
    region_label = state_value in labeled_states ? "" : "State $state_value"
    push!(labeled_states, state_value)

    # Extend half a step on either side of each time index so adjacent regions
    # tile the axis without gaps and single-step regions remain visible.
    vspan!(p1, [start_idx - 0.5, next_start - 0.5]; fillalpha=0.5, color=bg_color,
           label=region_label)
end

# Explicit line colors: series added after the spans could otherwise pick up
# shifted palette colors.
plot!(p1, 1:T, x[1, :]; label="x₁", linewidth=1.5, color=:navy)
plot!(p1, 1:T, x[2, :]; label="x₂", linewidth=1.5, color=:darkorange)

# Adjust the plot appearance
title!(p1, "Latent Dynamics with State")
xlabel!(p1, "Time")
ylabel!(p1, "State Value")
ylims!(p1, -3, 3)

p1
Example block output

Create a new SLDS model with different initial parameters and fit it to the data

# Create a model to start EM from, using reasonable (but deliberately imperfect) guesses.

# Initial regime transition matrix: strongly persistent regimes.
A = [0.9 0.1; 0.1 0.9]
A ./= sum(A, dims=2) ## Normalize rows to sum to 1

# Random initial regime distribution. Draw from the seeded `rng` (not the global RNG)
# so the initialization, and therefore the fitted result, is reproducible.
πₖ = rand(rng, K)
πₖ ./= sum(πₖ) ## Normalize to sum to 1

Q = Matrix(0.001 * I(state_dim))

x0 = [0.0; 0.0]
P0 = Matrix(0.001 * I(state_dim))

# Set up the observation parameters (the same C and R are used for every regime at
# initialization); C is drawn from the seeded rng for reproducibility.
C = randn(rng, obs_dim, state_dim)
R = Matrix(0.1 * I(obs_dim))

# One LDS per regime. They differ only in the rotation angle f (radians per step) of the
# dynamics matrix, which is scaled by 0.95 so its spectral radius is below 1.
B = [StateSpaceDynamics.LinearDynamicalSystem(
        StateSpaceDynamics.GaussianStateModel(0.95 * [cos(f) -sin(f); sin(f) cos(f)], Q, x0, P0),
        StateSpaceDynamics.GaussianObservationModel(C, R),
        state_dim, obs_dim, fill(true, 6)) for f in [0.7, 0.1]]

learned_model = SwitchingLinearDynamicalSystem(A, B, πₖ, model.K)
SwitchingLinearDynamicalSystem{Float64, Matrix{Float64}, Vector{Float64}, Vector{LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}}}([0.9 0.1; 0.1 0.9], LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}[LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}(GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}([0.726600077920264 -0.6120068028758064; 0.6120068028758064 0.726600077920264], [0.001 0.0; 0.0 0.001], [0.0, 0.0], [0.001 0.0; 0.0 0.001]), GaussianObservationModel{Float64, Matrix{Float64}}([-0.2768436540249182 1.0553781976751377; 0.11361442651749082 0.5795593434957488; … ; -1.045229397049543 1.2020964359511086; 1.8697489616795715 -1.2841356956379089], [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), 2, 10, Bool[1, 1, 1, 1, 1, 1]), LinearDynamicalSystem{Float64, GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}, GaussianObservationModel{Float64, Matrix{Float64}}}(GaussianStateModel{Float64, Matrix{Float64}, Vector{Float64}}([0.9452539570141245 -0.09484174581448675; 0.09484174581448675 0.9452539570141245], [0.001 0.0; 0.0 0.001], [0.0, 0.0], [0.001 0.0; 0.0 0.001]), GaussianObservationModel{Float64, Matrix{Float64}}([-0.2768436540249182 1.0553781976751377; 0.11361442651749082 0.5795593434957488; … ; -1.045229397049543 1.2020964359511086; 1.8697489616795715 -1.2841356956379089], [0.1 0.0 … 0.0 0.0; 0.0 0.1 … 0.0 0.0; … ; 0.0 0.0 … 0.1 0.0; 0.0 0.0 … 0.0 0.1]), 2, 10, Bool[1, 1, 1, 1, 1, 1])], [0.7013565205730752, 0.2986434794269248], 2)

Fit the model to the data

# Run EM on the observations y, updating learned_model in place.
# Returns: `mls` — one ELBO value per iteration (plotted below); `param_diff` — one value
# per iteration, presumably the size of the parameter update (confirm in the package docs);
# `FB` — forward-backward results whose `γ` field holds log regime responsibilities;
# `FS` — one filter/smoother result per regime, each exposing `x_smooth`.
# NOTE(review): in the captured output the ELBO peaks around iteration 10 and then slowly
# declines. Exact EM cannot decrease it, so this likely reflects the approximate inference
# used for SLDS — worth checking before increasing max_iter.
mls, param_diff, FB, FS = fit!(learned_model, y; max_iter=25) # use 25 iterations of EM
([-97079.74434109525, -21014.483449788248, -19050.095010428522, -18099.885171867263, -17862.10645282446, -17804.837966156534, -17778.33438585185, -17764.42561018367, -17757.182804149605, -17753.67889163311  …  -17758.935030679346, -17761.34336637926, -17763.918297549728, -17766.625282492074, -17769.439740045545, -17772.34371472153, -17775.32373870886, -17778.369431069936, -17781.472562409093, -17784.626420203967], [20.95199561628859, 2.291118905878763, 0.7595611159425671, 0.7620388465969193, 0.22787989259721536, 0.08903188564207053, 0.06133305501722204, 0.04350791501634204, 0.030934914442839766, 0.022153157061444193  …  0.011306253526673426, 0.012140932105170347, 0.013029014058098464, 0.01392865379392921, 0.01482540573160377, 0.015716150027296136, 0.016601948042540646, 0.017485007725938105, 0.018367479157868405, 0.019251046256445938], StateSpaceDynamics.ForwardBackward{Float64, Vector{Float64}, Matrix{Float64}, Array{Float64, 3}}([-40.759960748284044 -4.858865906228484 … -7.6973643460655845 -14.610176710565748; -34.0287530223749 -20.195654077623363 … -35.429952078674916 -38.39888891803196], [-40.759960748284044 -45.71694317790801 … -6412.582846614814 -6427.291139848775; -155.53345177093516 -63.325871493328876 … -6442.58757449144 -6453.351992200266], [-6386.531179100466 -6381.574196672582 … -14.70829323395642 0.0; -6389.091621493188 -6384.134639287638 … -17.26873584839724 0.0], [0.0 -1.7398633644916117e-9 … 0.0 -4.547473508864641e-12; -117.33393341537248 -20.169370932216225 … -32.565170491066965 -26.06085235149567], [-1.7398633644916117e-9 -20.169370932216225; -117.3339336394456 -132.64522875596504;;; -1.7398633644916117e-9 -147.4851559905028; -20.169370932215315 -162.7964511070213;;; 0.0 -76.692932386758; -147.48515576643058 -219.32001233923165;;; … ;;; 0.0 -30.267541518268445; -29.99596176409341 -55.40542746840492;;; 0.0 -32.56517049107606; -30.26754151825844 -57.97463619537757;;; -4.547473508864641e-12 -26.060852351496578; -32.56517049168633 -53.76794702922143]), 
StateSpaceDynamics.FilterSmooth{Float64}[StateSpaceDynamics.FilterSmooth{Float64}([-0.14139697569927479 -0.0036112724252746775 … -0.019510455421217397 0.017604987630413755; 0.018139547631266312 -0.03541029787861343 … 0.0334863257402473 -0.0072094878696341794], [1.416971797449966e-5 -2.312119877690916e-6; -2.312119877690916e-6 3.968182012055234e-5;;; 0.0003321699736069951 -5.8139893050158e-5; -5.8139893050158e-5 0.0008592615457725448;;; 0.0003346209696419567 -6.1033440720274474e-5; -6.1033440720274474e-5 0.0008719867248851624;;; … ;;; 0.0003346634167673146 -6.109023591324026e-5; -6.109023591324026e-5 0.000872156458284284;;; 0.0003347610501802562 -6.100839911538423e-5; -6.100839911538423e-5 0.0008741622858586135;;; 0.00035730658364090253 -8.261239765217035e-5; -8.261239765217035e-5 0.0009895443268411381], [-0.14139697569927479 -0.0036112724252746775 … -0.019510455421217397 0.017604987630413755; 0.018139547631266312 -0.03541029787861343 … 0.0334863257402473 -0.0072094878696341794;;;], [0.020007274454875806 -0.0025671892954916915; -0.0025671892954916915 0.0003687250083875316;;; 0.00034521126213654435 6.97363392496411e-5; 6.97363392496411e-5 0.0021131507416246795;;; 0.04077872465721982 0.008660562221707557; 0.008660562221707557 0.0027527610272773105;;; … ;;; 0.0003363298285935287 -8.16997489620169e-5; -8.16997489620169e-5 0.0011270466489478213;;; 0.0007154189209235675 -0.0007143418646908438; -0.0007143418646908438 0.0019954962974405624;;; 0.0006672421731079238 -0.0002095353424186981; -0.0002095353424186981 0.0010415210421835404;;;;], [0.0 0.0; 0.0 0.0;;; 0.0005114594041911828 -6.723268997750052e-5; 0.005007716310411231 -0.0006377247004059926;;; -0.0007064863943814662 -0.007159059970076956; -0.00013816315160646912 -0.0014366361800047926;;; … ;;; 3.0277831847890295e-5 -4.092124723488655e-5; -0.00010861730046547372 0.0001306833019926615;;; 4.521122743233139e-5 -0.0003499597053815958; -2.4884986152525367e-5 0.0006351879244006658;;; -0.00032255320679870246 
0.0005465691045504713; 0.00016055022804973773 -0.00012628606064155225;;;;]), StateSpaceDynamics.FilterSmooth{Float64}([-0.15143493312950765 -0.15704742993563456 … 0.022540476335865298 0.02020709944963191; 0.016418919468403206 0.011969725842904538 … -0.030694403416864347 -0.030515192843616437], [0.000796468574335045 5.1031281027911025e-5; 5.1031281027911025e-5 0.0007985150852388643;;; 0.0034661822425669517 0.0013028941491425936; 0.0013028941491425936 0.003445815817824533;;; 0.0058175500352808 0.0023961163782228994; 0.0023961163782228994 0.005765510998408426;;; … ;;; 0.030166130039629594 0.012174515296104058; 0.012174515296104058 0.028850924741484565;;; 0.030247117234151508 0.012200776587404108; 0.012200776587404108 0.028944536775129707;;; 0.030320375955944265 0.012224770504078008; 0.012224770504078008 0.02903042320857634], [-0.15143493312950765 -0.15704742993563456 … 0.022540476335865298 0.02020709944963191; 0.016418919468403206 0.011969725842904538 … -0.030694403416864347 -0.030515192843616437;;;], [0.0237290075462735 -0.0024353666907284996; -0.0024353666907284996 0.001068096001748774;;; 0.028130077491954997 -0.0005769205315197111; -0.0005769205315197111 0.00358909015457883;;; 0.032478598841020154 0.0011240447418013023; 0.0011240447418013023 0.005826205019111767;;; … ;;; 0.03079241010054962 0.011404068612907086; 0.011404068612907086 0.029798724501830142;;; 0.03075519030759921 0.011508910113542775; 0.011508910113542775 0.02988668317624692;;; 0.030728702824111578 0.011608146967562355; 0.011608146967562355 0.02996160020285944;;;;], [0.0 0.0; 0.0 0.0;;; 0.0245214288009227 -0.002505368628730213; -0.001792672424406613 0.0009575335965717435;;; 0.028888728188651475 -0.0006383318663632134; -0.00010575479858167367 0.003333128762762009;;; … ;;; 0.029357582603473587 0.011569681947697395; 0.009806863575519836 0.028320961644323177;;; 0.029314353233531338 0.011676423357533522; 0.00991514365005562 0.028415939885032708;;; 0.029282722317263823 0.011777807159930447; 
0.010017928764050398 0.028497232044841846;;;;])])

Plot the ELBO (evidence lower bound) over EM iterations

# ELBO trace, one point per EM iteration returned by fit!; axis labels set inline.
plot(mls; label="ELBO", linewidth=1.5, xlabel="Iteration", ylabel="ELBO")
Example block output

Compare the true and learned models

# Collapse the per-regime smoothed latent means into a single trajectory: at each time
# step, weight every regime's smoothed mean by that regime's posterior responsibility.

latents = zeros(state_dim, T)  # state_dim × T (one column per time step), not K × T
resp = exp.(FB.γ)              # FB.γ holds log-responsibilities; exponentiate to probabilities

# Accumulate one regime at a time across all time steps at once. Each entry still
# receives its terms in the order 0 + regime 1 + regime 2 + ..., exactly as a
# per-time-step loop would.
for k in 1:K
    @views latents .+= FS[k].x_smooth[:, 1:T] .* resp[k:k, 1:T]
end


# Overlay the true latent trajectories (black) and the responsibility-weighted learned
# ones (colored). Dimension 1 is shifted up by 2 and dimension 2 down by 2 so the two
# traces do not overlap. Caveat: an LDS latent space is only identified up to an
# invertible linear transform, so the learned latents need not align exactly with the truth.
plt = plot(size=(800, 500), background_color=:white, margin=5Plots.mm)

# Ground-truth trajectories
plot!(plt, x[1, :] .+ 2; label="x₁ (True)", linewidth=2, color=:black, alpha=0.8)
plot!(plt, x[2, :] .- 2; label="x₂ (True)", linewidth=2, color=:black, alpha=0.8)

# Learned trajectories
plot!(plt, latents[1, :] .+ 2; label="x₁ (Learned)", linewidth=1.5, color=:firebrick)
plot!(plt, latents[2, :] .- 2; label="x₂ (Learned)", linewidth=1.5, color=:royalblue)

# Titles and axes: no y label; instead, put tick labels at the two offset baselines
title!(plt, "SLDS: True vs Learned Latent States")
xlabel!(plt, "Time")
ylabel!(plt, "")
yticks!(plt, [-2, 2], ["x₂", "x₁"])

# Dashed guides at the two baselines
hline!(plt, [2]; color=:gray, alpha=0.3, linestyle=:dash, label="")
hline!(plt, [-2]; color=:gray, alpha=0.3, linestyle=:dash, label="")

xlims!(plt, 0, T)
Example block output

This page was generated using Literate.jl.