# Optimization Utilities

Optimization helper functions.

## optimize_utils

### `run_sgd(loss_fn, params, dataset, optimizer=optax.adam(0.001), batch_size=1, num_epochs=50, shuffle=False, return_param_history=False, return_grad_history=False, key=jr.PRNGKey(0))`
Note that `batch_emissions` initially has shape `(N, T)`, where `N` is the number of independent sequences and `T` is the length of each sequence. At each step, a random subset of shape `(B, T)` is sampled, where `B` is the batch size; the sampling is over entire sequences, not over individual time steps.
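To make that shape convention concrete, here is a minimal sketch of sampling whole sequences from a batch of emissions (the array shapes and variable names are illustrative, not taken from the library):

```python
import jax.numpy as jnp
import jax.random as jr

# Hypothetical data: N=10 independent sequences, each of length T=100.
emissions = jnp.zeros((10, 100))
key = jr.PRNGKey(0)

# Sample B=4 sequence indices without replacement; whole sequences are
# selected, so every time step of each chosen sequence is kept.
idx = jr.choice(key, 10, shape=(4,), replace=False)
batch = emissions[idx]  # shape (B, T) = (4, 100)
```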
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fn` | Callable | Objective function to be minimized. | required |
| `params` | | Initial values of the parameters to be estimated. | required |
| `dataset` | | PyTree of data arrays with a leading batch dimension. | required |
| `optimizer` | | Optax optimizer. | `optax.adam(0.001)` |
| `batch_size` | int | Number of sequences used at each update step. | `1` |
| `num_epochs` | int | Number of epochs, i.e. full passes over the dataset. | `50` |
| `shuffle` | bool | Whether to shuffle the emissions. | `False` |
| `return_param_history` | bool | Whether to return the history of parameters. | `False` |
| `return_grad_history` | bool | Whether to return the history of gradients. | `False` |
| `key` | PRNGKey | Random number generator key. | `jr.PRNGKey(0)` |
Returns:

| Name | Type | Description |
|---|---|---|
| `hmm` | | HMM with optimized parameters. |
| `losses` | | Output of `loss_fn` stored at each step. |
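The training loop described above can be sketched as follows. This is a simplified stand-in, not the library's implementation: `run_sgd_sketch` is a hypothetical name, plain gradient-descent updates replace the optax optimizer, and the parameter/gradient history options are omitted.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def run_sgd_sketch(loss_fn, params, dataset, learning_rate=1e-2,
                   batch_size=1, num_epochs=50, key=jr.PRNGKey(0)):
    """Mini-batch gradient descent over whole sequences.

    `dataset` is a PyTree of arrays with a leading batch dimension N;
    each step uses B=batch_size entire sequences, not individual time steps.
    """
    num_seqs = len(jax.tree_util.tree_leaves(dataset)[0])
    num_batches = num_seqs // batch_size
    loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))

    losses = []
    for _ in range(num_epochs):
        # Reshuffle the sequence indices every epoch.
        key, subkey = jr.split(key)
        perm = jr.permutation(subkey, num_seqs)
        for i in range(num_batches):
            idx = perm[i * batch_size:(i + 1) * batch_size]
            minibatch = jax.tree_util.tree_map(lambda x: x[idx], dataset)
            loss, grads = loss_and_grad(params, minibatch)
            # Plain SGD step in place of an optax optimizer update.
            params = jax.tree_util.tree_map(
                lambda p, g: p - learning_rate * g, params, grads)
            losses.append(loss)
    return params, jnp.stack(losses)
```

For example, fitting a scalar mean by minimizing a squared-error loss converges to the data mean, and `losses` records one entry per update step.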