Optimization Utilities

Optimization helper functions.

optimize_utils

run_sgd(loss_fn, params, dataset, optimizer=optax.adam(0.001), batch_size=1, num_epochs=50, shuffle=False, return_param_history=False, return_grad_history=False, key=jr.PRNGKey(0))

Note that batch_emissions initially has shape (N, T), where N is the number of independent sequences and T is the length of each sequence. At each step, a random subset of B entire sequences, with shape (B, T), is sampled, where B is the batch size; sampling is over sequences, not individual time steps.
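A minimal sketch of this sampling scheme, assuming a hypothetical emissions array of N = 5 sequences of length T = 10 (the names `batch_emissions`, `N`, `T`, and `B` follow the note above; the rest is illustrative):

```python
import jax.numpy as jnp
import jax.random as jr

# Hypothetical batch of N = 5 independent sequences, each of length T = 10.
N, T = 5, 10
batch_emissions = jnp.arange(N * T, dtype=jnp.float32).reshape(N, T)

# At each step, sample B whole sequences (rows), not individual time steps.
B = 2
key = jr.PRNGKey(0)
idx = jr.choice(key, N, shape=(B,), replace=False)
minibatch = batch_emissions[idx]  # shape (B, T)
```

The sampled minibatch keeps every time step of the chosen sequences, so models that depend on within-sequence structure see complete trajectories.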

Parameters:

Name Type Description Default
loss_fn

Objective function.

required
params

Initial values of the parameters to be estimated.

required
dataset

PyTree of data arrays with leading batch dimension

required
optimizer

Optimizer.

optax.adam(0.001)
batch_size

Number of sequences used at each update step.

1
num_epochs

Number of passes (epochs) over the full dataset.

50
shuffle

Indicates whether to shuffle the sequences at each epoch.

False
return_param_history

Indicates whether to return history of parameters.

False
return_grad_history

Indicates whether to return history of gradients.

False
key

RNG key.

PRNGKey(0)

Returns:

Name Type Description
params

Optimized parameters after the final update step.

losses

Output of loss_fn stored at each step.