Skip to content

Commit 6b9a7db

Browse files
committed
hanlde special callbacks
1 parent 29e0c62 commit 6b9a7db

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# v0.3.1
22
* Added `jacobian` function
33
* Removed `EomVector` nonsense.
4+
* Now `trajectory` correctly gives equi-spaced points when the ODEProblem has
5+
"special" callbacks.
46

57
# v0.3.0
68
## BREAKING

src/continuous.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,15 @@ function evolve(ds::ContinuousDS, t = 1.0, state = ds.prob.u0;
254254
return get_sol(prob, diff_eq_kwargs)[end]
255255
end
256256

257-
257+
function use_tstops(prob::ODEProblem)
258+
if prob.callback == nothing
259+
return false
260+
elseif typeof(prob.callback) <: CallbackSet
261+
any(x->typeof(x.affect!)<:ManifoldProjection, prob.callback.discrete_callbacks)
262+
else
263+
return typeof(prob.callback.affect!) == ManifoldProjection
264+
end
265+
end
258266

259267
# See discrete.jl for the documentation string
260268
function trajectory(ds::ContinuousDS, T;
@@ -263,11 +271,10 @@ function trajectory(ds::ContinuousDS, T;
263271
# Necessary due to DifferentialEquations:
264272

265273
if typeof(T) <: Real && !issubtype(typeof(T), AbstractFloat)
274+
T<=0 && throw(ArgumentError("Total time `T` must be positive."))
266275
T = convert(Float64, T)
267276
end
268-
T<=0 && throw(ArgumentError("Total time `T` must be positive."))
269277

270-
D = dimension(ds)
271278
if typeof(T) <: Real
272279
t = zero(T):dt:T #time vector
273280
elseif typeof(T) == Tuple
@@ -277,6 +284,9 @@ function trajectory(ds::ContinuousDS, T;
277284
prob = ODEProblem(ds, T)
278285
kw = Dict{Symbol, Any}(diff_eq_kwargs) #nessesary conversion to add :saveat
279286
kw[:saveat] = t
287+
# Take special care of callback sessions and use `tstops` if necessary
288+
use_tstops(prob) && (kw[:tstops] = t)
289+
280290
return Dataset(get_sol(prob, kw))
281291
end
282292

0 commit comments

Comments
 (0)