Base State-Space Model Class¶
Abstract SSM base classes and emissions models used by cd_dynamax.
ssm_temissions
¶
Posterior
¶
Bases: Protocol
A NamedTuple with parameters stored as jax.Array in the leaf nodes.
Prior
¶
Bases: ABC
cd-dynamax priors should implement:

- `sample`: returns a tensor of samples from this prior
- `log_prob`: returns the log probability of a given parameter set under this prior
log_prob(x)
abstractmethod
¶
Compute the log probability of a given parameter set under a specific prior class.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x |  | PyTree with samples to evaluate log_prob at | required |

Returns:

| Type | Description |
|---|---|
|  | PyTree with log_prob values |
sample(key, M)
abstractmethod
¶
Draw samples from this prior; to be implemented by specific prior classes.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key |  | random number generator key | required |
| M |  | number of samples to draw | required |

Returns:

| Type | Description |
|---|---|
|  | PyTree with samples |
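To make the `sample`/`log_prob` contract concrete, here is a minimal, hypothetical sketch of a scalar Gaussian prior in plain Python. It is an illustration only: actual cd-dynamax priors operate on JAX pytrees and `jax.random` keys, not Python lists and `random.Random`.

```python
import math
import random

# Hypothetical illustration of the Prior interface: a scalar Gaussian
# prior implementing `sample` (draw M samples) and `log_prob` (evaluate
# the log density at given samples).
class GaussianPrior:
    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def sample(self, rng, M):
        # Draw M samples from N(mean, std^2).
        return [rng.gauss(self.mean, self.std) for _ in range(M)]

    def log_prob(self, x):
        # Log density of N(mean, std^2) at each entry of x.
        c = -0.5 * math.log(2 * math.pi * self.std ** 2)
        return [c - 0.5 * ((xi - self.mean) / self.std) ** 2 for xi in x]

prior = GaussianPrior()
samples = prior.sample(random.Random(0), 5)
log_probs = prior.log_prob(samples)
```

In cd-dynamax the returned values would be stored in the leaves of a pytree, one entry per sampled parameter set.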
SSM
¶
Bases: ABC
A base cd-dynamax class for continuous-discrete state space models. Such models consist of parameters, which we may learn, as well as hyperparameters, which specify static properties of the model. This base class allows parameters to be indicated in a standardized way so that they can easily be converted to/from unconstrained form for optimization.
Abstract Methods
Models that inherit from SSM must implement a few key functions and properties:
- `initial_distribution`: returns the distribution over the initial state given parameters
- `transition_distribution`: returns the conditional distribution over the next state given the current state and parameters
- `emission_distribution`: returns the conditional distribution over the emission given the current state and parameters
- `log_prior`: (optional) returns the log prior probability of the parameters
- `emission_shape`: returns a tuple specification of the emission shape
- `inputs_shape`: returns a tuple specification of the input shape, or `None` if there are no inputs
The shape properties are required for properly handling batches of data.
Sampling and Computing Log Probabilities
Once these have been implemented, subclasses will inherit the ability to sample from these models and to compute log joint probabilities from the base class functions:
- `sample`: draws samples of the states and emissions for given parameters
- `log_prob`: computes the log joint probability of the states and emissions for given parameters
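As an illustration of how `sample` and `log_prob` follow from the initial, transition, and emission distributions, here is a hypothetical discrete-time scalar linear-Gaussian sketch in plain Python (not the cd-dynamax API, which works with JAX arrays and tfd distributions):

```python
import math
import random

# Hypothetical scalar linear-Gaussian SSM:
#   z_1 ~ N(0, 1),  z_{t+1} ~ N(A z_t, Q),  y_t ~ N(z_t, R)
A, Q, R = 0.9, 0.1, 0.05

def gauss_logpdf(x, mean, var):
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def sample(rng, num_timesteps):
    # Ancestral sampling: draw the initial state, then alternate
    # emission and transition draws.
    states, emissions = [], []
    z = rng.gauss(0.0, 1.0)
    for _ in range(num_timesteps):
        states.append(z)
        emissions.append(rng.gauss(z, math.sqrt(R)))
        z = rng.gauss(A * z, math.sqrt(Q))
    return states, emissions

def log_prob(states, emissions):
    # log p(y_{1:T}, z_{1:T}) = log p(z_1)
    #   + sum_t log p(z_{t+1} | z_t) + sum_t log p(y_t | z_t)
    lp = gauss_logpdf(states[0], 0.0, 1.0)
    for t in range(1, len(states)):
        lp += gauss_logpdf(states[t], A * states[t - 1], Q)
    for z, y in zip(states, emissions):
        lp += gauss_logpdf(y, z, R)
    return lp

zs, ys = sample(random.Random(0), 10)
joint_lp = log_prob(zs, ys)
```

The base class's `sample` and `log_prob` are the same two loops, expressed over whatever distributions the subclass provides.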
Inference
Many subclasses of SSMs expose basic functions for performing state inference.
- `marginal_log_prob`: computes the marginal log probability of the emissions, summing over latent states
- `filter`: computes the filtered posteriors
- `smoother`: computes the smoothed posteriors
Learning
Likewise, many SSMs will support learning with stochastic gradient descent (SGD) or Markov Chain Monte Carlo (MCMC).
The generic SSM class allows fitting the model with the preferred learning algorithm.
For SGD, any subclass that implements marginal_log_prob inherits the base class fitting function

- `fit_sgd`: run SGD to minimize the negative marginal log probability.

For black-box optimization, any subclass that implements marginal_log_prob inherits the base class fitting functions

- `fit_scipy`: run SciPy-based optimization to minimize the negative marginal log probability.
- `fit_scipy_jaxopt`: run jaxopt ScipyMinimize optimization to minimize the negative marginal log probability.

For MCMC, any subclass that implements marginal_log_prob inherits the base class fitting function

- `fit_mcmc`: run BlackJAX HMC to sample from the posterior over parameters given data.
diffeqsolve_settings: dict
property
¶
Return a dictionary of settings for the differential equation solver.
emission_shape: Tuple[int]
abstractmethod
property
¶
Return a pytree matching the pytree of tuples specifying the shape of a single time step's emissions.
inputs_shape: Optional[Tuple[int]]
property
¶
Return a pytree matching the pytree of tuples specifying the shape of a single time step's inputs.
e_step(params: ParameterSet, emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None) -> Tuple[SuffStatsSSM, Scalar]
¶
Perform an E-step to compute expected sufficient statistics under the posterior, \(p(z_{1:T} \mid y_{1:T}, u_{1:T}, \theta)\).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| emissions | Float[Array, 'num_timesteps emission_dim'] | emissions \(y_{1:T}\) | required |
| t_emissions | Optional[Float[Array, 'num_timesteps 1']] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) | None |
| inputs | Optional[Float[Array, 'num_timesteps input_dim']] | optional inputs \(u_{1:T}\) | None |

Returns:

| Type | Description |
|---|---|
| Tuple[SuffStatsSSM, Scalar] | Expected sufficient statistics under the posterior. |
emission_distribution(params: ParameterSet, state: Float[Array, ' state_dim'], inputs: Optional[Float[Array, ' input_dim']] = None) -> tfd.Distribution
abstractmethod
¶
Return a distribution over emissions given current state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| state | Float[Array, ' state_dim'] | current latent state \(z_t\) | required |
| inputs | Optional[Float[Array, ' input_dim']] | current inputs \(u_t\) | None |

Returns:

| Type | Description |
|---|---|
| Distribution | conditional distribution of current emission \(p(y_t \mid z_t, u_t, \theta)\) |
filter(params: ParameterSet, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Union[Any]] = None, inputs: Optional[Float[Array, 'ntime input_dim']] = None) -> Posterior
¶
Compute filtering distributions, \(p(z_t \mid y_{1:t}, u_{1:t}, \theta)\) for \(t=1,\ldots,T\).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| emissions | Float[Array, 'ntime emission_dim'] | observed data \(y_{1:T}\) | required |
| t_emissions | Optional[Float[Array, 'num_timesteps 1']] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) | None |
| filter_hyperparams | Optional[Union[Any]] | hyperparameters of the filtering algorithm | None |
| inputs | Optional[Float[Array, 'ntime input_dim']] | current inputs \(u_t\) | None |

Returns:

| Type | Description |
|---|---|
| Posterior | filtering distributions |
fisher_information(params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, warn: bool = True)
¶
Compute the observed Fisher information matrix for the model parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| props | PropertySet | properties specifying which parameters should be learned | required |
| emissions | Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']] | one or more sequences of emissions | required |
| t_emissions | Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) or (num_batches, num_timesteps, 1) | None |
| filter_hyperparams | Optional[Any] | if needed, hyperparameters of the filtering algorithm | None |
| inputs | Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] | one or more sequences of corresponding inputs | None |
| warn | bool | whether to print warnings from filters | True |

Returns: Fisher information PyTree.
NOTE: The Hessian computation requires reverse-mode autodiff.
fit_em(params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, num_iters: int = 50, verbose: bool = True) -> Tuple[ParameterSet, Float[Array, ' num_iters']]
¶
Compute parameter MLE/MAP estimates using Expectation-Maximization (EM).
EM aims to find parameters that maximize the marginal log probability,
\(\theta^\star = \operatorname{argmax}_\theta \log p(y_{1:T} \mid \theta).\)
It does so by iteratively forming a lower bound (the "E-step") and then maximizing it (the "M-step").
Note: emissions and inputs can either be single sequences or batches of sequences.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| props | PropertySet | properties specifying which parameters should be learned | required |
| emissions | Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']] | one or more sequences of emissions | required |
| t_emissions | Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) or (num_batches, num_timesteps, 1) | None |
| inputs | Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] | one or more sequences of corresponding inputs | None |
| num_iters | int | number of iterations of EM to run | 50 |
| verbose | bool | whether or not to show a progress bar | True |

Returns:

| Type | Description |
|---|---|
| Tuple[ParameterSet, Float[Array, ' num_iters']] | tuple of new parameters and log likelihoods over the course of EM iterations. |
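The E-step/M-step alternation can be sketched on a toy latent-variable model. The following is a hypothetical illustration in plain Python (not the cd-dynamax EM implementation): EM for the means of a two-component 1D Gaussian mixture with known unit variances and equal weights, where the latent variable is the component label.

```python
import math
import random

# Toy EM: alternate computing posterior responsibilities (E-step) with
# responsibility-weighted mean updates (M-step).
def em_fit(data, mu_init, num_iters=50):
    mu1, mu2 = mu_init
    for _ in range(num_iters):
        # E-step: posterior probability that each point came from component 1.
        resp = []
        for x in data:
            l1 = math.exp(-0.5 * (x - mu1) ** 2)
            l2 = math.exp(-0.5 * (x - mu2) ** 2)
            resp.append(l1 / (l1 + l2))
        # M-step: weighted means maximize the expected complete-data
        # log likelihood (the lower bound) given the responsibilities.
        w = sum(resp)
        mu1 = sum(r * x for r, x in zip(resp, data)) / w
        mu2 = sum((1 - r) * x for r, x in zip(resp, data)) / (len(data) - w)
    return mu1, mu2

rng = random.Random(0)
data = [rng.gauss(-2, 1) for _ in range(200)] + [rng.gauss(2, 1) for _ in range(200)]
mu1, mu2 = em_fit(data, (-1.0, 1.0))
```

In `fit_em`, the E-step instead computes expected sufficient statistics of the continuous latent path via `e_step`, and the M-step is `m_step`, but the structure of the loop is the same.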
fit_mcmc(initial_params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, mcmc_algorithm={'type': 'nuts', 'n_samples': 100, 'warmup_samples': 10, 'parameters': {}}, verbose=True, warn=True, key: PRNGKey = jr.PRNGKey(0)) -> Tuple[ParameterSet, ParameterSet, Float[Array, ' num_steps'], Float[Array, ' n_mcmc_samples']]
¶
Generate samples from the posterior using Markov Chain Monte Carlo (MCMC).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| initial_params | ParameterSet | initial parameters \(\theta\) | required |
| props | PropertySet | properties specifying which parameters should be learned | required |
| emissions | Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']] | one or more sequences of observed data | required |
| t_emissions | Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) or (num_batches, num_timesteps, 1) | None |
| filter_hyperparams | Optional[Any] | if needed, hyperparameters of the filtering algorithm | None |
| inputs | Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] | one or more sequences of corresponding inputs | None |
| mcmc_algorithm |  | dictionary specifying the Blackjax-based MCMC algorithm and its settings, with keys: type (type of MCMC algorithm, e.g. "nuts"); n_samples (number of samples to draw); warmup_samples (number of warmup steps); parameters (additional parameters for the MCMC algorithm) | {'type': 'nuts', 'n_samples': 100, 'warmup_samples': 10, 'parameters': {}} |
| verbose |  | whether or not to show a progress bar | True |
| warn |  | whether to print warnings from filters | True |
| key | PRNGKey | a random number generator | PRNGKey(0) |

Returns:

| Type | Description |
|---|---|
| Tuple[ParameterSet, ParameterSet, Float[Array, ' num_steps'], Float[Array, ' n_mcmc_samples']] | tuple of samples and log probabilities of the samples |
fit_scipy(initial_params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, method: str = 'Nelder-Mead', options: dict = {'maxiter': 5000}, return_param_history: bool = False, warn: bool = True) -> Tuple[ParameterSet, Float[Array, ' niter'], Optional[ParameterSet]]
¶
Compute parameter MLE/MAP estimates using the chosen SciPy-based method. SciPy-based optimizers aim to find parameters that maximize the marginal log probability,
\(\theta^\star = \operatorname{argmax}_\theta \log p(y_{1:T} \mid \theta),\)
by minimizing the negative of that quantity.
Note: emissions and inputs can either be single sequences or batches of sequences.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| initial_params | ParameterSet | model parameters \(\theta\) | required |
| props | PropertySet | model properties specifying which parameters should be learned | required |
| emissions | Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']] | one or more sequences of observed data | required |
| t_emissions | Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) or (num_batches, num_timesteps, 1) | None |
| filter_hyperparams | Optional[Any] | if needed, hyperparameters of the filtering algorithm | None |
| inputs | Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] | one or more sequences of corresponding inputs | None |
| method | str | optimization method to use, e.g. "Nelder-Mead", "BFGS", etc. | 'Nelder-Mead' |
| options | dict | dictionary of options to pass to the SciPy optimizer | {'maxiter': 5000} |
| return_param_history | bool | whether to return the history of parameters | False |
| warn | bool | whether to print warnings from filters | True |

Returns:

- params_fitted: final fitted parameters (PyTree, constrained)
- losses: array of loss values per iteration
- params_history (optional): constrained parameter history (if requested)
fit_scipy_jaxopt(initial_params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, method: str = 'nelder-mead', options: dict = {'maxiter': 100}) -> Tuple[ParameterSet, Float[Array, ' niter']]
¶
Compute parameter MLE/MAP estimates using the chosen SciPy-based method from jaxopt.
ScipyMinimize aims to find parameters that maximize the marginal log probability,
\(\theta^\star = \operatorname{argmax}_\theta \log p(y_{1:T} \mid \theta),\)
by minimizing the negative of that quantity.
Note: emissions and inputs can either be single sequences or batches of sequences.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| initial_params | ParameterSet | model parameters \(\theta\) | required |
| props | PropertySet | properties specifying which parameters should be learned | required |
| emissions | Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']] | one or more sequences of emissions | required |
| t_emissions | Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) or (num_batches, num_timesteps, 1) | None |
| filter_hyperparams | Optional[Any] | if needed, hyperparameters of the filtering algorithm | None |
| inputs | Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] | one or more sequences of corresponding inputs | None |
| method | str | optimization method to use, e.g. "Nelder-Mead", "BFGS", etc. | 'nelder-mead' |
| options | dict | dictionary of options to pass to the SciPy optimizer | {'maxiter': 100} |

Returns:

| Type | Description |
|---|---|
| Tuple[ParameterSet, Float[Array, ' niter']] | tuple of new parameters and losses (negative scaled marginal log probs) over the course of the optimizer's iterations. |
fit_sgd(initial_params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, optimizer: optax.GradientTransformation = optax.adam(0.001), batch_size: int = 1, num_epochs: int = 50, shuffle: bool = False, return_param_history: bool = False, return_grad_history: bool = False, warn: bool = True, key: PRNGKey = jr.PRNGKey(0)) -> Tuple[ParameterSet, Float[Array, ' niter']]
¶
Compute parameter MLE/MAP estimates using Stochastic Gradient Descent (SGD).
SGD aims to find parameters that maximize the marginal log probability,
\(\theta^\star = \operatorname{argmax}_\theta \log p(y_{1:T} \mid \theta),\)
by minimizing the negative of that quantity.
Note: emissions and inputs can either be single sequences or batches of sequences.
On each iteration, the algorithm grabs a minibatch of sequences and takes a gradient step.
One pass through the entire set of sequences is called an epoch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| initial_params | ParameterSet | model parameters \(\theta\) | required |
| props | PropertySet | model properties specifying which parameters should be learned | required |
| emissions | Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']] | one or more sequences of observed data | required |
| t_emissions | Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) or (num_batches, num_timesteps, 1) | None |
| filter_hyperparams | Optional[Any] | if needed, hyperparameters of the filtering algorithm | None |
| inputs | Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] | one or more sequences of corresponding inputs | None |
| optimizer | GradientTransformation | an optax optimizer | adam(0.001) |
| batch_size | int | number of sequences per minibatch | 1 |
| num_epochs | int | number of epochs of SGD to run | 50 |
| shuffle | bool | whether or not to shuffle the sequences on each epoch | False |
| return_param_history | bool | whether to return the history of parameters | False |
| return_grad_history | bool | whether to return the history of gradients | False |
| warn | bool | whether to print warnings from filters | True |
| key | PRNGKey | a random number generator for selecting minibatches | PRNGKey(0) |

Returns:

| Type | Description |
|---|---|
| Tuple[ParameterSet, Float[Array, ' niter']] | tuple of new parameters and losses (negative scaled marginal log probs) over the course of SGD iterations; if requested, the parameter and gradient histories are returned as well. |
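The shuffle/minibatch/epoch loop can be illustrated on a toy problem. This is a hypothetical plain-Python sketch (not the cd-dynamax SGD implementation, which uses optax and JAX autodiff): estimate the mean of N(mu, 1) data by gradient steps on the average negative log likelihood, whose gradient with respect to mu is (mu - x) per data point.

```python
import random

# Toy minibatch SGD: shuffle the data each epoch, take one gradient
# step per minibatch, and record a per-epoch loss.
def fit_sgd_toy(data, mu=0.0, batch_size=4, num_epochs=50, lr=0.05, seed=0):
    rng = random.Random(seed)
    losses = []
    for _ in range(num_epochs):            # one epoch = one pass over the data
        order = list(range(len(data)))
        rng.shuffle(order)                 # shuffle sequences each epoch
        for start in range(0, len(order), batch_size):
            batch = [data[i] for i in order[start:start + batch_size]]
            grad = sum(mu - x for x in batch) / len(batch)
            mu -= lr * grad                # gradient step on the minibatch
        # negative log likelihood per point, up to an additive constant
        losses.append(sum(0.5 * (x - mu) ** 2 for x in data) / len(data))
    return mu, losses

rng = random.Random(1)
data = [rng.gauss(3.0, 1.0) for _ in range(64)]
mu_hat, losses = fit_sgd_toy(data)
```

`fit_sgd` follows the same structure, with minibatches of whole sequences and gradients of the negative marginal log probability computed by autodiff.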
initial_distribution(params: ParameterSet, inputs: Optional[Float[Array, ' input_dim']]) -> tfd.Distribution
abstractmethod
¶
Return an initial distribution over latent states.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| inputs | Optional[Float[Array, ' input_dim']] | optional inputs \(u_1\) | required |

Returns:

| Type | Description |
|---|---|
| Distribution | distribution over the initial latent state, \(p(z_1 \mid u_1, \theta)\). |
initialize(*args, **kwargs)
abstractmethod
¶
Initialize the model parameters.
Returns: CD-SSM parameters and their properties.
log_prior(params: ParameterSet) -> Scalar
¶
Return the log prior probability of model parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |

Returns:

| Name | Type | Description |
|---|---|---|
| log_prior | Scalar | log prior probability. |
log_prob(params: ParameterSet, states: Float[Array, 'num_timesteps state_dim'], emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, key: PRNGKey = jr.PRNGKey(0)) -> Scalar
¶
Compute the log joint probability of the states and observations, \(\log p(y_{1:T}, z_{1:T} \mid \theta)\).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| states | Float[Array, 'num_timesteps state_dim'] | latent states \(z_{1:T}\) | required |
| emissions | Float[Array, 'num_timesteps emission_dim'] | observed data \(y_{1:T}\) | required |
| t_emissions | Optional[Float[Array, 'num_timesteps 1']] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) | None |
| inputs | Optional[Float[Array, 'num_timesteps input_dim']] | current inputs \(u_t\) | None |
| key | PRNGKey | random number generator key | PRNGKey(0) |

Returns:

| Type | Description |
|---|---|
| Scalar | log joint probability (scalar) of states and emissions |
m_step(params: ParameterSet, props: PropertySet, batch_stats: SuffStatsSSM, m_step_state: Any) -> ParameterSet
¶
Perform an M-step to find parameters that maximize the expected log joint probability.
Specifically, compute
\(\theta^\star = \operatorname{argmax}_\theta \, \mathbb{E}_{p(z_{1:T} \mid y_{1:T}, \theta')} \left[\log p(y_{1:T}, z_{1:T} \mid \theta)\right].\)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| props | PropertySet | properties specifying which parameters should be learned | required |
| batch_stats | SuffStatsSSM | sufficient statistics from each sequence | required |
| m_step_state | Any | any required state for optimizing the model parameters. | required |

Returns:

| Type | Description |
|---|---|
| ParameterSet | new parameters |
marginal_log_prob(params: ParameterSet, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Float[Array, 'ntime input_dim']] = None, key: PRNGKey = jr.PRNGKey(0)) -> Scalar
¶
Compute log marginal likelihood of observations, \(\log \sum_{z_{1:T}} p(y_{1:T}, z_{1:T} \mid \theta)\).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| emissions | Float[Array, 'ntime emission_dim'] | emissions \(y_{1:T}\) | required |
| t_emissions | Optional[Float[Array, 'num_timesteps 1']] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) | None |
| filter_hyperparams | Optional[Any] | hyperparameters of the filtering algorithm | None |
| inputs | Optional[Float[Array, 'ntime input_dim']] | current inputs \(u_t\) | None |
| key | PRNGKey | random number generator (for use in randomized methods approximating the marginal likelihood) | PRNGKey(0) |

Returns:

| Type | Description |
|---|---|
| Scalar | marginal log probability |
prior_averaged_fisher(prior: Prior, initial_params: ParameterSet, num_timesteps: int, t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'n_samples num_timesteps 1']]] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'n_samples num_timesteps input_dim']]] = None, filter_hyperparams: Optional[Union[Any]] = None, transition_type: Optional[str] = 'distribution', n_samples: int = 100, key: PRNGKey = jr.PRNGKey(0))
¶
Compute the expected Fisher information matrix with respect to the prior.
Note: the Fisher information computation takes gradients w.r.t. the parameters, so the user must ensure that the sampled parameters are trainable.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prior | Prior | prior distribution | required |
| initial_params | ParameterSet | initial parameters to use for those not sampled from the prior | required |
| num_timesteps | int | number of timesteps \(T\) | required |
| t_emissions | Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'n_samples num_timesteps 1']]] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) or (n_samples, num_timesteps, 1) | None |
| inputs | Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'n_samples num_timesteps input_dim']]] | inputs \(u_{1:T}\) | None |
| filter_hyperparams | Optional[Union[Any]] | hyperparameters of the filtering algorithm | None |
| transition_type | Optional[str] | type of transition function, either "distribution" (default) or "path": "distribution" samples from the (default Gaussian) transition distribution, which is exact for linear Gaussian SSMs; "path" runs an SDE solver to sample a path, which is exact up to discretization error | 'distribution' |
| n_samples | int | number of samples to draw from the prior | 100 |
| key | PRNGKey | random number generator | PRNGKey(0) |

Returns:

| Type | Description |
|---|---|
|  | Fisher information PyTree averaged over the prior. |
sample(params: ParameterSet, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, transition_type: Optional[str] = 'distribution') -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]
¶
Sample states \(z_{1:T}\) and emissions \(y_{1:T}\) given parameters \(\theta\) and (optionally) inputs \(u_{1:T}\).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| key | PRNGKey | random number generator | required |
| num_timesteps | int | number of timesteps \(T\) | required |
| t_emissions | Optional[Float[Array, 'num_timesteps 1']] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) | None |
| inputs | Optional[Float[Array, 'num_timesteps input_dim']] | inputs \(u_{1:T}\) | None |
| transition_type | Optional[str] | type of transition function, either "distribution" (default) or "path": "distribution" samples from the (default Gaussian) transition distribution, which is exact for linear Gaussian SSMs; "path" runs an SDE solver to sample a path, which is exact up to discretization error (not supported for linear Gaussian SSMs) | 'distribution' |

Returns:

| Type | Description |
|---|---|
| Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']] | latent states and emissions |
sample_batch(params: ParameterSet, key: PRNGKey, num_sequences: int, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, transition_type: Optional[str] = 'distribution') -> Tuple[Float[Array, 'num_sequences num_timesteps state_dim'], Float[Array, 'num_sequences num_timesteps emission_dim']]
¶
Sample a batch of sequences of states and emissions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| key | PRNGKey | random number generator | required |
| num_sequences | int | number of sequences to sample | required |
| num_timesteps | int | number of timesteps \(T\) | required |
| t_emissions | Optional[Float[Array, 'num_timesteps 1']] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) | None |
| inputs | Optional[Float[Array, 'num_timesteps input_dim']] | inputs \(u_{1:T}\) | None |
| transition_type | Optional[str] | type of transition function, either "distribution" (default) or "path": "distribution" samples from the (default Gaussian) transition distribution, which is exact for linear Gaussian SSMs; "path" runs an SDE solver to sample a path, which is exact up to discretization error (not supported for linear Gaussian SSMs) | 'distribution' |

Returns:

| Type | Description |
|---|---|
| Tuple[Float[Array, 'num_sequences num_timesteps state_dim'], Float[Array, 'num_sequences num_timesteps emission_dim']] | latent states and emissions |
score(params: ParameterSet, props: PropertySet, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Float[Array, 'ntime input_dim']] = None, return_log_prob: bool = False, warn: bool = True, key: PRNGKey = jr.PRNGKey(0)) -> Scalar
¶
Compute the score function, i.e., the gradient of the log marginal likelihood of observations: \(\nabla \log \sum_{z_{1:T}} p(y_{1:T}, z_{1:T} \mid \theta)\).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| props | PropertySet | properties specifying which parameters should be learned | required |
| emissions | Float[Array, 'ntime emission_dim'] | observed data \(y_{1:T}\) | required |
| t_emissions | Optional[Float[Array, 'num_timesteps 1']] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) | None |
| filter_hyperparams | Optional[Any] | hyperparameters of the filtering algorithm | None |
| inputs | Optional[Float[Array, 'ntime input_dim']] | current inputs \(u_t\) | None |
| return_log_prob | bool | whether or not to return the log probability in addition to its gradient | False |
| warn | bool | whether to print warnings from filters | True |
| key | PRNGKey | random number generator (for use in randomized methods approximating the marginal likelihood) | PRNGKey(0) |

Returns:

| Type | Description |
|---|---|
| Scalar | gradient of the marginal log probability |
Note: We need to exclude non-trainable parameters from the gradient computation because sometimes objects like integers or booleans are included in the parameter set. These are not differentiable and will cause the gradient to be None even with stop_gradient applied.
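As a toy illustration of what a score function is (not the cd-dynamax implementation, which differentiates the marginal log likelihood with JAX autodiff): for i.i.d. N(mu, 1) observations the log likelihood's gradient with respect to mu is \(\sum_i (x_i - \mu)\), and we can check the analytic score against a central finite difference.

```python
import math
import random

# Log likelihood of i.i.d. N(mu, 1) data and its analytic score.
def log_lik(data, mu):
    return sum(-0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2 for x in data)

def score(data, mu):
    # d/dmu log_lik = sum_i (x_i - mu)
    return sum(x - mu for x in data)

rng = random.Random(0)
data = [rng.gauss(1.0, 1.0) for _ in range(20)]
mu = 0.3
eps = 1e-5
# Central finite difference of the log likelihood w.r.t. mu.
fd = (log_lik(data, mu + eps) - log_lik(data, mu - eps)) / (2 * eps)
analytic = score(data, mu)
```

The same check (finite differences against autodiff gradients) is a useful sanity test when implementing new models, since non-differentiable leaves in the parameter pytree can silently yield wrong or missing gradients, as the note above warns.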
smoother(params: ParameterSet, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Union[Any]] = None, inputs: Optional[Float[Array, 'ntime input_dim']] = None) -> Posterior
¶
Compute smoothing distributions, \(p(z_t \mid y_{1:T}, u_{1:T}, \theta)\) for \(t=1,\ldots,T\).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| emissions | Float[Array, 'ntime emission_dim'] | observed data \(y_{1:T}\) | required |
| t_emissions | Optional[Float[Array, 'num_timesteps 1']] | continuous-time emission time instants: if not None, an array of shape (num_timesteps, 1) | None |
| filter_hyperparams | Optional[Union[Any]] | hyperparameters of the filtering algorithm | None |
| inputs | Optional[Float[Array, 'ntime input_dim']] | current inputs \(u_t\) | None |

Returns:

| Type | Description |
|---|---|
| Posterior | smoothing distributions |
transition_distribution(params: ParameterSet, state: Float[Array, ' state_dim'], t0: Optional[Float], t1: Optional[Float], inputs: Optional[Float[Array, ' input_dim']]) -> tfd.Distribution
abstractmethod
¶
Return a distribution over next latent state given current state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | ParameterSet | model parameters \(\theta\) | required |
| state | Float[Array, ' state_dim'] | current latent state \(z_t\) | required |
| t0 | Optional[Float] | initial time \(t_0\) of the transition interval | required |
| t1 | Optional[Float] | final time \(t_1\) of the transition interval | required |
| inputs | Optional[Float[Array, ' input_dim']] | current inputs \(u_t\) | required |

Returns:

| Type | Description |
|---|---|
| Distribution | conditional distribution of next latent state \(p(z_{t+1} \mid z_t, u_t, \theta)\). |
SuffStatsSSM
¶
Bases: Protocol
A NamedTuple with sufficient statistics stored as jax.Array in the leaf nodes.