Diffrax Utilities
Diffrax-related helper functions.
diffrax_utils
adjust_rhs(x, rhs, lower_bound=-100, upper_bound=100, lower_bound_derivative=-1000, upper_bound_derivative=1000)
Adjust the right-hand side of an ODE/SDE so that the state stays within [lower_bound, upper_bound] (default [-100, 100]) and the derivative stays within [lower_bound_derivative, upper_bound_derivative] (default [-1000, 1000]). If any state variable violates its bounds, set rhs = -x for all variables.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `jnp.ndarray` | Current state | *required* |
| `rhs` | `jnp.ndarray` | Right-hand side of the ODE/SDE | *required* |
| `lower_bound` | `float` | Lower bound for the state | `-100` |
| `upper_bound` | `float` | Upper bound for the state | `100` |
| `lower_bound_derivative` | `float` | Lower bound for the derivative | `-1000` |
| `upper_bound_derivative` | `float` | Upper bound for the derivative | `1000` |
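A minimal sketch of how such a guard could be implemented, assuming the state check is a global `jnp.any` over all variables (an assumption; only the behavior described above is given):

```python
import jax.numpy as jnp

def adjust_rhs(x, rhs, lower_bound=-100.0, upper_bound=100.0,
               lower_bound_derivative=-1000.0, upper_bound_derivative=1000.0):
    # Keep the derivative within its bounds.
    rhs = jnp.clip(rhs, lower_bound_derivative, upper_bound_derivative)
    # If any state variable leaves its bounds, replace the entire
    # right-hand side with -x so the state is pulled back toward zero.
    out_of_bounds = jnp.any((x < lower_bound) | (x > upper_bound))
    return jnp.where(out_of_bounds, -x, rhs)
```

Because `jnp.where` selects between the two branches element-wise, this stays compatible with `jax.jit` and can be called inside a diffrax vector field without Python-level control flow.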
diffeqsolve(drift, t0: float, t1: float, y0: jnp.ndarray, reverse: bool = False, args=None, solver: dfx.AbstractSolver = None, stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(), adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(), dt0: float = 0.01, tol_vbt: float | None = None, max_steps: int = 100000, diffusion=None, key=None, debug=DEBUG, **kwargs) -> jnp.ndarray
Solvers and adjoints are chosen following the diffrax documentation's recommendations for training neural ODEs. See: https://docs.kidger.site/diffrax/usage/how-to-choose-a-solver/ See: https://docs.kidger.site/diffrax/api/adjoints/
Note that RecursiveCheckpointAdjoint requires reverse-mode auto-differentiation.
Use DirectAdjoint if you need the flexibility of both forward-mode and reverse-mode auto-differentiation.
Defaults are chosen to be decent low-cost options for forward solves and backpropagated gradients.
If you want high-fidelity solutions (and their gradients), the following are recommended:
- for ODEs: choose a higher-order solver (Tsit5) and an adaptive stepsize controller (PIDController).
- for SDEs: follow the diffrax website's advice (choosing a very small dt0 with a constant stepsize controller is a safe option).
Things to pay attention to (that we have an incomplete understanding of):
- checkpoints in RecursiveCheckpointAdjoint: used to save memory during backpropagation.
- max_steps: reducing this can speed things up, but it also sets the default number of checkpoints.
The optimal way to set these parameters is unclear.