1- using OrdinaryDiffEq, StaticArrays
2- import OrdinaryDiffEq: ODEIntegrator, ODEProblem
3- using DiffEqBase: __init, ODEFunction
1+ using DiffEqBase, StaticArrays
2+ using DiffEqBase: __init, ODEFunction, AbstractODEIntegrator
43
54export CDS_KWARGS
65# ####################################################################################
76# Defaults #
87# ####################################################################################
9- const DEFAULT_SOLVER = Vern9 ()
10- const DEFAULT_DIFFEQ_KWARGS = (abstol = 1e-9 ,
11- reltol = 1e-9 , maxiters = typemax (Int))
8+ using SimpleDiffEq: SimpleATsit5
9+ const DEFAULT_SOLVER = SimpleATsit5 ()
10+ const DEFAULT_DIFFEQ_KWARGS = (abstol = 1e-6 ,
11+ reltol = 1e-6 , maxiters = typemax (Int))
1212
1313const CDS_KWARGS = (alg = DEFAULT_SOLVER, DEFAULT_DIFFEQ_KWARGS... )
1414
@@ -22,26 +22,23 @@ function ContinuousDynamicalSystem(prob::ODEProblem, args...)
2222 t0 = prob. tspan[1 ])
2323end
2424
25- function ODEProblem (ds:: CDS{IIP} , tspan, args... ) where {IIP}
26- # when stable, do ODEFunction(ds.f; jac = ds.jacobian)
27- return ODEProblem {IIP} (ds. f, ds. u0, tspan, args... )
25+ function DiffEqBase. ODEProblem (ds:: CDS{IIP} , tspan, args... ) where {IIP}
26+ return ODEProblem {IIP} (ODEFunction (ds. f; jac = ds. jacobian), ds. u0, tspan, args... )
2827end
2928
3029# ####################################################################################
3130# Integrators #
3231# ####################################################################################
33- stateeltype (:: ODEIntegrator{Alg, S} ) where {Alg , S} = eltype (S)
34- stateeltype (:: ODEIntegrator{Alg , S} ) where {
35- Alg , S<: Vector{<:AbstractArray{T}} } where {T} = T
32+ stateeltype (:: AbstractODEIntegrator{A, IIP, S} ) where {A, IIP , S} = eltype (S)
33+ stateeltype (:: AbstractODEIntegrator{A, IIP , S} ) where {
34+ A, IIP , S<: Vector{<:AbstractArray{T}} } where {T} = T
3635
3736function integrator (ds:: CDS{iip} , u0 = ds. u0;
3837 tfinal = Inf , diffeq... ) where {iip}
3938
4039 u = safe_state_type (Val {iip} (), u0)
4140 prob = ODEProblem {iip} (ds. f, u, (ds. t0, typeof (ds. t0)(tfinal)), ds. p)
4241
43- (haskey (diffeq, :saveat ) && tfinal == Inf ) && error (" Infinite solving!" )
44-
4542 solver = _get_solver (diffeq)
4643 integ = __init (prob, solver; DEFAULT_DIFFEQ_KWARGS... ,
4744 save_everystep = false , diffeq... )
@@ -108,12 +105,12 @@ function create_parallel(ds::CDS{true}, states)
108105 return paralleleom, st
109106end
110107
111- const STIFFSOLVERS = (ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2,
112- GenericImplicitEuler,
113- GenericTrapezoid, SDIRK2, Kvaerno3, KenCarp3, Cash4, Hairer4, Hairer42, Kvaerno4,
114- KenCarp4, Kvaerno5, KenCarp5, Rosenbrock23,
115- Rosenbrock32, ROS3P, Rodas3, RosShamp4, Veldd4, Velds4, GRK4T,
116- GRK4A, Ros4LStab, Rodas4, Rodas42, Rodas4P)
108+ # const STIFFSOLVERS = (ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2,
109+ # GenericImplicitEuler,
110+ # GenericTrapezoid, SDIRK2, Kvaerno3, KenCarp3, Cash4, Hairer4, Hairer42, Kvaerno4,
111+ # KenCarp4, Kvaerno5, KenCarp5, Rosenbrock23,
112+ # Rosenbrock32, ROS3P, Rodas3, RosShamp4, Veldd4, Velds4, GRK4T,
113+ # GRK4A, Ros4LStab, Rodas4, Rodas42, Rodas4P)
117114
118115function parallel_integrator (ds:: CDS , states; diffeq... )
119116 peom, st = create_parallel (ds, states)
@@ -134,61 +131,74 @@ function trajectory(ds::ContinuousDynamicalSystem, T, u = ds.u0;
134131
135132 t0 = ds. t0
136133 tvec = (t0+ Ttr): dt: (T+ t0+ Ttr)
137- integ = integrator (ds, u; tfinal = t0 + Ttr + T, diffeq... , saveat = tvec)
138- solve! (integ)
139- return Dataset (integ. sol. u)
134+ sol = Vector {SVector{dimension(ds), stateeltype(ds)}} (undef, length (tvec))
135+ integ = integrator (ds, u; dt = dt, tfinal = tvec[end ]+ 2 dt, diffeq... )
136+ step! (integ, Ttr)
137+ for (i, t) in enumerate (tvec)
138+ while t > integ. t
139+ step! (integ)
140+ end
141+ if integ. tprev ≤ t ≤ integ. t
142+ sol[i] = integ (t)
143+ else
144+ error (" should be integ.tprev ≤ t ≤ integ.t" )
145+ end
146+ end
147+ return Dataset (sol)
140148end
141149
142150# ####################################################################################
143151# Get States #
144152# ####################################################################################
145- get_state (integ:: ODEIntegrator{Alg, S} ) where {Alg, S<: AbstractVector } = integ. u
146- get_state (integ:: ODEIntegrator{Alg, S} ) where {Alg, S<: AbstractMatrix } =
147- integ. u[:, 1 ]
148- get_state (integ:: ODEIntegrator{Alg, S} ) where {Alg, S<: Vector{<:AbstractVector} } =
153+ get_state (integ:: AbstractODEIntegrator{Alg, IIP, S} ) where {Alg, IIP, S<: AbstractVector } =
154+ integ. u
155+ get_state (integ:: AbstractODEIntegrator{Alg, IIP, S} ) where {Alg, IIP, S<: AbstractMatrix } =
156+ integ. u[:, 1 ]
157+ get_state (integ:: AbstractODEIntegrator{Alg, IIP, S} ) where {Alg, IIP, S<: Vector{<:AbstractVector} } =
149158 integ. u[1 ]
150- get_state (integ:: ODEIntegrator {Alg, S} , k:: Int ) where {
151- Alg, S <: Vector{<:AbstractVector} } = integ. u[k]
152- get_state (integ:: ODEIntegrator {Alg, S} , k:: Int ) where {Alg, S<: AbstractMatrix } =
159+ get_state (integ:: AbstractODEIntegrator {Alg, IIP, S} , k:: Int ) where {Alg, IIP, S <: Vector{<:AbstractVector} } =
160+ integ. u[k]
161+ get_state (integ:: AbstractODEIntegrator {Alg, IIP, S} , k:: Int ) where {Alg, IIP , S<: AbstractMatrix } =
153162 integ. u[:, k]
154163
155164function set_state! (
156- integ:: ODEIntegrator {Alg, S} , u:: AbstractVector , k:: Int = 1
157- ) where {Alg, S<: Vector{<:AbstractVector} }
165+ integ:: AbstractODEIntegrator {Alg, IIP , S} , u:: AbstractVector , k:: Int = 1
166+ ) where {Alg, IIP, S<: Vector{<:AbstractVector} }
158167 integ. u[k] = u
159168 u_modified! (integ, true )
160169end
161170function set_state! (
162- integ:: ODEIntegrator{Alg, S} , u:: AbstractVector ) where {Alg, S<: Matrix }
171+ integ:: AbstractODEIntegrator{Alg, IIP, S} , u:: AbstractVector
172+ ) where {Alg, IIP, S<: Matrix }
163173 integ. u[:, 1 ] .= u
164174 u_modified! (integ, true )
165175end
166176function set_state! (
167- integ:: ODEIntegrator {Alg, S} , u:: AbstractVector
168- ) where {Alg, S<: SMatrix{D, K} } where {D, K}
177+ integ:: AbstractODEIntegrator {Alg, IIP , S} , u:: AbstractVector
178+ ) where {Alg, IIP, S<: SMatrix{D, K} } where {D, K}
169179 integ. u = hcat (SVector {D} (u), integ. u[:, SVector {K-1} (2 : K... )])
170180 u_modified! (integ, true )
171181end
172182
173- get_deviations (integ:: ODEIntegrator {Alg, S} ) where {Alg, S<: Matrix } =
183+ get_deviations (integ:: AbstractODEIntegrator {Alg, IIP, S} ) where {Alg, IIP , S<: Matrix } =
174184 @view integ. u[:, 2 : end ]
175185
176186
177187@generated function get_deviations (
178- integ:: ODEIntegrator {Alg, S} ) where {Alg, S<: SMatrix{D,K} } where {D,K}
188+ integ:: AbstractODEIntegrator {Alg, IIP, S} ) where {Alg, IIP , S<: SMatrix{D,K} } where {D,K}
179189 gens = [:($ k) for k= 2 : K]
180190 quote
181191 sind = SVector {$(K-1)} ($ (gens... ))
182192 integ. u[:, sind]
183193 end
184194end
185195
186- set_deviations! (integ:: ODEIntegrator {Alg, S} , Q) where {Alg, S<: Matrix } =
196+ set_deviations! (integ:: AbstractODEIntegrator {Alg, IIP, S} , Q) where {Alg, IIP , S<: Matrix } =
187197 (integ. u[:, 2 : end ] .= Q; u_modified! (integ, true ))
188- set_deviations! (integ:: ODEIntegrator {Alg, S} , Q) where {Alg, S<: SMatrix } =
198+ set_deviations! (integ:: AbstractODEIntegrator {Alg, IIP, S} , Q) where {Alg, IIP , S<: SMatrix } =
189199 (integ. u = hcat (integ. u[:,1 ], Q); u_modified! (integ, true ))
190200
191- function DiffEqBase. reinit! (integ:: ODEIntegrator , u0:: AbstractVector ,
201+ function DiffEqBase. reinit! (integ:: AbstractODEIntegrator , u0:: AbstractVector ,
192202 Q0:: AbstractMatrix ; kwargs... )
193203
194204 set_state! (integ, u0)
0 commit comments