import aerosandbox.numpy as np
import casadi as _cas
from typing import Union, Callable, Tuple, Optional, List
from scipy import integrate
def quad(
    func: Union[Callable, _cas.MX],
    a: float,
    b: float,
    full_output: bool = False,
    variable_of_integration: _cas.MX = None,
) -> Union[Tuple[float, float], Tuple[float, float, dict]]:
    """
    Compute the definite integral of `func` from `a` to `b`.

    Two code paths exist, selected by the type of `func`:

    * If `func` is not a CasADi type, the call is forwarded unchanged to
      `scipy.integrate.quad` (with `func`, `a`, `b`, and `full_output`), and
      SciPy's return value is returned as-is. `variable_of_integration` is
      not used on this path.

    * If `func` is a CasADi expression, the integral is evaluated by posing it
      as an initial value problem (state starts at 0, derivative equals `func`)
      and integrating from `a` to `b` with CasADi's "cvodes" integrator
      (abstol=1e-8, reltol=1e-6). Every symbol in the expression other than
      the variable of integration is treated as a parameter.

    Args:
        func: A callable of one variable, or a CasADi expression.
        a: Lower limit of integration.
        b: Upper limit of integration.
        full_output: If True, a third element with extra information is
            returned. On the CasADi path this is the raw integrator output
            dictionary.
        variable_of_integration: The CasADi symbol to integrate over. May be
            omitted only when the expression contains exactly one symbol.

    Returns:
        On the CasADi path: `(integral, tol)` or `(integral, tol, res)`, where
        `tol` is the fixed constant 1e-8 (not a computed error estimate).
        On the SciPy path: whatever `scipy.integrate.quad` returns.

    Raises:
        ValueError: If `func` is a CasADi expression that does not contain
            exactly one symbol and `variable_of_integration` was not given.
    """
    # Guard clause: anything that is not a CasADi expression is SciPy's job.
    if not np.is_casadi_type(func):
        return integrate.quad(
            func=func,
            a=a,
            b=b,
            full_output=full_output,
        )

    # Every symbolic primitive that appears in the expression graph.
    all_vars = _cas.symvar(func)

    if variable_of_integration is None:
        if len(all_vars) != 1:
            raise ValueError(
                f"`func` must be a function of one variable, or you must specify the `variable_of_integration`.\n"
                f"Currently, it is a function of: {all_vars}"
            )
        variable_of_integration = all_vars[0]

    # Whatever is not the variable of integration is carried as a parameter.
    parameters = []
    for var in all_vars:
        if not _cas.is_equal(var, variable_of_integration):
            parameters.append(var)
    parameter_vector = _cas.vertcat(*parameters)

    # The integral is the final value of a dummy state whose derivative is `func`.
    dae = {
        "x": _cas.MX.sym("dummy_variable"),
        "p": parameter_vector,
        "t": variable_of_integration,
        "ode": func,
    }
    integrator_options = {
        "abstol": 1e-8,
        "reltol": 1e-6,
    }
    integrator = _cas.integrator(
        "integrator",
        "cvodes",
        dae,
        a,  # t0
        b,  # tf
        integrator_options,
    )

    res = integrator(
        x0=0,
        p=parameter_vector,
    )

    tol = 1e-8

    return (res["xf"], tol, res) if full_output else (res["xf"], tol)
def solve_ivp(
    fun: Union[Callable, _cas.MX],
    t_span: Tuple[float, float],
    y0: Union[np.ndarray, _cas.MX],
    method: str = "RK45",
    t_eval: Union[np.ndarray, _cas.MX] = None,
    dense_output: bool = False,
    events: Union[Callable, List[Callable]] = None,
    vectorized: bool = False,
    args: Optional[Tuple] = None,
    t_variable: _cas.MX = None,
    y_variables: Union[_cas.MX, Tuple[_cas.MX]] = None,
    **options,
):
    """
    Solve an initial value problem dy/dt = fun(t, y), y(t0) = y0.

    The signature mirrors `scipy.integrate.solve_ivp`, with two extra
    arguments (`t_variable`, `y_variables`) for symbolic input. One of three
    backends is selected from `fun`:

    * "casadi_expr": `fun` is itself a CasADi expression.
    * "casadi_func": `fun` is a callable that returns a CasADi type when
      evaluated at `(t_span[0], y0, *args)`.
    * "numpy_func": `fun` is a callable returning plain numeric data; the call
      is forwarded unchanged to `scipy.integrate.solve_ivp`.

    Both CasADi backends integrate with CasADi's "cvodes" integrator
    (abstol=1e-8, reltol=1e-6) in a normalized time that runs from 0 to 1; the
    ODE is rescaled by `(tf - t0)` accordingly. `method`, `vectorized`, and
    `**options` are only forwarded to SciPy and are not used by the CasADi
    backends.

    Args:
        fun: Right-hand side of the ODE: a callable `fun(t, y, *args)` or a
            CasADi expression.
        t_span: `(t0, tf)` interval of integration.
        y0: Initial state. If it is not a CasADi type it is concatenated with
            `casadi.vertcat(*y0)` for the CasADi backends.
        method: Integration method (SciPy backend only).
        t_eval: Times at which to report the solution. For the CasADi
            backends, a numeric `t_eval` (with numeric, distinct `t0` and
            `tf`) is honored; otherwise 100 uniformly spaced points over
            `t_span` are used.
        dense_output: SciPy backend only; raises for CasADi backends.
        events: SciPy backend only; raises for CasADi backends.
        vectorized: SciPy backend only.
        args: Extra positional arguments passed to `fun` (SciPy backend only;
            a non-empty value raises for CasADi backends).
        t_variable: The CasADi symbol representing time. Required when `fun`
            is a CasADi expression; must be None when `fun` is callable.
        y_variables: The CasADi state symbols, as a single CasADi object or a
            sequence of symbols. Only allowed when `fun` is a CasADi
            expression; if omitted there, every symbol in `fun` other than
            `t_variable` is taken to be a state.
        **options: Forwarded to `scipy.integrate.solve_ivp` (SciPy backend
            only).

    Returns:
        SciPy backend: the result of `scipy.integrate.solve_ivp`.
        CasADi backends: an `OdeResult` whose `t` and `y` are taken from the
        CasADi integrator output; the counters are zero, `status` is 0,
        `success` is True, and `sol`/`t_events`/`y_events` are None.

    Raises:
        TypeError: If `fun` is neither a CasADi expression nor callable.
        ValueError: If `t_variable`/`y_variables` are inconsistent with the
            kind of `fun`, or if a callable's output cannot be converted to
            an array.
        NotImplementedError: If `dense_output`, `events`, or `args` are
            requested with a CasADi backend.
    """
    # --- Determine which backend to use ---
    if np.is_casadi_type(fun, recursive=False):
        backend = "casadi_expr"
    else:
        # Check callability explicitly, rather than wrapping the probe call in
        # `except TypeError`, so that a TypeError raised *inside* `fun` is not
        # misreported as "fun is not a Callable".
        if not callable(fun):
            raise TypeError(
                "If `fun` is not a Callable, it must be a CasADi expression."
            )

        # Probe the function once to learn what type it returns. `args` must be
        # forwarded here, otherwise any callable that requires them fails at
        # the probe and never reaches SciPy.
        probe_args = () if args is None else tuple(args)
        f = np.array(fun(t_span[0], y0, *probe_args))

        if np.is_casadi_type(f):
            backend = "casadi_func"
        else:
            try:
                np.asanyarray(f)
            except ValueError as err:
                raise ValueError(
                    "The output of `fun` could not be converted to an array."
                ) from err
            backend = "numpy_func"

    # --- Validate the symbolic-variable arguments against the backend ---
    if backend == "casadi_func" or backend == "numpy_func":
        if t_variable is not None:
            raise ValueError(
                "If `fun` is a Callable, `t_variable` must be None (as it's implied)."
            )
        if y_variables is not None:
            raise ValueError(
                "If `fun` is a Callable, `y_variables` must be None (as they're implied)."
            )

    if backend == "casadi_expr":
        if t_variable is None:
            raise ValueError(
                "If `fun` is a CasADi expression, `t_variable` must be specified (and the y_variables are inferred)."
            )

        all_vars = _cas.symvar(fun)  # All variables found in the expression graph

        if y_variables is None:
            # Infer the states: every symbol in `fun` that is not the time variable.
            y_variables = np.array(
                [var for var in all_vars if not _cas.is_equal(var, t_variable)]
            )
        elif not np.is_casadi_type(y_variables, recursive=False):
            # A tuple/list of symbols is allowed by the signature; concatenate
            # it (as is done for `y0`) so that `.shape` and indexing work below.
            y_variables = _cas.vertcat(*y_variables)

    # --- SciPy backend: pure pass-through ---
    if backend == "numpy_func":
        return integrate.solve_ivp(
            fun=fun,
            t_span=t_span,
            y0=y0,
            method=method,
            t_eval=t_eval,
            dense_output=dense_output,
            events=events,
            vectorized=vectorized,
            args=args,
            **options,
        )

    # --- CasADi backends ("casadi_func" or "casadi_expr") ---
    if dense_output:
        raise NotImplementedError(
            "dense_output is not yet implemented for CasADi functions."
        )
    if events is not None:
        raise NotImplementedError(
            "Events are not yet implemented for CasADi functions."
        )
    if args:
        raise NotImplementedError(
            "args are not yet implemented for CasADi functions."
        )

    if not np.is_casadi_type(y0, recursive=False):
        y0 = _cas.vertcat(*y0)

    if backend == "casadi_func":
        # Trace the callable with fresh symbols to obtain a CasADi expression.
        t_variable = _cas.MX.sym("t")
        y_variables = _cas.MX.sym("y", y0.shape[0], y0.shape[1])
        fun = np.array(fun(t_variable, y_variables))

    # At this point:
    #   * `fun` is a CasADi expression
    #   * `t_variable` is a CasADi variable
    #   * `y_variables` is a CasADi variable, possibly a vector of variables

    t0 = t_span[0]
    tf = t_span[1]

    # Re-express the ODE in normalized time s in [0, 1], where
    # t = t0 + (tf - t0) * s, hence dy/ds = (tf - t0) * dy/dt. After the
    # substitution, `t_variable` plays the role of s.
    ode = _cas.substitute(
        fun,
        t_variable,
        t0 + (tf - t0) * t_variable,
    ) * (tf - t0)

    # Parameters are all symbols in the ODE that are neither time nor a state.
    all_vars = _cas.symvar(ode)

    def variable_is_t_or_y(var):
        return (
            _cas.is_equal(var, t_variable)
            or _cas.is_equal(var, y_variables)
            or any(
                _cas.is_equal(var, y_variables[i])
                for i in range(np.prod(y_variables.shape))
            )
        )

    parameters = _cas.vertcat(
        *[var for var in all_vars if not variable_is_t_or_y(var)]
    )

    # Output grid, in normalized time. A numeric `t_eval` is mapped onto
    # [0, 1]. If `t_eval` is absent, or anything involved is symbolic (the
    # integrator needs a numeric grid), or the span is degenerate, fall back
    # to 100 uniformly spaced points.
    t_eval_is_usable = (
        t_eval is not None
        and not any(np.is_casadi_type(value) for value in (t_eval, t0, tf))
        and tf != t0
    )
    if t_eval_is_usable:
        simtime_eval = (np.asanyarray(t_eval, dtype=float) - t0) / (tf - t0)
    else:
        simtime_eval = np.linspace(0, 1, 100)

    # Define the integrator
    integrator = _cas.integrator(
        "integrator",
        "cvodes",
        {
            "x": y_variables,
            "p": parameters,
            "t": t_variable,
            "ode": ode,
            # Quadrature of 1 over normalized time; `qf` therefore recovers
            # the normalized time at each output point.
            "quad": 1,
        },
        0,
        simtime_eval,
        {  # Options
            "abstol": 1e-8,
            "reltol": 1e-6,
        },
    )

    res = integrator(
        x0=y0,
        p=parameters,
    )

    return integrate._ivp.ivp.OdeResult(
        t=t0 + (tf - t0) * res["qf"],  # map normalized time back to real time
        y=res["xf"],
        t_events=None,
        y_events=None,
        nfev=0,
        njev=0,
        nlu=0,
        status=0,
        message="",
        success=True,
        sol=None,
    )
if __name__ == "__main__":
    # Demo script: solves the Lotka-Volterra system twice (once from a callable,
    # once from a CasADi expression) and plots each solution.
    #
    # A `quad` smoke test is kept here, disabled, for reference:
    #     t = cas.MX.sym("t")
    #     print(quad(func=t ** 2, a=0, b=1))

    def lotkavolterra_func(t, z):
        """Lotka-Volterra right-hand side for state z = (x, y)."""
        a, b, c, d = 1.5, 1, 3, 1
        z = _cas.MX(z)
        x, y = z[0], z[1]
        dx_dt = a * x - b * x * y
        dy_dt = -c * y + d * x * y
        return [dx_dt, dy_dt]

    t_eval = np.linspace(0, 15, 3000)

    # Symbolic final time; only used in the `substitute` calls when plotting.
    # (A symbolic grid, `np.linspace(0, tf, 100)`, was an earlier experiment.)
    tf = _cas.MX.sym("tf")

    # --- Case 1: the ODE is given as a Python callable ---
    time_span = (t_eval[0], t_eval[-1])
    sol = solve_ivp(
        lotkavolterra_func,
        t_span=time_span,
        y0=[10, 5],
    )
    z = sol.y

    import matplotlib.pyplot as plt

    def _plot_solution(solution):
        """Evaluate a solution numerically (with tf -> 15) and show it."""
        times = _cas.evalf(_cas.substitute(solution.t.T, tf, 15))
        states = _cas.evalf(_cas.substitute(solution.y.T, tf, 15))
        plt.plot(times, states)
        plt.xlabel("t")
        plt.legend(["x", "y"], shadow=True)
        plt.title("Lotka-Volterra System")
        plt.show()

    _plot_solution(sol)

    # --- Case 2: the ODE is given as a CasADi expression ---
    t = _cas.MX.sym("t")
    m = _cas.MX.sym("m")
    n = _cas.MX.sym("n")
    a, b, c, d = 1.5, 1, 3, 1

    lotkavolterra_expr = np.array(
        [
            a * m - b * m * n,
            -c * n + d * m * n,
        ]
    )

    # `y_variables` is deliberately left out, so the states are inferred.
    sol = solve_ivp(
        lotkavolterra_expr,
        t_span=(t_eval[0], t_eval[-1]),
        t_eval=t_eval,
        y0=[10, 5],
        t_variable=t,
    )

    _plot_solution(sol)