
Models

Continuous-discrete nonlinear Gaussian SSM model definitions.

models

ContDiscreteNonlinearGaussianSSM

Bases: SSM

Definition of a Continuous-Discrete Nonlinear Gaussian State Space Model.

We assume a model of the form

\[ dz=f(z,u_t,t)dt + L(t) Q^{1/2}_c dW_t\]

We assume an initial state distribution of the form

\[p(z_0) = N(z_0 | m, S)\]

and observations are obtained by

\[p(y_{t_k} | z_{t_k}) = N(y_{t_k} | h(z_{t_k}, u_{t_k}), R_{t_k})\]

where the model parameters are

  • \(z_{t_k}\) = latent state of size state_dim
  • \(m\) = mean of the initial state
  • \(S\) = covariance matrix of the initial state
  • \(f\) = deterministic drift function of the dynamics (the right-hand side of the SDE), used to compute the transition function
  • \(F\) = Jacobian of the drift function, evaluated at the mean
  • \(L\) = diffusion coefficient multiplying the Brownian motion
  • \(Q\) = covariance of the Brownian motion driving the dynamics (system noise)
  • \(u_t\) = input covariates of size input_dim (defaults to 0)
  • \(y_{t_k}\) = observed variable (emission) of size emission_dim
  • \(h\) = emission (observation) function
  • \(R\) = covariance matrix of the emission (observation) noise

These parameters of the model are stored in a separate object of type ParamsCDNLGSSM.
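The generative model above can be illustrated with a plain NumPy sketch that simulates the SDE with an Euler-Maruyama scheme and then draws Gaussian emissions. This is not the library's sampler (which relies on an SDE solver); the helper names `euler_maruyama` and `emit`, the example drift, and the time grid are all hypothetical, and \(Q^{1/2}\) is taken here to be the Cholesky factor of \(Q\).

```python
import numpy as np


def euler_maruyama(f, L, Q, z0, t_grid, rng, inputs=None):
    """Hypothetical helper: simulate dz = f(z, u, t) dt + L(t) Q^{1/2} dW on t_grid."""
    Q_sqrt = np.linalg.cholesky(Q)  # one valid square root: Q_sqrt @ Q_sqrt.T == Q
    z = np.asarray(z0, dtype=float)
    path = [z.copy()]
    for k in range(len(t_grid) - 1):
        t, dt = t_grid[k], t_grid[k + 1] - t_grid[k]
        u = None if inputs is None else inputs[k]
        dW = rng.normal(size=Q.shape[0]) * np.sqrt(dt)  # Brownian increment ~ N(0, dt I)
        z = z + f(z, u, t) * dt + L(t) @ Q_sqrt @ dW
        path.append(z.copy())
    return np.stack(path)


def emit(h, R, states, rng):
    """Hypothetical helper: draw y_k ~ N(h(z_k, u_k), R), with inputs omitted."""
    R_chol = np.linalg.cholesky(R)
    return np.stack([h(z, None) + R_chol @ rng.normal(size=R.shape[0]) for z in states])


# Example: 2-D Ornstein-Uhlenbeck-type drift, observed through its first coordinate.
rng = np.random.default_rng(0)
f = lambda z, u, t: -0.5 * z
L = lambda t: np.eye(2)
Q = 0.1 * np.eye(2)
h = lambda z, u: z[:1]
R = np.array([[0.01]])
t_grid = np.linspace(0.0, 5.0, 501)
states = euler_maruyama(f, L, Q, np.ones(2), t_grid, rng)
emissions = emit(h, R, states[::50], rng)  # observe every 50th grid point
print(states.shape, emissions.shape)  # (501, 2) (11, 1)
```

Because \(Q^{1/2} (Q^{1/2})^\top = Q\), the per-step noise covariance is \(L Q L^\top \Delta t\), so any matrix square root of \(Q\) gives the same distribution.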

emission_distribution(params: ParamsCDNLGSSM, state: Float[Array, ' state_dim'], inputs: Optional[Float[Array, ' input_dim']] = None) -> tfd.Distribution

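CD-NLGSSM emission distribution.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM model parameters.

required
state Float[Array, ' state_dim']

current state.

required
inputs Optional[Float[Array, ' input_dim']]

optional inputs.

None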

Returns:

Type Description
Distribution

Gaussian, state-conditional emission distribution
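The emission density that this method returns as a distribution object can also be written out directly. The sketch below is a NumPy stand-in, not the library's implementation: it evaluates \(\log N(y \mid h(z, u), R)\) through a Cholesky factor of \(R\). The function names and the example emission function are hypothetical.

```python
import numpy as np


def gaussian_logpdf(y, mean, cov):
    """log N(y | mean, cov), computed through a Cholesky factorization of cov."""
    d = y.shape[-1]
    chol = np.linalg.cholesky(cov)
    # a solves chol @ a = y - mean, so a @ a = (y - mean)^T cov^{-1} (y - mean).
    a = np.linalg.solve(chol, y - mean)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + a @ a)


def emission_log_prob(h, R, state, y, inputs=None):
    """State-conditional emission log-density log p(y | z) = log N(y | h(z, u), R)."""
    return gaussian_logpdf(y, h(state, inputs), R)


h = lambda z, u: np.array([np.sin(z[0]), z[0] * z[1]])  # hypothetical emission function
R = 0.1 * np.eye(2)
print(emission_log_prob(h, R, np.array([0.0, 2.0]), np.zeros(2)))
```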

filter(params, emissions, t_emissions=None, inputs=None, filter_type: str = 'EnKF', filter_state_order: str = 'first', filter_emission_order: str = 'first', filter_num_iter: int = 1, filter_state_cov_rescaling: float = 1.0, filter_dt_average: float = 0.1, enkf_N_particles: int = 25, enkf_inflation_delta: float = 0.0, diffeqsolve_max_steps: int = 100, diffeqsolve_dt0: float = 0.01, output_fields=None, key=jr.PRNGKey(0), diffeqsolve_kwargs: Optional[dict] = {}, extra_filter_kwargs: Optional[dict] = {}, warn: bool = True)

A high-level filtering interface that computes the filtered posterior distribution over states.

Parameters:

Name Type Description Default
params

CD-NLGSSM model parameters.

required
emissions

sequence of observations.

required
t_emissions

optional array of the continuous-time instants at which the observations were taken.

None
inputs

Optional input sequence.

None
filter_type str

Which filter to run ("EKF", "EnKF", "UKF").

'EnKF'
filter_state_order str

Order of Taylor expansion for dynamics used in the filter.

'first'
filter_emission_order str

Order of Taylor expansion for emissions (used by the EKF only).

'first'
filter_num_iter int

Number of iterations for iterated filters (EKF only).

1
filter_state_cov_rescaling float

Factor by which the state covariance is rescaled after each update (for accurate likelihoods, an inflation delta is preferable).

1.0
filter_dt_average float

[Only for state_order="Discrete"] Average step size to determine constant state noise cov in filter.

0.1
enkf_N_particles int

Number of particles (for EnKF only).

25
enkf_inflation_delta float

EnKF covariance inflation (ignored by EKF/UKF).

0.0
diffeqsolve_max_steps int

Max steps for ODE solver between observations.

100
diffeqsolve_dt0 float

Initial step size for the ODE/SDE solver (by default the solver uses a fixed step size).

0.01
output_fields

Which fields to return from the filter.

None
key

Random number generator key.

PRNGKey(0)
diffeqsolve_kwargs Optional[dict]

Extra kwargs for the ODE solver (e.g., {"solver": diffrax.Heun(), "dt0": 1e-2}).

{}
extra_filter_kwargs Optional[dict]

Extra kwargs specific to the chosen filter (e.g., {"emission_order": "zeroth"} for EKF).

{}
warn bool

whether to issue warnings (e.g., about PSD issues)

True
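With filter_type="EnKF", the update at each observation time is an ensemble Kalman analysis step. The sketch below shows one common variant, the stochastic (perturbed-observation) EnKF, in NumPy. It is an illustration only: the function name is hypothetical, and treating enkf_inflation_delta as a multiplicative inflation of the ensemble anomalies by (1 + delta) is an assumption; the library may define the inflation differently.

```python
import numpy as np


def enkf_update(ensemble, y, h, R, rng, inflation_delta=0.0):
    """Stochastic (perturbed-observation) EnKF analysis step.

    ensemble: (N, state_dim) forecast particles; y: (emission_dim,) observation.
    inflation_delta is ASSUMED to scale the anomalies by (1 + delta).
    """
    N = ensemble.shape[0]
    mean = ensemble.mean(axis=0)
    ensemble = mean + (1.0 + inflation_delta) * (ensemble - mean)  # inflate the spread
    Y = np.stack([h(x) for x in ensemble])  # predicted emissions, (N, emission_dim)
    X_anom = ensemble - ensemble.mean(axis=0)
    Y_anom = Y - Y.mean(axis=0)
    P_xy = X_anom.T @ Y_anom / (N - 1)      # state-emission cross-covariance
    P_yy = Y_anom.T @ Y_anom / (N - 1) + R  # innovation covariance
    K = np.linalg.solve(P_yy, P_xy.T).T     # Kalman gain P_xy P_yy^{-1}
    # Each particle assimilates its own noisy copy of the observation.
    y_pert = y + rng.multivariate_normal(np.zeros(len(y)), R, size=N)
    return ensemble + (y_pert - Y) @ K.T


rng = np.random.default_rng(0)
prior = rng.normal(size=(25, 2))  # 25 particles, matching the enkf_N_particles default
h = lambda x: x[:1]               # observe the first state coordinate
posterior = enkf_update(prior, np.array([1.0]), h, np.array([[0.1]]), rng, inflation_delta=0.05)
print(posterior.shape)  # (25, 2)
```

Perturbing the observation per particle keeps the analysis ensemble spread consistent with the Kalman posterior covariance; deterministic square-root variants avoid that extra sampling noise.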

filter_and_forecast(params, emissions_filter, t_emissions_filter, t_emissions_forecast, inputs_filter=None, inputs_forecast=None, filter_type: str = 'EnKF', filter_state_order: str = 'first', filter_emission_order: str = 'first', filter_num_iter: int = 1, filter_state_cov_rescaling: float = 1.0, filter_dt_average: float = 0.1, enkf_N_particles: int = 25, enkf_inflation_delta: float = 0.0, diffeqsolve_max_steps: int = 100, diffeqsolve_dt0: float = 0.01, key=jr.PRNGKey(0), diffeqsolve_kwargs: Optional[dict] = {}, extra_filter_kwargs: Optional[dict] = {}, warn: bool = True)

A high-level interface combining filtering and forecasting steps. It computes the filtered posterior distributions over states, and then runs forecasting from the last filtered state.

Parameters:

Name Type Description Default
params

CD-NLGSSM model parameters.

required
emissions_filter

sequence of observations for filtering.

required
t_emissions_filter

continuous-time instants of the observations used for filtering.

required
t_emissions_forecast

continuous-time instants of the observations to be forecast.

required
inputs_filter

Optional input sequence for filtering.

None
inputs_forecast

Optional input sequence for forecasting.

None
filter_type str

Which filter to run ("EKF", "EnKF", "UKF").

'EnKF'
filter_state_order str

Order of Taylor expansion for dynamics used in the filter.

'first'
filter_emission_order str

Order of Taylor expansion for emissions (used by the EKF only).

'first'
filter_num_iter int

Number of iterations for iterated filters (EKF only).

1
filter_state_cov_rescaling float

Factor by which the state covariance is rescaled after each update (for accurate likelihoods, an inflation delta is preferable).

1.0
filter_dt_average float

[Only for state_order="Discrete"] Average step size to determine constant state noise cov in filter.

0.1
enkf_N_particles int

Number of particles (for EnKF only).

25
enkf_inflation_delta float

EnKF covariance inflation (ignored by EKF/UKF).

0.0
diffeqsolve_max_steps int

Max steps for ODE solver between observations.

100
diffeqsolve_dt0 float

Initial step size for the ODE/SDE solver (by default the solver uses a fixed step size).

0.01
key

Random number generator key.

PRNGKey(0)
diffeqsolve_kwargs Optional[dict]

Extra kwargs for the ODE solver (e.g., {"solver": diffrax.Heun(), "dt0": 1e-2}).

{}
extra_filter_kwargs Optional[dict]

Extra kwargs specific to the chosen filter (e.g., {"emission_order": "zeroth"} for EKF).

{}
warn bool

whether to issue warnings (e.g., about PSD issues)

True

Returns:

Name Type Description
filtered

filtering posterior over states

forecasted

forecasting distribution over future states

initial_distribution(params: ParamsCDNLGSSM, inputs: Optional[Float[Array, ' input_dim']] = None) -> tfd.Distribution

Initial distribution.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM model parameters.

required
inputs Optional[Float[Array, ' input_dim']]

optional initial inputs (not used in state initialization).

None

Returns:

Type Description
Distribution

initial Gaussian state distribution.

initialize(key: Optional[Float[Array, ' key']] = jr.PRNGKey(0), init_prior: Prior = None, initial_mean: dict = None, initial_cov: dict = None, dynamics_drift: dict = None, dynamics_diffusion_coefficient: dict = None, dynamics_diffusion_cov: dict = None, dynamics_approx_order: Optional[float] = 2.0, emission_function: dict = None, emission_cov: dict = None) -> Tuple[ParamsCDNLGSSM, ParamsCDNLGSSM]

Initialize CD-NLGSSM parameters, and their corresponding properties.

Parameters:

Name Type Description Default
key Optional[Float[Array, ' key']]

Random number generator key. Defaults to jr.PRNGKey(0).

PRNGKey(0)
init_prior Prior

prior distribution for the initialization. Defaults to None.

None
initial_mean dict

parameter \(m\). Defaults to None.

None
initial_cov dict

parameter \(S\). Defaults to None.

None
dynamics_drift dict

The drift function of the latent dynamics. Defaults to None.

None
dynamics_diffusion_coefficient dict

parameter \(L\). Defaults to None.

None
dynamics_diffusion_cov dict

parameter \(Q\). Defaults to None.

None
dynamics_approx_order Optional[float]

order of the approximation to the dynamics. Defaults to 2.

2.0
emission_function dict

The emission function. Defaults to None.

None
emission_cov dict

parameter \(R\). Defaults to None.

None

Returns:

Type Description
Tuple[ParamsCDNLGSSM, ParamsCDNLGSSM]

parameters and their properties.

marginal_log_prob(params: ParamsCDNLGSSM, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'ntime 1']] = None, filter_hyperparams: Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]] = EKFHyperParams(), inputs: Optional[Float[Array, 'ntime input_dim']] = None, key: PRNGKey = jr.PRNGKey(0), warn: bool = True) -> Scalar

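Compute the marginal log-likelihood of a sequence of emissions under the CD-NLGSSM model.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM model parameters.

required
emissions Float[Array, 'ntime emission_dim']

sequence of observations.

required
t_emissions Optional[Float[Array, 'ntime 1']]

optional array of the continuous-time instants at which the observations were taken.

None
filter_hyperparams Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]]

hyperparameters for the filtering algorithm.

EKFHyperParams()
inputs Optional[Float[Array, 'ntime input_dim']]

optional input sequence.

None
key PRNGKey

random number generator key.

PRNGKey(0)

Returns:

Type Description
Scalar

marginal log-likelihood of the emissions, \(\log p(y_{1:T})\).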
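For Gaussian filters, a marginal log-likelihood like this one is usually accumulated through the prediction-error decomposition \(\log p(y_{1:T}) = \sum_k \log N(y_k \mid H m_k^-, H P_k^- H^\top + R)\). The NumPy sketch below shows that recursion in the simplest case: a linear-Gaussian transition \(z_{k+1} = A z_k + w_k\), \(w_k \sim N(0, Q_d)\), between observation times and a linear emission matrix \(H\). In a CD-NLGSSM the predicted moments would instead come from integrating the dynamics; the function name, \(A\), \(Q_d\), and the convention that the first emission is observed at the initial-state time are assumptions made for illustration.

```python
import numpy as np


def kalman_marginal_loglik(m0, S0, A, Qd, H, R, emissions):
    """Sum of one-step-ahead predictive log-densities log N(y_k | H m_k^-, H P_k^- H^T + R).

    ASSUMPTIONS: linear-Gaussian transition z_{k+1} = A z_k + N(0, Qd) between
    observations, and the first emission is observed at the initial-state time.
    """
    m, P = m0, S0
    ll = 0.0
    for k, y in enumerate(emissions):
        if k > 0:  # predict forward to the next observation time
            m, P = A @ m, A @ P @ A.T + Qd
        S = H @ P @ H.T + R  # innovation covariance
        r = y - H @ m        # innovation
        _, logdet = np.linalg.slogdet(S)
        ll += -0.5 * (len(y) * np.log(2 * np.pi) + logdet + r @ np.linalg.solve(S, r))
        K = P @ H.T @ np.linalg.inv(S)  # Kalman gain
        m, P = m + K @ r, P - K @ S @ K.T
    return ll


A, Qd = np.array([[0.9]]), np.array([[0.2]])
H, R = np.eye(1), np.array([[0.5]])
ys = np.array([[0.3], [-0.1], [0.7]])
print(kalman_marginal_loglik(np.zeros(1), np.eye(1), A, Qd, H, R, ys))
```

The recursion gives the same value as evaluating the joint Gaussian density of \(y_{1:T}\) directly, but at a cost linear in the number of observations.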

sample_dist(params: ParamsCDNLGSSM, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None) -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

Sample from the joint distribution of the CD-NLGSSM to produce state and emission trajectories.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM model parameters

required
key PRNGKey

random number generator key

required
num_timesteps int

number of time steps to sample

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of the continuous-time instants at which the observations are taken.

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs.

None

Returns:

Type Description
Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

latent states and observed emissions

sample_path(params: ParamsCDNLGSSM, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None) -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

Sample from a forward path of the CD-NLGSSM to produce state and emission trajectories.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

model parameters

required
key PRNGKey

random number generator key

required
num_timesteps int

number of time steps to sample

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of the continuous-time instants at which the observations are taken.

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs.

None

Returns:

Type Description
Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

latent states and observed emissions

sample_prior(prior: Prior, M: int, init_params: Optional[ParamsCDNLGSSM] = None, key: Optional[PRNGKey] = jr.PRNGKey(0)) -> Tuple[ParamsCDNLGSSM, ParamsCDNLGSSM]

Sample from the prior distribution over CD-NLGSSM model parameters.

Parameters:

Name Type Description Default
prior Prior

Prior distribution.

required
M int

Number of samples to draw.

required
init_params Optional[ParamsCDNLGSSM]

Parameters to use as initialization; if not provided, default parameters are used.

None
key Optional[PRNGKey]

Random number generator key.

PRNGKey(0)

Returns:

Type Description
Tuple[ParamsCDNLGSSM, ParamsCDNLGSSM]

Tuple of sampled CD-NLGSSM parameters and properties objects.

transition_distribution(params: ParamsCDNLGSSM, state: Float[Array, ' state_dim'], t0: Optional[Float] = None, t1: Optional[Float] = None, inputs: Optional[Float[Array, ' input_dim']] = None) -> tfd.Distribution

CD-NLGSSM Transition distribution, defined as the pushforward of the continuous-time, nonlinear dynamics and discrete inputs (controls).

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM model parameters.

required
state Float[Array, ' state_dim']

current state.

required
t0 Optional[Float]

initial time.

None
t1 Optional[Float]

final time.

None
inputs Optional[Float[Array, ' input_dim']]

optional inputs.

None

Returns:

Type Description
Distribution

Gaussian state transition distribution.

Note: for general CD-NLSSMs, no specific distribution can be returned unless the Fokker-Planck equation of the model SDE is solved. Because this model is a CD-NLGSSM, a Gaussian distribution is returned.
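A minimal way to see why a Gaussian transition is a reasonable approximation: over a short interval \(\Delta t = t_1 - t_0\), one Euler-Maruyama step of the SDE gives \(z_{t_1} \mid z_{t_0} \approx N(z_{t_0} + f(z_{t_0}, u, t_0)\,\Delta t,\; L Q L^\top \Delta t)\). The sketch below computes those two moments in NumPy. The library pushes the state forward with an ODE solver instead, so this one-step formula, the function name, and the pendulum drift in the example are hypothetical illustrations.

```python
import numpy as np


def euler_gaussian_transition(f, L, Q, state, t0, t1, inputs=None):
    """Gaussian approximation of p(z_{t1} | z_{t0} = state) from one Euler-Maruyama step.

    mean = z + f(z, u, t0) * dt,  cov = L(t0) Q L(t0)^T * dt,  dt = t1 - t0.
    Only accurate when dt is small relative to the time scale of f.
    """
    dt = t1 - t0
    mean = state + f(state, inputs, t0) * dt
    Lt = L(t0)
    cov = Lt @ Q @ Lt.T * dt
    return mean, cov


f = lambda z, u, t: np.array([z[1], -np.sin(z[0])])  # hypothetical pendulum drift
L = lambda t: np.array([[0.0], [1.0]])               # noise enters the velocity only
Q = np.array([[0.5]])
mean, cov = euler_gaussian_transition(f, L, Q, np.array([0.1, 0.0]), 0.0, 0.01)
```

For longer intervals, the mean and covariance would be propagated over many small steps (or with an adaptive solver) rather than in a single step.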

cdnlgssm_emissions(params: ParamsCDNLGSSM, t_states: Float[Array, 'num_timesteps 1'], state_means: Float[Array, 'num_timesteps state_dim'], state_covs: Optional[Float[Array, 'num_timesteps state_dim state_dim']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, filter_hyperparams: Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]] = None, key: PRNGKey = jr.PRNGKey(0), warn: bool = True) -> Tuple[Float[Array, 'num_timesteps emission_dim'], Float[Array, 'num_timesteps emission_dim emission_dim']]

Compute the emissions corresponding to:

  • a continuous-discrete nonlinear model, as specified by params;
  • a filter method for a continuous-discrete nonlinear model.

Depending on the hyperparameter class provided, it can execute an EKF, UKF, or EnKF.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

model parameters.

required
t_states Float[Array, 'num_timesteps 1']

continuous-time instants of the states.

required
state_means Float[Array, 'num_timesteps state_dim']

state means at time instants t_states, always required

required
state_covs Optional[Float[Array, 'num_timesteps state_dim state_dim']]

state covariances at the time instants t_states (optional). If None, the states are treated as point estimates and are simply pushed through the emission function.

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs, of shape (1 + num_timesteps) × input_dim. The extra input is needed for the initial emission, i.e., it should correspond to time t_init.

None
filter_hyperparams Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]]

hyper-parameters of the filter, optional

None
key PRNGKey

random key for sampling

PRNGKey(0)

Returns:

Name Type Description
emissions_mean Float[Array, 'num_timesteps emission_dim']

mean of emissions

emissions_covariance Float[Array, 'num_timesteps emission_dim emission_dim']

covariance of emissions
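When state_covs is provided, the emission moments depend on how the filter approximates the nonlinearity of \(h\). The NumPy sketch below contrasts a first-order (EKF-style) linearization, \(\mathbb{E}[y] \approx h(m)\), \(\mathrm{Cov}[y] \approx H P H^\top + R\), with the standard unscented transform used by UKFs; an EnKF would instead use sample statistics of an ensemble. This is not the library's code: the function names, the finite-difference Jacobian (standing in for automatic differentiation), and the sigma-point parameters alpha, beta, and kappa are assumptions. Inputs are omitted from \(h\) for brevity.

```python
import numpy as np


def numerical_jacobian(h, m, eps=1e-6):
    """Central finite-difference Jacobian of h at m (a stand-in for autodiff)."""
    cols = []
    for i in range(m.size):
        e = np.zeros_like(m)
        e[i] = eps
        cols.append((h(m + e) - h(m - e)) / (2 * eps))
    return np.stack(cols, axis=1)


def emission_moments_first_order(h, R, m, P=None):
    """EKF-style moments: h(m) and H P H^T + R; with P=None, only push m through h."""
    if P is None:  # point estimate: no covariance to propagate
        return h(m), None
    H = numerical_jacobian(h, m)
    return h(m), H @ P @ H.T + R


def emission_moments_unscented(h, R, m, P, alpha=1.0, beta=2.0, kappa=1.0):
    """UKF-style moments from 2n+1 sigma points (standard unscented transform)."""
    n = m.size
    lam = alpha**2 * (n + kappa) - n
    S = np.linalg.cholesky((n + lam) * P)
    sigmas = np.vstack([m, m + S.T, m - S.T])  # rows: m, m + S[:, i], m - S[:, i]
    wm = np.full(2 * n + 1, 1.0 / (2 * (n + lam)))
    wc = wm.copy()
    wm[0] = lam / (n + lam)
    wc[0] = lam / (n + lam) + (1 - alpha**2 + beta)
    Y = np.stack([h(s) for s in sigmas])
    mean = wm @ Y
    d = Y - mean
    cov = (wc[:, None] * d).T @ d + R
    return mean, cov


H = np.array([[1.0, 2.0], [0.0, 1.0], [3.0, 0.0]])
h = lambda z: H @ z  # linear emission: both approximations are exact here
m, P, R = np.array([1.0, -1.0]), np.array([[1.0, 0.3], [0.3, 2.0]]), 0.1 * np.eye(3)
mean_fo, cov_fo = emission_moments_first_order(h, R, m, P)
mean_ut, cov_ut = emission_moments_unscented(h, R, m, P)
```

For a linear \(h\) both methods return \(Hm\) and \(HPH^\top + R\); they differ only when \(h\) is nonlinear, where the unscented transform captures some second-order effects that the linearization misses.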

cdnlgssm_filter(params: ParamsCDNLGSSM, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]] = EKFHyperParams(), inputs: Optional[Float[Array, 'ntime input_dim']] = None, num_iter: Optional[int] = 1, output_fields: Optional[List[str]] = None, key: PRNGKey = jr.PRNGKey(0), warn: bool = True) -> PosteriorGSSMFiltered

Run a continuous-discrete nonlinear filter to produce the marginal likelihood and filtered state estimates.

Depending on the hyperparameter class provided, it can execute an EKF, UKF, or EnKF.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM parameters.

required
emissions Float[Array, 'ntime emission_dim']

observation sequence.

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of the continuous-time instants at which the observations were taken.

None
filter_hyperparams Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]]

hyper-parameters of the filter

EKFHyperParams()
inputs Optional[Float[Array, 'ntime input_dim']]

optional array of inputs.

None
num_iter Optional[int]

number of linearizations around posterior for update step (default 1).

1
output_fields Optional[List[str]]

list of fields to return in posterior object. These can take the values "filtered_means", "filtered_covariances", "predicted_means", "predicted_covariances", and "marginal_loglik".

None
key PRNGKey

random key (e.g., for EnKF).

PRNGKey(0)
warn bool

whether to issue warnings during filtering.

True

Returns:

Name Type Description
post PosteriorGSSMFiltered

filtered posterior object.
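The num_iter argument controls how many times the emission function is relinearized during the update step. A minimal NumPy sketch of an iterated EKF update in Gauss-Newton form is shown below; with num_iter=1 it reduces to the standard EKF update. The function name and the caller-supplied Jacobian H_fn are hypothetical, and the library's own update may differ in details.

```python
import numpy as np


def iterated_ekf_update(m, P, y, h, H_fn, R, num_iter=1):
    """Iterated EKF measurement update (num_iter=1 is the standard EKF update).

    Each iteration relinearizes h around the current posterior mean estimate x;
    H_fn(x) returns the Jacobian of h at x.
    """
    if num_iter < 1:
        raise ValueError("num_iter must be at least 1")
    x = m
    for _ in range(num_iter):
        H = H_fn(x)
        S = H @ P @ H.T + R             # innovation covariance at this linearization
        K = P @ H.T @ np.linalg.inv(S)  # Kalman gain
        # Gauss-Newton step using h(z) ~ h(x) + H (z - x).
        x = m + K @ (y - h(x) - H @ (m - x))
    P_post = P - K @ S @ K.T  # covariance from the final linearization
    return x, P_post


h = lambda z: np.array([z[0] ** 2])        # hypothetical nonlinear emission
H_fn = lambda z: np.array([[2.0 * z[0]]])  # its Jacobian
for n in (1, 3):
    x, P_post = iterated_ekf_update(np.array([1.0]), np.array([[0.5]]), np.array([2.0]),
                                    h, H_fn, np.array([[0.1]]), num_iter=n)
    print(n, x, P_post)
```

For a linear emission function every iteration produces the same update, so extra iterations only matter when \(h\) is nonlinear.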

cdnlgssm_forecast(params: ParamsCDNLGSSM, init_forecast: Union[tfd.Distribution, Float[Array, 'state_dim 1']], t_init: Float[Array, '1 1'], t_forecast: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]] = EKFHyperParams(), inputs: Optional[Float[Array, 'ntime input_dim']] = None, output_fields: Optional[List[str]] = ['forecasted_state_means', 'forecasted_state_covariances'], key: PRNGKey = jr.PRNGKey(0), diffeqsolve_settings: dict = {}, warn: bool = True) -> GSSMForecast

Run a continuous-discrete nonlinear model forward to produce forecasted state estimates.

Depending on the hyperparameter class provided, it can execute an EKF, UKF, or EnKF.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM parameters.

required
init_forecast Union[Distribution, Float[Array, 'state_dim 1']]

initial condition from which to start forecasting:

  • if init_forecast is a distribution, that distribution is forecast using the chosen filtering method;
  • if init_forecast is a point estimate of the state, a forward path starting at that state is forecast.

required
t_init Float[Array, '1 1']

time instant of the forecast's initial condition.

required
t_forecast Optional[Float[Array, 'num_timesteps 1']]

optional array of the continuous-time instants at which to forecast.

None
filter_hyperparams Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]]

hyper-parameters of the filter

EKFHyperParams()
inputs Optional[Float[Array, 'ntime input_dim']]

optional array of inputs, of shape (1 + num_timesteps) × input_dim. The extra input is needed for the initial emission, i.e., it should correspond to time t_init.

None
output_fields Optional[List[str]]

list of fields to return in the forecast object. When forecasting Gaussian distributions with a filtering method, these can be "forecasted_state_means" and "forecasted_state_covariances"; when forecasting paths by solving the SDE, the field is "forecasted_state_path".

['forecasted_state_means', 'forecasted_state_covariances']
key PRNGKey

random key (e.g., for Ensemble Kalman).

PRNGKey(0)
diffeqsolve_settings dict

settings for the SDE solver

{}
warn bool

whether to issue warnings during forecasting (e.g., PSD issues).

True

Returns:

Name Type Description
post GSSMForecast

forecast object.

cdnlgssm_joint_sample(params: ParamsCDNLGSSM, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, diffeqsolve_settings={}) -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

Sample from the joint distribution of a CD-NLGSSM to produce state and emission trajectories.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM model parameters

required
key PRNGKey

random number generator key

required
num_timesteps int

number of time steps to sample

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of the continuous-time instants at which the observations are taken.

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs.

None
diffeqsolve_settings

settings for the SDE solver.

{}

Returns:

Type Description
Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

latent states and observed emissions

cdnlgssm_path_sample(params: ParamsCDNLGSSM, key: PRNGKey, num_timesteps: int, t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, inputs: Optional[Float[Array, 'num_timesteps input_dim']] = None, diffeqsolve_settings={}) -> Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

Sample from a forward path of the CD-NLGSSM to produce state and emission trajectories.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM parameters

required
key PRNGKey

random number generator key

required
num_timesteps int

number of time steps to sample

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of the continuous-time instants at which the observations are taken.

None
inputs Optional[Float[Array, 'num_timesteps input_dim']]

optional array of inputs.

None
diffeqsolve_settings

settings for the SDE solver.

{}

Returns:

Type Description
Tuple[Float[Array, 'num_timesteps state_dim'], Float[Array, 'num_timesteps emission_dim']]

latent states and observed emissions

cdnlgssm_smoother(params: ParamsCDNLGSSM, emissions: Float[Array, 'ntime emission_dim'], t_emissions: Optional[Float[Array, 'num_timesteps 1']] = None, filter_hyperparams: Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]] = EKFHyperParams(), inputs: Optional[Float[Array, 'ntime input_dim']] = None, num_iter: Optional[int] = 1, output_fields: Optional[List[str]] = ['filtered_means', 'filtered_covariances', 'smoothed_means', 'smoothed_covariances', 'marginal_loglik'], key: PRNGKey = jr.PRNGKey(0), warn: bool = True) -> PosteriorGSSMFiltered

Run a continuous-discrete nonlinear smoother to produce the marginal likelihood and smoothed state estimates.

Depending on the hyperparameter class provided, it can execute different smoothers.

Parameters:

Name Type Description Default
params ParamsCDNLGSSM

CD-NLGSSM parameters.

required
emissions Float[Array, 'ntime emission_dim']

observation sequence.

required
t_emissions Optional[Float[Array, 'num_timesteps 1']]

optional array of the continuous-time instants at which the observations were taken.

None
filter_hyperparams Optional[Union[EKFHyperParams, EnKFHyperParams, UKFHyperParams]]

hyper-parameters of the smoother to use

EKFHyperParams()
inputs Optional[Float[Array, 'ntime input_dim']]

optional array of inputs.

None
num_iter Optional[int]

optional; number of linearizations around the posterior for the update step (default 1).

1
output_fields Optional[List[str]]

list of fields to return in posterior object. These can take the values "filtered_means", "filtered_covariances", "smoothed_means", "smoothed_covariances", and "marginal_loglik".

['filtered_means', 'filtered_covariances', 'smoothed_means', 'smoothed_covariances', 'marginal_loglik']
key PRNGKey

random key (e.g., for EnKS).

PRNGKey(0)
warn bool

whether to issue warnings during smoothing (e.g., PSD issues).

True

Returns:

Name Type Description
post PosteriorGSSMFiltered

posterior object.
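For Gaussian smoothers, smoothed moments are typically obtained by a backward Rauch-Tung-Striebel (RTS) pass over the filtered moments. The sketch below shows that backward pass in NumPy under a simplifying assumption: the transition between consecutive observation times is linear-Gaussian, \(z_{k+1} = A z_k + w_k\), \(w_k \sim N(0, Q_d)\), with constant \(A\) and \(Q_d\). For a CD-NLGSSM these quantities would come from linearizing and integrating the SDE between observation times, so the function name and its inputs are hypothetical rather than the library's smoother.

```python
import numpy as np


def rts_smoother(filtered_means, filtered_covs, A, Qd):
    """Backward RTS pass, ASSUMING z_{k+1} = A z_k + N(0, Qd) between observations."""
    T = len(filtered_means)
    ms, Ps = [None] * T, [None] * T
    ms[-1], Ps[-1] = filtered_means[-1], filtered_covs[-1]  # last step: smoothed = filtered
    for k in range(T - 2, -1, -1):
        m_f, P_f = filtered_means[k], filtered_covs[k]
        m_pred = A @ m_f                       # one-step prediction from step k
        P_pred = A @ P_f @ A.T + Qd
        G = P_f @ A.T @ np.linalg.inv(P_pred)  # smoother gain
        ms[k] = m_f + G @ (ms[k + 1] - m_pred)
        Ps[k] = P_f + G @ (Ps[k + 1] - P_pred) @ G.T
    return np.stack(ms), np.stack(Ps)


# Filtered moments of a scalar random walk: prior N(0, 1), Qd = 0.5,
# unit observation noise, observations y = (1, 2).
filtered_means = [np.array([0.5]), np.array([1.25])]
filtered_covs = [np.array([[0.5]]), np.array([[0.5]])]
ms, Ps = rts_smoother(filtered_means, filtered_covs, np.eye(1), np.array([[0.5]]))
# ms[:, 0] equals the batch Gaussian posterior means (0.875, 1.25),
# and Ps[:, 0, 0] the posterior variances (0.375, 0.5).
```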

compute_pushforward(x0: Float[Array, ' state_dim'], P0: Float[Array, 'state_dim state_dim'], params: ParamsCDNLGSSM, t0: Float, t1: Float, inputs: Optional[Float[Array, ' input_dim']] = None, diffeqsolve_settings: dict = {}) -> Tuple[Float[Array, 'state_dim state_dim'], Float[Array, 'state_dim state_dim']]

Compute the push-forward of the sufficient statistics (mean and covariance) of a Gaussian distribution through the CD-NLGSSM dynamics.

Parameters:

Name Type Description Default
x0 Float[Array, ' state_dim']

initial mean

required
P0 Float[Array, 'state_dim state_dim']

initial covariance

required
params ParamsCDNLGSSM

model parameters

required
t0 Float

initial time

required
t1 Float

final time

required
inputs Optional[Float[Array, ' input_dim']]

optional inputs

None
diffeqsolve_settings dict

settings for the diffrax diffeqsolve function

{}

Returns:

Type Description
Tuple[Float[Array, 'state_dim state_dim'], Float[Array, 'state_dim state_dim']]

mean and covariance of the push-forward distribution
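One common way to push a Gaussian's mean and covariance through nonlinear SDE dynamics is to integrate the first-order (linearized) moment equations \(\dot m = f(m, u, t)\) and \(\dot P = F P + P F^\top + L Q L^\top\), where \(F\) is the Jacobian of \(f\) at the mean. The NumPy sketch below integrates them with explicit Euler steps. The library instead passes diffeqsolve_settings to a diffrax solver, and whether compute_pushforward uses exactly these first-order equations (rather than a higher-order approximation) is an assumption here; the function name, the Jacobian argument F, and the fixed step count are hypothetical.

```python
import numpy as np


def pushforward_moments(x0, P0, f, F, L, Q, t0, t1, inputs=None, num_steps=1000):
    """First-order moment ODEs integrated with explicit Euler steps:

        dm/dt = f(m, u, t)
        dP/dt = F(m, u, t) P + P F(m, u, t)^T + L(t) Q L(t)^T
    """
    dt = (t1 - t0) / num_steps
    m, P = np.array(x0, dtype=float), np.array(P0, dtype=float)
    t = t0
    for _ in range(num_steps):
        Fm = F(m, inputs, t)
        Lt = L(t)
        dm = f(m, inputs, t)
        dP = Fm @ P + P @ Fm.T + Lt @ Q @ Lt.T
        m, P = m + dt * dm, P + dt * dP
        t += dt
    return m, P


theta, q = 0.7, 0.3
m, P = pushforward_moments(np.array([2.0]), np.array([[0.1]]),
                           f=lambda z, u, t: -theta * z,
                           F=lambda z, u, t: np.array([[-theta]]),
                           L=lambda t: np.eye(1), Q=np.array([[q]]), t0=0.0, t1=1.5)
# For this linear (Ornstein-Uhlenbeck) case the exact moments are known in closed form.
m_exact = 2.0 * np.exp(-theta * 1.5)
P_exact = 0.1 * np.exp(-2 * theta * 1.5) + q / (2 * theta) * (1 - np.exp(-2 * theta * 1.5))
```

For linear drift the first-order equations are exact, so the remaining gap to the closed form is only the Euler discretization error; for nonlinear drift they are an approximation whose quality depends on how nonlinear \(f\) is over the spread of \(P\).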