Simulating and Fitting a Hidden Markov Model

This tutorial demonstrates how to use StateSpaceDynamics.jl to create, sample from, and fit Hidden Markov Models (HMMs).

Load Packages

using LinearAlgebra
using Plots
using Random
using StateSpaceDynamics
using StableRNGs
# Seeded, version-stable random stream so the sampled data and all figures below are
# reproducible from run to run.
rng = StableRNGs.StableRNG(1234);

Create a two-state HMM with 2-D Gaussian emissions

# Dimensionality of each observation vector.
output_dim = 2

# Transition matrix (each row sums to 1). Large diagonal entries make both states
# "sticky", so the chain tends to stay in a state for many steps.
A = [0.99 0.01
     0.05 0.95];
# Initial-state distribution: both states equally likely at t = 1.
πₖ = [0.5, 0.5]

# State 1: isotropic Gaussian centred at (-1, -1) with variance 0.1 per dimension.
μ_1 = [-1.0, -1.0]
Σ_1 = Matrix(0.1I, output_dim, output_dim)
emission_1 = GaussianEmission(; output_dim, μ = μ_1, Σ = Σ_1)

# State 2: isotropic Gaussian centred at (1, 1) with variance 0.2 per dimension.
μ_2 = [1.0, 1.0]
Σ_2 = Matrix(0.2I, output_dim, output_dim)
emission_2 = GaussianEmission(; output_dim, μ = μ_2, Σ = Σ_2)

# Assemble the ground-truth two-state model we will sample from.
model = HiddenMarkovModel(; K = 2, B = [emission_1, emission_2], A, πₖ)
HiddenMarkovModel{Float64, Vector{Float64}, Matrix{Float64}, Vector{GaussianEmission{Float64, Vector{Float64}, Matrix{Float64}}}}([0.99 0.01; 0.05 0.95], GaussianEmission{Float64, Vector{Float64}, Matrix{Float64}}[GaussianEmission{Float64, Vector{Float64}, Matrix{Float64}}(2, [-1.0, -1.0], [0.1 0.0; 0.0 0.1]), GaussianEmission{Float64, Vector{Float64}, Matrix{Float64}}(2, [1.0, 1.0], [0.2 0.0; 0.0 0.2])], [0.5, 0.5], 2)

Sample from the HMM

# Draw one long trajectory. `rand` returns the hidden-state labels and the observations
# (one column per time step, as the indexing `data[1, …]` below relies on).
num_samples = 10_000
true_labels, data = rand(rng, model; n = num_samples)
([2, 2, 2, 2, 2, 2, 2, 1, 1, 1  …  2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1.4047907916794145 0.6939424571412892 … 0.6216506081766058 0.5659043566187294; 0.25062154873164444 0.6931441433341328 … 1.6380448790988147 1.0830411162585103])

Visualize the sampled dataset

# Only the first `n_plot` time steps are drawn: plotting all samples with a connecting
# line would bury the trajectory. The title below reports the same count.
n_plot = min(100, num_samples)
x_vals = data[1, 1:n_plot]
y_vals = data[2, 1:n_plot]
labels_slice = true_labels[1:n_plot]

state_colors = [:dodgerblue, :crimson]

plt = plot()
# Colour each observation by the hidden state that generated it.
for state in eachindex(state_colors)
    idx = findall(==(state), labels_slice)
    scatter!(plt, x_vals[idx], y_vals[idx];
        color = state_colors[state],
        label = "State $state",
        markersize = 6)
end

# Faint line through consecutive observations to show the temporal path.
plot!(plt, x_vals, y_vals;
    color = :gray,
    lw = 1.5,
    linealpha = 0.4,
    label = "")

# Mark where the plotted segment starts and ends.
scatter!(plt, [x_vals[1]], [y_vals[1]];
    color = :green,
    markershape = :star5,
    markersize = 10,
    label = "Start")

scatter!(plt, [x_vals[end]], [y_vals[end]];
    color = :black,
    markershape = :diamond,
    markersize = 8,
    label = "End")

xlabel!(plt, "Output dim 1")
ylabel!(plt, "Output dim 2")
title!(plt, "Emissions from HMM (First $n_plot Points)")
Example block output

Initialize a new HMM with deliberately perturbed parameters and fit it to the sampled data with EM

# Deliberately misspecified starting point: means pulled toward the origin, inflated
# covariances, and a different transition matrix / initial distribution. Fitting must
# move these toward the parameters of `model`.
μ_1 = [-0.25, -0.25]
Σ_1 = 0.3 * Matrix{Float64}(I, output_dim, output_dim)
emission_1 = GaussianEmission(output_dim=output_dim, μ=μ_1, Σ=Σ_1)

μ_2 = [0.25, 0.25]
Σ_2 = 0.5 * Matrix{Float64}(I, output_dim, output_dim)
# State 2 must use its own initial guesses; building it from μ_1/Σ_1 would start both
# states with identical emissions and leave (μ_2, Σ_2) unused.
emission_2 = GaussianEmission(output_dim=output_dim, μ=μ_2, Σ=Σ_2)

A = [0.8 0.2; 0.05 0.95]
πₖ = [0.6, 0.4]
test_model = HiddenMarkovModel(K=2, B=[emission_1, emission_2], A=A, πₖ=πₖ)

# `fit!` modifies `test_model` in place; `lls` holds the log-likelihood recorded at each
# EM iteration (plotted below).
lls = fit!(test_model, data)

# Log-likelihood trace returned by `fit!`, one point per EM iteration.
plot(lls;
    title = "Log-likelihood over EM Iterations",
    xlabel = "EM Iteration",
    ylabel = "Log-Likelihood")
Example block output

Decode the most likely latent state sequence with the Viterbi algorithm and compare it to the true labels

# Most likely hidden-state path under the fitted model.
pred_labels = viterbi(test_model, data);

# First 1000 time steps as 1×T matrices, so each heatmap renders as a single strip.
true_mat = permutedims(true_labels[1:1000])
pred_mat = permutedims(pred_labels[1:1000])

p1 = heatmap(true_mat;
    title = "True State Labels",
    colormap = :roma50, colorbar = false, framestyle = :box,
    xlabel = "", xticks = false,
    ylabel = "", yticks = false)

p2 = heatmap(pred_mat;
    title = "Predicted State Labels",
    colormap = :roma50, colorbar = false, framestyle = :box,
    xlabel = "Timepoints", xticks = 0:200:1000,
    ylabel = "", yticks = false)

# Stack the strips vertically: ground truth on top, Viterbi decoding below.
plot(p1, p2; layout = (2, 1), size = (700, 500), margin = 5Plots.mm)
Example block output

Sampling multiple, independent trials of data from an HMM

n_trials = 100
n_samples = 1000

# One label vector and one observation matrix per independent trial.
all_true_labels = Vector{Vector{Int}}(undef, n_trials)
all_data = Vector{Matrix{Float64}}(undef, n_trials)

# The loop-local names deliberately differ from the globals `true_labels`/`data` defined
# earlier: reusing them in a top-level loop triggers Julia's soft-scope ambiguity warning.
for trial in eachindex(all_data)
    trial_labels, trial_data = rand(rng, model; n = n_samples)
    all_true_labels[trial] = trial_labels
    all_data[trial] = trial_data
end
┌ Warning: Assignment to `true_labels` in soft scope is ambiguous because a global variable by the same name exists: `true_labels` will be treated as a new local. Disambiguate by using `local true_labels` to suppress this warning or `global true_labels` to assign to the existing global variable.
└ @ hidden_markov_model_example.md:159
┌ Warning: Assignment to `data` in soft scope is ambiguous because a global variable by the same name exists: `data` will be treated as a new local. Disambiguate by using `local data` to suppress this warning or `global data` to assign to the existing global variable.
└ @ hidden_markov_model_example.md:159

Fitting an HMM to multiple, independent trials of data

# Same deliberately misspecified starting point as in the single-trial example, now fitted
# jointly to all trials.
μ_1 = [-0.25, -0.25]
Σ_1 = 0.3 * Matrix{Float64}(I, output_dim, output_dim)
emission_1 = GaussianEmission(output_dim=output_dim, μ=μ_1, Σ=Σ_1)

μ_2 = [0.25, 0.25]
Σ_2 = 0.5 * Matrix{Float64}(I, output_dim, output_dim)
# State 2 must use its own initial guesses; building it from μ_1/Σ_1 would start both
# states with identical emissions and leave (μ_2, Σ_2) unused.
emission_2 = GaussianEmission(output_dim=output_dim, μ=μ_2, Σ=Σ_2)

A = [0.8 0.2; 0.05 0.95]
πₖ = [0.6, 0.4]
test_model = HiddenMarkovModel(K=2, B=[emission_1, emission_2], A=A, πₖ=πₖ)

# Passing a vector of per-trial matrices fits one shared model to all trials; `lls` is the
# log-likelihood trace over EM iterations (plotted below).
lls = fit!(test_model, all_data)

# Log-likelihood trace of the multi-trial fit, one point per EM iteration.
plot(lls;
    title = "Log-likelihood over EM Iterations",
    xlabel = "EM Iteration",
    ylabel = "Log-Likelihood")
Example block output

Visualize latent state predictions for multiple trials of data using Viterbi

# Decode every trial (one label sequence per trial), then stack the sequences into
# (trials × timepoints) matrices. `reduce(hcat, …)` avoids splatting all trials into a
# single varargs call, and `permutedims` yields a plain Matrix rather than an adjoint.
all_pred_labels_vec = viterbi(test_model, all_data)
all_pred_labels = permutedims(reduce(hcat, all_pred_labels_vec))
all_true_labels_matrix = permutedims(reduce(hcat, all_true_labels))

# First 10 trials, first 500 timepoints of each.
true_subset = all_true_labels_matrix[1:10, 1:500]
pred_subset = all_pred_labels[1:10, 1:500]

# Settings shared by both trial-by-time heatmaps.
trial_heatmap_kw = (
    colormap = :roma50,
    colorbar = false,
    ylabel = "Trials",
    yticks = true,
    margin = 5Plots.mm,
    legend = false,
)

# Ground-truth labels: rows are trials, columns are timepoints.
p1 = heatmap(true_subset;
    trial_heatmap_kw...,
    title = "True State Labels",
    xlabel = "",
    xticks = false)

# Viterbi-decoded labels on the same grid, for direct visual comparison.
p2 = heatmap(pred_subset;
    trial_heatmap_kw...,
    title = "Predicted State Labels",
    xlabel = "Timepoints",
    xticks = true)

final_plot = plot(p1, p2;
    layout = (2, 1),
    size = (850, 550),
    margin = 5Plots.mm,
    legend = false)

display(final_plot)

This page was generated using Literate.jl.