Fitting
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import minimize
from example_models import get_linear_chain_2v
from mxlpy import Simulator, fit, make_protocol, plot, unwrap
Fitting¶
Almost every model at some point needs to be fitted to experimental data to be validated.
mxlpy offers highly customisable routines for fitting either time series or steady-states.
For this tutorial we are going to use the fit
module to optimise our parameter values and the plot
module to plot some results.
Let's get started!
Creating synthetic data¶
Normally, you would fit your model to experimental data.
Here, for the sake of simplicity, we will generate some synthetic data.
Check out the basics tutorial if you need a refresher on building and simulating models.
# Small trick: keep the model factory in a variable so the rest of this
# file reuses it, and switching to a different model is a one-line change.
model_fn = get_linear_chain_2v

# "True" system: k1=1, k2=2, k3=1, simulated on 101 evenly spaced time
# points in [0, 10]. ``get_combined()`` yields concentrations and fluxes
# together in one DataFrame, which stands in for measured data below.
res = unwrap(
    Simulator(model_fn())
    .update_parameters({"k1": 1.0, "k2": 2.0, "k3": 1.0})
    .simulate_time_course(np.linspace(0, 10, 101))
    .get_result()
).get_combined()

fig, ax = plot.lines(res)
ax.set(xlabel="time / a.u.", ylabel="Conc. & Flux / a.u.")
plt.show()
Steady-states¶
For the steady-state fit we need two inputs:
- the steady state data, which we supply as a
pandas.Series
- an initial parameter guess
The fitting routine will compare all data contained in that series to the model output.
Note that the data contains both concentrations and fluxes!
# The final row of the time course (t = 10) approximates the steady state
# and is used as steady-state "data" from here on.
data = res.iloc[-1]
data.head()
x 0.500000 y 1.000045 v1 1.000000 v2 1.000000 v3 1.000045 Name: 10.0, dtype: float64
# Fit all three rate constants to the steady-state Series defined above
# (concentrations *and* fluxes), starting from a perturbed initial guess.
# Reuse ``data`` rather than re-slicing ``res`` so the text and the code
# refer to the same object.
fit_result = unwrap(
    fit.steady_state(
        model_fn(),
        p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
        data=data,
    )
)
fit_result.best_pars
{'k1': np.float64(1.000015202475239), 'k2': np.float64(2.0000309249184327), 'k3': np.float64(0.9999697802169417)}
If only some of the data is required, you can use a subset of it.
The fitting routine will only try to fit concentrations and fluxes contained in that series.
# Only the ``x`` and ``y`` entries of the steady state are compared here.
# NOTE: in the output below, k1:k2:k3 keeps the true 1:2:1 ratio while the
# absolute values drift, which suggests these entries alone only constrain
# parameter ratios for this model.
fit_result = unwrap(
    fit.steady_state(
        model_fn(),
        data=data.loc[["x", "y"]],
        p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    )
)
fit_result.best_pars
{'k1': np.float64(0.9829433889293213), 'k2': np.float64(1.9658867533095319), 'k3': np.float64(0.9828987681888647)}
Time course¶
For the time course fit we again need two inputs:
- the time course data, which we supply as a
pandas.DataFrame
- an initial parameter guess
The fitting routine will simulate the model at every time point in the DataFrame's index
and compare all of them.
Other than that, the same rules of the steady-state fitting apply.
# Fit against the whole time course: every time point in the index and
# every column of ``res`` enters the comparison.
fit_result = unwrap(
    fit.time_course(
        model_fn(),
        data=res,
        p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    )
)
fit_result.best_pars
{'k1': np.float64(0.9999999948101959), 'k2': np.float64(1.9999999615864308), 'k3': np.float64(0.9999999922253802)}
Protocol time courses¶
Normally, you would fit your model to experimental data.
Here, again, for the sake of simplicity, we will generate some synthetic data.
# Three protocol steps that switch k1 from 1.0 to 2.0 and back to 1.0.
# The leading ``1`` of each step is presumably its duration — confirm in
# the make_protocol docs.
protocol = make_protocol(
    [
        (1, {"k1": 1.0}),
        (1, {"k1": 2.0}),
        (1, {"k1": 1.0}),
    ]
)

# Run the protocol on the true system (k2=2, k3=1), sampling 10 time
# points per step, and collect concentrations and fluxes together.
res_protocol = unwrap(
    Simulator(model_fn())
    .update_parameters({"k1": 1.0, "k2": 2.0, "k3": 1.0})
    .simulate_protocol(protocol, time_points_per_step=10)
    .get_result()
).get_combined()

fig, ax = plot.lines(res_protocol)
ax.set(xlabel="time / a.u.", ylabel="Conc. & Flux / a.u.")
plt.show()
For the protocol time course fit we need three inputs
- an initial parameter guess
- the time course data, which we supply as a
pandas.DataFrame
- the protocol, which we supply as a
pandas.DataFrame
Note that parameters set by the protocol (here k1) cannot be fitted.
# k1 is dictated by the protocol, so only k2 and k3 are free parameters.
fit_result = unwrap(
    fit.protocol_time_course(
        model_fn(),
        p0={"k2": 1.87, "k3": 1.093},
        protocol=protocol,
        data=res_protocol,
    )
)
fit_result.best_pars
{'k2': np.float64(1.9999999995061162), 'k3': np.float64(1.000000000653317)}
First finish line
With that, you now know most of what you will need on a day-to-day basis for fitting in mxlpy. Congratulations!
Advanced topics / customisation¶
Internally, all fitting routines are built so that they call the following tree of functions:
- minimize_fn
    - residual_fn
        - integrator
        - loss_fn
You can therefore use dependency injection to overwrite the minimisation function, the loss function, the residual function and the integrator if need be.
from functools import partial
from typing import TYPE_CHECKING, cast
from mxlpy.fit import LossFn
from mxlpy.integrators import Scipy
if TYPE_CHECKING:
import pandas as pd
from mxlpy.fit import ResidualFn
from mxlpy.model import Model
from mxlpy.types import Array, IntegratorType
Custom loss function¶
You can change the loss function that the residual function uses to score the model output, via the loss_fn
keyword.
Depending on the use case (time course vs. steady state), this function will be passed either two pandas DataFrames or two pandas Series.
def mean_absolute_error(
x: pd.DataFrame | pd.Series,
y: pd.DataFrame | pd.Series,
) -> float:
"""Mean absolute error between two dataframes."""
return cast(float, np.mean(np.abs(x - y)))
# Same time-course fit as before, but scored with the custom
# mean-absolute-error loss. Like every other fit in this file, the result
# is passed through ``unwrap`` before ``.best_pars`` is read; presumably
# this turns a failed fit into an error at this line.
fit_result = unwrap(
    fit.time_course(
        model_fn(),
        p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
        data=res,
        loss_fn=mean_absolute_error,
    )
)
fit_result.best_pars
{'k1': np.float64(0.9999999602399153), 'k2': np.float64(1.999999919411726), 'k3': np.float64(0.9999999475129249)}
Custom integrator¶
You can replace the default integrator with one of your choice by partially applying any of the existing integrator classes.
Here, for example, we choose the Scipy
solver suite and set both its relative and absolute tolerances to 1e-6.
# ``partial`` pre-binds rtol/atol, so the fitting routine receives a Scipy
# integrator factory with tighter-than-default tolerances already set.
fit_result = unwrap(
    fit.time_course(
        model_fn(),
        data=res,
        p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
        integrator=partial(Scipy, rtol=1e-6, atol=1e-6),
    )
)
fit_result.best_pars
{'k1': np.float64(0.9999999887485048), 'k2': np.float64(2.0000000418475845), 'k3': np.float64(0.9999999287055419)}
Custom minimisation¶
You can change the default minimize_fn
from L-BFGS-B
to any other function that takes a ResidualFn
and minimizes it.
from mxlpy.fit import Bounds, MinResult
def nelder_mead(
    residual_fn: ResidualFn,
    p0: dict[str, float],
    bounds: Bounds,
) -> MinResult | None:
    """Minimise ``residual_fn`` with SciPy's Nelder-Mead simplex method.

    Intended as a drop-in ``minimize_fn`` for the ``fit`` routines.

    Args:
        residual_fn: Objective, called with a parameter vector ordered like
            ``p0``.
        p0: Initial guess, mapping parameter name to starting value.
        bounds: Per-parameter ``(lower, upper)`` bounds. Parameters missing
            from it are constrained to ``(1e-6, 1e6)``.

    Returns:
        A ``MinResult`` holding the fitted parameters and the final
        residual, or ``None`` if SciPy reports the optimisation as
        unsuccessful.

    """
    names = list(p0)
    optimum = minimize(
        residual_fn,
        x0=[p0[name] for name in names],
        bounds=[bounds.get(name, (1e-6, 1e6)) for name in names],
        method="Nelder-Mead",
    )
    if not optimum.success:
        return None
    fitted = {name: value for name, value in zip(names, optimum.x, strict=True)}
    return MinResult(parameters=fitted, residual=optimum.fun)
# Swap the default L-BFGS-B minimiser for the Nelder-Mead wrapper above.
fit_result = unwrap(
    fit.time_course(
        model_fn(),
        data=res,
        p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
        minimize_fn=nelder_mead,
    )
)
fit_result.best_pars
{'k1': np.float64(1.0000044128393781), 'k2': np.float64(1.9999586259064912), 'k3': np.float64(1.000005095705572)}
Custom residual function¶
You can change the residual function to include further behaviour.
A bare-bones implementation is given below:
def time_course_residual(
    par_values: Array,
    # This will be filled out by partial
    par_names: list[str],
    data: pd.DataFrame,
    model: Model,
    y0: dict[str, float] | None,
    integrator: IntegratorType,
    loss_fn: LossFn,
) -> float:
    """Score one candidate parameter vector against time-course data.

    Args:
        par_values: Candidate values supplied by the minimiser, ordered like
            ``par_names``.
        par_names: Names of the parameters being fitted.
        data: Time-course data. Its index gives the time points to simulate
            and its columns select which model outputs are compared.
        model: Model that receives the candidate values (via
            ``update_parameters``) before simulating.
        y0: Initial conditions handed to the ``Simulator``.
        integrator: Integrator handed to the ``Simulator``.
        loss_fn: Called as ``loss_fn(simulated, data)`` to produce the score.

    Returns:
        Whatever ``loss_fn`` returns, or ``inf`` when the simulation yields
        no result.

    """
    candidate = dict(zip(par_names, par_values, strict=True))
    # NOTE(review): ``update_parameters`` is applied to ``model`` itself and
    # its return value is simulated; whether it mutates in place — confirm.
    simulation = Simulator(
        model.update_parameters(candidate),
        y0=y0,
        integrator=integrator,
    ).simulate_time_course(cast(list, data.index))

    outcome = simulation.get_result()
    if outcome is None:
        # No simulation result: give the minimiser the worst possible score.
        return np.inf

    # Restrict the model output to the columns present in the data.
    simulated = outcome.get_combined().loc[:, cast(list, data.columns)]
    return loss_fn(simulated, data)
# Plug the hand-written residual into the time-course fit.
fit_result = unwrap(
    fit.time_course(
        model_fn(),
        data=res,
        p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
        residual_fn=time_course_residual,
    )
)
fit_result.best_pars
{'k1': np.float64(0.9999999948101959), 'k2': np.float64(1.9999999615864308), 'k3': np.float64(0.9999999922253802)}