Lorenz63 neural drift SGD fit
# Imports
import os
import sys

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
# CD-NLGSSM imports
from cd_dynamax import make_key_sequence
from cd_dynamax.src.utils.experiment_utils import *
from cd_dynamax.src.utils.simulation_utils import filter_and_forecast
from cd_dynamax.src.utils.demo_utils import sample_t_emissions
# Python plotting import
sys.path.append("../scripts")
from filter_forecast_plotting_utils import (
data_plotter,
plot_filter_then_forecast_state_results,
)
from parameter_learning_plotting_utils import plot_mll_learning_curve
# Set default config path
config_path = "../configs/"
Define the Lorenz 63 Model¶
We generate data from a Lorenz 63 system, whose dynamics are governed by the following stochastic differential equations:
\begin{align*} \frac{d x}{d t} &= a(y-x) + \sigma w_x(t) \\ \frac{d y}{d t} &= x(b-z) - y + \sigma w_y(t) \\ \frac{d z}{d t} &= xy - cz + \sigma w_z(t), \end{align*}
With parameters $a=10, b=28, c=8/3$, the system gives rise to chaotic behavior, and we choose $\sigma=1.0$ for diffusion.
To generate data, we numerically approximate random path solutions to this SDE using Heun's method (i.e. improved Euler), as implemented in Diffrax.
We assume the observation model is \begin{align*} y(t) &= H x(t) + r(t) \\ r(t) &\sim N(0,R), \end{align*} where we choose $R=I$.
Namely, we impose partial observability with H=[1, 0, 0], with noisy observations, sampled at irregular time intervals.
# Config file with the default Lorenz 63 setup in which only the
# first state component (x_1) is observed.
default_lorenz63_config_model = "model/l63_x1_filter"

# True Lorenz '63 drift with the classic chaotic parameters:
# sigma=10, rho=28, beta=8/3.
from cd_dynamax.src.utils.physics_based_models import LearnableLorenz63_Drift

_true_drift = LearnableLorenz63_Drift(sigma=10.0, rho=28.0, beta=8.0 / 3.0)

# Override only the drift; props=None keeps the default parameter properties.
true_model_def = {
    "initial_values.dynamics_drift": {"params": _true_drift, "props": None}
}

# Instantiate the true CD-NLGSSM model from the config plus the drift override.
true_model, true_params, true_props = create_cddynamax_model_from_config(
    config_path=config_path,
    true_model_config_file=default_lorenz63_config_model,
    overrides=true_model_def,
)
Simulate data using true CDNLGSSM model¶
# Deterministic key sequence (seed 0) so the whole notebook is reproducible.
keys = make_key_sequence(0)

# Irregularly-spaced observation times on [0, 50] with average spacing ~0.01.
t_emissions = sample_t_emissions(
    start=0.0,
    stop=50.0,
    dt=0.01,
    regular=False,
    key=next(keys),
)

# Draw one state trajectory plus its noisy partial observations at those times.
# The default SDE diffeqsolve settings use dfx.Heun() with
# dfx.ConstantStepSize() and step size dt0=0.01.
states, emissions = true_model.sample(
    params=true_params,
    key=next(keys),
    num_timesteps=len(t_emissions),
    t_emissions=t_emissions,
    transition_type="path",
)
[make_t_emissions] Generated 5000 points from 0.0 to 50.0 with avg step ~0.0100 [make_t_emissions] Smallest time step = 0.000004 Sampling from SDE solver path: this may be an unnecessarily poor approximation if you're simulating from a linear SDE. It is an appropriate choice for non-linear SDEs. CD-NLGSSM Sampling from continuous-discrete non-linear Gaussian SSM path
# Plot each state dimension in its own row (3 rows), overlaying the noisy
# observations only on the rows that are actually observed.
fig, axes = plt.subplots(true_model.state_dim, 1, figsize=(10, 8), sharex=True)
state_labels = ["x", "y", "z"]
emission_dim = emissions.shape[1]  # number of observed components (hoisted)
for d in range(true_model.state_dim):
    # H = [1, 0, 0]: only the leading emission_dim state components are observed.
    observed = d < emission_dim
    data_plotter(
        ax=axes[d],
        t_idx=t_emissions,
        states=states[:, d],
        state_label=state_labels[d],
        state_style={"color": "black", "linestyle": "-", "linewidth": 1.5},
        emissions=emissions[:, d] if observed else None,
        emission_label="True Observation" if observed else None,
        emissions_style={
            "color": "orange",
            "linestyle": "None",
            "marker": "x",
            "markersize": 3,
            "alpha": 0.5,
        },
        # Fix: was `True and d < emissions.shape[1]` — the `True and` was redundant.
        plot_observations=observed,
    )
axes[-1].set_xlabel("Time")
plt.suptitle("Lorenz '63 System States and Observations")
plt.tight_layout()
plt.show()
Fitting a Lorenz 63 model with Neural Network based drift function to observed data¶
Neural Network based function definition¶
Learnable model definition¶
# Learnable model: same observation setup (only x_1 observed) as the true model.
learnable_lorenz63_config_model = "model/l63_x1_filter"

# Replace the physical drift by a two-layer GeLU network.
from cd_dynamax.src.utils.data_driven_models import LearnableNN_TwoLayerGeLU

# Hidden width of the NN drift.
nn_hidden_dim = 128

# He-style (fan-in) initialization scales for each layer.
state_dim = true_model.state_dim
scale_layer1 = jnp.sqrt(2.0 / state_dim)
scale_layer2 = jnp.sqrt(2.0 / nn_hidden_dim)

nn_drift_init = LearnableNN_TwoLayerGeLU(
    weights1=scale_layer1 * jr.normal(next(keys), (nn_hidden_dim, state_dim)),
    bias1=jnp.zeros((nn_hidden_dim,)),
    weights2=scale_layer2 * jr.normal(next(keys), (state_dim, nn_hidden_dim)),
    bias2=jnp.zeros((state_dim,)),
)

# Mark every NN parameter as trainable.
nn_drift_props = LearnableNN_TwoLayerGeLU(
    weights1=ParameterProperties(trainable=True),
    bias1=ParameterProperties(trainable=True),
    weights2=ParameterProperties(trainable=True),
    bias2=ParameterProperties(trainable=True),
)

learnable_nndrift_model_def = {
    "initial_values.dynamics_drift": {
        "params": nn_drift_init,
        "props": nn_drift_props,
    }
}

# Create and initialize the learnable CD-NLGSSM model.
learnable_model, learnable_params, learnable_props = create_cddynamax_model_from_config(
    config_path=config_path,
    true_model_config_file=learnable_lorenz63_config_model,
    overrides=learnable_nndrift_model_def,
)
Fitting: SGD-based optimization of log-likelihood, computed via Ensemble Kalman Filter (EnKF)¶
Ensemble Kalman Filter (EnKF)
# Default EnKF settings live in their own config file.
default_enkf_config = "filter/enkf_StateFirst"

# Build the filter hyperparameters (and display metadata) from that config.
enkf_hyperparams, enkf_info = create_cddynamax_filter_from_config(
    config_path=config_path,
    filter_config_file=default_enkf_config,
    overrides={},
)
SGD for log-likelihood maximization
# Load the SGD fitting configuration.
# Fix: `os` was used below without a visible import (it may have leaked in
# through the earlier wildcard import, but an explicit import is correct).
import os
from configparser import ConfigParser

# Fit configuration file, resolved relative to config_path.
fit_config_file = "fitting/sgd_l63_x1_nndrift_fit_to_data"
fit_config_filepath = os.path.join(config_path, fit_config_file)

# Fail fast with a clear message if the config file is missing.
if not os.path.exists(fit_config_filepath):
    raise FileNotFoundError(f"Configuration file '{fit_config_filepath}' not found.")

# Parse the fitting configuration; each section names one optimization method.
config = ConfigParser()
config.read(fit_config_filepath)
optim_configs = config.sections()
# Run SGD for parameter estimation (MAP via the model's fit_sgd method).
if "sgd" in optim_configs:
    # SGD configuration section.
    sgd_config = config["sgd"]
    print("\nStarting SGD fitting...")
    # SECURITY NOTE: eval() executes arbitrary code read from the config file.
    # Acceptable for trusted local experiment configs; never use on untrusted input.
    optimizer = eval(sgd_config.get("optimizer", "optax.adam(1e-3)"))
    fit_results = learnable_model.fit_sgd(
        initial_params=learnable_params,
        props=learnable_props,
        emissions=emissions,
        t_emissions=t_emissions,
        filter_hyperparams=enkf_hyperparams,
        inputs=None,
        optimizer=optimizer,
        batch_size=sgd_config.getint("batch_size", 1),
        num_epochs=sgd_config.getint("num_epochs", 50),
        shuffle=sgd_config.getboolean("shuffle", False),
        return_param_history=sgd_config.getboolean("return_param_history", False),
        return_grad_history=sgd_config.getboolean("return_grad_history", False),
        key=jr.PRNGKey(sgd_config.getint("key", 0)),
    )
else:
    # Fix: the code below uses `fit_results` and `sgd_config` unconditionally,
    # so a missing [sgd] section previously surfaced as a confusing NameError.
    # Fail here with an explicit message instead.
    raise KeyError(f"No [sgd] section found in '{fit_config_filepath}'.")
Starting SGD fitting...
Organize results for easier plotting
# Fitted parameters come first in the results tuple.
sgd_fitted_params = fit_results[0]

# fit_sgd reports normalized (per-datapoint) negative log-likelihoods;
# rescale by the total number of observations and flip the sign.
sgd_loss_history = -fit_results[1] * emissions.size

# Optional histories: their positions in the tuple depend on which were
# requested, so consume the tail in order.
_want_param_hist = sgd_config.getboolean("return_param_history", False)
_want_grad_hist = sgd_config.getboolean("return_grad_history", False)
_extras = iter(fit_results[2:])
if _want_param_hist:
    sgd_fitted_params_history = next(_extras)
if _want_grad_hist:
    sgd_fitted_grad_history = next(_extras)
# Learning curve of the marginal log-likelihood over SGD epochs, with the
# true-parameter likelihood as a reference.
_mll_plot_kwargs = dict(
    true_model=true_model,
    true_params=true_params,
    true_emissions=emissions,
    t_emissions=t_emissions,
    marginal_lls=sgd_loss_history,
    filter_hyperparams=enkf_hyperparams,
    plot_save_path=None,  # display the figure instead of saving it
)
plot_mll_learning_curve(**_mll_plot_kwargs)
Filter then forecast using true and learned models¶
We now generate a new test trajectory from the true model, then do filtering and forecasting (using both true and learned models) to evaluate short-term forecasting performance.
# Regularly-spaced test observation times on [0, 100] with spacing 0.01.
test_t_emissions = sample_t_emissions(
    start=0.0,
    stop=100.0,
    dt=0.01,
    regular=True,
    key=next(keys),
)
[make_t_emissions] Generated 10000 points from 0.0 to 100.0 with avg step ~0.0100 [make_t_emissions] Smallest time step = 0.010000
# Fresh held-out trajectory from the TRUE model on the test time grid.
test_states, test_emissions = true_model.sample(
    params=true_params,
    key=next(keys),
    num_timesteps=len(test_t_emissions),
    t_emissions=test_t_emissions,
    transition_type="path",
)
Sampling from SDE solver path: this may be an unnecessarily poor approximation if you're simulating from a linear SDE. It is an appropriate choice for non-linear SDEs. CD-NLGSSM Sampling from continuous-discrete non-linear Gaussian SSM path
# Filtering/forecasting window: assimilate observations on [T0, T_filter_end],
# then forecast forward without data until T_forecast_end.
T0, T_filter_end, T_forecast_end = 45.0, 50.0, 52.0
Filtering and Forecasting with the true model¶
Note that the diffusion in the true SDE model (combined with the chaotic sensitivity of the Lorenz drift) makes even high-quality forecasts diverge quickly.
# Announce the filtering/forecasting run with the true model.
_filter_name = (
    enkf_info["name"]
    if enkf_info is not None and "name" in enkf_info
    else "the specified filter"
)
print(
    f"Running filtering with {_filter_name} with True model from T={T0} "
    f"up to T={T_filter_end} and forecasting up to T={T_forecast_end}."
)

# Filter with the EnKF under the TRUE parameters on [T0, T_filter_end], then
# forecast (no data assimilation) until T_forecast_end. The returned indices
# locate the filter and forecast windows inside test_t_emissions.
(
    true_test_enkf_filtered,
    true_test_enkf_forecasted,
    start_idx_filter,
    stop_idx_filter,
    start_idx_forecast,
    stop_idx_forecast,
) = filter_and_forecast(
    model_params=true_params,
    filter_hyperparams=enkf_hyperparams,
    t_emissions=test_t_emissions,
    emissions=test_emissions,
    T0=T0,
    T_filter_end=T_filter_end,
    T_forecast_end=T_forecast_end,
)
Running filtering with EnKF (1st order) with True model from T=45.0 up to T=50.0 and forecasting up to T=52.0.
# Visualize filtering and forecasting results under the TRUE model.
_test_data = {
    "states": test_states,
    "emissions": test_emissions,
    "t_emissions": test_t_emissions,
}
_true_results = {
    "filtered": tree_to_dict(true_test_enkf_filtered),
    "forecasted": tree_to_dict(true_test_enkf_forecasted),
    "start_idx_filter": start_idx_filter,
    "stop_idx_filter": stop_idx_filter,
    "start_idx_forecast": start_idx_forecast,
    "stop_idx_forecast": stop_idx_forecast,
}
plot_filter_then_forecast_state_results(
    data=_test_data,
    results=_true_results,
    plot_only_filter_forecast_window=True,
    results_file=None,  # display here rather than saving
    filter_info=enkf_info,
    plot_uncertainty=True,
    plot_observations=True,
    plot_mse=False,
)
Filtering and Forecasting with the learned model¶
# Announce the filtering/forecasting run with the learned model.
_filter_name = (
    enkf_info["name"]
    if enkf_info is not None and "name" in enkf_info
    else "the specified filter"
)
print(
    f"Running filtering with {_filter_name} with Learned model from T={T0} "
    f"up to T={T_filter_end} and forecasting up to T={T_forecast_end}."
)

# Same filter-then-forecast procedure, now using the SGD-fitted parameters.
(
    test_enkf_filtered,
    test_enkf_forecasted,
    start_idx_filter,
    stop_idx_filter,
    start_idx_forecast,
    stop_idx_forecast,
) = filter_and_forecast(
    model_params=sgd_fitted_params,
    filter_hyperparams=enkf_hyperparams,
    t_emissions=test_t_emissions,
    emissions=test_emissions,
    T0=T0,
    T_filter_end=T_filter_end,
    T_forecast_end=T_forecast_end,
)
Running filtering with EnKF (1st order) with Learned model from T=45.0 up to T=50.0 and forecasting up to T=52.0.
# Visualize filtering and forecasting results under the LEARNED model.
_test_data = {
    "states": test_states,
    "emissions": test_emissions,
    "t_emissions": test_t_emissions,
}
_learned_results = {
    "filtered": tree_to_dict(test_enkf_filtered),
    "forecasted": tree_to_dict(test_enkf_forecasted),
    "start_idx_filter": start_idx_filter,
    "stop_idx_filter": stop_idx_filter,
    "start_idx_forecast": start_idx_forecast,
    "stop_idx_forecast": stop_idx_forecast,
}
plot_filter_then_forecast_state_results(
    data=_test_data,
    results=_learned_results,
    plot_only_filter_forecast_window=True,
    results_file=None,  # display here rather than saving
    filter_info=enkf_info,
    plot_uncertainty=True,
    plot_observations=True,
    plot_mse=False,
)
Lorenz 63's chaotic behavior: differences in true Vs learned model forecasting¶
Next, we will look at the long-term statistical behavior of the learned vs true models by simulating a long trajectory from each and plotting histograms and other statistical summaries.
# Dense, regular time grid for a long simulation (10,000 time units at dt=0.01):
# the 50-unit training window is far too short to reveal the true attractor.
long_t_emissions = jnp.arange(0.0, 1e4, 0.01)[:, None]
# Long rollout under the TRUE parameters: reference attractor statistics.
true_long_states, true_long_emissions = true_model.sample(
    params=true_params,
    key=next(keys),
    num_timesteps=len(long_t_emissions),
    t_emissions=long_t_emissions,
    transition_type="path",
)
Sampling from SDE solver path: this may be an unnecessarily poor approximation if you're simulating from a linear SDE. It is an appropriate choice for non-linear SDEs. CD-NLGSSM Sampling from continuous-discrete non-linear Gaussian SSM path
# Long rollout under the LEARNED (SGD-fitted) parameters.
fitted_long_states, fitted_long_emissions = learnable_model.sample(
    params=sgd_fitted_params,
    key=next(keys),
    num_timesteps=len(long_t_emissions),
    t_emissions=long_t_emissions,
    transition_type="path",
)
Sampling from SDE solver path: this may be an unnecessarily poor approximation if you're simulating from a linear SDE. It is an appropriate choice for non-linear SDEs. CD-NLGSSM Sampling from continuous-discrete non-linear Gaussian SSM path
Chaos dynamics plotting¶
# Compare long-run chaotic statistics of the true vs learned dynamics.
# With only x observed, the learned 2nd and 3rd state coordinates need not
# match the true ones at all — only invariant statistics are comparable.
from cd_dynamax.src.utils.plotting_chaos_utils import analyze_chaotic_dynamics

_t_grid = long_t_emissions.squeeze()
analyze_chaotic_dynamics(
    true_long_states,
    fitted_long_states,
    t=_t_grid,
    drift_true=true_params.dynamics.drift.f,
    drift_learn=sgd_fitted_params.dynamics.drift.f,
    burnin_frac=0.5,  # discard the first half of the trajectory as transient
)
Estimated Largest Lyapunov Exponent: True : 1.069 Learned: 0.793
Estimated Correlation Dimension D2: True : 1.601 Learned : 1.479 Estimated Correlation Dimensions: True D2=1.601, Learned D2=1.479