Kalman Filter
Inference algorithms for continuous-discrete linear Gaussian SSMs.
inference
KFHyperParams
Bases: NamedTuple
Lightweight container for filtering hyperparameters.
cdlgssm_emissions(params: ParamsCDLGSSM, t_states: Float[Array, 'num_timesteps 1'], state_means: Float[Array, 'num_timesteps state_dim'], state_covs: Optional[Float[Array, 'num_timesteps state_dim state_dim']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, key: PRNGKey = jr.PRNGKey(0), warn: bool = True) -> Tuple[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_timesteps emission_dim emission_dim']]
Compute the emissions (mean and covariance) corresponding to a continuous-discrete linear model, as specified by `params`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `ParamsCDLGSSM` | model parameters. | required |
| `t_states` | `Float[Array, 'num_timesteps 1']` | continuous time instants of the states. | required |
| `state_means` | `Float[Array, 'num_timesteps state_dim']` | state means at the time instants `t_states`; always required. | required |
| `state_covs` | `Optional[Float[Array, 'num_timesteps state_dim state_dim']]` | state covariances at the time instants `t_states`; optional. If None, the states are treated as point estimates and are simply pushed through the emission function. | `None` |
| `inputs` | `Optional[Float[Array, 'num_timesteps input_dim']]` | optional array of inputs, of shape `(1 + num_timesteps, input_dim)`. The extra input is needed for the initial emission, i.e., it should be at time `t_init`. | `None` |
| `key` | `PRNGKey` | random key for sampling. | `PRNGKey(0)` |
| `warn` | `bool` | whether to issue warnings during filtering (e.g., PSD issues). | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `emissions_mean` | `Float[Array, 'num_timesteps emission_dim']` | mean of the emissions |
| `emissions_covariance` | `Float[Array, 'num_timesteps emission_dim emission_dim']` | covariance of the emissions |
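
The emission computation can be illustrated for a linear Gaussian emission model \(y_t = H z_t + d + v_t\), \(v_t \sim \mathcal{N}(0, R)\). This is a plain NumPy sketch rather than JAX; `H`, `d`, `R` and `linear_gaussian_emissions` are hypothetical names, not fields of `ParamsCDLGSSM`, and input terms are omitted. Under these assumptions the emission mean is \(H m_t + d\) and the covariance is \(H P_t H^\top + R\), which reduces to \(R\) when `state_covs` is None.

```python
import numpy as np


def linear_gaussian_emissions(H, d, R, state_means, state_covs=None):
    """Push state moments through an assumed emission model y = H z + d + v, v ~ N(0, R).

    state_means: (T, state_dim); state_covs: (T, state_dim, state_dim) or None.
    Returns emission means (T, emission_dim) and covariances (T, emission_dim, emission_dim).
    """
    means = state_means @ H.T + d  # H m_t + d for every t
    T = state_means.shape[0]
    if state_covs is None:
        # Point estimates: only the emission noise contributes to the covariance.
        covs = np.broadcast_to(R, (T,) + R.shape).copy()
    else:
        # H P_t H^T + R for every t
        covs = np.einsum("ij,tjk,lk->til", H, state_covs, H) + R
    return means, covs


H = np.array([[1.0, 0.0]])
d = np.array([0.5])
R = np.array([[0.1]])
m = np.array([[1.0, 2.0], [3.0, 4.0]])
P = np.stack([np.eye(2), 2 * np.eye(2)])
mu, S = linear_gaussian_emissions(H, d, R, m, P)
```

Here `mu[:, 0]` is `[1.5, 3.5]` and `S[:, 0, 0]` is `[1.1, 2.1]`.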
cdlgssm_filter(params: ParamsCDLGSSM, emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(), inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, warn: bool = True) -> PosteriorGSSMFiltered
Run a continuous-discrete Kalman filter to compute the marginal likelihood and the filtered state estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `ParamsCDLGSSM` | CD-LGSSM parameters. | required |
| `emissions` | `Float[Array, 'num_timesteps emission_dim']` | array of observations. | required |
| `t_emissions` | `Optional[Float[Array, 'num_timesteps 1']]` | continuous time instants of the observations; if not None, an array of shape `(num_timesteps, 1)`. | `None` |
| `filter_hyperparams` | `Optional[KFHyperParams]` | hyperparameters for the filter. | `KFHyperParams()` |
| `inputs` | `Optional[Float[Array, 'num_timesteps input_dim']]` | optional array of inputs. | `None` |
| `warn` | `bool` | whether to issue warnings during filtering (e.g., PSD issues). | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `PosteriorGSSMFiltered` | `PosteriorGSSMFiltered` | filtered posterior object |
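
A minimal NumPy sketch of a continuous-discrete Kalman filter, assuming a scalar Ornstein-Uhlenbeck state \(dz_t = -\theta z_t \, dt + \sigma \, dB_t\), whose push-forward between observation times has a closed form, and linear Gaussian emissions \(y_k = H z_{t_k} + v_k\). The prior \(\mathcal{N}(m_0, P_0)\) is placed at the first observation time, and times are a 1-D array for simplicity. `ou_pushforward` and `cd_kalman_filter` are hypothetical names for illustration, not the library's implementation.

```python
import numpy as np


def ou_pushforward(theta, sigma, dt):
    """Exact (A, Q) for the scalar OU SDE dz = -theta z dt + sigma dB over a step dt."""
    A = np.array([[np.exp(-theta * dt)]])
    Q = np.array([[sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * dt))]])
    return A, Q


def cd_kalman_filter(m0, P0, H, R, emissions, t_emissions, pushforward):
    """Predict with pushforward(dt) between observation times, then do a Kalman update.

    Returns the marginal log likelihood and the filtered means and covariances.
    """
    m, P = m0, P0
    log_lik = 0.0
    means, covs = [], []
    for k, y in enumerate(emissions):
        if k > 0:
            # Predict: push the Gaussian through the discretized dynamics.
            A, Q = pushforward(t_emissions[k] - t_emissions[k - 1])
            m = A @ m
            P = A @ P @ A.T + Q
        # Update with the k-th observation.
        S = H @ P @ H.T + R              # innovation covariance
        resid = y - H @ m                # innovation
        K = np.linalg.solve(S, H @ P).T  # Kalman gain P H^T S^{-1}
        log_lik += -0.5 * (resid @ np.linalg.solve(S, resid)
                           + np.linalg.slogdet(2 * np.pi * S)[1])
        m = m + K @ resid
        P = P - K @ S @ K.T
        means.append(m)
        covs.append(P)
    return log_lik, np.array(means), np.array(covs)


def ou_step(dt):
    return ou_pushforward(theta=1.0, sigma=0.5, dt=dt)


t = np.array([0.0, 0.5, 1.5, 1.7])
y = np.array([[0.3], [0.1], [-0.2], [0.0]])
ll, fm, fc = cd_kalman_filter(np.zeros(1), np.eye(1), np.eye(1), 0.1 * np.eye(1),
                              y, t, ou_step)
```

The log likelihood accumulated this way (prediction error decomposition) equals the log density of all observations under their joint Gaussian distribution.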
cdlgssm_forecast(params: ParamsCDLGSSM, init_forecast: Union[tfd.Distribution, Float[Array, 'state_dim 1']], t_init: Float[Array, '1 1'], t_forecast: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(), inputs: Optional[Float[Array, 'ntime input_dim']] = None, output_fields: Optional[List[str]] = ['forecasted_state_means', 'forecasted_state_covariances'], key: PRNGKey = jr.PRNGKey(0), diffeqsolve_settings: dict = {}, warn: bool = True) -> GSSMForecast
Run a continuous-discrete Kalman filter to produce forecasted state estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `ParamsCDLGSSM` | CD-LGSSM parameters. | required |
| `init_forecast` | `Union[Distribution, Float[Array, 'state_dim 1']]` | initial condition from which to start forecasting. If `init_forecast` is a distribution, that distribution is forecast forward using the chosen filtering method; if it is a point estimate of the state, a forward path is forecast starting at that state. | required |
| `t_init` | `Float[Array, '1 1']` | time instant of the initial condition of the forecast. | required |
| `t_forecast` | `Optional[Float[Array, 'num_timesteps 1']]` | continuous time instants at which to forecast; if not None, an array of shape `(num_timesteps, 1)`. | `None` |
| `filter_hyperparams` | `Optional[KFHyperParams]` | hyperparameters of the filter. | `KFHyperParams()` |
| `inputs` | `Optional[Float[Array, 'ntime input_dim']]` | optional array of inputs, of shape `(1 + num_timesteps, input_dim)`. The extra input is needed for the initial emission, i.e., it should be at time `t_init`. | `None` |
| `output_fields` | `Optional[List[str]]` | list of fields to return in the forecast object. When forecasting Gaussian distributions (via filtering methods), the valid values are `"forecasted_state_means"` and `"forecasted_state_covariances"`; when forecasting paths (by solving the SDE), the valid value is `"forecasted_state_path"`. | `['forecasted_state_means', 'forecasted_state_covariances']` |
| `key` | `PRNGKey` | random key (e.g., for the Ensemble Kalman filter). | `PRNGKey(0)` |
| `diffeqsolve_settings` | `dict` | settings for the SDE solver. | `{}` |
| `warn` | `bool` | whether to issue warnings during filtering (e.g., PSD issues). | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `post` | `GSSMForecast` | forecast object |
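
Forecasting a Gaussian distribution amounts to repeated prediction steps with no measurement updates. The NumPy sketch below assumes the same closed-form scalar Ornstein-Uhlenbeck push-forward as a stand-in for the model's dynamics; `ou_pushforward` and `forecast_moments` are hypothetical names, and the point-estimate (path) branch of `init_forecast` is not covered.

```python
import numpy as np


def ou_pushforward(theta, sigma, dt):
    """Exact (A, Q) for the scalar OU SDE dz = -theta z dt + sigma dB over a step dt."""
    A = np.array([[np.exp(-theta * dt)]])
    Q = np.array([[sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * dt))]])
    return A, Q


def forecast_moments(m_init, P_init, t_init, t_forecast, pushforward):
    """Propagate N(m_init, P_init) from t_init through the sorted times t_forecast.

    There are no measurement updates: each step is a pure prediction.
    """
    m, P, t_prev = m_init, P_init, t_init
    means, covs = [], []
    for t in t_forecast:
        A, Q = pushforward(t - t_prev)
        m = A @ m
        P = A @ P @ A.T + Q
        means.append(m)
        covs.append(P)
        t_prev = t
    return np.array(means), np.array(covs)


def ou_step(dt):
    return ou_pushforward(theta=1.0, sigma=0.5, dt=dt)


t_fc = np.array([0.5, 1.0, 2.0, 10.0])
# Start from a known state z = 2 at t_init = 0 (zero initial covariance).
fm, fc = forecast_moments(np.array([2.0]), np.zeros((1, 1)), 0.0, t_fc, ou_step)
```

The forecast mean decays as \(2 e^{-t}\) and the variance approaches the stationary value \(\sigma^2 / (2\theta) = 0.125\).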
cdlgssm_joint_sample(params: ParamsCDLGSSM, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, diffeqsolve_settings={}, warn: bool = True) -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]
Sample from the joint distribution to produce state and emission trajectories.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `ParamsCDLGSSM` | CD-LGSSM parameters. | required |
| `key` | `PRNGKey` | random key. | required |
| `num_timesteps` | `int` | number of time steps to sample. | required |
| `t_emissions` | `Optional[Float[Array, 'num_timesteps 1']]` | continuous time instants of the observations; if not None, an array of shape `(num_timesteps, 1)`. | `None` |
| `inputs` | `Optional[Float[Array, 'num_timesteps input_dim']]` | optional array of inputs. | `None` |
| `diffeqsolve_settings` | | settings for the differential-equation solver. | `{}` |
| `warn` | `bool` | whether to issue warnings during sampling (e.g., PSD issues). | `True` |
Returns:
| Type | Description |
|---|---|
| `Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]` | latent states and observed emissions |
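
A NumPy sketch of joint sampling, assuming the states are drawn as a Gaussian Markov chain using the exact discretized transitions between observation times, with linear Gaussian emissions. This may differ from what the library does internally; `ou_pushforward` and `joint_sample` are hypothetical names, and NumPy's `Generator` replaces a JAX `PRNGKey`.

```python
import numpy as np


def ou_pushforward(theta, sigma, dt):
    """Exact (A, Q) for the scalar OU SDE dz = -theta z dt + sigma dB over a step dt."""
    A = np.array([[np.exp(-theta * dt)]])
    Q = np.array([[sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * dt))]])
    return A, Q


def joint_sample(rng, m0, P0, H, R, t_emissions, pushforward):
    """Draw z_0 ~ N(m0, P0), z_k | z_{k-1} ~ N(A z_{k-1}, Q), y_k | z_k ~ N(H z_k, R)."""
    T = len(t_emissions)
    states = np.empty((T, m0.shape[0]))
    emissions = np.empty((T, H.shape[0]))
    z = rng.multivariate_normal(m0, P0)
    for k in range(T):
        if k > 0:
            A, Q = pushforward(t_emissions[k] - t_emissions[k - 1])
            z = rng.multivariate_normal(A @ z, Q)
        states[k] = z
        emissions[k] = rng.multivariate_normal(H @ z, R)
    return states, emissions


def ou_step(dt):
    return ou_pushforward(theta=1.0, sigma=0.5, dt=dt)


rng = np.random.default_rng(0)
t = np.array([0.0, 1.0, 2.0])
states, emissions = joint_sample(rng, np.zeros(1), np.eye(1), np.eye(1),
                                 0.1 * np.eye(1), t, ou_step)
```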
cdlgssm_path_sample(params: ParamsCDLGSSM, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, diffeqsolve_settings={}, warn: bool = True) -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]
Sample from a forward path of the CD-LGSSM to produce state and emission trajectories.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `ParamsCDLGSSM` | CD-LGSSM parameters. | required |
| `key` | `PRNGKey` | random number generator key. | required |
| `num_timesteps` | `int` | number of time steps to sample. | required |
| `t_emissions` | `Optional[Float[Array, 'num_timesteps 1']]` | continuous time instants of the observations; if not None, an array of shape `(num_timesteps, 1)`. | `None` |
| `inputs` | `Optional[Float[Array, 'num_timesteps input_dim']]` | optional array of inputs. | `None` |
| `diffeqsolve_settings` | | settings for the SDE solver. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]` | latent states and observed emissions |
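
Sampling a forward path means simulating the SDE itself. The sketch below uses a fixed-step Euler-Maruyama scheme for \(dz_t = F z_t \, dt + L \, dB_t\) with time-invariant `F` and `L` and standard Brownian motion; the scheme, the step cap `dt_max` and the name `euler_maruyama_path` are assumptions standing in for whatever SDE solver `diffeqsolve_settings` configures.

```python
import numpy as np


def euler_maruyama_path(rng, z0, F, L, H, R, t_emissions, dt_max=1e-2):
    """Simulate dz = F z dt + L dB (standard Brownian B) with Euler-Maruyama.

    The state is recorded at each time in t_emissions, together with a noisy
    emission y = H z + v, v ~ N(0, R).
    """
    z = np.array(z0, dtype=float)
    t = t_emissions[0]
    states, emissions = [], []
    for t_next in t_emissions:
        if t_next > t:
            n = int(np.ceil((t_next - t) / dt_max))  # substeps to reach t_next
            h = (t_next - t) / n
            for _ in range(n):
                dB = rng.normal(scale=np.sqrt(h), size=L.shape[1])
                z = z + (F @ z) * h + L @ dB
            t = t_next
        states.append(z.copy())
        emissions.append(rng.multivariate_normal(H @ z, R))
    return np.array(states), np.array(emissions)


rng = np.random.default_rng(1)
F = np.array([[-1.0]])
t = np.array([0.0, 0.5, 1.0])
states, emissions = euler_maruyama_path(rng, [1.0], F, np.array([[0.5]]),
                                        np.eye(1), 0.1 * np.eye(1), t)
```

With `L` and `R` set to zero the path becomes the deterministic solution \(e^{-t}\), up to the \(O(\text{dt\_max})\) discretization error.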
cdlgssm_posterior_sample(key: PRNGKey, params: ParamsCDLGSSM, emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[KFHyperParams] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, jitter: Optional[Scalar] = 0, warn: bool = True) -> Float[Array, 'ntime state_dim']
Run forward-filtering, backward-sampling to draw samples from \(p(z_{1:T} \mid y_{1:T}, u_{1:T})\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `PRNGKey` | random number key. | required |
| `params` | `ParamsCDLGSSM` | CD-LGSSM parameters. | required |
| `emissions` | `Float[Array, 'num_timesteps emission_dim']` | sequence of observations. | required |
| `t_emissions` | `Optional[Float[Array, 'num_timesteps 1']]` | continuous time instants of the observations; if not None, an array of shape `(num_timesteps, 1)`. | `None` |
| `inputs` | `Optional[Float[Array, 'num_timesteps input_dim']]` | optional sequence of inputs. | `None` |
| `jitter` | `Optional[Scalar]` | padding added to the diagonal of the covariance matrix before sampling. | `0` |
| `warn` | `bool` | whether to issue warnings during smoothing (e.g., PSD issues). | `True` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'ntime state_dim']` | one sample of \(z_{1:T}\) from the posterior distribution on latent states. |
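
Forward filtering, backward sampling can be sketched in NumPy as follows, assuming a scalar Ornstein-Uhlenbeck model with exact discretized transitions and linear Gaussian emissions. The filter keeps the transition \((A_k, Q_k)\) into each step; the backward pass then samples \(z_k \mid z_{k+1}, y_{1:k}\). `forward_filter` and `ffbs_sample` are hypothetical names, and `jitter` mirrors the documented parameter by padding the diagonal of each sampling covariance.

```python
import numpy as np


def ou_pushforward(theta, sigma, dt):
    """Exact (A, Q) for the scalar OU SDE dz = -theta z dt + sigma dB over a step dt."""
    A = np.array([[np.exp(-theta * dt)]])
    Q = np.array([[sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * dt))]])
    return A, Q


def forward_filter(m0, P0, H, R, emissions, t_emissions, pushforward):
    """Kalman filter that also keeps the transition (A, Q) used to reach each step."""
    m, P = m0, P0
    ms, Ps, As, Qs = [], [], [], []
    for k, y in enumerate(emissions):
        if k > 0:
            A, Q = pushforward(t_emissions[k] - t_emissions[k - 1])
            m, P = A @ m, A @ P @ A.T + Q
            As.append(A)
            Qs.append(Q)
        S = H @ P @ H.T + R
        K = np.linalg.solve(S, H @ P).T
        m, P = m + K @ (y - H @ m), P - K @ S @ K.T
        ms.append(m)
        Ps.append(P)
    return ms, Ps, As, Qs


def ffbs_sample(rng, m0, P0, H, R, emissions, t_emissions, pushforward, jitter=0.0):
    """Forward filtering, backward sampling: one draw of z_{1:T} given y_{1:T}."""
    ms, Ps, As, Qs = forward_filter(m0, P0, H, R, emissions, t_emissions, pushforward)
    eye = np.eye(m0.shape[0])
    z = rng.multivariate_normal(ms[-1], Ps[-1] + jitter * eye)
    draws = [z]
    for k in range(len(ms) - 2, -1, -1):
        A, Q = As[k], Qs[k]  # transition from step k to step k + 1
        P_pred = A @ Ps[k] @ A.T + Q
        G = np.linalg.solve(P_pred, A @ Ps[k]).T  # P_k A^T P_pred^{-1}
        mean = ms[k] + G @ (z - A @ ms[k])
        cov = Ps[k] - G @ P_pred @ G.T
        z = rng.multivariate_normal(mean, cov + jitter * eye)
        draws.append(z)
    return np.array(draws[::-1])


def ou_step(dt):
    return ou_pushforward(theta=1.0, sigma=0.5, dt=dt)


rng = np.random.default_rng(0)
t = np.array([0.0, 0.5, 1.5, 1.7])
y = np.array([[0.3], [0.1], [-0.2], [0.0]])
sample = ffbs_sample(rng, np.zeros(1), np.eye(1), np.eye(1), 0.1 * np.eye(1), y, t, ou_step)
```

Averaging many such draws should recover the exact posterior mean of the latent states.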
cdlgssm_smoother(params: ParamsCDLGSSM, emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(), inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, smoother_type: Optional[str] = 'cd_smoother_1', warn: bool = True) -> PosteriorGSSMSmoothed
Run forward filtering and backward smoothing to compute expectations under the posterior distribution on latent states. Two different smoothers are implemented; `smoother_type` selects which one is run.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `ParamsCDLGSSM` | CD-LGSSM parameters. | required |
| `emissions` | `Float[Array, 'num_timesteps emission_dim']` | array of observations. | required |
| `t_emissions` | `Optional[Float[Array, 'num_timesteps 1']]` | continuous time instants of the observations; if not None, an array of shape `(num_timesteps, 1)`. | `None` |
| `inputs` | `Optional[Float[Array, 'num_timesteps input_dim']]` | array of inputs. | `None` |
| `smoother_type` | `Optional[str]` | which smoother to run: `'cd_smoother_1'` implements Särkkä's Algorithm 3.17; `'cd_smoother_2'` implements Särkkä's Algorithm 3.18. | `'cd_smoother_1'` |
| `warn` | `bool` | whether to issue warnings during smoothing (e.g., PSD issues). | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `PosteriorGSSMSmoothed` | `PosteriorGSSMSmoothed` | smoothed posterior object |
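
The two smoother options refer to continuous-discrete algorithms. The NumPy sketch below swaps in the simpler discrete-time Rauch-Tung-Striebel (RTS) smoother applied to the exact discretization of a scalar Ornstein-Uhlenbeck model; it is neither Algorithm 3.17 nor 3.18. Because the posterior at the observation times is the same Gaussian either way, it should give the same smoothed marginals there, up to the accuracy of any ODE solves. `forward_filter` and `rts_smoother` are hypothetical names.

```python
import numpy as np


def ou_pushforward(theta, sigma, dt):
    """Exact (A, Q) for the scalar OU SDE dz = -theta z dt + sigma dB over a step dt."""
    A = np.array([[np.exp(-theta * dt)]])
    Q = np.array([[sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * dt))]])
    return A, Q


def forward_filter(m0, P0, H, R, emissions, t_emissions, pushforward):
    """Kalman filter that also keeps the transition (A, Q) used to reach each step."""
    m, P = m0, P0
    ms, Ps, As, Qs = [], [], [], []
    for k, y in enumerate(emissions):
        if k > 0:
            A, Q = pushforward(t_emissions[k] - t_emissions[k - 1])
            m, P = A @ m, A @ P @ A.T + Q
            As.append(A)
            Qs.append(Q)
        S = H @ P @ H.T + R
        K = np.linalg.solve(S, H @ P).T
        m, P = m + K @ (y - H @ m), P - K @ S @ K.T
        ms.append(m)
        Ps.append(P)
    return ms, Ps, As, Qs


def rts_smoother(m0, P0, H, R, emissions, t_emissions, pushforward):
    """Backward RTS recursion over the filtered moments."""
    ms, Ps, As, Qs = forward_filter(m0, P0, H, R, emissions, t_emissions, pushforward)
    sm, sP = [ms[-1]], [Ps[-1]]  # smoothed = filtered at the last step
    for k in range(len(ms) - 2, -1, -1):
        A, Q = As[k], Qs[k]
        P_pred = A @ Ps[k] @ A.T + Q
        G = np.linalg.solve(P_pred, A @ Ps[k]).T  # smoother gain P_k A^T P_pred^{-1}
        sm.append(ms[k] + G @ (sm[-1] - A @ ms[k]))
        sP.append(Ps[k] + G @ (sP[-1] - P_pred) @ G.T)
    return np.array(sm[::-1]), np.array(sP[::-1])


def ou_step(dt):
    return ou_pushforward(theta=1.0, sigma=0.5, dt=dt)


t = np.array([0.0, 0.5, 1.5, 1.7])
y = np.array([[0.3], [0.1], [-0.2], [0.0]])
sm, sP = rts_smoother(np.zeros(1), np.eye(1), np.eye(1), 0.1 * np.eye(1), y, t, ou_step)
```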
compute_pushforward(params: ParamsCDLGSSM, t0: Float, t1: Float, diffeqsolve_settings: dict = {}, warn: bool = True) -> Tuple[Float[Array, 'state_dim state_dim'], Float[Array, 'state_dim state_dim']]
Compute the push-forward of the continuous-time linear Gaussian dynamics defined by the SDE \(dz_t = F_t z_t \, dt + L_t \, dB_t\) from time t0 to t1, returning the dynamics matrix A and covariance Q such that \(z_{t_1} = A z_{t_0} + \epsilon\), with \(\epsilon \sim \mathcal{N}(0, Q)\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `ParamsCDLGSSM` | CD-LGSSM model parameters. | required |
| `t0` | `Float` | initial time. | required |
| `t1` | `Float` | final time. | required |
| `diffeqsolve_settings` | `dict` | settings for the differential-equation solver. | `{}` |
| `warn` | `bool` | whether to issue warnings (e.g., about PSD issues). | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `A` | `Float[Array, 'state_dim state_dim']` | dynamics matrix from `t0` to `t1` |
| `Q` | `Float[Array, 'state_dim state_dim']` | dynamics covariance from `t0` to `t1` |
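
For time-invariant `F` and `L`, the push-forward has a closed form through the matrix exponential. The sketch below uses the matrix fraction decomposition \(\exp\left(\begin{bmatrix} F & L Q_c L^\top \\ 0 & -F^\top \end{bmatrix} \Delta t\right) = \begin{bmatrix} A & B \\ 0 & A^{-\top} \end{bmatrix}\), \(Q = B A^\top\). It assumes constant `F`, `L` and a Brownian spectral density `Qc` (identity for standard Brownian motion); the documented function accepts `diffeqsolve_settings`, which suggests it integrates the moment ODEs numerically to handle time-varying \(F_t, L_t\). A small hand-written `expm` keeps the block dependency-free (`scipy.linalg.expm` would normally be preferred), and `pushforward_lti` is a hypothetical name.

```python
import numpy as np


def expm(M, order=20):
    """Matrix exponential by scaling and squaring with a truncated Taylor series."""
    norm = np.linalg.norm(M, ord=1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    X = M / 2.0**s  # now ||X||_1 <= 1/2, so the Taylor series converges fast
    term = np.eye(M.shape[0])
    result = term.copy()
    for k in range(1, order + 1):
        term = term @ X / k
        result = result + term
    for _ in range(s):  # undo the scaling: exp(M) = exp(X)^(2^s)
        result = result @ result
    return result


def pushforward_lti(F, L, Qc, dt):
    """(A, Q) for dz = F z dt + L dB with Brownian spectral density Qc, over a step dt."""
    n = F.shape[0]
    M = np.zeros((2 * n, 2 * n))
    M[:n, :n] = F
    M[:n, n:] = L @ Qc @ L.T
    M[n:, n:] = -F.T
    E = expm(M * dt)
    A = E[:n, :n]
    Q = E[:n, n:] @ A.T
    return A, 0.5 * (Q + Q.T)  # symmetrize against round-off


# Wiener-velocity (integrated Brownian motion) model: position and velocity.
F = np.array([[0.0, 1.0], [0.0, 0.0]])
L = np.array([[0.0], [1.0]])
A, Q = pushforward_lti(F, L, np.array([[2.0]]), dt=0.5)
```

For this model the known result is \(A = \begin{bmatrix} 1 & \Delta t \\ 0 & 1 \end{bmatrix}\) and \(Q = q \begin{bmatrix} \Delta t^3/3 & \Delta t^2/2 \\ \Delta t^2/2 & \Delta t \end{bmatrix}\), here with \(q = 2\) and \(\Delta t = 0.5\).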
preprocess_args(f)
Preprocess the parameter and input arguments in case some are set to None.
preprocess_params_and_inputs(params, num_timesteps, inputs)
Preprocess parameters in case some are set to None.
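
The two preprocessing helpers are described only briefly. The following is a hypothetical sketch of the pattern, assuming `preprocess_args` decorates functions taking `params`, `emissions` and `inputs`, that missing inputs become an empty `(num_timesteps, 0)` array, and that missing input weights become zero matrices. The dict-based `params`, the `input_weights` and `state_dim` keys, and all function names are illustrative, not the library's.

```python
import functools
import inspect

import numpy as np


def preprocess_params_and_inputs_sketch(params, num_timesteps, inputs):
    """Hypothetical: fill in missing inputs and input weights with zeros."""
    if inputs is None:
        inputs = np.zeros((num_timesteps, 0))  # zero-width inputs: no effect on the model
    full = dict(params)  # copy so the caller's params are not mutated
    if full.get("input_weights") is None:
        full["input_weights"] = np.zeros((full["state_dim"], inputs.shape[1]))
    return full, inputs


def preprocess_args_sketch(f):
    """Hypothetical decorator: bind arguments, then preprocess params and inputs."""
    sig = inspect.signature(f)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        num_timesteps = len(bound.arguments["emissions"])
        params, inputs = preprocess_params_and_inputs_sketch(
            bound.arguments["params"], num_timesteps, bound.arguments["inputs"])
        bound.arguments["params"] = params
        bound.arguments["inputs"] = inputs
        return f(*bound.args, **bound.kwargs)

    return wrapper


@preprocess_args_sketch
def toy_filter(params, emissions, inputs=None):
    # The wrapped function can rely on inputs and input_weights being arrays.
    return params["input_weights"].shape, inputs.shape
```

Calling `toy_filter({"state_dim": 2, "input_weights": None}, np.zeros((5, 3)))` returns `((2, 0), (5, 0))`.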