Skip to content

Commit 30d09de

Browse files
authored
feat: MTK Jacobian for CoupledDEs (#229)
* feat: MTK jacobian for CoupledDEs * add CoupledSDE to CoreDynamicalSystem * implement comments * add docs * don't use `dynamical_rule` but prob * make `dynamic_rule` always take the nested f * simplify `dynamic_rule` to not directly access nested function f
1 parent 7b7ac7a commit 30d09de

File tree

4 files changed

+67
-20
lines changed

4 files changed

+67
-20
lines changed

ext/src/CoupledSDEs.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ function DynamicalSystemsBase.CoupledSDEs(
118118
noise_process=nothing,
119119
seed=UInt64(0)
120120
)
121-
return CoupledSDEs(dynamic_rule(ds), current_state(ds), p;
121+
prob = referrenced_sciml_prob(ds)
122+
# we want the symbolic jacobian to be transfered over
123+
# dynamic_rule(ds) takes the deepest nested f wich does not have a jac field
124+
return CoupledSDEs(prob.f, current_state(ds), p;
122125
g, noise_strength, covariance, diffeq, noise_prototype, noise_process, seed)
123126
end
124127

@@ -130,9 +133,11 @@ deterministic part of `ds`.
130133
"""
131134
function DynamicalSystemsBase.CoupledODEs(
132135
sys::CoupledSDEs; diffeq=DEFAULT_DIFFEQ, t0=0.0)
136+
prob = referrenced_sciml_prob(sys)
137+
# we want the symbolic jacobian to be transfered over
138+
# dynamic_rule(ds) takes the deepest nested f wich does not have a jac field
133139
return CoupledODEs(
134-
dynamic_rule(sys), SVector{length(sys.integ.u)}(sys.integ.u), sys.p0;
135-
diffeq=diffeq, t0=t0
140+
prob.f, SVector{length(sys.integ.u)}(sys.integ.u), sys.p0; diffeq=diffeq, t0=t0
136141
)
137142
end
138143

@@ -155,16 +160,6 @@ StateSpaceSets.dimension(::CoupledSDEs{IIP,D}) where {IIP,D} = D
155160
DynamicalSystemsBase.current_state(ds::CoupledSDEs) = current_state(ds.integ)
156161
DynamicalSystemsBase.isdeterministic(ds::CoupledSDEs) = false
157162

158-
function DynamicalSystemsBase.dynamic_rule(ds::CoupledSDEs)
159-
# with remake it can happen that we have nested SDEFunctions
160-
# sciml integrator deals with this internally well
161-
f = ds.integ.f
162-
while hasfield(typeof(f), :f)
163-
f = f.f
164-
end
165-
return f
166-
end
167-
168163
function DynamicalSystemsBase.set_state!(ds::CoupledSDEs, u::AbstractArray)
169164
(set_state!(ds.integ, u); ds)
170165
end
@@ -224,6 +219,8 @@ function covariance_matrix(ds::CoupledSDEs)::AbstractMatrix
224219
(A == nothing) ? nothing : A * A'
225220
end
226221

222+
jacobian(sde::CoupledSDEs) = DynamicalSystemsBase.jacobian(CoupledODEs(sde))
223+
227224
###########################################################################################
228225
# Utilities
229226
###########################################################################################

src/core_systems/jacobian.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import ForwardDiff
66
jacobian(ds::CoreDynamicalSystem)
77
88
Construct the Jacobian rule for the dynamical system `ds`.
9-
This is done via automatic differentiation using module
9+
If the system already has a Jacobian rule constructed via ModelingToolkit it returns this,
10+
otherwise it constructs the Jacobian rule with automatic differentiation using module
1011
[`ForwardDiff`](https://github.com/JuliaDiff/ForwardDiff.jl).
1112
1213
## Description
@@ -19,7 +20,17 @@ For in-place systems, `jacobian` returns the Jacobian rule as a function
1920
at the state `u`, parameters `p` and time `t` and save the result in `J0`.
2021
"""
2122
function jacobian(ds::CoreDynamicalSystem{IIP}) where {IIP}
22-
_jacobian(ds, Val{IIP}())
23+
if ds isa ContinuousTimeDynamicalSystem
24+
prob = referrenced_sciml_prob(ds)
25+
if prob.f isa SciMLBase.AbstractDiffEqFunction && !isnothing(prob.f.jac)
26+
jac = prob.f.jac
27+
else
28+
jac = _jacobian(ds, Val{IIP}())
29+
end
30+
else
31+
jac = _jacobian(ds, Val{IIP}())
32+
end
33+
return jac
2334
end
2435

2536
function _jacobian(ds, ::Val{true})
@@ -43,4 +54,6 @@ function _jacobian(ds, ::Val{false})
4354
f = dynamic_rule(ds)
4455
Jf = (u, p, t = 0) -> ForwardDiff.jacobian((x) -> f(x, p, t), u)
4556
return Jf
46-
end
57+
end
58+
59+
jacobian(ds::CoupledSDEs) = jacobian(CoupledODEs(ds))

test/jacobian.jl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function iip(du, u, p, t)
99
return nothing
1010
end
1111

12-
#%%
12+
1313
@testset "IDT=$(IDT), IIP=$(IIP)" for IDT in (true, false), IIP in (false, true)
1414
SystemType = IDT ? DeterministicIteratedMap : CoupledODEs
1515
rule = IIP ? iip : oop
@@ -26,4 +26,41 @@ end
2626
else
2727
@test J(current_state(ds), current_parameters(ds), 0.0) == result
2828
end
29-
end
29+
end
30+
31+
@testset "MTK Jacobian" begin
32+
using ModelingToolkit
33+
using ModelingToolkit: Num, RuntimeGeneratedFunctions.RuntimeGeneratedFunction
34+
using DynamicalSystemsBase: SciMLBase
35+
@independent_variables t
36+
@variables u(t)[1:2]
37+
D = Differential(t)
38+
39+
eqs = [D(u[1]) ~ 3.0 * u[1],
40+
D(u[2]) ~ -3.0 * u[2]]
41+
@named sys = ODESystem(eqs, t)
42+
sys = structural_simplify(sys)
43+
44+
prob = ODEProblem(sys, [1.0, 1.0], (0.0, 1.0); jac=true)
45+
ode = CoupledODEs(prob)
46+
47+
jac = jacobian(ode)
48+
@test jac.jac_oop isa RuntimeGeneratedFunction
49+
@test jac([1.0, 1.0], [], 0.0) == [3 0;0 -3]
50+
51+
@testset "CoupledSDEs" begin
52+
# just to check if MTK @brownian does not give any problems
53+
using StochasticDiffEq
54+
@brownian β
55+
eqs = [D(u[1]) ~ 3.0 * u[1]+ β,
56+
D(u[2]) ~ -3.0 * u[2] + β]
57+
@mtkbuild sys = System(eqs, t)
58+
59+
prob = SDEProblem(sys, [1.0, 1.0], (0.0, 1.0), jac=true)
60+
sde = CoupledSDEs(prob)
61+
62+
jac = jacobian(sde)
63+
@test jac.jac_oop isa RuntimeGeneratedFunction
64+
@test jac([1.0, 1.0], [], 0.0) == [3 0;0 -3]
65+
end
66+
end

test/stochastic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,12 @@ end
8181

8282
# CoupledODEs creation
8383
ds = CoupledODEs(lorenz_oop)
84-
@test dynamic_rule(ds) == lorenz_rule
84+
@test dynamic_rule(ds).f == lorenz_rule
8585
@test ds.integ.alg isa Tsit5
8686
test_dynamical_system(ds, u0, p0; idt = false, iip = false)
8787
# and back
8888
sde = CoupledSDEs(ds, p0)
89-
@test dynamic_rule(sde) == lorenz_rule
89+
@test dynamic_rule(sde).f.f == lorenz_rule
9090
@test sde.integ.alg isa SOSRA
9191
end
9292

0 commit comments

Comments
 (0)