Base State-Space Model Class

Abstract SSM base classes and emissions models used by cd_dynamax.

ssm_temissions

Posterior

Bases: Protocol

A NamedTuple with parameters stored as jax.Array in the leaf nodes.

Prior

Bases: ABC

cd-dynamax priors: each prior should provide

  • sample, which returns a tensor of samples from this prior
  • log_prob, which returns the log probability of a given parameter set under the prior
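
As an illustration, a minimal prior satisfying this interface might look as follows. GaussianPrior is a hypothetical class written in plain NumPy so it runs stand-alone; the real priors operate on JAX arrays and PRNG keys.

```python
import numpy as np

class GaussianPrior:
    """Hypothetical example of the Prior interface: an isotropic
    Gaussian prior over a flat parameter vector (not part of cd_dynamax)."""

    def __init__(self, dim, scale=1.0):
        self.dim = dim
        self.scale = scale

    def sample(self, key, M):
        # `key` plays the role of a PRNG key; here it seeds NumPy's generator.
        rng = np.random.default_rng(key)
        return rng.normal(0.0, self.scale, size=(M, self.dim))

    def log_prob(self, x):
        # Log density of N(0, scale^2 I), evaluated per sample.
        x = np.atleast_2d(x)
        const = -0.5 * self.dim * np.log(2 * np.pi * self.scale ** 2)
        return const - 0.5 * np.sum(x ** 2, axis=-1) / self.scale ** 2

prior = GaussianPrior(dim=2)
samples = prior.sample(key=0, M=5)   # shape (5, 2)
lp = prior.log_prob(samples)         # shape (5,)
```

Any class with these two methods (returning samples of shape (M, ...) and per-sample log probabilities) fits the contract described above.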

log_prob(x) abstractmethod

Compute the log probability under the prior; to be implemented by specific prior classes

Parameters:

Name Type Description Default
x

Pytree with samples to evaluate log_prob at

required

Returns:

Type Description

PyTree with log_prob values

sample(key, M) abstractmethod

A sampling function to be defined by specific prior classes

Parameters:

Name Type Description Default
key

Random number key

required
M

number of samples to draw

required

Returns:

Type Description

PyTree with samples

SSM

Bases: ABC

A base cd-dynamax class for continuous-discrete state space models. Such models consist of parameters, which we may learn, as well as hyperparameters, which specify static properties of the model. This base class allows parameters to be indicated in a standardized way so that they can easily be converted to/from unconstrained form for optimization.

Abstract Methods

Models that inherit from SSM must implement a few key functions and properties:

  • initial_distribution returns the distribution over the initial state given parameters
  • transition_distribution returns the conditional distribution over the next state given the current state and parameters
  • emission_distribution returns the conditional distribution over the emission given the current state and parameters
  • log_prior (optional) returns the log prior probability of the parameters
  • emission_shape returns a tuple specification of the emission shape
  • inputs_shape returns a tuple specification of the input shape, or None if there are no inputs.

The shape properties are required for properly handling batches of data.
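
To make the contract concrete, here is a self-contained sketch of a linear-Gaussian toy model exposing the same method names. ToyLinearGaussianSSM is hypothetical: it returns (mean, cov) pairs rather than tfd.Distribution objects so the example runs without cd_dynamax installed.

```python
import numpy as np

class ToyLinearGaussianSSM:
    """Illustrative stand-in for an SSM subclass (not the real cd_dynamax
    API): each *_distribution method returns a (mean, cov) pair."""

    def __init__(self, A, Q, H, R, m0, S0):
        self.A, self.Q, self.H, self.R = A, Q, H, R
        self.m0, self.S0 = m0, S0

    @property
    def emission_shape(self):
        return (self.H.shape[0],)      # shape of one time step's emission

    @property
    def inputs_shape(self):
        return None                    # this toy model takes no inputs

    def initial_distribution(self, inputs=None):
        return self.m0, self.S0        # p(z_1)

    def transition_distribution(self, state, inputs=None):
        return self.A @ state, self.Q  # p(z_{t+1} | z_t) = N(A z_t, Q)

    def emission_distribution(self, state, inputs=None):
        return self.H @ state, self.R  # p(y_t | z_t) = N(H z_t, R)

model = ToyLinearGaussianSSM(
    A=np.array([[0.9]]), Q=np.array([[0.1]]),
    H=np.array([[1.0]]), R=np.array([[0.5]]),
    m0=np.zeros(1), S0=np.eye(1))
```

In the real base class the parameters are passed explicitly to each method and the distributions are tfd.Distribution objects; the shape properties serve the same role as here.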

Sampling and Computing Log Probabilities

Once these have been implemented, subclasses will inherit the ability to sample from these models and to compute log joint probabilities from the base class functions:

  • sample draws samples of the states and emissions for given parameters
  • log_prob computes the log joint probability of the states and emissions for given parameters
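
The ancestral sampling behind the inherited sample can be sketched in plain NumPy for the linear-Gaussian case. The helper sample_ssm below is hypothetical; the real method draws from the model's distribution objects with JAX PRNG keys.

```python
import numpy as np

def sample_ssm(seed, A, Q, H, R, m0, S0, num_timesteps):
    """Ancestral sampling of z_{1:T} and y_{1:T} for a linear-Gaussian SSM:
    a sketch of what the inherited `sample` does internally."""
    rng = np.random.default_rng(seed)
    d, p = A.shape[0], H.shape[0]
    states = np.zeros((num_timesteps, d))
    emissions = np.zeros((num_timesteps, p))
    z = rng.multivariate_normal(m0, S0)                  # z_1 ~ p(z_1)
    for t in range(num_timesteps):
        states[t] = z
        emissions[t] = rng.multivariate_normal(H @ z, R) # y_t | z_t
        z = rng.multivariate_normal(A @ z, Q)            # z_{t+1} | z_t
    return states, emissions

states, emissions = sample_ssm(
    seed=0,
    A=np.array([[0.9]]), Q=np.array([[0.1]]),
    H=np.array([[1.0]]), R=np.array([[0.5]]),
    m0=np.zeros(1), S0=np.eye(1), num_timesteps=100)
```

log_prob then simply accumulates the corresponding Gaussian log densities of each transition and emission along the sampled (or given) trajectory.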

Inference

Many subclasses of SSMs expose basic functions for performing state inference.

  • marginal_log_prob computes the marginal log probability of the emissions, summing over latent states
  • filter computes the filtered posteriors
  • smoother computes the smoothed posteriors
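
For the linear-Gaussian case, the filtered posteriors and the marginal log probability can be sketched with a standard discrete-time Kalman filter. kalman_filter below is a hypothetical NumPy helper, not the continuous-discrete routine cd_dynamax actually uses.

```python
import numpy as np

def kalman_filter(ys, A, Q, H, R, m0, S0):
    """Discrete-time Kalman filter: computes the filtered means/covariances
    and accumulates the marginal log likelihood of the emissions."""
    m, S = m0, S0
    ll = 0.0
    means, covs = [], []
    for t, y in enumerate(ys):
        if t > 0:
            m, S = A @ m, A @ S @ A.T + Q              # predict step
        v = y - H @ m                                  # innovation
        Sv = H @ S @ H.T + R                           # innovation covariance
        ll += -0.5 * (len(y) * np.log(2 * np.pi)
                      + np.log(np.linalg.det(Sv))
                      + v @ np.linalg.solve(Sv, v))
        K = S @ H.T @ np.linalg.inv(Sv)                # Kalman gain
        m = m + K @ v                                  # update step
        S = (np.eye(len(m)) - K @ H) @ S
        means.append(m)
        covs.append(S)
    return np.stack(means), np.stack(covs), ll

ys = np.array([[1.0], [0.5], [1.5]])
means, covs, ll = kalman_filter(
    ys, A=np.array([[0.9]]), Q=np.array([[0.1]]),
    H=np.array([[1.0]]), R=np.array([[0.5]]),
    m0=np.zeros(1), S0=np.eye(1))
```

The smoother adds a backward pass over these filtered quantities to obtain \(p(z_t \mid y_{1:T})\).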

Learning

Likewise, many SSMs will support learning with stochastic gradient descent (SGD) or Markov Chain Monte Carlo (MCMC).

The generic SSM class allows the model to be fit with the preferred learning algorithm.

For SGD, any subclass that implements marginal_log_prob inherits the base class fitting function:

  • fit_sgd runs SGD to minimize the negative marginal log probability.

For black-box optimization, any subclass that implements marginal_log_prob inherits the base class fitting functions:

  • fit_scipy runs SciPy-based optimization to minimize the negative marginal log probability.
  • fit_scipy_jaxopt runs jaxopt ScipyMinimize optimization to minimize the negative marginal log probability.

For MCMC, any subclass that implements marginal_log_prob inherits the base class fitting function:

  • fit_mcmc runs BlackJAX HMC to sample from the posterior over parameters given data.
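
A cartoon of the SGD objective: fit_sgd minimizes the negative marginal log probability, which for the trivial model \(y_t \sim N(\theta, 1)\) reduces to the sketch below. fit_sgd_toy is a hypothetical helper that uses full-batch gradient descent for simplicity.

```python
import numpy as np

def fit_sgd_toy(ys, num_epochs=200, lr=0.1):
    """Gradient descent on the negative log-likelihood of y_t ~ N(theta, 1):
    a cartoon of fit_sgd's loss (negative marginal log probability)."""
    theta = 0.0
    losses = []
    for _ in range(num_epochs):
        # NLL up to a constant is 0.5 * sum (y - theta)^2, so its gradient is:
        grad = -np.sum(ys - theta)
        theta -= lr * grad / len(ys)     # averaged (full-batch) gradient step
        losses.append(0.5 * np.sum((ys - theta) ** 2))
    return theta, np.array(losses)

ys = np.random.default_rng(1).normal(3.0, 1.0, size=50)
theta, losses = fit_sgd_toy(ys)          # theta approaches ys.mean()
```

The real fit_sgd replaces the analytic gradient with autodiff of marginal_log_prob, uses an optax optimizer, and iterates over minibatches of sequences.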

diffeqsolve_settings: dict property

Return a dictionary of settings for the differential equation solver.

emission_shape: Tuple[int] abstractmethod property

Return a pytree matching the pytree of tuples specifying the shape of a single time step's emissions.

inputs_shape: Optional[Tuple[int]] property

Return a pytree matching the pytree of tuples specifying the shape of a single time step's inputs.

e_step(params: ParameterSet, emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None) -> Tuple[SuffStatsSSM, Scalar]

Perform an E-step to compute expected sufficient statistics under the posterior, \(p(z_{1:T} \mid y_{1:T}, u_{1:T}, \theta)\).

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
emissions Float[Array, 'num_timesteps emission_dim']

emissions \(y_{1:T}\)

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

continuous-time specific time instants: if not None, it is an array

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional inputs \(u_{1:T}\)

None

Returns:

Type Description
Tuple[SuffStatsSSM, Scalar]

Expected sufficient statistics under the posterior.

emission_distribution(params: ParameterSet, state: Float[Array, ' state_dim'], inputs: Optional[Float[Array, ' input_dim']] = None) -> tfd.Distribution abstractmethod

Return a distribution over emissions given current state.

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
state Float[Array, ' state_dim']

current latent state \(z_t\)

required
inputs Optional[Float[Array, ' input_dim']]

current inputs \(u_t\)

None

Returns:

Type Description
Distribution

conditional distribution of current emission \(p(y_t \mid z_t, u_t, \theta)\)

filter(params: ParameterSet, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Union[Any]] = None, inputs: Optional[Float[Array, 'ntime input_dim']] = None) -> Posterior

Compute filtering distributions, \(p(z_t \mid y_{1:t}, u_{1:t}, \theta)\) for \(t=1,\ldots,T\).

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
emissions Float[Array, 'ntime emission_dim']

observed data \(y_{1:T}\)

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

continuous-time specific time instants: if not None, it is an array

None
filter_hyperparams Optional[Union[Any]]

hyperparameters of the filtering algorithm

None
inputs Optional[Float[Array, 'ntime input_dim']]

current inputs \(u_t\)

None

Returns:

Type Description
Posterior

filtering distributions

fisher_information(params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, warn: bool = True)

Compute the observed Fisher information matrix for the model parameters.

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
props PropertySet

properties specifying which parameters should be learned

required
emissions Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']]

one or more sequences of emissions

required
t_emissions Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]]

continuous-time specific time instants: if not None, it is an array

None
filter_hyperparams Optional[Any]

if needed, hyperparameters of the filtering algorithm

None
inputs Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]]

one or more sequences of corresponding inputs

None
warn bool

whether to print warnings from filters

True

Returns: Fisher information PyTree.

NOTE

The Hessian computation requires reverse-mode autodiff.

fit_em(params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, num_iters: int = 50, verbose: bool = True) -> Tuple[ParameterSet, Float[Array, ' num_iters']]

Compute the parameter MLE/MAP estimate using Expectation-Maximization (EM).

EM aims to find parameters that maximize the marginal log probability,

\[\theta^\star = \mathrm{argmax}_\theta \; \log p(y_{1:T}, \theta \mid u_{1:T})\]

It does so by iteratively forming a lower bound (the "E-step") and then maximizing it (the "M-step").

Note: emissions and inputs can either be single sequences or batches of sequences.
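
The E-step/M-step alternation can be sketched on a toy model with one latent per observation, \(z_t \sim N(\theta, q)\) and \(y_t \mid z_t \sim N(z_t, r)\). fit_em_toy is a hypothetical helper, not the cd_dynamax implementation.

```python
import numpy as np

def fit_em_toy(ys, q=1.0, r=1.0, num_iters=50):
    """EM for z_t ~ N(theta, q), y_t | z_t ~ N(z_t, r): a cartoon of
    the fit_em loop (the fixed point is theta = mean(ys))."""
    theta = 0.0
    for _ in range(num_iters):
        # E-step: posterior mean of each latent z_t given y_t and theta
        mu = (r * theta + q * ys) / (q + r)
        # M-step: maximize the expected log joint over theta
        theta = mu.mean()
    return theta
```

Each iteration contracts the error in theta by a factor r / (q + r), mirroring how the real E-step (e_step) computes expected sufficient statistics and the M-step (m_step) re-solves for the parameters.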

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
props PropertySet

properties specifying which parameters should be learned

required
emissions Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']]

one or more sequences of emissions

required
t_emissions Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]]

continuous-time specific time instants: if not None, it is an array

None
inputs Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]]

one or more sequences of corresponding inputs

None
num_iters int

number of iterations of EM to run

50
verbose bool

whether or not to show a progress bar

True

Returns:

Type Description
Tuple[ParameterSet, Float[Array, ' num_iters']]

tuple of new parameters and log likelihoods over the course of EM iterations.

fit_mcmc(initial_params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, mcmc_algorithm={'type': 'nuts', 'n_samples': 100, 'warmup_samples': 10, 'parameters': {}}, verbose=True, warn=True, key: PRNGKey = jr.PRNGKey(0)) -> Tuple[ParameterSet, ParameterSet, Float[Array, ' num_steps'], Float[Array, ' n_mcmc_samples']]

Generate samples from the posterior using Markov Chain Monte Carlo (MCMC).

Parameters:

Name Type Description Default
initial_params ParameterSet

initial parameters \(\theta\)

required
props PropertySet

properties specifying which parameters should be learned

required
emissions Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']]

one or more sequences of observed data

required
t_emissions Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]]

continuous-time specific time instants: if not None, it is an array

None
filter_hyperparams Optional[Any]

if needed, hyperparameters of the filtering algorithm

None
inputs Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]]

one or more sequences of corresponding inputs

None
mcmc_algorithm

dictionary specifying the MCMC algorithm to use and its settings, based on BlackJAX. It should contain the following keys: type (type of MCMC algorithm, e.g. "nuts"), n_samples (number of samples to draw), warmup_samples (number of warmup steps), and parameters (additional parameters for the MCMC algorithm)

{'type': 'nuts', 'n_samples': 100, 'warmup_samples': 10, 'parameters': {}}
verbose

whether or not to show a progress bar

True
warn

whether to print warnings from filters

True
key PRNGKey

a random number generator

PRNGKey(0)

Returns:

Type Description
Tuple[ParameterSet, ParameterSet, Float[Array, ' num_steps'], Float[Array, ' n_mcmc_samples']]

tuple of samples and log probabilities of the samples
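
The sampling loop behind fit_mcmc can be sketched with a scalar random-walk Metropolis sampler. metropolis_toy is hypothetical; the real method wraps BlackJAX kernels such as NUTS and targets the (unconstrained) log posterior of the trainable parameters.

```python
import numpy as np

def metropolis_toy(log_prob, init, n_samples=5000, warmup=500, step=1.0, seed=0):
    """Random-walk Metropolis over a scalar parameter: propose a Gaussian
    perturbation, accept with probability min(1, p(prop)/p(current))."""
    rng = np.random.default_rng(seed)
    theta, lp = init, log_prob(init)
    samples, lps = [], []
    for i in range(warmup + n_samples):
        prop = theta + step * rng.normal()
        lp_prop = log_prob(prop)
        if np.log(rng.uniform()) < lp_prop - lp:   # accept/reject
            theta, lp = prop, lp_prop
        if i >= warmup:                            # discard warmup draws
            samples.append(theta)
            lps.append(lp)
    return np.array(samples), np.array(lps)

# Target: an unnormalized N(2, 1) log density.
samples, lps = metropolis_toy(lambda x: -0.5 * (x - 2.0) ** 2, init=0.0)
```

As in fit_mcmc, the returned arrays pair the post-warmup samples with their log probabilities.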

fit_scipy(initial_params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, method: str = 'Nelder-Mead', options: dict = {'maxiter': 5000}, return_param_history: bool = False, warn: bool = True) -> Tuple[ParameterSet, Float[Array, ' niter'], Optional[ParameterSet]]

Compute the parameter MLE/MAP estimate using the chosen SciPy-based method. SciPy-based optimizers aim to find parameters that maximize the marginal log probability,

\[\theta^\star = \mathrm{argmax}_\theta \; \log p(y_{1:T}, \theta \mid u_{1:T})\]

by minimizing the negative of that quantity.

Note: emissions and inputs can either be single sequences or batches of sequences.

Parameters:

Name Type Description Default
initial_params ParameterSet

model parameters \(\theta\)

required
props PropertySet

model properties specifying which parameters should be learned

required
emissions Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']]

one or more sequences of observed data

required
t_emissions Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]]

continuous-time specific time instants: if not None, it is an array

None
filter_hyperparams Optional[Any]

if needed, hyperparameters of the filtering algorithm

None
inputs Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]]

one or more sequences of corresponding inputs

None
method str

optimization method to use, e.g. "Nelder-Mead", "BFGS", etc.

'Nelder-Mead'
options dict

dictionary of options to pass to the SciPy optimizer

{'maxiter': 5000}
return_param_history bool

whether to return the history of parameters

False
warn bool

whether to print warnings from filters

True

Returns:

params_fitted: final fitted parameters (PyTree, constrained)
losses: array of loss values per iteration
params_history (optional): constrained parameter history (if requested)

fit_scipy_jaxopt(initial_params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, method: str = 'nelder-mead', options: dict = {'maxiter': 100}) -> Tuple[ParameterSet, Float[Array, ' niter']]

Compute the parameter MLE/MAP estimate using the chosen SciPy-based method from jaxopt.

ScipyMinimize aims to find parameters that maximize the marginal log probability,

\[\theta^\star = \mathrm{argmax}_\theta \; \log p(y_{1:T}, \theta \mid u_{1:T})\]

by minimizing the negative of that quantity.

Note: emissions and inputs can either be single sequences or batches of sequences.

Parameters:

Name Type Description Default
initial_params ParameterSet

model parameters \(\theta\)

required
props PropertySet

properties specifying which parameters should be learned

required
emissions Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']]

one or more sequences of emissions

required
t_emissions Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]]

continuous-time specific time instants: if not None, it is an array

None
filter_hyperparams Optional[Any]

if needed, hyperparameters of the filtering algorithm

None
inputs Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]]

one or more sequences of corresponding inputs

None
method str

optimization method to use, e.g. "Nelder-Mead", "BFGS", etc.

'nelder-mead'
options dict

dictionary of options to pass to the SciPy optimizer

{'maxiter': 100}

Returns:

Type Description
Tuple[ParameterSet, Float[Array, ' niter']]

tuple of new parameters and losses (negative scaled marginal log probs) over the course of Nelder-Mead iterations.

fit_sgd(initial_params: ParameterSet, props: PropertySet, emissions: Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']], t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]] = None, optimizer: optax.GradientTransformation = optax.adam(0.001), batch_size: int = 1, num_epochs: int = 50, shuffle: bool = False, return_param_history: bool = False, return_grad_history: bool = False, warn: bool = True, key: PRNGKey = jr.PRNGKey(0)) -> Tuple[ParameterSet, Float[Array, ' niter']]

Compute the parameter MLE/MAP estimate using Stochastic Gradient Descent (SGD).

SGD aims to find parameters that maximize the marginal log probability,

\[\theta^\star = \mathrm{argmax}_\theta \; \log p(y_{1:T}, \theta \mid u_{1:T})\]

by minimizing the negative of that quantity.

Note: emissions and inputs can either be single sequences or batches of sequences.

On each iteration, the algorithm grabs a minibatch of sequences and takes a gradient step.

One pass through the entire set of sequences is called an epoch.

Parameters:

Name Type Description Default
initial_params ParameterSet

model parameters \(\theta\)

required
props PropertySet

model properties specifying which parameters should be learned

required
emissions Union[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_batches num_timesteps emission_dim']]

one or more sequences of observed data

required
t_emissions Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'num_batches num_timesteps 1']]]

continuous-time specific time instants: if not None, it is an array

None
filter_hyperparams Optional[Any]

if needed, hyperparameters of the filtering algorithm

None
inputs Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'num_batches num_timesteps input_dim']]]

one or more sequences of corresponding inputs

None
optimizer GradientTransformation

an optax optimizer for minimization

adam(0.001)
batch_size int

number of sequences per minibatch

1
num_epochs int

number of epochs of SGD to run

50
shuffle bool

whether or not to shuffle the sequences on each epoch

False
return_param_history bool

whether to return the history of parameters

False
return_grad_history bool

whether to return the history of gradients

False
key PRNGKey

a random number generator for selecting minibatches

PRNGKey(0)

Returns:

Type Description
Tuple[ParameterSet, Float[Array, ' niter']]

tuple of new parameters and losses (negative scaled marginal log probs) over the course of SGD iterations. If requested, the histories of parameters and gradients are returned as well.

initial_distribution(params: ParameterSet, inputs: Optional[Float[Array, ' input_dim']]) -> tfd.Distribution abstractmethod

Return an initial distribution over latent states.

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
inputs Optional[Float[Array, ' input_dim']]

optional inputs \(u_t\)

required

Returns:

Type Description
Distribution

distribution over initial latent state, \(p(z_1 \mid u_t, \theta)\).

initialize(*args, **kwargs) abstractmethod

Initialize the model parameters.

Returns: CD-SSM parameters and their properties.

log_prior(params: ParameterSet) -> Scalar

Return the log prior probability of model parameters.

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required

Returns:

Name Type Description
log_prior Scalar

log prior probability.

log_prob(params: ParameterSet, states: Float[Array, 'num_timesteps state_dim'], emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, key: PRNGKey = jr.PRNGKey(0)) -> Scalar

Compute the log joint probability of the states and observations, \(\log p(y_{1:T}, z_{1:T} \mid \theta)\).

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
states Float[Array, 'num_timesteps state_dim']

latent states \(z_{1:T}\)

required
emissions Float[Array, 'num_timesteps emission_dim']

observed data \(y_{1:T}\)

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

continuous-time specific time instants: if not None, it is an array

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

current inputs \(u_t\)

None
key PRNGKey

random number generator key

PRNGKey(0)

Returns:

Type Description
Scalar

log joint probability (scalar) of states and emissions

m_step(params: ParameterSet, props: PropertySet, batch_stats: SuffStatsSSM, m_step_state: Any) -> ParameterSet

Perform an M-step to find parameters that maximize the expected log joint probability.

Specifically, compute

\[\theta^\star = \mathrm{argmax}_\theta \; \mathbb{E}_{p(z_{1:T} \mid y_{1:T}, u_{1:T}, \theta)} \big[\log p(y_{1:T}, z_{1:T}, \theta \mid u_{1:T}) \big]\]

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
props PropertySet

properties specifying which parameters should be learned

required
batch_stats SuffStatsSSM

sufficient statistics from each sequence

required
m_step_state Any

any required state for optimizing the model parameters.

required

Returns:

Type Description
ParameterSet

new parameters

marginal_log_prob(params: ParameterSet, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Float[Array, 'ntime input_dim']] = None, key: PRNGKey = jr.PRNGKey(0)) -> Scalar

Compute log marginal likelihood of observations, \(\log \sum_{z_{1:T}} p(y_{1:T}, z_{1:T} \mid \theta)\).

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
emissions Float[Array, 'ntime emission_dim']

emissions \(y_{1:T}\)

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

continuous-time specific time instants: if not None, it is an array

None
filter_hyperparams Optional[Any]

hyperparameters of the filtering algorithm

None
inputs Optional[Float[Array, 'ntime input_dim']]

current inputs \(u_t\)

None
key PRNGKey

random number generator (for use in randomized methods approximating the marginal likelihood)

PRNGKey(0)

Returns:

Type Description
Scalar

marginal log probability

prior_averaged_fisher(prior: Prior, initial_params: ParameterSet, num_timesteps: int, t_emissions: Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'n_samples num_timesteps 1']]] = None, inputs: Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'n_samples num_timesteps input_dim']]] = None, filter_hyperparams: Optional[Union[Any]] = None, transition_type: Optional[str] = 'distribution', n_samples: int = 100, key: PRNGKey = jr.PRNGKey(0))

Compute the expected Fisher information matrix with respect to the prior.

Note: since the Fisher information computation requires gradients with respect to the parameters, the user must ensure that the sampled parameters are trainable.

Parameters:

Name Type Description Default
prior Prior

prior distribution

required
initial_params ParameterSet

initial parameters to use for those not sampled from the prior

required
num_timesteps int

number of timesteps \(T\)

required
t_emissions Optional[Union[Float[Array, 'num_timesteps 1'], Float[Array, 'n_samples num_timesteps 1']]]

continuous-time specific time instants: if not None, it is an array

None
inputs Optional[Union[Float[Array, 'num_timesteps input_dim'], Float[Array, 'n_samples num_timesteps input_dim']]]

inputs \(u_{1:T}\)

None
filter_hyperparams Optional[Union[Any]]

hyperparameters of the filtering algorithm

None
transition_type Optional[str]

type of transition function, either "distribution" (default) or "path". "distribution" samples from the (default Gaussian) transition distribution; this is exact for linear Gaussian SSMs. "path" runs an SDE solver to sample the distribution; this is more "exact" (up to discretization error).

'distribution'
n_samples int

number of samples to draw from the prior

100
key PRNGKey

random number generator

PRNGKey(0)

Returns:

Type Description

Fisher information PyTree averaged over the prior.

sample(params: ParameterSet, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, transition_type: Optional[str] = 'distribution') -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

Sample states \(z_{1:T}\) and emissions \(y_{1:T}\) given parameters \(\theta\) and (optionally) inputs \(u_{1:T}\).

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
key PRNGKey

random number generator

required
num_timesteps int

number of timesteps \(T\)

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

continuous-time specific time instants: if not None, it is an array

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

inputs \(u_{1:T}\)

None
transition_type Optional[str]

type of transition function, either "distribution" (default) or "path". "distribution" samples from the (default Gaussian) transition distribution; this is exact for linear Gaussian SSMs. "path" runs an SDE solver to sample the distribution; this is more "exact" (up to discretization error), but is not supported for linear Gaussian SSMs.

'distribution'

Returns:

Type Description
Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

latent states and emissions

sample_batch(params: ParameterSet, key: PRNGKey, num_sequences: int, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, transition_type: Optional[str] = 'distribution') -> Tuple[Float[Array, 'num_sequences num_timesteps state_dim'], Float[Array, 'num_sequences num_timesteps emission_dim']]

Sample a batch of sequences of states and emissions.

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
key PRNGKey

random number generator

required
num_sequences int

number of sequences to sample

required
num_timesteps int

number of timesteps \(T\)

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

continuous-time specific time instants: if not None, it is an array

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

inputs \(u_{1:T}\)

None
transition_type Optional[str]

type of transition function, either "distribution" (default) or "path". "distribution" samples from the (default Gaussian) transition distribution; this is exact for linear Gaussian SSMs. "path" runs an SDE solver to sample the distribution; this is more "exact" (up to discretization error), but is not supported for linear Gaussian SSMs.

'distribution'

Returns:

Type Description
Tuple[Float[Array, 'num_sequences num_timesteps state_dim'], Float[Array, 'num_sequences num_timesteps emission_dim']]

latent states and emissions

score(params: ParameterSet, props: PropertySet, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Any] = None, inputs: Optional[Float[Array, 'ntime input_dim']] = None, return_log_prob: bool = False, warn: bool = True, key: PRNGKey = jr.PRNGKey(0)) -> Scalar

Compute the score function, i.e., the gradient of the log marginal likelihood of observations: \(\nabla \log \sum_{z_{1:T}} p(y_{1:T}, z_{1:T} \mid \theta)\).

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
props PropertySet

properties specifying which parameters should be learned

required
emissions Float[Array, 'ntime emission_dim']

observed data \(y_{1:T}\)

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

continuous-time specific time instants: if not None, it is an array

None
filter_hyperparams Optional[Any]

hyperparameters of the filtering algorithm

None
inputs Optional[Float[Array, 'ntime input_dim']]

current inputs \(u_t\)

None
return_log_prob bool

whether or not to return the log probability in addition to its gradient

False
warn bool

whether to print warnings from filters

True
key PRNGKey

random number generator (for use in randomized methods approximating the marginal likelihood)

PRNGKey(0)

Returns:

Type Description
Scalar

gradient of marginal log probability

Note: We need to exclude non-trainable parameters from the gradient computation because sometimes objects like integers or booleans are included in the parameter set. These are not differentiable and will cause the gradient to be None even with stop_gradient applied.

smoother(params: ParameterSet, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Union[Any]] = None, inputs: Optional[Float[Array, 'ntime input_dim']] = None) -> Posterior

Compute smoothing distributions, \(p(z_t \mid y_{1:T}, u_{1:T}, \theta)\) for \(t=1,\ldots,T\).

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
emissions Float[Array, 'ntime emission_dim']

observed data \(y_{1:T}\)

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

continuous-time specific time instants: if not None, it is an array

None
filter_hyperparams Optional[Union[Any]]

hyperparameters of the filtering algorithm

None
inputs Optional[Float[Array, 'ntime input_dim']]

current inputs \(u_t\)

None

Returns:

Type Description
Posterior

smoothing distributions

transition_distribution(params: ParameterSet, state: Float[Array, ' state_dim'], t0: Optional[Float], t1: Optional[Float], inputs: Optional[Float[Array, ' input_dim']]) -> tfd.Distribution abstractmethod

Return a distribution over next latent state given current state.

Parameters:

Name Type Description Default
params ParameterSet

model parameters \(\theta\)

required
state Float[Array, ' state_dim']

current latent state \(z_t\)

required
inputs Optional[Float[Array, ' input_dim']]

current inputs \(u_t\)

required

Returns:

Type Description
Distribution

conditional distribution of next latent state \(p(z_{t+1} \mid z_t, u_t, \theta)\).

SuffStatsSSM

Bases: Protocol

A NamedTuple with sufficient statistics stored as jax.Array in the leaf nodes.