@@ -31,6 +31,7 @@ Training State containing:
 Internal fields:
 
   - `cache`: Cached values. Implementations are free to use this for whatever they want.
+  - `allocator_cache`: Used by GPUArrays-compatible backends to cache memory allocations.
   - `objective_function`: Objective function; might be cached.
 
 !!! warning
@@ -41,6 +42,7 @@ Internal fields:
 @concrete struct TrainState
     cache
     objective_function
+    allocator_cache
     model
     parameters
     states
@@ -55,6 +57,7 @@ function Adapt.adapt_structure(to::AbstractDevice, ts::TrainState)
     return TrainState(
         nothing,
         nothing,
+        get_allocator_cache(to),
         ts.model,
         to(ts.parameters),
         to(ts.states),
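
Review note: this overload is what runs when a `TrainState` is moved with a device functor. A minimal usage sketch (model, optimizer, and shapes are illustrative; assumes the device functor routes through `Adapt` as usual):

    using Lux, Optimisers, Random

    model = Dense(4 => 2)
    ps, st = Lux.setup(Random.default_rng(), model)
    ts = Training.TrainState(model, ps, st, Adam(0.001f0))

    # Moving the state hits Adapt.adapt_structure(::AbstractDevice, ::TrainState):
    # caches are dropped and the device's allocator cache is picked up.
    dev = gpu_device()
    ts_on_device = dev(ts)
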
@@ -91,16 +94,7 @@ function Adapt.adapt_structure(to::ReactantDevice, ts::TrainState)
     This ensures the optimizer state and other internal states are on the device on
     construction.
     """
-    return TrainState(
-        nothing,
-        nothing,
-        ts.model,
-        to(ts.parameters),
-        to(ts.states),
-        ts.optimizer,
-        to(ts.optimizer_state),
-        ts.step,
-    )
+    return @invoke Adapt.adapt_structure(to::AbstractDevice, ts::TrainState)
 end
 
 """
@@ -125,9 +119,13 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr
         optimizer = ReactantCompatibleOptimisers.make_reactant_compatible(optimizer, dev)
     end
     st_opt = Optimisers.setup(optimizer, ps)
-    return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
+    return TrainState(
+        nothing, nothing, get_allocator_cache(dev), model, ps, st, optimizer, st_opt, 0
+    )
 end
 
+get_allocator_cache(_) = nothing
+
 @concrete struct TrainingBackendCache
     backend
     first_try <: StaticBool
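
Review note: `get_allocator_cache(_) = nothing` is the generic fallback; a GPUArrays-compatible device extension is expected to override it with a real cache. A hypothetical sketch of such an overload, assuming GPUArrays' `AllocCache` type (neither the device method nor the type appears in this diff):

    # Hypothetical extension code, not part of this PR.
    using GPUArrays: AllocCache
    using MLDataDevices: CUDADevice

    # Hand each CUDA TrainState a fresh allocation cache to reuse across steps.
    get_allocator_cache(::CUDADevice) = AllocCache()
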
@@ -190,14 +188,25 @@ function apply_gradients(ts::TrainState, grads)
         )
         return apply_gradients_reactant(ts, grads)
     end
+    return apply_gradients_with_allocator_cache(ts.allocator_cache, ts, grads)
+end
+
+# apply_gradients -> apply_gradients_reactant (for ReactantBackend)
+#                 -> apply_gradients_with_allocator_cache -> apply_gradients_impl
+
+function apply_gradients_with_allocator_cache(::Nothing, ts::TrainState, grads)
+    return apply_gradients_impl(ts, grads)
+end
+
+function apply_gradients_impl(ts::TrainState, grads)
     optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
     @set! ts.parameters = ps
     @set! ts.optimizer_state = optimizer_state
     @set! ts.step = ts.step + 1
     return ts
 end
 
-function apply_gradients_reactant end
+function apply_gradients_reactant end  # method added in ReactantExt
 
 """
     apply_gradients!(ts::TrainState, grads)
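
Review note: the `::Nothing` method above is the no-cache fallback. A backend that supplies a real cache would add its own method that runs the update inside a cached-allocation scope; a hedged sketch, assuming GPUArrays' `@cached` macro (the actual extension is not part of this diff):

    # Hypothetical: how a GPUArrays-backed cache could hook into the chain.
    using GPUArrays: AllocCache, @cached

    function apply_gradients_with_allocator_cache(cache::AllocCache, ts::TrainState, grads)
        # Reuse buffers allocated in earlier steps instead of hitting the allocator.
        return @cached cache apply_gradients_impl(ts, grads)
    end
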
@@ -214,12 +223,23 @@ function apply_gradients!(ts::TrainState, grads)
         )
         return apply_gradients_reactant!(ts, grads)
     end
+    return apply_gradients_with_allocator_cache!(ts.allocator_cache, ts, grads)
+end
+
+# apply_gradients! -> apply_gradients_reactant! (for ReactantBackend)
+#                  -> apply_gradients_with_allocator_cache! -> apply_gradients_impl!
+
+function apply_gradients_with_allocator_cache!(::Nothing, ts::TrainState, grads)
+    return apply_gradients_impl!(ts, grads)
+end
+
+function apply_gradients_impl!(ts::TrainState, grads)
     Optimisers.update!(ts.optimizer_state, ts.parameters, grads)
     @set! ts.step = ts.step + 1
     return ts
 end
 
-function apply_gradients_reactant! end
+function apply_gradients_reactant! end  # method added in ReactantExt
 
 const SYNC_DOCSTRING = """
   - `sync`: If `true`, the reactant function is compiled with `sync=true`.
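
Review note: the in-place chain differs from the non-mutating one only in `apply_gradients_impl!` using `Optimisers.update!`, which overwrites the existing optimizer state and parameter arrays instead of returning fresh ones. A standalone illustration of the two Optimisers.jl calls:

    using Optimisers

    ps = (w = rand(Float32, 2),)
    opt_state = Optimisers.setup(Descent(0.1f0), ps)
    grads = (w = ones(Float32, 2),)

    opt_state2, ps2 = Optimisers.update(opt_state, ps, grads)  # returns new state/params
    Optimisers.update!(opt_state, ps, grads)                   # mutates them in place
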
@@ -288,20 +308,17 @@ A 4-Tuple containing:
288308"""
289309function compute_gradients (ad, obj_fn:: F , data, ts:: TrainState ; sync:: Bool = false ) where {F}
290310 dev_type = get_device_type ((ts. parameters, ts. states))
291- return compute_gradients_impl (maybe_wrap_adtype (ad, dev_type; sync), obj_fn, data, ts)
311+ return compute_gradients_impl_with_allocator_cache (
312+ maybe_wrap_adtype (ad, dev_type; sync), ts. allocator_cache, obj_fn, data, ts
313+ )
292314end
293315
294- maybe_wrap_adtype (backend:: ReactantBackend , :: Any ; kwargs... ) = backend
295- maybe_wrap_adtype (ad:: AbstractADType , :: Any ; kwargs... ) = ad
296- function maybe_wrap_adtype (
297- ad:: AbstractADType ,
298- :: Type{ReactantDevice} ;
299- return_gradients:: Utils.BoolType = True (),
300- sync:: Bool = false ,
301- )
302- ad isa AutoEnzyme && return ReactantBackend (static (return_gradients), sync)
303- throw (ArgumentError (" Computing gradients for models on XLA is supported only with \
304- Enzyme.jl (`AutoEnzyme`)." ))
316+ # compute_gradients -> compute_gradients_impl_with_allocator_cache -> compute_gradients_impl
317+
318+ function compute_gradients_impl_with_allocator_cache (
319+ backend, :: Nothing , obj_fn:: F , data, ts:: TrainState
320+ ) where {F}
321+ return compute_gradients_impl (backend, obj_fn, data, ts)
305322end
306323
307324function compute_gradients_impl (ad, :: F , _, ts:: TrainState ) where {F}
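
Review note: user code is unaffected; the allocator cache is threaded through internally. A minimal sketch of the documented call path (model, loss function, and data are illustrative):

    using ADTypes, Lux, Optimisers, Random, Zygote

    model = Dense(4 => 2)
    ps, st = Lux.setup(Random.default_rng(), model)
    ts = Training.TrainState(model, ps, st, Adam(0.001f0))

    # Objective functions return (loss, updated_state, stats).
    function mse_loss(model, ps, st, (x, y))
        y_pred, st = model(x, ps, st)
        return sum(abs2, y_pred .- y), st, (;)
    end

    x, y = rand(Float32, 4, 8), rand(Float32, 2, 8)
    grads, loss, stats, ts = Training.compute_gradients(AutoZygote(), mse_loss, (x, y), ts)
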
@@ -328,6 +345,19 @@ for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme, :Mooncake)
     end
 end
 
+maybe_wrap_adtype(backend::ReactantBackend, ::Any; kwargs...) = backend
+maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad
+function maybe_wrap_adtype(
+    ad::AbstractADType,
+    ::Type{ReactantDevice};
+    return_gradients::Utils.BoolType=True(),
+    sync::Bool=false,
+)
+    ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), sync)
+    throw(ArgumentError("Computing gradients for models on XLA is supported only with \
+                         Enzyme.jl (`AutoEnzyme`)."))
+end
+
 function generate_wrappers(::F, m, ps, st, data, ::False) where {F}
     @warn "Detected function wrapper generation with function being updated between calls. \
            This will generate type-unstable code. A possible reason for this is \
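
Review note: `maybe_wrap_adtype` only rewrites the adtype for Reactant devices; everywhere else it is a pass-through. Expected behavior, sketched:

    maybe_wrap_adtype(AutoEnzyme(), ReactantDevice)  # -> ReactantBackend(...)
    maybe_wrap_adtype(AutoZygote(), ReactantDevice)  # -> throws ArgumentError
    maybe_wrap_adtype(AutoZygote(), CPUDevice)       # -> AutoZygote(), unchanged
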
@@ -395,7 +425,9 @@ function single_train_step!(
     backend = maybe_wrap_adtype(
         backend, get_device_type((ts.parameters, ts.states)); return_gradients, sync
     )
-    return single_train_step_impl!(backend, obj_fn, data, ts)
+    return single_train_step_impl_with_allocator_cache!(
+        backend, ts.allocator_cache, obj_fn, data, ts
+    )
 end
 
 """
@@ -429,16 +461,29 @@ function single_train_step(
     backend = maybe_wrap_adtype(
         backend, get_device_type((ts.parameters, ts.states)); return_gradients, sync
     )
-    return single_train_step_impl(backend, obj_fn, data, ts)
+    return single_train_step_impl_with_allocator_cache(
+        backend, ts.allocator_cache, obj_fn, data, ts
+    )
 end
 
+# single_train_step -> single_train_step_impl_with_allocator_cache -> single_train_step_impl
+
 for inplace in ("!", "")
     step = Symbol(:single_train_step_impl, inplace)
+    step_allocator_cache = Symbol(:single_train_step_impl_with_allocator_cache, inplace)
     apply_fn = Symbol(:apply_gradients, inplace)
-    @eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
-        grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
-        ts = $(apply_fn)(ts, grads)
-        return grads, loss, stats, ts
+    @eval begin
+        function $(step_allocator_cache)(
+            backend, ::Nothing, obj_fn::F, data, ts::TrainState
+        ) where {F}
+            return $(step)(backend, obj_fn, data, ts)
+        end
+
+        function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
+            grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
+            ts = $(apply_fn)(ts, grads)
+            return grads, loss, stats, ts
+        end
     end
 end
 
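Review note: hand-expanding the `@eval begin ... end` loop for `inplace = "!"` should give the following pair of definitions (the non-mutating pair is identical modulo the `!`):

    function single_train_step_impl_with_allocator_cache!(
        backend, ::Nothing, obj_fn::F, data, ts::TrainState
    ) where {F}
        return single_train_step_impl!(backend, obj_fn, data, ts)
    end

    function single_train_step_impl!(backend, obj_fn::F, data, ts::TrainState) where {F}
        grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
        ts = apply_gradients!(ts, grads)
        return grads, loss, stats, ts
    end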