
Kalman Filter

Inference algorithms for continuous-discrete linear Gaussian SSMs.

inference

KFHyperParams

Bases: NamedTuple

Lightweight container for filtering hyperparameters.

cdlgssm_emissions(params: ParamsCDLGSSM, t_states: Float[Array, 'num_timesteps 1'], state_means: Float[Array, 'num_timesteps state_dim'], state_covs: Optional[Float[Array, 'num_timesteps state_dim state_dim']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, key: PRNGKey = jr.PRNGKey(0), warn: bool = True) -> Tuple[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_timesteps emission_dim emission_dim']]

Compute the emission means and covariances implied by the given state estimates under the continuous-discrete linear model specified by params.

Parameters:

Name Type Description Default
params ParamsCDLGSSM

model parameters.

required
t_states Float[Array, 'num_timesteps 1']

continuous-time instants of the states

required
state_means Float[Array, 'num_timesteps state_dim']

state means at time instants t_states, always required

required
state_covs Optional[Float[Array, 'num_timesteps state_dim state_dim']]

state covariances at the time instants t_states (optional); if None, the states are treated as point estimates and simply pushed through the emission function

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs, of shape (1 + num_timesteps, input_dim); the extra input is needed for the initial emission, i.e., it should correspond to time t_init

None
key PRNGKey

random key for sampling

PRNGKey(0)
warn bool

whether to issue warnings (e.g., about PSD issues).

True

Returns:

Name Type Description
emissions_mean Float[Array, 'num_timesteps emission_dim']

mean of emissions

emissions_covariance Float[Array, 'num_timesteps emission_dim emission_dim']

covariance of emissions
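To illustrate what the returned mean and covariance represent, here is a minimal NumPy sketch that assumes a linear Gaussian emission model \(y_t = H z_t + d + v_t\), \(v_t \sim \mathcal{N}(0, R)\). The names H, d, R and the helper linear_gaussian_emissions are hypothetical, inputs are ignored, and this is not the library's implementation. Returning R alone when state_covs is None is an assumption about the point-estimate case.

```python
import numpy as np


def linear_gaussian_emissions(H, d, R, state_means, state_covs=None):
    """Moments of y_t = H z_t + d + v_t, v_t ~ N(0, R), given state moments.

    Hypothetical helper. Shapes: H (emission_dim, state_dim), d (emission_dim,),
    R (emission_dim, emission_dim), state_means (T, state_dim),
    state_covs (T, state_dim, state_dim) or None.
    """
    # E[y_t] = H m_t + d, vectorized over time steps.
    means = state_means @ H.T + d
    num_timesteps = state_means.shape[0]
    if state_covs is None:
        # Point-estimate states (assumption): only emission noise remains.
        covs = np.broadcast_to(R, (num_timesteps,) + R.shape).copy()
    else:
        # Cov[y_t] = H P_t H^T + R, batched over time by matmul broadcasting.
        covs = H @ state_covs @ H.T + R
    return means, covs


H = np.array([[1.0, 0.0]])
d = np.array([0.5])
R = np.array([[0.1]])
m = np.array([[1.0, 2.0], [3.0, 4.0]])
P = np.stack([np.eye(2), 2.0 * np.eye(2)])
means, covs = linear_gaussian_emissions(H, d, R, m, P)
# means -> [[1.5], [3.5]]; covs[:, 0, 0] -> [1.1, 2.1]
```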

cdlgssm_filter(params: ParamsCDLGSSM, emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(), inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, warn: bool = True) -> PosteriorGSSMFiltered

Run a continuous-discrete Kalman filter to produce the marginal likelihood and filtered state estimates.

Parameters:

Name Type Description Default
params ParamsCDLGSSM

CD-LGSSM parameters

required
emissions Float[Array, 'num_timesteps emission_dim']

array of observations.

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of continuous-time instants at which the observations were made

None
filter_hyperparams Optional[KFHyperParams]

hyperparameters for the filter.

KFHyperParams()
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs.

None
warn bool

whether to issue warnings during filtering (e.g., PSD issues).

True

Returns:

Name Type Description
PosteriorGSSMFiltered PosteriorGSSMFiltered

filtered posterior object
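The filter alternates a continuous-time prediction with a discrete measurement update. The following is a minimal NumPy sketch of that recursion, assuming time-invariant F and L driven by standard Brownian motion, linear emissions y = H z + v with v ~ N(0, R), a prior N(m0, P0) at a hypothetical start time t0, and one-dimensional time arrays rather than shape (num_timesteps, 1). The prediction integrates the mean and covariance ODEs with fixed-step RK4; the function name cd_kalman_filter is hypothetical, and the library's own solver and filter_hyperparams options are not reproduced.

```python
import numpy as np


def _predict(F, LLt, m, P, dt, num_steps=50):
    """Integrate dm/dt = F m and dP/dt = F P + P F^T + L L^T over dt with RK4."""
    def f(m, P):
        return F @ m, F @ P + P @ F.T + LLt

    h = dt / num_steps
    for _ in range(num_steps):
        k1 = f(m, P)
        k2 = f(m + h / 2 * k1[0], P + h / 2 * k1[1])
        k3 = f(m + h / 2 * k2[0], P + h / 2 * k2[1])
        k4 = f(m + h * k3[0], P + h * k3[1])
        m = m + h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
        P = P + h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
    return m, (P + P.T) / 2


def cd_kalman_filter(F, L, H, R, m0, P0, t0, t_emissions, emissions):
    """Continuous-discrete Kalman filter (hypothetical sketch).

    Returns the marginal log likelihood and the filtered means and covariances.
    """
    LLt = L @ L.T
    m, P, t = m0, P0, t0
    log_lik, means, covs = 0.0, [], []
    for t_k, y in zip(t_emissions, emissions):
        # Prediction: propagate the moments from the previous time to t_k.
        m, P = _predict(F, LLt, m, P, t_k - t)
        t = t_k
        # Update with y = H z + v, v ~ N(0, R).
        S = H @ P @ H.T + R
        resid = y - H @ m
        log_lik -= 0.5 * (resid @ np.linalg.solve(S, resid)
                          + np.linalg.slogdet(2 * np.pi * S)[1])
        K = np.linalg.solve(S, H @ P).T  # Kalman gain P H^T S^{-1}
        m = m + K @ resid
        P = P - K @ S @ K.T
        means.append(m)
        covs.append(P)
    return log_lik, np.array(means), np.array(covs)


# Brownian motion (F = 0, L = 1) observed once with unit noise.
one = np.ones((1, 1))
log_lik, means, covs = cd_kalman_filter(
    np.zeros((1, 1)), one, one, one, np.zeros(1), np.eye(1), 0.0,
    [1.0], np.array([[2.0]]))
```

With F = 0 and L = 1 the predicted variance after one time unit is 1 + 1 = 2, so the update can be checked by hand: gain 2/3, filtered mean 4/3, filtered variance 2/3.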

cdlgssm_forecast(params: ParamsCDLGSSM, init_forecast: Union[tfd.Distribution, Float[Array, 'state_dim 1']], t_init: Float[Array, '1 1'], t_forecast: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(), inputs: Optional[Float[Array, 'ntime input_dim']] = None, output_fields: Optional[List[str]] = ['forecasted_state_means', 'forecasted_state_covariances'], key: PRNGKey = jr.PRNGKey(0), diffeqsolve_settings: dict = {}, warn: bool = True) -> GSSMForecast

Run a continuous-discrete Kalman filter to produce forecasted state estimates.

Parameters:

Name Type Description Default
params ParamsCDLGSSM

CD-LGSSM parameters.

required
init_forecast Union[Distribution, Float[Array, 'state_dim 1']]

initial condition to start forecasting from:
- if init_forecast is a distribution, that distribution is forecast using the selected filtering method;
- if init_forecast is a point estimate of the state, a forward path is forecast starting at that state.

required
t_init Float[Array, '1 1']

time-instant of the initial condition of forecast

required
t_forecast Optional[Float[Array, 'num_timesteps 1']]

optional array of continuous-time instants at which forecasts are produced

None
filter_hyperparams Optional[KFHyperParams]

hyper-parameters of the filter

KFHyperParams()
inputs Optional[Float[Array, 'ntime input_dim']]

optional array of inputs, of shape (1 + num_timesteps, input_dim); the extra input is needed for the initial emission, i.e., it should correspond to time t_init

None
output_fields Optional[List[str]]

list of fields to return in the forecast object. Allowed values:
- when forecasting Gaussian distributions (via filtering methods): "forecasted_state_means", "forecasted_state_covariances";
- when forecasting paths (by solving the SDE): "forecasted_state_path".

['forecasted_state_means', 'forecasted_state_covariances']
key PRNGKey

random key (e.g., for the ensemble Kalman filter).

PRNGKey(0)
diffeqsolve_settings dict

settings for the SDE solver

{}
warn bool

whether to issue warnings during filtering (e.g., PSD issues).

True

Returns:

Name Type Description
post GSSMForecast

forecasted object.
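For a Gaussian init_forecast, forecasting amounts to running only the prediction step of the filter, with no measurement updates. The sketch below assumes time-invariant F and L and uses a hypothetical function name, moment_forecast; it propagates the mean and covariance ODEs to each forecast time with fixed-step RK4 and compares the result with the closed-form moments of an Ornstein-Uhlenbeck process.

```python
import numpy as np


def moment_forecast(F, L, m_init, P_init, t_init, t_forecast, num_steps=100):
    """Propagate N(m_init, P_init) from t_init to each time in t_forecast (no updates).

    Hypothetical sketch: integrates dm/dt = F m and dP/dt = F P + P F^T + L L^T
    with fixed-step RK4, assuming time-invariant F and L.
    """
    LLt = L @ L.T

    def f(m, P):
        return F @ m, F @ P + P @ F.T + LLt

    m, P, t = m_init, P_init, t_init
    means, covs = [], []
    for t_k in t_forecast:
        h = (t_k - t) / num_steps
        for _ in range(num_steps):
            k1 = f(m, P)
            k2 = f(m + h / 2 * k1[0], P + h / 2 * k1[1])
            k3 = f(m + h / 2 * k2[0], P + h / 2 * k2[1])
            k4 = f(m + h * k3[0], P + h * k3[1])
            m = m + h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
            P = P + h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        t = t_k
        means.append(m)
        covs.append(P)
    return np.array(means), np.array(covs)


# Ornstein-Uhlenbeck example: dz = -z dt + sqrt(2) dB (stationary variance 1).
F = np.array([[-1.0]])
L = np.array([[np.sqrt(2.0)]])
t_forecast = np.array([0.5, 1.0, 2.0])
means, covs = moment_forecast(F, L, np.array([3.0]), np.array([[0.5]]), 0.0, t_forecast)
# Closed-form OU moments for comparison.
exact_means = 3.0 * np.exp(-t_forecast)
exact_vars = 1.0 - 0.5 * np.exp(-2.0 * t_forecast)
```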

cdlgssm_joint_sample(params: ParamsCDLGSSM, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, diffeqsolve_settings={}, warn: bool = True) -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

Sample from the joint distribution to produce state and emission trajectories.

Parameters:

Name Type Description Default
params ParamsCDLGSSM

CD-LGSSM parameters

required
key PRNGKey

random key

required
num_timesteps int

number of time steps to sample

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of continuous-time instants at which the observations are made

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs.

None
diffeqsolve_settings

settings for the diff-eq solver

{}
warn bool

whether to issue warnings during sampling (e.g., PSD issues).

True

Returns:

Type Description
Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

latent states and observed emissions
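One way to sample from the joint distribution, assumed here rather than taken from the library, is to discretize the dynamics between consecutive emission times into push-forward pairs (A_k, Q_k), as compute_pushforward returns, and then run ancestral sampling of the resulting discrete-time linear Gaussian model. This NumPy sketch assumes linear emissions y = H z + v, an initial state N(m0, P0) at the first emission time, and a hypothetical function name, joint_sample; the usage example builds the transitions from the closed-form Ornstein-Uhlenbeck push-forward instead of calling the library.

```python
import numpy as np


def joint_sample(rng, m0, P0, As, Qs, H, R):
    """Ancestral sampling of states and emissions (hypothetical sketch).

    As[k], Qs[k] are the push-forward matrices for the interval between
    emission times k and k + 1, so len(As) == num_timesteps - 1.
    """
    state_dim, emission_dim = len(m0), H.shape[0]
    z = rng.multivariate_normal(m0, P0)  # initial state at the first emission time
    states, emissions = [], []
    for k in range(len(As) + 1):
        if k > 0:
            # z_k = A z_{k-1} + eps, eps ~ N(0, Q)
            z = As[k - 1] @ z + rng.multivariate_normal(np.zeros(state_dim), Qs[k - 1])
        # y_k = H z_k + v, v ~ N(0, R)
        y = H @ z + rng.multivariate_normal(np.zeros(emission_dim), R)
        states.append(z)
        emissions.append(y)
    return np.array(states), np.array(emissions)


# Ornstein-Uhlenbeck dz = -z dt + dB on irregular times (closed-form push-forward).
t = np.array([0.0, 0.3, 1.0, 2.5])
dts = np.diff(t)
As = [np.array([[np.exp(-dt)]]) for dt in dts]
Qs = [np.array([[0.5 * (1.0 - np.exp(-2.0 * dt))]]) for dt in dts]
states, emissions = joint_sample(np.random.default_rng(0), np.zeros(1), np.eye(1),
                                 As, Qs, np.eye(1), 0.1 * np.eye(1))
```

Passing the random generator explicitly keeps draws reproducible, in the same spirit as the explicit key argument above.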

cdlgssm_path_sample(params: ParamsCDLGSSM, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, diffeqsolve_settings={}, warn: bool = True) -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

Sample a forward path of the CD-LGSSM to produce state and emission trajectories.

Parameters:

Name Type Description Default
params ParamsCDLGSSM

CD-LGSSM parameters

required
key PRNGKey

random number generator key

required
num_timesteps int

number of time steps to sample

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of continuous-time instants at which the observations are made

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs.

None
diffeqsolve_settings

settings for the SDE solver

{}
warn bool

whether to issue warnings during sampling (e.g., PSD issues).

True

Returns:

Type Description
Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

latent states and observed emissions
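cdlgssm_path_sample is described as using an SDE solver (configured via diffeqsolve_settings). As a simple stand-in for that solver, the sketch below integrates \(dz_t = F z_t \, dt + L \, dB_t\) with a plain Euler-Maruyama scheme using a fixed, hypothetical number of substeps between emission times. The function name euler_maruyama_path, the time-invariant F and L, and the linear emission model y = H z + v are assumptions, not the library's API.

```python
import numpy as np


def euler_maruyama_path(rng, F, L, H, R, m0, P0, t_emissions, substeps=100):
    """Sample dz = F z dt + L dB with Euler-Maruyama, recording z and y = H z + v
    at each emission time (hypothetical sketch, time-invariant F and L)."""
    z = rng.multivariate_normal(m0, P0)  # initial state at t_emissions[0]
    states, emissions = [], []
    for k, t_k in enumerate(t_emissions):
        if k > 0:
            h = (t_k - t_emissions[k - 1]) / substeps
            for _ in range(substeps):
                dB = rng.normal(scale=np.sqrt(h), size=L.shape[1])  # Brownian increment
                z = z + (F @ z) * h + L @ dB
        y = H @ z + rng.multivariate_normal(np.zeros(H.shape[0]), R)
        states.append(z)
        emissions.append(y)
    return np.array(states), np.array(emissions)


# With zero diffusion and zero emission noise the path follows z(t) = exp(-t) z(0).
zero = np.zeros((1, 1))
states, emissions = euler_maruyama_path(
    np.random.default_rng(1), -np.eye(1), zero, np.eye(1), zero,
    np.ones(1), zero, [0.0, 1.0, 2.0], substeps=1000)
```

Euler-Maruyama has only first-order weak accuracy, which is one reason a configurable, higher-order SDE solver is preferable in practice.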

cdlgssm_posterior_sample(key: PRNGKey, params: ParamsCDLGSSM, emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[KFHyperParams] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, jitter: Optional[Scalar] = 0, warn: bool = True) -> Float[Array, 'ntime state_dim']

Run forward-filtering, backward-sampling to draw samples from \(p(z_{1:T} \mid y_{1:T}, u_{1:T})\).

Parameters:

Name Type Description Default
key PRNGKey

random number key.

required
params ParamsCDLGSSM

CD-LGSSM parameters.

required
emissions Float[Array, 'num_timesteps emission_dim']

sequence of observations.

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of continuous-time instants at which the observations were made

None
filter_hyperparams Optional[KFHyperParams]

hyperparameters for the filter.

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional sequence of inputs.

None
jitter Optional[Scalar]

padding to add to the diagonal of the covariance matrix before sampling.

0
warn bool

whether to issue warnings during smoothing (e.g., PSD issues).

True

Returns:

Type Description
Float[Array, 'ntime state_dim']

one sample of \(z_{1:T}\) from the posterior distribution on latent states.
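Forward-filtering, backward-sampling first runs a Kalman filter, then draws \(z_T\) from the last filtered marginal and each earlier \(z_t\) from \(p(z_t \mid z_{t+1}, y_{1:t})\), moving backwards. The NumPy sketch below works on a discretized model, assuming per-interval push-forward matrices (A_k, Q_k), linear emissions y = H z + v, and a prior N(m0, P0) at the first emission time; the function name ffbs is hypothetical, and jitter is added to each covariance diagonal before sampling, mirroring the jitter parameter above.

```python
import numpy as np


def ffbs(rng, m0, P0, As, Qs, H, R, emissions, jitter=0.0):
    """Forward-filtering, backward-sampling on a discretized linear-Gaussian SSM.

    Hypothetical sketch: As[k], Qs[k] map time step k to k + 1, and the prior
    N(m0, P0) sits at the first emission time.
    """
    state_dim = len(m0)
    # Forward pass: discrete-time Kalman filter.
    fm, fP = [], []
    m, P = m0, P0
    for k, y in enumerate(emissions):
        if k > 0:
            A, Q = As[k - 1], Qs[k - 1]
            m, P = A @ m, A @ P @ A.T + Q
        S = H @ P @ H.T + R
        K = np.linalg.solve(S, H @ P).T
        m, P = m + K @ (y - H @ m), P - K @ S @ K.T
        fm.append(m)
        fP.append(P)
    # Backward pass: z_T from the last filtered marginal, then z_k | z_{k+1}.
    pad = jitter * np.eye(state_dim)
    z = rng.multivariate_normal(fm[-1], fP[-1] + pad)
    samples = [z]
    for k in range(len(emissions) - 2, -1, -1):
        A, Q = As[k], Qs[k]
        P_pred = A @ fP[k] @ A.T + Q
        J = np.linalg.solve(P_pred, A @ fP[k]).T  # backward gain P_k A^T P_pred^{-1}
        mean = fm[k] + J @ (z - A @ fm[k])
        cov = fP[k] - J @ P_pred @ J.T
        z = rng.multivariate_normal(mean, (cov + cov.T) / 2 + pad)
        samples.append(z)
    return np.array(samples[::-1])


# Nearly noise-free observations: posterior samples should track the data.
t = np.array([0.0, 0.5, 1.5, 2.0])
As = [np.array([[np.exp(-dt)]]) for dt in np.diff(t)]
Qs = [np.array([[0.5 * (1.0 - np.exp(-2.0 * dt))]]) for dt in np.diff(t)]
y = np.array([[0.3], [-0.2], [1.0], [0.5]])
sample = ffbs(np.random.default_rng(0), np.zeros(1), np.eye(1), As, Qs,
              np.eye(1), 1e-8 * np.eye(1), y, jitter=1e-9)
```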

cdlgssm_smoother(params: ParamsCDLGSSM, emissions: Float[Array, 'num_timesteps emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[KFHyperParams] = KFHyperParams(), inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, smoother_type: Optional[str] = 'cd_smoother_1', warn: bool = True) -> PosteriorGSSMSmoothed

Run forward filtering and backward smoothing to compute expectations under the posterior distribution on latent states. Two different smoothers are implemented, selected by smoother_type.

Parameters:

Name Type Description Default
params ParamsCDLGSSM

CD-LGSSM parameters

required
emissions Float[Array, 'num_timesteps emission_dim']

array of observations.

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of continuous-time instants at which the observations were made

None
filter_hyperparams Optional[KFHyperParams]

hyperparameters for the filter.

KFHyperParams()
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs.

None
smoother_type Optional[str]

which smoother to run:
- cd_smoother_1: Sarkka's Algorithm 3.17
- cd_smoother_2: Sarkka's Algorithm 3.18

'cd_smoother_1'
warn bool

whether to issue warnings during smoothing (e.g., PSD issues).

True

Returns:

Name Type Description
PosteriorGSSMSmoothed PosteriorGSSMSmoothed

smoothed posterior object.
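cd_smoother_1 and cd_smoother_2 follow Sarkka's algorithms. As a simpler, swapped-in technique that shows the same forward-backward structure, the sketch below applies the discrete-time Rauch-Tung-Striebel (RTS) smoother to per-interval push-forward matrices (A_k, Q_k). It assumes linear emissions y = H z + v, a prior N(m0, P0) at the first emission time, and a hypothetical function name, rts_smoother; it is not a transcription of either library smoother.

```python
import numpy as np


def rts_smoother(m0, P0, As, Qs, H, R, emissions):
    """Kalman filter + discrete-time Rauch-Tung-Striebel smoother (hypothetical sketch).

    As[k], Qs[k] map time step k to k + 1; the prior N(m0, P0) is at step 0.
    """
    fm, fP = [], []
    m, P = m0, P0
    for k, y in enumerate(emissions):
        if k > 0:
            A, Q = As[k - 1], Qs[k - 1]
            m, P = A @ m, A @ P @ A.T + Q
        S = H @ P @ H.T + R
        K = np.linalg.solve(S, H @ P).T
        m, P = m + K @ (y - H @ m), P - K @ S @ K.T
        fm.append(m)
        fP.append(P)
    # Backward pass, starting from the last filtered marginal.
    sm, sP = [fm[-1]], [fP[-1]]
    for k in range(len(emissions) - 2, -1, -1):
        A, Q = As[k], Qs[k]
        P_pred = A @ fP[k] @ A.T + Q
        G = np.linalg.solve(P_pred, A @ fP[k]).T  # smoother gain P_k A^T P_pred^{-1}
        sm.append(fm[k] + G @ (sm[-1] - A @ fm[k]))
        sP.append(fP[k] + G @ (sP[-1] - P_pred) @ G.T)
    return np.array(fm), np.array(fP), np.array(sm[::-1]), np.array(sP[::-1])


t = np.array([0.0, 0.5, 1.5, 2.0])
As = [np.array([[np.exp(-dt)]]) for dt in np.diff(t)]
Qs = [np.array([[0.5 * (1.0 - np.exp(-2.0 * dt))]]) for dt in np.diff(t)]
y = np.array([[0.3], [-0.2], [1.0], [0.5]])
fm, fP, sm, sP = rts_smoother(np.zeros(1), np.eye(1), As, Qs, np.eye(1), 0.5 * np.eye(1), y)
```

At the final time the smoothed and filtered marginals coincide, and elsewhere smoothing can only reduce the variance, since it conditions on more data.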

compute_pushforward(params: ParamsCDLGSSM, t0: Float, t1: Float, diffeqsolve_settings: dict = {}, warn: bool = True) -> Tuple[Float[Array, 'state_dim state_dim'], Float[Array, 'state_dim state_dim']]

Compute the push-forward of the continuous-time linear Gaussian dynamics defined by the SDE \(dz_t = F_t z_t \, dt + L_t \, dB_t\) from time t0 to t1, returning the dynamics matrix A and covariance Q such that \(z_{t_1} = A z_{t_0} + \epsilon\), \(\epsilon \sim \mathcal{N}(0, Q)\).

Parameters:

Name Type Description Default
params ParamsCDLGSSM

CD-LGSSM model parameters

required
t0 Float

initial time

required
t1 Float

final time

required
diffeqsolve_settings dict

settings for the diff-eq solver

{}
warn bool

whether to issue warnings (e.g., about PSD issues).

True

Returns:

Name Type Description
A Float[Array, 'state_dim state_dim']

dynamics matrix from t0 to t1

Q Float[Array, 'state_dim state_dim']

dynamics covariance from t0 to t1
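For this SDE, A is the state-transition matrix and Q is the accumulated diffusion; both satisfy matrix ODEs: \(dA/dt = F_t A\) with \(A(t_0) = I\), and \(dQ/dt = F_t Q + Q F_t^\top + L_t L_t^\top\) with \(Q(t_0) = 0\). The NumPy sketch below integrates these with fixed-step RK4 as a stand-in for the library's diff-eq solver; the function name pushforward_rk4 and its num_steps argument are hypothetical, and the usage example checks the result against the closed-form Ornstein-Uhlenbeck transition.

```python
import numpy as np


def pushforward_rk4(F, L, t0, t1, num_steps=100):
    """Return (A, Q) such that z_{t1} = A z_{t0} + eps, eps ~ N(0, Q), for
    dz_t = F(t) z_t dt + L(t) dB_t (hypothetical sketch).

    F, L: callables mapping t to (state_dim, state_dim) and (state_dim, noise_dim)
    arrays. Integrates dA/dt = F A, A(t0) = I and dQ/dt = F Q + Q F^T + L L^T,
    Q(t0) = 0, with fixed-step classical Runge-Kutta.
    """
    n = F(t0).shape[0]

    def rhs(t, A, Q):
        Ft, Lt = F(t), L(t)
        return Ft @ A, Ft @ Q + Q @ Ft.T + Lt @ Lt.T

    A, Q = np.eye(n), np.zeros((n, n))
    h = (t1 - t0) / num_steps
    t = t0
    for _ in range(num_steps):
        k1 = rhs(t, A, Q)
        k2 = rhs(t + h / 2, A + h / 2 * k1[0], Q + h / 2 * k1[1])
        k3 = rhs(t + h / 2, A + h / 2 * k2[0], Q + h / 2 * k2[1])
        k4 = rhs(t + h, A + h * k3[0], Q + h * k3[1])
        A = A + h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
        Q = Q + h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        t += h
    # Symmetrize to remove round-off asymmetry in Q.
    return A, (Q + Q.T) / 2


# Ornstein-Uhlenbeck: F = -theta, L = sigma, integrated from t0 = 0 to t1 = 2.
theta, sigma = 0.5, 1.2
A, Q = pushforward_rk4(lambda t: np.array([[-theta]]),
                       lambda t: np.array([[sigma]]), 0.0, 2.0)
A_exact = np.exp(-theta * 2.0)
Q_exact = sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * 2.0))
```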

preprocess_args(f)

Preprocess the parameter and input arguments in case some are set to None.

preprocess_params_and_inputs(params, num_timesteps, inputs)

Preprocess parameters and inputs in case some are set to None.
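The following sketch shows one way a decorator like preprocess_args could fill a missing inputs argument before the wrapped function runs, using inspect.signature to bind arguments. It is hypothetical: the decorator name fill_none_inputs, the zero-width (num_timesteps, 0) default, and the toy function are assumptions, and the library's version may also fill None entries in params (the job described for preprocess_params_and_inputs).

```python
import inspect
from functools import wraps

import numpy as np


def fill_none_inputs(f):
    """Hypothetical decorator: replace inputs=None with a (num_timesteps, 0) zero array."""
    sig = inspect.signature(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()  # make sure 'inputs' appears even if omitted
        if bound.arguments["inputs"] is None:
            num_timesteps = len(bound.arguments["emissions"])
            bound.arguments["inputs"] = np.zeros((num_timesteps, 0))
        return f(*bound.args, **bound.kwargs)

    return wrapper


@fill_none_inputs
def toy_filter(params, emissions, inputs=None):
    # Downstream code can rely on inputs always being an array.
    return inputs.shape
```

A zero-width input array lets downstream matrix products with input weights of shape (state_dim, 0) vanish without special-casing None.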