Simulating and Fitting a Hidden Markov Model with Gaussian Emissions

This tutorial demonstrates how to use StateSpaceDynamics.jl to create, sample from, and fit Hidden Markov Models (HMMs) with Gaussian emission distributions. This is the classical HMM formulation where each hidden state generates observations from a different multivariate Gaussian distribution.

Unlike GLM-HMMs, this model doesn't use input features - each state simply emits observations from its own characteristic Gaussian distribution. This makes it ideal for clustering time series data, identifying behavioral regimes, or modeling switching dynamics where each state has a distinct statistical signature.
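
Concretely, the generative process is a Markov chain over hidden states $z_t \in \{1, \dots, K\}$, with each observation drawn from the Gaussian attached to the current state:

$P(z_1 = k) = \pi_k, \qquad P(z_t = j \mid z_{t-1} = i) = A_{ij}, \qquad x_t \mid z_t = k \sim \mathcal{N}(\mu_k, \Sigma_k)$

The transition matrix $A$, initial distribution $\pi_k$, and the per-state means $\mu_k$ and covariances $\Sigma_k$ are all defined below.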

Load Required Packages

using LinearAlgebra
using Plots
using Random
using StateSpaceDynamics
using StableRNGs
using Statistics: mean, std
using LaTeXStrings

Set up reproducible random number generation

rng = StableRNG(1234);

Create a Gaussian Emission HMM

We'll create an HMM with two hidden states, each emitting 2D Gaussian observations. This creates a simple but illustrative model where hidden states correspond to different regions in the observation space.

output_dim = 2;  # Each observation is a 2D vector

Define the state transition dynamics: $A_{ij} = P(\text{state}_t = j \mid \text{state}_{t-1} = i)$. High diagonal values mean states are "sticky" (tend to persist).

A = [0.99 0.01;    # From state 1: 99% stay, 1% switch to state 2
     0.05 0.95];   # From state 2: 5% switch to state 1, 95% stay

Initial state probabilities: $\pi_k = P(\text{state}_1 = k)$

πₖ = [0.5; 0.5];

Define the emission distributions for each hidden state. State 1 is centered at (-1, -1) with small variance (a tight cluster).

μ_1 = [-1.0, -1.0]
Σ_1 = 0.1 * Matrix{Float64}(I, output_dim, output_dim)
emission_1 = GaussianEmission(output_dim=output_dim, μ=μ_1, Σ=Σ_1);

State 2: Centered at (1, 1) with larger variance (more spread out)

μ_2 = [1.0, 1.0]
Σ_2 = 0.2 * Matrix{Float64}(I, output_dim, output_dim)
emission_2 = GaussianEmission(output_dim=output_dim, μ=μ_2, Σ=Σ_2);

Construct the complete HMM

model = HiddenMarkovModel(
    K=2,                        # Number of hidden states
    B=[emission_1, emission_2], # Emission distributions
    A=A,                        # State transition matrix
    πₖ=πₖ                      # Initial state distribution
);

print("Created Gaussian HMM with 2 states:\n")
print("State 1: μ = $μ_1, σ² = $(Σ_1[1,1]) (tight cluster)\n")
print("State 2: μ = $μ_2, σ² = $(Σ_2[1,1]) (looser cluster)\n");
Created Gaussian HMM with 2 states:
State 1: μ = [-1.0, -1.0], σ² = 0.1 (tight cluster)
State 2: μ = [1.0, 1.0], σ² = 0.2 (looser cluster)

Sample from the HMM

Generate synthetic data from our true model. Each state generates observations from its own Gaussian distribution without requiring input features. The rand function samples both the hidden state sequence and the corresponding observations.

num_samples = 10000;
true_labels, data = rand(rng, model, n=num_samples);

Visualize the Sampled Dataset

Create a 2D scatter plot showing observations colored by their true hidden state. This illustrates how each state generates observations from a distinct region of space. We will also plot a trajectory line to show the temporal evolution for the first 1000 timepoints.

x_vals = data[1, 1:num_samples]
y_vals = data[2, 1:num_samples]
labels_slice = true_labels[1:num_samples]

state_colors = [:dodgerblue, :crimson]

p1 = plot()

for state in 1:2
    idx = findall(labels_slice .== state)
    scatter!(x_vals[idx], y_vals[idx];
        color=state_colors[state],
        label="State $state",
        markersize=3,
        alpha=0.6)
end

plot!(x_vals[1:1000], y_vals[1:1000];
    color=:gray, lw=1, alpha=0.3, label="Trajectory")

scatter!([x_vals[1]], [y_vals[1]]; marker=:star5, markersize=8,
         color=:green, label="Start")
scatter!([x_vals[end]], [y_vals[end]]; marker=:diamond, markersize=6,
         color=:black, label="End")

plot!(xlabel=L"x_1", ylabel=L"x_2",
      title="HMM Emissions by True Hidden State",
      legend=:topleft)
[Figure: 2D scatter of emissions colored by true hidden state, with the first 1000 timesteps overlaid as a trajectory]

Initialize and Fit HMM with EM

In reality, we only observe the data, not the hidden states. The goal of fitting is to learn the latent state sequence and the model parameters that best explain the data. We will initialize a new HMM with incorrect parameters and use the Expectation-Maximization (EM) algorithm to iteratively refine the parameters and infer the hidden states.
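
Each EM iteration alternates an E-step, which uses the forward-backward algorithm to compute the posterior probability of each hidden state given the data and the current parameters $\theta$, and an M-step, which re-estimates $\theta$ by maximizing the expected complete-data log-likelihood:

$\theta^{\text{new}} = \arg\max_{\theta} \; \mathbb{E}_{z_{1:T} \sim P(z_{1:T} \mid x_{1:T}, \theta^{\text{old}})}\left[\log p(x_{1:T}, z_{1:T} \mid \theta)\right]$

Each iteration is guaranteed not to decrease the log-likelihood, which is why monitoring the log-likelihood is a useful convergence check.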

μ_1_init = [-0.25, -0.25]  # Closer to center than true
Σ_1_init = 0.3 * Matrix{Float64}(I, output_dim, output_dim)  # Larger variance
emission_1_init = GaussianEmission(output_dim=output_dim, μ=μ_1_init, Σ=Σ_1_init);

μ_2_init = [0.25, 0.25]    # Closer to center than true
Σ_2_init = 0.5 * Matrix{Float64}(I, output_dim, output_dim)  # Much larger variance
emission_2_init = GaussianEmission(output_dim=output_dim, μ=μ_2_init, Σ=Σ_2_init);

A_init = [0.8 0.2; 0.05 0.95]  # Less persistent than true model
πₖ_init = [0.6, 0.4];           # Biased toward state 1

test_model = HiddenMarkovModel(K=2, B=[emission_1_init, emission_2_init],
                              A=A_init, πₖ=πₖ_init);

Fit using Expectation-Maximization

lls = fit!(test_model, data);

print("EM converged in $(length(lls)) iterations\n")
print("Log-likelihood improved by $(round(lls[end] - lls[1], digits=1))\n");
EM converged in 5 iterations
Log-likelihood improved by 23097.5

Plot EM convergence

p2 = plot(lls, xlabel="EM Iteration", ylabel="Log-Likelihood",
          title="EM Algorithm Convergence", legend=false,
          marker=:circle, markersize=3, lw=2, color=:darkblue)

p2
[Figure: EM algorithm convergence, log-likelihood vs. EM iteration]

Hidden State Decoding with Viterbi

Now that we have learned the model parameters from the observed data, we can decode the most likely sequence of hidden states using the Viterbi algorithm. Then, in this toy example where we know the true latent state path, we can assess the accuracy of our state predictions.
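
Viterbi finds the single most probable state path $\hat{z}_{1:T} = \arg\max_{z_{1:T}} P(z_{1:T} \mid x_{1:T})$ by dynamic programming over

$\delta_1(j) = \pi_j \, p(x_1 \mid z_1 = j), \qquad \delta_t(j) = \max_{i} \, \delta_{t-1}(i) \, A_{ij} \, p(x_t \mid z_t = j),$

followed by a backtracking pass that reads off the maximizing state at each time step.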

pred_labels = viterbi(test_model, data);

accuracy = mean(true_labels .== pred_labels)
1.0

Which set of parameters gets called "state 1" and which "state 2" is arbitrary and does not affect the correctness of the model: EM can converge with the states swapped relative to our original convention. We check for this and correct it if necessary; for models with more than two states, a full permutation search is needed (see the sketch after the accuracy output below).

swapped_pred = 3 .- pred_labels  # Convert 1→2, 2→1
swapped_accuracy = mean(true_labels .== swapped_pred)

if swapped_accuracy > accuracy
    pred_labels = swapped_pred
    accuracy = swapped_accuracy
    print("Detected and corrected label switching\n")
end

print("State prediction accuracy: $(round(accuracy*100, digits=1))%\n");
State prediction accuracy: 100.0%
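
With only two states, the swap above is the only possible relabeling. For a model with K > 2 states, one way to align predictions with the reference labels is to search over all K! permutations. Below is a minimal sketch; align_labels is a hypothetical helper (not part of StateSpaceDynamics.jl), and it assumes the Combinatorics package is installed.

using Combinatorics: permutations   # assumed to be installed; supplies the K! relabelings
using Statistics: mean

# Hypothetical helper: relabel `pred` by the permutation that best matches `truth`.
function align_labels(truth::AbstractVector{<:Integer},
                      pred::AbstractVector{<:Integer}, K::Integer)
    best_perm = collect(1:K)
    best_acc = mean(truth .== pred)
    for perm in permutations(1:K)
        relabeled = [perm[z] for z in pred]   # map predicted state z to perm[z]
        acc = mean(truth .== relabeled)
        if acc > best_acc
            best_acc, best_perm = acc, collect(perm)
        end
    end
    return [best_perm[z] for z in pred], best_acc
end

aligned_pred, aligned_acc = align_labels(true_labels, pred_labels, 2);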

Our model looks like it is doing pretty well! Let's visualize the predicted and true state sequences as heatmaps (first 1000 timepoints).

n_display = 1000
true_seq = reshape(true_labels[1:n_display], 1, :)
pred_seq = reshape(pred_labels[1:n_display], 1, :)

p3 = plot(
    heatmap(true_seq, colormap=:roma, title="True State Sequence",
           xticks=false, yticks=false, colorbar=false),
    heatmap(pred_seq, colormap=:roma, title="Predicted State Sequence (Viterbi)",
           xlabel="Time Steps (1-$n_display)", xticks=0:200:n_display,
           yticks=false, colorbar=false),
    layout=(2, 1), size=(800, 300)
)
[Figure: heatmaps of the true and Viterbi-predicted state sequences, first 1000 timepoints]

Multiple Independent Trials

Many real applications involve multiple independent sequences (e.g., multiple subjects, sessions, or trials). StateSpaceDynamics.jl makes it easy to incorporate data from multiple trials into parameter learning. Once again, we will generate a synthetic dataset from our ground-truth model to illustrate this process.

n_trials = 100    # Number of independent sequences
n_samples = 1000  # Length of each sequence

all_true_labels = Vector{Vector{Int}}(undef, n_trials);
all_data = Vector{Matrix{Float64}}(undef, n_trials);

for i in 1:n_trials  # Sample each trial independently
    labels_trial, data_trial = rand(rng, model, n=n_samples)
    all_true_labels[i] = labels_trial
    all_data[i] = data_trial
end

total_state1_prop = mean([mean(labels .== 1) for labels in all_true_labels])
print("Average State 1 proportion: $(round(total_state1_prop, digits=3))\n");
Average State 1 proportion: 0.84

Multi-Trial HMM Fitting

When fitting to multiple independent sequences, EM accounts for each sequence starting independently from the initial state distribution. Here, we initialize a new model and fit it to all trials simultaneously.
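
Because the trials are independent, the overall log-likelihood is simply the sum of the per-trial log-likelihoods, with each trial's forward pass starting from the initial state distribution $\pi_k$:

$\log p\left(x^{(1)}, \dots, x^{(R)} \mid \theta\right) = \sum_{r=1}^{R} \log p\left(x^{(r)}_{1:T_r} \mid \theta\right)$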

test_model_multi = HiddenMarkovModel(
    K=2,
    B=[deepcopy(emission_1_init), deepcopy(emission_2_init)],
    A=A_init, πₖ=πₖ_init
)

lls_multi = fit!(test_model_multi, all_data);

Let's check how training went and what parameters were learned.

print("Multi-trial EM converged in $(length(lls_multi)) iterations\n")
print("Log-likelihood improved by $(round(lls_multi[end] - lls_multi[1], digits=1))\n");
print("Multi-trial learned parameters:\n")
print("State 1: μ = $(round.(test_model_multi.B[1].μ, digits=3)), σ² = $(round(test_model_multi.B[1].Σ[1,1], digits=3))\n")
print("State 2: μ = $(round.(test_model_multi.B[2].μ, digits=3)), σ² = $(round(test_model_multi.B[2].Σ[1,1], digits=3))\n");
Multi-trial EM converged in 9 iterations
Log-likelihood improved by 1371.2
Multi-trial learned parameters:
State 1: μ = [-0.997, -1.0], σ² = 0.1
State 2: μ = [1.006, 1.004], σ² = 0.2

Visualize multi-trial EM convergence

p4 = plot(lls_multi, xlabel="EM Iteration", ylabel="Log-Likelihood",
          title="Multi-Trial EM Convergence", legend=false,
          marker=:circle, markersize=3, lw=2, color=:darkgreen)
[Figure: multi-trial EM convergence, log-likelihood vs. EM iteration]

Multi-Trial State Decoding

Now that we have learned the parameters, we can use Viterbi to find the most likely hidden state sequence for each trial with a single function call.

all_pred_labels_vec = viterbi(test_model_multi, all_data);

all_pred_labels = hcat(all_pred_labels_vec...)';      # trials × time
all_true_labels_matrix = hcat(all_true_labels...)';   # trials × time
100×1000 adjoint(::Matrix{Int64}) with eltype Int64:
 2  2  2  2  1  1  1  1  1  1  1  1  1  …  2  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1
 2  2  2  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     2  2  2  2  2  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1  …  2  2  2  2  1  1  1  1  1  1  1  1
 2  2  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1
 2  2  2  2  1  1  1  1  1  1  1  1  1     1  1  1  1  2  2  2  2  2  2  2  2
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1
 2  2  2  2  2  2  2  2  2  2  2  2  1     1  1  1  1  1  1  1  1  1  1  1  1
 ⋮              ⋮              ⋮        ⋱        ⋮              ⋮           
 2  2  2  2  2  2  2  2  2  2  2  2  2     1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  2  2  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  2  2  2  2  2  2
 1  2  2  2  2  2  2  2  2  2  2  2  2     1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1  …  1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1
 2  2  2  1  1  1  1  1  1  2  2  2  2     1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1     1  1  1  1  1  1  1  1  1  1  1  1

Calculate the overall accuracy across all trials, accounting for label switching.

overall_accuracy = mean(all_true_labels_matrix .== all_pred_labels);

swapped_pred_all = 3 .- all_pred_labels;
swapped_accuracy_all = mean(all_true_labels_matrix .== swapped_pred_all);

if swapped_accuracy_all > overall_accuracy
    all_pred_labels = swapped_pred_all
    overall_accuracy = swapped_accuracy_all
    print("Corrected label switching in multi-trial analysis\n")
end

print("Overall state prediction accuracy: $(round(overall_accuracy*100, digits=1))%\n");
Overall state prediction accuracy: 100.0%

We can also look at per-trial accuracies to see how consistent the model is across trials.

trial_accuracies = [mean(all_true_labels_matrix[i, :] .== all_pred_labels[i, :]) for i in 1:n_trials]
print("Per-trial accuracy: $(round(mean(trial_accuracies)*100, digits=1))% ± $(round(std(trial_accuracies)*100, digits=1))%\n");
Per-trial accuracy: 100.0% ± 0.0%

Visualize subset of trials (first 10 trials, first 500 timepoints)

n_trials_display = 10
n_time_display = 500

true_subset = all_true_labels_matrix[1:n_trials_display, 1:n_time_display]
pred_subset = all_pred_labels[1:n_trials_display, 1:n_time_display]

p5 = plot(
    heatmap(true_subset, colormap=:roma, title="True States ($n_trials_display trials)",
           xticks=false, ylabel="Trial", colorbar=false),
    heatmap(pred_subset, colormap=:roma, title="Predicted States (Viterbi)",
           xlabel="Time Steps", ylabel="Trial", colorbar=false),
    layout=(2, 1), size=(900, 400)
)
[Figure: heatmaps of true and predicted states for the first 10 trials, first 500 timepoints]

Parameter Recovery Assessment

Since we have access to the true model parameters, we can quantitatively assess how well the multi-trial fitting procedure recovered them.

true_μ1_orig, true_μ2_orig = [-1.0, -1.0], [1.0, 1.0]
learned_μ1 = test_model_multi.B[1].μ
learned_μ2 = test_model_multi.B[2].μ
2-element Vector{Float64}:
 1.0064258376923567
 1.004188904637427

Compare emission model mean vectors

μ1_error = norm(true_μ1_orig - learned_μ1) / norm(true_μ1_orig)
μ2_error = norm(true_μ2_orig - learned_μ2) / norm(true_μ2_orig)

print("Mean vector recovery errors:\n")
print("State 1: $(round(μ1_error*100, digits=1))%, State 2: $(round(μ2_error*100, digits=1))%\n")
Mean vector recovery errors:
State 1: 0.2%, State 2: 0.5%

Compare emission variances (the diagonal entries of the covariance matrices)

true_Σ1_orig, true_Σ2_orig = 0.1, 0.2
learned_Σ1 = test_model_multi.B[1].Σ[1,1]
learned_Σ2 = test_model_multi.B[2].Σ[1,1]

Σ1_error = abs(true_Σ1_orig - learned_Σ1) / true_Σ1_orig
Σ2_error = abs(true_Σ2_orig - learned_Σ2) / true_Σ2_orig

print("Variance recovery errors:\n")
print("State 1: $(round(Σ1_error*100, digits=1))%, State 2: $(round(Σ2_error*100, digits=1))%\n");
Variance recovery errors:
State 1: 0.4%, State 2: 0.2%

Compare transition matrices

true_A_orig = [0.99 0.01; 0.05 0.95]
A_error = norm(true_A_orig - test_model_multi.A) / norm(true_A_orig)
print("Transition matrix error: $(round(A_error*100, digits=1))%\n")
Transition matrix error: 0.3%

Summary

This tutorial demonstrated the complete workflow for Gaussian emission Hidden Markov Models. We covered how to create, sample from, fit, and perform state inference with HMMs using StateSpaceDynamics.jl.

Key Concepts:

  • Discrete hidden states with Gaussian emission distributions
  • Temporal dependencies through Markovian state transitions
  • EM algorithm for joint parameter learning and state inference
  • Viterbi decoding for finding most likely state sequences

Technical Insights:

  • Label switching is a common identifiability issue requiring detection and correction
  • Multi-trial analysis provides more robust parameter estimates than single sequences
  • Parameter recovery quality depends on state separation and sequence length
  • Convergence monitoring through log-likelihood plots ensures proper algorithm behavior

Applications:

  • Time series clustering and regime detection
  • Behavioral state analysis in sequential data
  • Exploratory analysis of temporal datasets with latent structure
  • Foundation for more complex state-space models

Gaussian HMMs provide a fundamental framework for modeling sequential data with discrete latent structure, serving as both standalone models and building blocks for more sophisticated probabilistic time series methods.


This page was generated using Literate.jl.