Commit 654bd34

feat: use a caching allocator for GPUArrays workflows (#1549)
* feat: use a caching allocator for GPUArrays workflows
* fix: switch arg position
* fix: GPUArrays compat
* fix: other device types
1 parent 510f710 commit 654bd34
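In user code the cache is picked up transparently whenever the `TrainState` lives on a GPUArrays-backed device. A minimal sketch of the intended workflow, assuming Zygote for AD and CUDA.jl as the (illustrative) GPU backend; the model, batch sizes, and learning rate below are made up:

using Lux, Optimisers, Random, Zygote
using CUDA  # illustrative: any backend exposing a GPUArrays device works

model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model)
ts = Training.TrainState(model, ps, st, Adam(0.01f0))

gdev = gpu_device()  # an AbstractGPUDevice when a GPU backend is functional
ts = gdev(ts)        # the TrainState now carries an allocator cache (see the diff below)

x, y = gdev(rand(Float32, 4, 8)), gdev(rand(Float32, 2, 8))
# Temporaries allocated during the step are cached and reused on subsequent steps.
_, loss, _, ts = Training.single_train_step!(AutoZygote(), MSELoss(), (x, y), ts)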

File tree: 3 files changed (+123, -31 lines)
Project.toml

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

@@ -64,6 +65,7 @@ WeightInitializers = {path = "lib/WeightInitializers"}
 LuxComponentArraysExt = "ComponentArrays"
 LuxEnzymeExt = "Enzyme"
 LuxFluxExt = "Flux"
+LuxGPUArraysExt = "GPUArrays"
 LuxLossFunctionsExt = "LossFunctions"
 LuxMLUtilsExt = "MLUtils"
 LuxMPIExt = "MPI"

@@ -93,6 +95,7 @@ Flux = "0.16.3"
 ForwardDiff = "0.10.36, =1"
 FunctionWrappers = "1.1.3"
 Functors = "0.5"
+GPUArrays = "11"
 GPUArraysCore = "0.2"
 LinearAlgebra = "1.10"
 LossFunctions = "0.11.1, 1"
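Because GPUArrays is declared as a weak dependency, `LuxGPUArraysExt` only loads once GPUArrays itself is in the session, which normally happens indirectly through a GPU backend package. A quick check, sketched here with CUDA.jl as an illustrative trigger:

using Lux
using CUDA  # assumption: any package that pulls in GPUArrays triggers the extension

ext = Base.get_extension(Lux, :LuxGPUArraysExt)
ext === nothing && @warn "LuxGPUArraysExt did not load; is GPUArrays available?"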

ext/LuxGPUArraysExt.jl

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+module LuxGPUArraysExt
+
+using GPUArrays: AllocCache, @cached
+using Lux: Training
+using MLDataDevices: AbstractGPUDevice
+
+Training.get_allocator_cache(::AbstractGPUDevice) = AllocCache()
+
+function Training.compute_gradients_impl_with_allocator_cache(
+    backend, alloc_cache::AllocCache, obj_fn::F, data, ts::Training.TrainState
+) where {F}
+    @cached alloc_cache begin
+        return Training.compute_gradients_impl(backend, obj_fn, data, ts)
+    end
+end
+
+for inplace in ("!", "")
+    step_with_alloc_cache = Symbol(:single_train_step_impl_with_allocator_cache, inplace)
+    step_inner = Symbol(:single_train_step_impl, inplace)
+    apply_gradients_with_alloc_cache = Symbol(
+        :apply_gradients_with_allocator_cache, inplace
+    )
+    apply_fn = Symbol(:apply_gradients_impl, inplace)
+
+    @eval begin
+        function Training.$(apply_gradients_with_alloc_cache)(
+            alloc_cache::AllocCache, ts::Training.TrainState, grads
+        )
+            @cached alloc_cache begin
+                return Training.$(apply_fn)(ts, grads)
+            end
+        end
+
+        function Training.$(step_with_alloc_cache)(
+            backend, alloc_cache::AllocCache, obj_fn::F, data, ts::Training.TrainState
+        ) where {F}
+            @cached alloc_cache begin
+                return Training.$(step_inner)(backend, obj_fn, data, ts)
+            end
+        end
+    end
+end
+
+end
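The extension is a thin wrapper over GPUArrays' caching allocator. For reference, a standalone sketch of that underlying API as I understand it from the GPUArrays v11 caching-allocator docs, using CUDA.jl arrays purely for illustration:

using GPUArrays: GPUArrays, AllocCache, @cached
using CUDA  # illustrative; any GPUArrays-compatible array type works

cache = AllocCache()
for _ in 1:100
    @cached cache begin
        # Buffers allocated inside the block are retained in `cache` on exit,
        # so later iterations reuse them instead of hitting the GPU allocator.
        tmp = CUDA.rand(Float32, 1024, 1024)
        sum(tmp * tmp)
    end
end
GPUArrays.unsafe_free!(cache)  # release the cached buffers once the loop is done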

src/helpers/training.jl

Lines changed: 76 additions & 31 deletions
@@ -31,6 +31,7 @@ Training State containing:
 Internal fields:

 - `cache`: Cached values. Implementations are free to use this for whatever they want.
+- `allocator_cache`: Used by GPUArrays compatible backends to cache memory allocations.
 - `objective_function`: Objective function might be cached.

 !!! warning
@@ -41,6 +42,7 @@ Internal fields:
 @concrete struct TrainState
     cache
     objective_function
+    allocator_cache
     model
     parameters
     states
@@ -55,6 +57,7 @@ function Adapt.adapt_structure(to::AbstractDevice, ts::TrainState)
     return TrainState(
         nothing,
         nothing,
+        get_allocator_cache(to),
         ts.model,
         to(ts.parameters),
         to(ts.states),
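The next hunk replaces the hand-written `ReactantDevice` adapt rule with an `@invoke` of the generic method above, so the new `allocator_cache` argument only has to be handled in one place. For readers unfamiliar with `@invoke`, a tiny standalone illustration (plain Julia, nothing Lux-specific):

f(x::Number) = "number method"
f(x::Integer) = "integer method"

f(1)                  # "integer method": normal dispatch picks the most specific method
@invoke f(1::Number)  # "number method": forces the ::Number method, which is how the
                      # ReactantDevice rule below reuses the AbstractDevice rule above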
@@ -91,16 +94,7 @@ function Adapt.adapt_structure(to::ReactantDevice, ts::TrainState)
     This ensures the optimizer state and other internal states are on the device on
     construction.
     """
-    return TrainState(
-        nothing,
-        nothing,
-        ts.model,
-        to(ts.parameters),
-        to(ts.states),
-        ts.optimizer,
-        to(ts.optimizer_state),
-        ts.step,
-    )
+    return @invoke Adapt.adapt_structure(to::AbstractDevice, ts::TrainState)
 end

 """
@@ -125,9 +119,13 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr
         optimizer = ReactantCompatibleOptimisers.make_reactant_compatible(optimizer, dev)
     end
     st_opt = Optimisers.setup(optimizer, ps)
-    return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
+    return TrainState(
+        nothing, nothing, get_allocator_cache(dev), model, ps, st, optimizer, st_opt, 0
+    )
 end

+get_allocator_cache(_) = nothing
+
 @concrete struct TrainingBackendCache
     backend
     first_try <: StaticBool
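`get_allocator_cache` thus defaults to `nothing`; only the GPUArrays extension overrides it for `AbstractGPUDevice`. A hedged REPL-style sketch of the resulting behaviour (the GPU line is an expectation, not something verifiable without a working GPU backend):

using Lux

Lux.Training.get_allocator_cache(cpu_device()) === nothing  # true: generic fallback above

# With a GPU backend loaded (so LuxGPUArraysExt is active), the expectation is:
# Lux.Training.get_allocator_cache(gpu_device()) isa GPUArrays.AllocCache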
@@ -190,14 +188,25 @@ function apply_gradients(ts::TrainState, grads)
     )
         return apply_gradients_reactant(ts, grads)
     end
+    return apply_gradients_with_allocator_cache(ts.allocator_cache, ts, grads)
+end
+
+# apply_gradients -> apply_gradients_reactant (for ReactantBackend)
+#                 -> apply_gradients_with_allocator_cache -> apply_gradients_impl
+
+function apply_gradients_with_allocator_cache(::Nothing, ts::TrainState, grads)
+    return apply_gradients_impl(ts, grads)
+end
+
+function apply_gradients_impl(ts::TrainState, grads)
     optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
     @set! ts.parameters = ps
     @set! ts.optimizer_state = optimizer_state
     @set! ts.step = ts.step + 1
     return ts
 end

-function apply_gradients_reactant end
+function apply_gradients_reactant end # updated in ReactantExt

 """
     apply_gradients!(ts::TrainState, grads)
@@ -214,12 +223,23 @@ function apply_gradients!(ts::TrainState, grads)
     )
         return apply_gradients_reactant!(ts, grads)
     end
+    return apply_gradients_with_allocator_cache!(ts.allocator_cache, ts, grads)
+end
+
+# apply_gradients! -> apply_gradients_reactant! (for ReactantBackend)
+#                  -> apply_gradients_with_allocator_cache! -> apply_gradients_impl!
+
+function apply_gradients_with_allocator_cache!(::Nothing, ts::TrainState, grads)
+    return apply_gradients_impl!(ts, grads)
+end
+
+function apply_gradients_impl!(ts::TrainState, grads)
     Optimisers.update!(ts.optimizer_state, ts.parameters, grads)
     @set! ts.step = ts.step + 1
     return ts
 end

-function apply_gradients_reactant! end
+function apply_gradients_reactant! end # updated in ReactantExt

 const SYNC_DOCSTRING = """
     - `sync`: If `true`, then the compiled reactant function is compiled with `sync=true`.
@@ -288,20 +308,17 @@ A 4-Tuple containing:
 """
 function compute_gradients(ad, obj_fn::F, data, ts::TrainState; sync::Bool=false) where {F}
     dev_type = get_device_type((ts.parameters, ts.states))
-    return compute_gradients_impl(maybe_wrap_adtype(ad, dev_type; sync), obj_fn, data, ts)
+    return compute_gradients_impl_with_allocator_cache(
+        maybe_wrap_adtype(ad, dev_type; sync), ts.allocator_cache, obj_fn, data, ts
+    )
 end

-maybe_wrap_adtype(backend::ReactantBackend, ::Any; kwargs...) = backend
-maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad
-function maybe_wrap_adtype(
-    ad::AbstractADType,
-    ::Type{ReactantDevice};
-    return_gradients::Utils.BoolType=True(),
-    sync::Bool=false,
-)
-    ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), sync)
-    throw(ArgumentError("Computing gradients for models on XLA is supported only with \
-                         Enzyme.jl (`AutoEnzyme`)."))
+# compute_gradients -> compute_gradients_impl_with_allocator_cache -> compute_gradients_impl
+
+function compute_gradients_impl_with_allocator_cache(
+    backend, ::Nothing, obj_fn::F, data, ts::TrainState
+) where {F}
+    return compute_gradients_impl(backend, obj_fn, data, ts)
 end

 function compute_gradients_impl(ad, ::F, _, ts::TrainState) where {F}
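The public `compute_gradients` signature is unchanged; it now merely threads `ts.allocator_cache` through to the implementation. Continuing the hypothetical setup sketched under the commit message above (model, `ts`, `x`, `y` as defined there):

grads, loss, stats, ts = Training.compute_gradients(AutoZygote(), MSELoss(), (x, y), ts)
ts = Training.apply_gradients!(ts, grads)  # in-place optimizer update; ts.step is bumped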
@@ -328,6 +345,19 @@ for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme, :Mooncake)
     end
 end

+maybe_wrap_adtype(backend::ReactantBackend, ::Any; kwargs...) = backend
+maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad
+function maybe_wrap_adtype(
+    ad::AbstractADType,
+    ::Type{ReactantDevice};
+    return_gradients::Utils.BoolType=True(),
+    sync::Bool=false,
+)
+    ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), sync)
+    throw(ArgumentError("Computing gradients for models on XLA is supported only with \
+                         Enzyme.jl (`AutoEnzyme`)."))
+end
+
 function generate_wrappers(::F, m, ps, st, data, ::False) where {F}
     @warn "Detected function wrapper generation with function being updated between calls. \
            This will generate type-unstable code. A possible reason for this is \
@@ -395,7 +425,9 @@ function single_train_step!(
     backend = maybe_wrap_adtype(
         backend, get_device_type((ts.parameters, ts.states)); return_gradients, sync
     )
-    return single_train_step_impl!(backend, obj_fn, data, ts)
+    return single_train_step_impl_with_allocator_cache!(
+        backend, ts.allocator_cache, obj_fn, data, ts
+    )
 end

 """
@@ -429,16 +461,29 @@ function single_train_step(
     backend = maybe_wrap_adtype(
         backend, get_device_type((ts.parameters, ts.states)); return_gradients, sync
     )
-    return single_train_step_impl(backend, obj_fn, data, ts)
+    return single_train_step_impl_with_allocator_cache(
+        backend, ts.allocator_cache, obj_fn, data, ts
+    )
 end

+# single_train_step -> single_train_step_impl_with_allocator_cache -> single_train_step_impl
+
 for inplace in ("!", "")
     step = Symbol(:single_train_step_impl, inplace)
+    step_allocator_cache = Symbol(:single_train_step_impl_with_allocator_cache, inplace)
     apply_fn = Symbol(:apply_gradients, inplace)
-    @eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
-        grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
-        ts = $(apply_fn)(ts, grads)
-        return grads, loss, stats, ts
+    @eval begin
+        function $(step_allocator_cache)(
+            backend, ::Nothing, obj_fn::F, data, ts::TrainState
+        ) where {F}
+            return $(step)(backend, obj_fn, data, ts)
+        end
+
+        function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
+            grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
+            ts = $(apply_fn)(ts, grads)
+            return grads, loss, stats, ts
+        end
     end
 end
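For reference, the `@eval` loop above expands to roughly the following pair of methods for the in-place variant (the out-of-place variant is identical minus the trailing `!`):

function single_train_step_impl_with_allocator_cache!(
    backend, ::Nothing, obj_fn::F, data, ts::TrainState
) where {F}
    return single_train_step_impl!(backend, obj_fn, data, ts)
end

function single_train_step_impl!(backend, obj_fn::F, data, ts::TrainState) where {F}
    grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
    ts = apply_gradients!(ts, grads)
    return grads, loss, stats, ts
end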