Universal Differential Equations
A Universal Differential Equation (UDE) combines a mechanistic ODE with a neural network to learn unknown or partially-known dynamics from data. The ODE encodes prior biological knowledge; the neural network (called the NDE component) learns residual dynamics the ODE cannot explain on its own.
This notebook demonstrates UDE training on a Lotka-Volterra predator-prey system with an unknown external forcing term.
import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import pandas as pd
from example_models import get_lotka_volterra
from mxlpy import plot
from mxlpy.jax.models import Node, Ode, Ude
from mxlpy.jax.train import IntegrationSettings, train
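# Convert the MxlPy Lotka-Volterra model into a JAX ODE and simulate 100 points on t ∈ [0, 50]
lv = Ode.from_mxlpy(get_lotka_volterra())
ts = jnp.linspace(0, 50, 100)
y0 = jnp.array([10.0, 10.0])  # initial prey and predator populations
ys = lv.integrate(ts, y0)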
fig, ax = plt.subplots()
ax.plot(ts, ys)
ax.legend(["prey", "pred"])
plt.show()
Baseline: mechanistic ODE
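The Lotka-Volterra system describes predator-prey dynamics with four parameters (a, b, c, d). It serves as the mechanistic prior: the part of the system we already understand and can encode explicitly. To create training data with an unknown forcing, a small sinusoidal perturbation, 0.14·sin(t), is added to both simulated trajectories (ys_forced); the UDE only ever sees this perturbed data.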
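For reference, here is a minimal NumPy sketch of one common Lotka-Volterra parameterization, dx/dt = a·x − b·x·y and dy/dt = c·x·y − d·y, integrated with a fixed-step fourth-order Runge-Kutta scheme. The exact equations and parameter values inside get_lotka_volterra are not shown in this notebook, so the form and the values a = 1.0, b = 0.1, c = 0.02, d = 0.4 are assumptions, and the notebook itself integrates its models through mxlpy rather than with this helper. The names use t_grid and traj so they do not overwrite the notebook's ts and ys.

```python
import numpy as np


def lotka_volterra(t, y, a=1.0, b=0.1, c=0.02, d=0.4):
    """Assumed LV right-hand side for y = [prey, predator]; t is unused (autonomous)."""
    prey, pred = y
    return np.array([a * prey - b * prey * pred, c * prey * pred - d * pred])


def lv_invariant(y, a=1.0, b=0.1, c=0.02, d=0.4):
    """Quantity conserved by this LV form: c*prey - d*ln(prey) + b*pred - a*ln(pred)."""
    prey, pred = y[..., 0], y[..., 1]
    return c * prey - d * np.log(prey) + b * pred - a * np.log(pred)


def rk4(rhs, t_grid, y_init, substeps=20):
    """Fixed-step classical RK4; returns the state at every time in t_grid."""
    states = [np.asarray(y_init, dtype=float)]
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        h = (t1 - t0) / substeps
        y, t = states[-1], t0
        for _ in range(substeps):
            k1 = rhs(t, y)
            k2 = rhs(t + h / 2, y + h / 2 * k1)
            k3 = rhs(t + h / 2, y + h / 2 * k2)
            k4 = rhs(t + h, y + h * k3)
            y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
            t += h
        states.append(y)
    return np.stack(states)


t_grid = np.linspace(0, 50, 100)
traj = rk4(lotka_volterra, t_grid, [10.0, 10.0])
print(traj.shape)  # (100, 2)
```

Because this form of the system conserves lv_invariant, checking that lv_invariant(traj) stays essentially constant along the solution is a quick sanity test for the integrator.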
# Synthetic training data: superimpose 0.14·sin(t) on both species (the "unknown" forcing)
ys_forced = (ys.T + 0.14 * jnp.sin(ts)).T
fig, ax = plt.subplots()
ax.plot(ts, ys_forced)
ax.legend(["prey", "pred"])
plt.show()
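The double transpose in the forcing line is a broadcasting device: ys has shape (time points, species), so adding the length-100 vector 0.14·sin(ts) to ys.T adds the same value to both species at each time point, and the final .T restores the original layout. A short NumPy check with random stand-in data (not the notebook's trajectories) shows that indexing the time vector with [:, None] gives the identical result without the transposes.

```python
import numpy as np

# Stand-in arrays with the notebook's shapes: 100 time points, 2 species.
t_grid = np.linspace(0, 50, 100)
traj = np.random.default_rng(0).uniform(1.0, 30.0, size=(100, 2))

# Form used in the notebook: transpose so time is the last axis, broadcast, transpose back.
forced_transposed = (traj.T + 0.14 * np.sin(t_grid)).T
# Equivalent form: a (100, 1) column broadcasts across the species axis.
forced_column = traj + 0.14 * np.sin(t_grid)[:, None]

print(np.allclose(forced_transposed, forced_column))  # True
```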
Training the UDE
We construct a Ude by combining the mechanistic Ode with a small neural network (Node). With op="+", the NDE output is added to the ODE right-hand side wherever the solver evaluates it, so the combined model is dy/dt = f_ODE(t, y) + NN(t, y).
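The additive combination can be spelled out in a few lines of NumPy. This is only an illustration: tanh_mlp is a hypothetical two-layer network, not how mxlpy's Node is implemented, and lotka_volterra uses an assumed parameterization. With every network weight set to zero the correction vanishes and the UDE right-hand side reduces to the mechanistic ODE.

```python
import numpy as np


def lotka_volterra(t, y, a=1.0, b=0.1, c=0.02, d=0.4):
    """Assumed mechanistic right-hand side f_ODE(t, y) for y = [prey, predator]."""
    prey, pred = y
    return np.array([a * prey - b * prey * pred, c * prey * pred - d * pred])


def tanh_mlp(params, t, y):
    """Hypothetical stand-in for the neural network: state in, 2-vector out (t unused)."""
    w1, b1, w2, b2 = params
    hidden = np.tanh(w1 @ y + b1)
    return w2 @ hidden + b2


def ude_rhs(t, y, params):
    """Additive UDE, the analogue of op="+": mechanistic term plus learned correction."""
    return lotka_volterra(t, y) + tanh_mlp(params, t, y)


rng = np.random.default_rng(0)
width = 8
random_params = (
    rng.normal(size=(width, 2)),
    np.zeros(width),
    0.1 * rng.normal(size=(2, width)),
    np.zeros(2),
)
zero_params = tuple(np.zeros_like(p) for p in random_params)

state = np.array([10.0, 10.0])
print(np.allclose(ude_rhs(0.0, state, zero_params), lotka_volterra(0.0, state)))  # True
```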
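Training uses train, which freezes the ODE parameters and only optimises the neural network weights. A curriculum schedule gradually increases the proportion of the trajectory used to compute the loss, which helps avoid poor local minima early in training. Each entry of training_steps pairs a number of optimisation steps with the fraction of the trajectory used: 500 steps at 20 %, 500 steps at 50 %, then 10,000 steps on the full trajectory.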
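Reading each entry of training_steps as (number of steps, fraction of the trajectory), the loss side of such a curriculum can be sketched as below. Both helpers are hypothetical, and the sketch assumes the fraction selects a prefix of the trajectory starting at t = 0; how mxlpy's train actually selects the points is not shown in this notebook.

```python
import math

import numpy as np


def curriculum_mse(pred, target, fraction):
    """Mean squared error over the first `fraction` of the time points (at least one)."""
    n_used = max(1, math.ceil(fraction * len(target)))
    return float(np.mean((pred[:n_used] - target[:n_used]) ** 2))


def expand_schedule(stages):
    """Turn [(n_steps, fraction), ...] into the fraction used at every training step."""
    return [fraction for n_steps, fraction in stages for _ in range(n_steps)]


# Same shape as the notebook's schedule, with step counts shortened for illustration.
schedule = expand_schedule([(5, 0.2), (5, 0.5), (10, 1.0)])
print(len(schedule), schedule[0], schedule[-1])  # 20 0.2 1.0

# A prediction that is exact early on but drifts later: early stages do not see the drift.
target = np.zeros(100)
pred = np.where(np.arange(100) < 50, 0.0, 1.0)
print(curriculum_mse(pred, target, 0.2), curriculum_mse(pred, target, 1.0))  # 0.0 0.5
```

Fitting the early part of the trajectory first keeps the optimiser away from solutions that only match late, hard-to-reach behaviour, which is the motivation given above for the curriculum.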
ude, losses = train(
    Ude(
        ode=lv,  # mechanistic prior; its parameters stay frozen during training
        nn=Node(
            n_obs=2,
            width=8,
            depth=3,
            key=jax.random.PRNGKey(42),
            out_scale=jnp.array([0.1]),
        ),
        op="+",  # add the NDE output to the ODE right-hand side
    ),
    ts=ts,
    ys=ys_forced,
    # curriculum: (training steps, fraction of the trajectory used in the loss)
    training_steps=[
        (500, 0.2),
        (500, 0.5),
        (10_000, 1.0),  # in practice longer, shortened for docs
    ],
    avg_every=50,
    integration_settings=IntegrationSettings(
        max_steps=8192,
        method=diffrax.Tsit5,
    ),
    optim=optax.adabelief(learning_rate=2e-4),
)
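ys_pred = ude.integrate(ts, y0)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), layout="constrained")
# Left: forced data (solid) vs. UDE prediction (dashed)
ax1.plot(ts, ys_forced)
plot.reset_prop_cycle(ax=ax1)
ax1.plot(ts, ys_pred, ls="dashed")
ax1.legend(["prey", "pred"])
# Right: training loss. `losses` has one entry per curriculum stage (3 here) and the
# entries differ in length, so flatten them before building the Series.
loss_history = pd.Series([float(x) for stage in losses for x in stage])
loss_history[12:].plot(ax=ax2)
ax2.set(title="Training loss")
plt.show()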
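pd.Series(..., dtype=float) needs a one-dimensional sequence of numbers. If the loss history is a list with one array per curriculum stage, the arrays have different lengths (assuming avg_every=50 records one averaged value per 50 steps, the three stages would give 10, 10 and 200 values), so NumPy cannot build a rectangular array from them and raises a ValueError about an inhomogeneous shape. The sketch below reproduces this with made-up per-stage histories; stage_losses is hypothetical, not the real output of train.

```python
import numpy as np

# Hypothetical per-stage histories with lengths 500/50, 500/50 and 10_000/50.
stage_losses = [
    np.linspace(1.0, 0.5, 10),
    np.linspace(0.5, 0.3, 10),
    np.linspace(0.3, 0.1, 200),
]

try:
    # Roughly what pd.Series(stage_losses, dtype=float) attempts internally.
    np.asarray(stage_losses, dtype=float)
except ValueError as err:
    print("ragged:", err)

# Concatenating the stages end to end gives one flat, plottable history.
flat = np.concatenate([np.ravel(np.asarray(stage, dtype=float)) for stage in stage_losses])
print(flat.shape)  # (220,)
```

pd.Series(flat) can then be sliced and plotted as usual.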
Separate the contribution of each model part
After training, we can evaluate the ODE and NDE components independently along the predicted trajectory. jax.vmap maps each component over all time points in one call. This decomposition reveals what the neural network actually learned: ideally, it captures the effect of the unknown forcing on the dynamics.
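jax.vmap(f, in_axes=(0, 0, None)) maps f over the leading axis of its first two arguments (the time points and the matching rows of the trajectory) while passing the third argument (args) unchanged to every call. The NumPy loop below computes the same stacked result for a toy function; component_rhs is a hypothetical stand-in for ude.ode or ude.nn, and vmap produces this result without a Python loop.

```python
import numpy as np


def component_rhs(t, y, args):
    """Hypothetical stand-in for ude.ode / ude.nn: one time point and state in, dy/dt out."""
    return -y + np.sin(t)  # args is accepted but unused, like the empty array in the notebook


def vmap_like(f, t_grid, states, args):
    """Loop equivalent of jax.vmap(f, in_axes=(0, 0, None))(t_grid, states, args)."""
    return np.stack([f(t, y, args) for t, y in zip(t_grid, states)])


t_grid = np.linspace(0, 50, 100)
traj = np.ones((100, 2))
rhs = vmap_like(component_rhs, t_grid, traj, np.array([]))
print(rhs.shape)  # (100, 2): one dy/dt row per time point, as plotted in each panel
```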
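# Evaluate each component's right-hand side at every (t, y) of the predicted trajectory;
# in_axes=(0, 0, None) maps over time points and states and passes the empty args unchanged.
rhs_ode = jax.vmap(ude.ode, in_axes=(0, 0, None))(ts, ys_pred, jnp.array([]))
rhs_nde = jax.vmap(ude.nn, in_axes=(0, 0, None))(ts, ys_pred, jnp.array([]))
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), layout="constrained")
ax1.set(title="ODE")
ax2.set(title="NDE")
ax1.plot(ts, rhs_ode)
ax2.plot(ts, rhs_nde)
for ax in (ax1, ax2):
    ax.legend(["dpreydt", "dpreddt"])
plt.show()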
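The ODE panel shows the mechanistic dynamics encoded in the prior model. The NDE panel shows the learned correction, i.e. the neural network's contribution to the right-hand side that accounts for dynamics the ODE could not explain. A well-trained UDE produces an NDE signal that approximates the effect of the unknown forcing on the right-hand side; because the forcing was added to the trajectory as 0.14·sin(t), that effect is not the sine curve itself.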
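Because the forcing in this notebook is superimposed on the simulated trajectory rather than written into the equations, it helps to derive what an exact additive correction would have to be. Write f for the mechanistic right-hand side, y(t) for the unforced solution, 𝟏 = (1, 1)ᵀ, and ỹ(t) for the forced data (the same perturbation on both species). This derivation follows from the notebook's data construction; it is not something mxlpy computes.

```latex
\begin{aligned}
\tilde y(t) &= y(t) + 0.14\,\sin(t)\,\mathbf{1}, \qquad \dot y = f(y), \\
\dot{\tilde y}(t) &= f\big(y(t)\big) + 0.14\,\cos(t)\,\mathbf{1} \\
&= f\big(\tilde y(t)\big)
 + \underbrace{\Big[f\big(\tilde y(t) - 0.14\,\sin(t)\,\mathbf{1}\big) - f\big(\tilde y(t)\big) + 0.14\,\cos(t)\,\mathbf{1}\Big]}_{g^{*}(t)}
\end{aligned}
```

An additive NDE that reproduced the data exactly would therefore output g*(t) along the trajectory: the derivative of the perturbation, 0.14·cos(t), plus the change in the mechanistic term caused by evaluating f at the perturbed state. The NDE panel is best compared with g*(t), not with 0.14·sin(t).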