Skip to content

Commit 0a12bb5

Browse files
committed
get_sol upgrade and ODEProblem with callbacks
1 parent 7139d75 commit 0a12bb5

File tree

3 files changed

+38
-26
lines changed

3 files changed

+38
-26
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
* Added energy conservation option to the Henon Helies system
44
* All `ContinuousDS` evolution now internally passes thrgouh the `get_sol` function,
55
which improves the clarity and stability of the ecosystem greatly!!!
6-
6+
* Improved stability in propagating `solve` keywords.
7+
* `get_sol` now returns solution and time vector for generality purposes.
8+
* `get_sol` is now also exported.
79

810
# v0.3.3
911
## Non-breaking

src/continuous.jl

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import OrdinaryDiffEq.ODEProblem
33
import OrdinaryDiffEq.ODEIntegrator
44

55
export ContinuousDS, variational_integrator, ODEIntegrator, ODEProblem
6-
export ContinuousDynamicalSystem
6+
export ContinuousDynamicalSystem, DEFAULT_DIFFEQ_KWARGS, get_sol
77

88
#######################################################################################
99
# Constructors #
@@ -165,9 +165,16 @@ ODEProblem(ds::ContinuousDS, tspan::Tuple, state = ds.prob.u0) =
165165
ODEProblem{true}(ds.prob.f, state, tspan,
166166
callback = ds.prob.callback, mass_matrix = ds.prob.mass_matrix)
167167

168-
ODEProblem(ds::ContinuousDS, t::Real, state, cb) =
169-
ODEProblem{true}(ds.prob.f, state, (zero(t), t),
170-
callback = cb, mass_matrix = ds.prob.mass_matrix)
168+
function ODEProblem(ds::ContinuousDS, t::Real, state, cb)
169+
if ds.prob.callback == nothing
170+
return ODEProblem{true}(ds.prob.f, state, (zero(t), t),
171+
callback = cb, mass_matrix = ds.prob.mass_matrix)
172+
else
173+
return ODEProblem{true}(ds.prob.f, state, (zero(t), t),
174+
callback = CallbackSet(cb, ds.prob.callback),
175+
mass_matrix = ds.prob.mass_matrix)
176+
end
177+
end
171178

172179
"""
173180
ODEIntegrator(ds::ContinuousDS, t [, state]; diff_eq_kwargs)
@@ -270,7 +277,7 @@ const DEFAULT_SOLVER = Vern9()
270277
function evolve(ds::ContinuousDS, t = 1.0, state = ds.prob.u0;
271278
diff_eq_kwargs = DEFAULT_DIFFEQ_KWARGS)
272279
prob = ODEProblem(ds, t, state)
273-
return get_sol(prob, diff_eq_kwargs)[end]
280+
return get_sol(prob, diff_eq_kwargs)[1][end]
274281
end
275282

276283
evolve!(ds::ContinuousDS, t = 1.0; diff_eq_kwargs = DEFAULT_DIFFEQ_KWARGS) =
@@ -279,7 +286,7 @@ evolve!(ds::ContinuousDS, t = 1.0; diff_eq_kwargs = DEFAULT_DIFFEQ_KWARGS) =
279286
function extract_solver(diff_eq_kwargs)
280287
# Extract solver from kwargs
281288
if haskey(diff_eq_kwargs, :solver)
282-
newkw = copy(diff_eq_kwargs)
289+
newkw = deepcopy(diff_eq_kwargs)
283290
solver = diff_eq_kwargs[:solver]
284291
pop!(newkw, :solver)
285292
else
@@ -290,31 +297,41 @@ function extract_solver(diff_eq_kwargs)
290297
end
291298

292299
"""
293-
get_sol(prob::ODEProblem, diff_eq_kwargs::Dict = Dict())
294-
Solve the `prob` using `solve` and return the solution.
300+
get_sol(prob::ODEProblem [, diff_eq_kwargs::Dict, extra_kwargs::Dict])
301+
Solve the `prob` using `solve` and return the solutions vector as well as
302+
the time vector.
295303
296-
Correctly uses `tstops` if necessary (e.g. in the presence of `ManifoldProjection`).
304+
The second and third
305+
arguments are optional *position* arguments, passed to `solve` as keyword arguments.
306+
They both have to be dictionaries of `Symbol` keys.
307+
Only the second argument may contain a solver via the `:solver` key.
308+
309+
`get_sol` correctly uses `tstops` if necessary
310+
(e.g. in the presence of `DiscreteCallback`s).
297311
"""
298-
function get_sol(prob::ODEProblem, diff_eq_kwargs::Dict = DEFAULT_DIFFEQ_KWARGS)
312+
function get_sol(prob::ODEProblem, diff_eq_kwargs::Dict = DEFAULT_DIFFEQ_KWARGS,
313+
extra_kwargs = Dict())
299314

300315
solver, newkw = extract_solver(diff_eq_kwargs)
301316
# Take special care of callback sessions and use `tstops` if necessary
302317
# in conjuction with `saveat`
303318
if haskey(newkw, :saveat) && use_tstops(prob)
304-
newkw[:tstops] = newkw[:saveat]
319+
sol = solve(prob, solver; newkw..., extra_kwargs..., save_everystep=false,
320+
tstops = newkw[:saveat])
321+
else
322+
sol = solve(prob, solver; newkw..., extra_kwargs..., save_everystep=false)
305323
end
306324

307-
sol = solve(prob, solver; newkw..., save_everystep=false)
308-
return sol.u
325+
return sol.u, sol.t
309326
end
310327

311328
function use_tstops(prob::ODEProblem)
312329
if prob.callback == nothing
313330
return false
314331
elseif typeof(prob.callback) <: CallbackSet
315-
any(x->typeof(x.affect!)<:ManifoldProjection, prob.callback.discrete_callbacks)
332+
any(x->typeof(x)<:DiscreteCallback, prob.callback.discrete_callbacks)
316333
else
317-
return typeof(prob.callback.affect!) <: ManifoldProjection
334+
return typeof(prob.callback) <: DiscreteCallback
318335
end
319336
end
320337

@@ -335,15 +352,8 @@ function trajectory(ds::ContinuousDS, T;
335352
end
336353

337354
prob = ODEProblem(ds, T)
338-
if eltype(diff_eq_kwargs) != Pair{Symbol,Any}
339-
# nessesary conversion to add :saveat
340-
kw = Dict{Symbol, Any}(diff_eq_kwargs)
341-
else
342-
kw = diff_eq_kwargs
343-
end
344-
kw[:saveat] = t
345355

346-
return Dataset(get_sol(prob, kw))
356+
return Dataset(get_sol(prob, diff_eq_kwargs, Dict(:saveat => t))[1])
347357
end
348358

349359
#######################################################################################

test/continuous_systems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,6 @@ end
146146
E1 = [Hhh(p) for p in tra1]
147147
E2 = [Hhh(p) for p in tra2]
148148

149-
@test std(E1) < 1e-13
150-
@test std(E2) < 1e-13
149+
@test std(E1) < 1e-12
150+
@test std(E2) < 1e-12
151151
end

0 commit comments

Comments
 (0)