# Choosing Latent Dimensionality for Linear Dynamical Systems (LDS)
One of the most consequential decisions when fitting an LDS is the choice of latent dimensionality K. Cross-validation works for any state-space model (Gaussian LDS, Poisson LDS, nonlinear SSMs, and so on), which makes it a reliable default for this task. This tutorial demonstrates robust CV-based model selection.
# Load Required Packages
using StateSpaceDynamics
using LinearAlgebra
using Random
using Plots
using Statistics
using StableRNGs
using Printf

Fix the RNG for reproducible results:
rng = StableRNG(1234);
# Create a True Gaussian LDS System
For demonstration, we'll create a ground-truth LDS with K=4 latent dimensions. This system will exhibit interesting dynamics like oscillations and decay patterns.
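For reference, the generative model being simulated is the standard Gaussian LDS, written in the same notation as the parameters defined below:

$$x_t = A x_{t-1} + b + w_t, \qquad w_t \sim \mathcal{N}(0, Q)$$
$$y_t = C x_t + d + v_t, \qquad v_t \sim \mathcal{N}(0, R)$$
$$x_1 \sim \mathcal{N}(\mu_0, \Sigma_0)$$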
K_true = 4 # True latent dimensionality
D = 10 # Observation dimensionality
T = 300; # Number of time steps

Create interesting dynamics: oscillating + decaying modes
θ = π/12 # Oscillation frequency
λ = 0.92 # Decay rate
true_A = [cos(θ) -sin(θ) 0.0 0.0;
          sin(θ)  cos(θ) 0.0 0.0;
          0.0     0.0    λ   0.0;
          0.0     0.0    0.0 0.85*λ];
true_Q = 0.05 * Matrix(I(K_true)); # Process noise covariance
true_b = zeros(K_true)
Random.seed!(rng, 42)
# Observation matrix - each latent dimension affects multiple observations
true_C = randn(rng, D, K_true) * 0.6;
true_d = zeros(D)
true_R = 0.1 * Matrix(I(D)); # Observation noise covariance
true_μ0 = zeros(K_true) # Initial state parameters
true_Σ0 = 0.1 * Matrix(I(K_true));
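As a quick optional check (not part of the original setup), the eigenvalues of true_A expose these modes directly: the rotation block contributes a complex-conjugate pair e^{±iθ} with modulus 1 (a noise-driven sustained oscillation), while the diagonal entries give real eigenvalues λ and 0.85λ (geometric decay).

# Optional sanity check: the spectrum of the dynamics matrix reveals the latent modes
eigs = eigvals(true_A)
println("eigenvalues of true_A: ", round.(eigs; digits=3))
println("moduli: ", round.(abs.(eigs); digits=3)) # ≈ 1.0, 1.0, 0.92, 0.782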
true_lds = LinearDynamicalSystem(
    GaussianStateModel(true_A, true_Q, true_b, true_μ0, true_Σ0),
    GaussianObservationModel(true_C, true_R, true_d),
    K_true,
    D,
    fill(true, 6) # Fit flags for the six parameter blocks: x0, P0, A (and b), Q, C, R
);
latent_states, observations = rand(rng, true_lds; tsteps=T, ntrials=1); # Generate ground truth data

Visualize the true latent dynamics and observations
p1 = plot(layout=(2,2), size=(1000, 600))
plot!(1:T, latent_states[1, :], label="Latent 1 (cos)",
      linewidth=2, subplot=1, title="Oscillating Modes")
plot!(1:T, latent_states[2, :], label="Latent 2 (sin)",
      linewidth=2, subplot=1)
plot!(1:T, latent_states[3, :], label="Latent 3 (decay)",
      linewidth=2, subplot=2, title="Decaying Modes")
plot!(1:T, latent_states[4, :], label="Latent 4 (decay)",
      linewidth=2, subplot=2)
plot!(1:T, observations[1, :], label="Obs 1", alpha=0.7, subplot=3, title="Observations 1-3")
plot!(1:T, observations[2, :], label="Obs 2", alpha=0.7, subplot=3)
plot!(1:T, observations[3, :], label="Obs 3", alpha=0.7, subplot=3)
plot!(1:T, observations[4, :], label="Obs 4", alpha=0.7, subplot=4, title="Observations 4-6")
plot!(1:T, observations[5, :], label="Obs 5", alpha=0.7, subplot=4)
plot!(1:T, observations[6, :], label="Obs 6", alpha=0.7, subplot=4)
# Prepare Data for Cross-Validation
y_data = reshape(observations, D, T, 1) # (obs_dim, tsteps, ntrials)
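A quick shape check confirms the layout the fitting functions below expect (this assert is illustrative, not part of the original script):

@assert size(y_data) == (D, T, 1) # (obs_dim, tsteps, ntrials)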
# Cross-Validation Setup
K_candidates = 1:8 # Test latent dimensions from 1 to 8
n_folds = 5 # Number of CV folds
fold_size = T ÷ n_folds;

Storage for CV results
cv_scores = zeros(length(K_candidates), n_folds)
cv_mean = zeros(length(K_candidates))
cv_std = zeros(length(K_candidates));
println("Starting Cross-Validation for Model Selection...")
println("="^60)
# Perform K-Fold Cross-Validation
for (k_idx, K) in enumerate(K_candidates)
    println("Testing K = $K...")
    fold_scores = zeros(n_folds)

    for fold in 1:n_folds
        val_start = (fold - 1) * fold_size + 1
        val_end = min(fold * fold_size, T)
        train_indices = vcat(1:(val_start-1), (val_end+1):T)
        val_indices = val_start:val_end

        y_train = y_data[:, train_indices, :]
        y_val = y_data[:, val_indices, :]

        A_init = 0.9 * Matrix(I(K)) + 0.1 * randn(rng, K, K)
        Q_init = 0.1 * Matrix(I(K))
        b_init = zeros(K)
        C_init = randn(rng, D, K) * 0.5
        R_init = 0.2 * Matrix(I(D))
        d_init = zeros(D)
        μ0_init = zeros(K)
        Σ0_init = 0.1 * Matrix(I(K))

        lds_candidate = LinearDynamicalSystem(
            GaussianStateModel(A_init, Q_init, b_init, μ0_init, Σ0_init),
            GaussianObservationModel(C_init, R_init, d_init),
            K,
            D,
            fill(true, 6) # Fit all parameters
        )

        try
            lls, _ = fit!(lds_candidate, y_train; max_iter=200, tol=1e-6, progress=false)
            x_val, _ = smooth(lds_candidate, y_val[:, :, 1])
            val_ll = loglikelihood(x_val, lds_candidate, y_val[:, :, 1])
            fold_scores[fold] = sum(val_ll) / length(val_indices) # Normalize by sequence length
        catch e
            println(" Warning: Fold $fold failed for K=$K: $e")
            fold_scores[fold] = -Inf
        end
    end

    cv_scores[k_idx, :] = fold_scores
    cv_mean[k_idx] = mean(fold_scores)
    cv_std[k_idx] = std(fold_scores)
    @printf(" K=%d: CV Score = %.3f ± %.3f\n", K, cv_mean[k_idx], cv_std[k_idx])
end
# Find Optimal K
best_k_idx = argmax(cv_mean)
best_K = K_candidates[best_k_idx]
println("\n" * "="^60)
println("CROSS-VALIDATION RESULTS:")
println("="^60)
@printf("True K: %d\n", K_true)
@printf("Best K: %d (CV Score: %.3f ± %.3f)\n", best_K, cv_mean[best_k_idx], cv_std[best_k_idx])
println()
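When the CV curve plateaus beyond the true dimensionality, the plain argmax can favor slightly larger K. A common, more conservative alternative (not used in this tutorial) is the one-standard-error rule: pick the smallest K whose mean score is within one standard error of the best. A minimal sketch:

# Hypothetical one-standard-error rule (illustrative refinement, not in the original)
se_best = cv_std[best_k_idx] / sqrt(n_folds)
K_1se = K_candidates[findfirst(cv_mean .>= cv_mean[best_k_idx] - se_best)]
println("One-SE rule would select K = $K_1se")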
p2 = plot(K_candidates, cv_mean;
    yerror = cv_std,
    marker = :circle,
    markersize = 6,
    linewidth = 2,
    xlabel = "Latent Dimensionality (K)",
    ylabel = "Cross-Validation Score",
    title = "Model Selection via Cross-Validation",
    legend = false,
    size = (800, 500),
)
vline!([K_true]; linestyle=:dash, color=:green, linewidth=2, label="")
vline!([best_K]; linestyle=:dot, color=:red, linewidth=2, label="")
yr = extrema(cv_mean)
y1 = yr[2] - 0.05*(yr[2] - yr[1])
y2 = yr[2] - 0.15*(yr[2] - yr[1])
annotate!((K_true, y1, text("True K=$(K_true)", :green, 10)),
          (best_K, y2, text("Selected K=$(best_K)", :red, 10)))

Initialize final model
A_final = 0.9 * Matrix(I(best_K)) + 0.1 * randn(rng, best_K, best_K)
Q_final = 0.1 * Matrix(I(best_K))
b_final = zeros(best_K)
C_final = randn(rng, D, best_K) * 0.5
R_final = 0.2 * Matrix(I(D))
d_final = zeros(D) # Initialize the offset rather than reusing the ground-truth true_d
μ0_final = zeros(best_K)
Σ0_final = 0.1 * Matrix(I(best_K))
final_lds = LinearDynamicalSystem(
    GaussianStateModel(A_final, Q_final, b_final, μ0_final, Σ0_final),
    GaussianObservationModel(C_final, R_final, d_final),
    best_K,
    D,
    fill(true, 6)
)

Linear Dynamical System:
------------------------
Gaussian State Model:
---------------------
State Parameters:
A = [0.972 -0.0737 0.0176 0.0746; -0.0262 0.855 0.00471 -0.131; 0.00165 0.0674 0.848 0.0787; 0.0802 0.0642 -0.0126 1.03]
Q = [0.1 0.0 0.0 0.0; 0.0 0.1 0.0 0.0; 0.0 0.0 0.1 0.0; 0.0 0.0 0.0 0.1]
Initial State:
b = [0.0, 0.0, 0.0, 0.0]
x0 = [0.0, 0.0, 0.0, 0.0]
P0 = [0.1 0.0 0.0 0.0; 0.0 0.1 0.0 0.0; 0.0 0.0 0.1 0.0; 0.0 0.0 0.0 0.1]
Gaussian Observation Model:
---------------------------
size(C) = (10, 4)
size(R) = (10, 10)
size(d) = (10,)
Parameters to update:
---------------------
x0, P0, A (and b), Q, C, R
Fit on full dataset
final_lls, _ = fit!(final_lds, y_data; max_iter=500, tol=1e-8);

(output: the EM log-likelihood trace climbs from ≈ -14478 to ≈ 1270.4 while the per-iteration change shrinks toward 1e-6)

# Compare Learned vs True Dynamics
Use the correct input format for the smooth function (it expects a 3D array):
x_learned, P_learned = smooth(final_lds, y_data)
plt1 = plot(
    1:length(final_lls), final_lls,
    linewidth=2,
    xlabel="EM Iteration",
    ylabel="Log-Likelihood",
    title="Learning Curve (Final Model)"
)
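Beyond eyeballing the learning curve, one can check programmatically that the log-likelihood never decreased, a property exact EM guarantees up to numerical tolerance. A minimal sketch:

# Verify EM monotonicity up to a small numerical tolerance
@assert all(diff(final_lls) .> -1e-6) "EM log-likelihood decreased between iterations"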
n_plot = min(4, best_K, K_true)
colors = [:blue, :red, :green, :orange]
plt2 = plot(title="True vs Learned Latent Dynamics", xlabel="Time", ylabel="Latent State Value")
for i in 1:n_plot
    if i <= size(latent_states, 1)
        plot!(plt2, 1:T, latent_states[i, :],
              label="True Latent $i", color=colors[i],
              linestyle=:solid, linewidth=2)
    end
    if i <= size(x_learned, 1)
        plot!(plt2, 1:T, x_learned[i, :],
              label="Learned Latent $i", color=colors[i],
              linestyle=:dash, linewidth=2)
    end
end
p3 = plot(plt1, plt2, layout = @layout([a; b]), size=(1000, 600))

Compute reconstruction error. x_learned is (latent_dim, tsteps, ntrials) with a singleton trial dimension, so we drop it first.
x_learned = x_learned[:, :, 1]
y_pred = final_lds.obs_model.C * x_learned .+ final_lds.obs_model.d # Include the learned offset d
reconstruction_error = mean((observations - y_pred).^2)
@printf("Reconstruction MSE: %.6f\n", reconstruction_error)

Reconstruction MSE: 0.102450

This page was generated using Literate.jl.