@@ -3,7 +3,7 @@ import OrdinaryDiffEq.ODEProblem
33import OrdinaryDiffEq. ODEIntegrator
44
55export ContinuousDS, variational_integrator, ODEIntegrator, ODEProblem
6- export ContinuousDynamicalSystem
6+ export ContinuousDynamicalSystem, DEFAULT_DIFFEQ_KWARGS, get_sol
77
88# ######################################################################################
99# Constructors #
@@ -165,9 +165,16 @@ ODEProblem(ds::ContinuousDS, tspan::Tuple, state = ds.prob.u0) =
165165ODEProblem {true} (ds. prob. f, state, tspan,
166166callback = ds. prob. callback, mass_matrix = ds. prob. mass_matrix)
167167
168- ODEProblem (ds:: ContinuousDS , t:: Real , state, cb) =
169- ODEProblem {true} (ds. prob. f, state, (zero (t), t),
170- callback = cb, mass_matrix = ds. prob. mass_matrix)
168+ function ODEProblem (ds:: ContinuousDS , t:: Real , state, cb)
169+ if ds. prob. callback == nothing
170+ return ODEProblem {true} (ds. prob. f, state, (zero (t), t),
171+ callback = cb, mass_matrix = ds. prob. mass_matrix)
172+ else
173+ return ODEProblem {true} (ds. prob. f, state, (zero (t), t),
174+ callback = CallbackSet (cb, ds. prob. callback),
175+ mass_matrix = ds. prob. mass_matrix)
176+ end
177+ end
171178
172179"""
173180 ODEIntegrator(ds::ContinuousDS, t [, state]; diff_eq_kwargs)
@@ -270,7 +277,7 @@ const DEFAULT_SOLVER = Vern9()
270277function evolve (ds:: ContinuousDS , t = 1.0 , state = ds. prob. u0;
271278 diff_eq_kwargs = DEFAULT_DIFFEQ_KWARGS)
272279 prob = ODEProblem (ds, t, state)
273- return get_sol (prob, diff_eq_kwargs)[end ]
280+ return get_sol (prob, diff_eq_kwargs)[1 ][ end ]
274281end
275282
276283evolve! (ds:: ContinuousDS , t = 1.0 ; diff_eq_kwargs = DEFAULT_DIFFEQ_KWARGS) =
@@ -279,7 +286,7 @@ evolve!(ds::ContinuousDS, t = 1.0; diff_eq_kwargs = DEFAULT_DIFFEQ_KWARGS) =
279286function extract_solver (diff_eq_kwargs)
280287 # Extract solver from kwargs
281288 if haskey (diff_eq_kwargs, :solver )
282- newkw = copy (diff_eq_kwargs)
289+ newkw = deepcopy (diff_eq_kwargs)
283290 solver = diff_eq_kwargs[:solver ]
284291 pop! (newkw, :solver )
285292 else
@@ -290,31 +297,41 @@ function extract_solver(diff_eq_kwargs)
290297end
291298
292299"""
293- get_sol(prob::ODEProblem, diff_eq_kwargs::Dict = Dict())
294- Solve the `prob` using `solve` and return the solution.
300+ get_sol(prob::ODEProblem [, diff_eq_kwargs::Dict, extra_kwargs::Dict])
301+ Solve the `prob` using `solve` and return the solutions vector as well as
302+ the time vector.
295303
296- Correctly uses `tstops` if necessary (e.g. in the presence of `ManifoldProjection`).
304+ The second and third
305+ arguments are optional *position* arguments, passed to `solve` as keyword arguments.
306+ They both have to be dictionaries of `Symbol` keys.
307+ Only the second argument may contain a solver via the `:solver` key.
308+
309+ `get_sol` correctly uses `tstops` if necessary
310+ (e.g. in the presence of `DiscreteCallback`s).
297311"""
298- function get_sol (prob:: ODEProblem , diff_eq_kwargs:: Dict = DEFAULT_DIFFEQ_KWARGS)
312+ function get_sol (prob:: ODEProblem , diff_eq_kwargs:: Dict = DEFAULT_DIFFEQ_KWARGS,
313+ extra_kwargs = Dict ())
299314
300315 solver, newkw = extract_solver (diff_eq_kwargs)
301316 # Take special care of callback sessions and use `tstops` if necessary
302317 # in conjuction with `saveat`
303318 if haskey (newkw, :saveat ) && use_tstops (prob)
304- newkw[:tstops ] = newkw[:saveat ]
319+ sol = solve (prob, solver; newkw... , extra_kwargs... , save_everystep= false ,
320+ tstops = newkw[:saveat ])
321+ else
322+ sol = solve (prob, solver; newkw... , extra_kwargs... , save_everystep= false )
305323 end
306324
307- sol = solve (prob, solver; newkw... , save_everystep= false )
308- return sol. u
325+ return sol. u, sol. t
309326end
310327
311328function use_tstops (prob:: ODEProblem )
312329 if prob. callback == nothing
313330 return false
314331 elseif typeof (prob. callback) <: CallbackSet
315- any (x-> typeof (x. affect! )<: ManifoldProjection , prob. callback. discrete_callbacks)
332+ any (x-> typeof (x)<: DiscreteCallback , prob. callback. discrete_callbacks)
316333 else
317- return typeof (prob. callback. affect! ) <: ManifoldProjection
334+ return typeof (prob. callback) <: DiscreteCallback
318335 end
319336end
320337
@@ -335,15 +352,8 @@ function trajectory(ds::ContinuousDS, T;
335352 end
336353
337354 prob = ODEProblem (ds, T)
338- if eltype (diff_eq_kwargs) != Pair{Symbol,Any}
339- # nessesary conversion to add :saveat
340- kw = Dict {Symbol, Any} (diff_eq_kwargs)
341- else
342- kw = diff_eq_kwargs
343- end
344- kw[:saveat ] = t
345355
346- return Dataset (get_sol (prob, kw) )
356+ return Dataset (get_sol (prob, diff_eq_kwargs, Dict ( :saveat => t))[ 1 ] )
347357end
348358
349359# ######################################################################################
0 commit comments