import matplotlib.pyplot as plt
from mxlpy import Model, Simulator, fns, plot, unwrap
# SIR and SIRD models
In epidemiology, compartmental models are often applied to model infectious diseases.
Common compartments include ones for Susceptible, Infectious and Recovered individuals, which are included in the SIR model.
In this model there are two transitions (called reactions in mxlpy) between these compartments:
- susceptible individuals can become infected by contact with an infected person: $\beta S I$
- infected individuals recover at a rate proportional to the number of infected individuals: $\gamma I$
These transition rates are scaled by $\beta$, the average number of contacts per person per unit time, and by $\gamma$, the inverse of the average infection duration, respectively.
def sir() -> Model:
    """Build the SIR (Susceptible-Infectious-Recovered) compartmental model.

    Initial values are s=0.9, i=0.1, r=0.0 and the rate constants are
    beta=0.2 and gamma=0.1. Two reactions connect the compartments:

    - ``infection``: moves one unit from ``s`` to ``i`` with rate
      ``fns.mass_action_2s(s, i, beta)``, i.e. beta * S * I.
    - ``recovery``: moves one unit from ``i`` to ``r`` with rate
      ``fns.mass_action_1s(i, gamma)``, i.e. gamma * I.

    Returns:
        The assembled mxlpy ``Model``.

    """
    initial_values = {"s": 0.9, "i": 0.1, "r": 0.0}
    rate_constants = {"beta": 0.2, "gamma": 0.1}

    model = Model().add_variables(initial_values).add_parameters(rate_constants)

    # S -> I, driven by contact between susceptible and infected individuals.
    model = model.add_reaction(
        "infection",
        fns.mass_action_2s,
        args=["s", "i", "beta"],
        stoichiometry={"s": -1, "i": 1},
    )

    # I -> R, first-order in the infected population.
    model = model.add_reaction(
        "recovery",
        fns.mass_action_1s,
        args=["i", "gamma"],
        stoichiometry={"i": -1, "r": 1},
    )
    return model
# Simulate the SIR model and plot the compartment trajectories (first axis)
# next to the reaction fluxes (second axis).
# NOTE(review): `unwrap` presumably raises if no result was produced — confirm.
sim = Simulator(sir())
res = unwrap(sim.simulate(100).get_result())

fig, (ax1, ax2) = plot.two_axes(figsize=(7.5, 3.5))
for ax, series, ylabel in (
    (ax1, res.variables, "Relative Population"),
    (ax2, res.fluxes, "Rate of change"),
):
    _ = plot.lines(series, ax=ax)
    ax.set(xlabel="Time / a.u.", ylabel=ylabel)
plt.show()
We can now easily extend the original model by adding an additional compartment and transition.
The SIRD model for example differentiates between recovered and deceased individuals.
It therefore has an additional compartment for deceased individuals and a transition from infected to deceased individuals, whose rate is proportional to the number of infected individuals and to the mortality $\mu$ of the infection: $\mu I$
def sird() -> Model:
    """Build the SIRD model by extending the SIR model with a deceased compartment.

    Starting from ``sir()``, this adds the variable ``d`` (initial value 0.0),
    the mortality parameter ``mu`` (0.01), and a ``death`` reaction that moves
    one unit from ``i`` to ``d`` with rate ``fns.mass_action_1s(i, mu)``,
    i.e. mu * I.

    Returns:
        The assembled mxlpy ``Model``.

    """
    model = sir()
    model = model.add_variable("d", 0.0).add_parameter("mu", 0.01)

    # I -> D, first-order in the infected population.
    model = model.add_reaction(
        "death",
        fns.mass_action_1s,
        args=["i", "mu"],
        stoichiometry={"i": -1, "d": 1},
    )
    return model
# Simulate the SIRD model and plot the compartment trajectories (first axis)
# next to the reaction fluxes (second axis).
# NOTE(review): `unwrap` presumably raises if no result was produced — confirm.
sim = Simulator(sird())
res = unwrap(sim.simulate(100).get_result())

fig, (ax1, ax2) = plot.two_axes(figsize=(7.5, 3.5))
for ax, series, ylabel in (
    (ax1, res.variables, "Relative Population"),
    (ax2, res.fluxes, "Rate of change"),
):
    _ = plot.lines(series, ax=ax)
    ax.set(xlabel="Time / a.u.", ylabel=ylabel)
plt.show()