
Commit 0fe8fa8

fix: caching in Reactant backend
1 parent 8c14c89 commit 0fe8fa8

File tree

  docs/Project.toml
  docs/src/introduction/index.md
  ext/LuxReactantExt/training.jl

3 files changed: +117 -112 lines changed

docs/Project.toml

Lines changed: 0 additions & 2 deletions

@@ -21,7 +21,6 @@ LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
 LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-OpenSSL_jll = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -62,7 +61,6 @@ LuxLib = "1.3.4"
 LuxTestUtils = "2"
 MLDataDevices = "1.6.10"
 NNlib = "0.9.27"
-OpenSSL_jll = "=3.0.16"
 Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"

docs/src/introduction/index.md

Lines changed: 22 additions & 13 deletions

@@ -20,12 +20,11 @@ Pkg.add("Lux")
 
 !!! tip "Pre-Requisites"
 
-    You need to install `Optimisers` and `Zygote` if not done already.
-    `Pkg.add(["Optimisers", "Zygote"])`
+    You need to install `Optimisers`, `Reactant` and `Enzyme` if not done already.
+    `Pkg.add(["Optimisers", "Enzyme", "Reactant"])`
 
 ```@example quickstart
-using Lux, Random, Optimisers, Zygote
-# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
+using Lux, Random, Optimisers, Enzyme, Reactant
 ```
 
 We take randomness very seriously
@@ -40,7 +39,7 @@ Build the model
 
 ```@example quickstart
 # Construct the layer
-model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
+model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 256, tanh), Dense(256, 10)))
 ```
 
 Models don't hold parameters and states so initialize them. From there on, we can just use
@@ -49,7 +48,7 @@ API that provides an uniform API over all supported AD systems.
 
 ```@example quickstart
 # Get the device determined by Lux
-dev = gpu_device()
+dev = reactant_device()
 
 # Parameter and State Variables
 ps, st = Lux.setup(rng, model) |> dev
@@ -58,25 +57,35 @@ ps, st = Lux.setup(rng, model) |> dev
 x = rand(rng, Float32, 128, 2) |> dev
 
 # Run the model
-y, st = Lux.apply(model, x, ps, st)
+## We need to use @jit to compile and run the model with Reactant
+y, st = @jit Lux.apply(model, x, ps, st)
+
+## For best performance, first compile the model with Reactant and then run it
+apply_compiled = @compile Lux.apply(model, x, ps, st)
+apply_compiled(model, x, ps, st)
 
 # Gradients
 ## First construct a TrainState
-train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
+train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))
 
 ## We can compute the gradients using Training.compute_gradients
+## TrainState handles compilation internally
 gs, loss, stats, train_state = Lux.Training.compute_gradients(
-    AutoZygote(), MSELoss(),
-    (x, dev(rand(rng, Float32, 10, 2))), train_state
+    AutoEnzyme(),
+    MSELoss(),
+    (x, dev(rand(rng, Float32, 10, 2))),
+    train_state
 )
 
 ## Optimization
 train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)
 
-# Both these steps can be combined into a single call
+# Both these steps can be combined into a single call (preferred approach)
 gs, loss, stats, train_state = Training.single_train_step!(
-    AutoZygote(), MSELoss(),
-    (x, dev(rand(rng, Float32, 10, 2))), train_state
+    AutoEnzyme(),
+    MSELoss(),
+    (x, dev(rand(rng, Float32, 10, 2))),
+    train_state
 )
 ```

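The updated quickstart compiles the train step on first use and then reuses the compiled function on subsequent calls, which is the caching path this commit fixes. A minimal end-to-end loop built only from the APIs shown in the diff above (`reactant_device`, `Training.TrainState`, `Training.single_train_step!`, `AutoEnzyme`, `MSELoss`) could look like the sketch below; the `train_loop` helper, the seeded `rng`, and the synthetic targets are illustrative, not part of the commit.

```julia
using Lux, Random, Optimisers, Enzyme, Reactant

rng = Random.default_rng()
Random.seed!(rng, 0)

# Same model shape as the updated quickstart
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 256, tanh), Dense(256, 10)))

dev = reactant_device()
ps, st = Lux.setup(rng, model) |> dev

# Synthetic data, purely for illustration
x = rand(rng, Float32, 128, 2) |> dev
y = rand(rng, Float32, 10, 2) |> dev

train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))

# `single_train_step!` compiles the train step on the first call and, with the
# caching behavior addressed by this commit, reuses it on later calls.
function train_loop(train_state, data; epochs=100)
    local loss
    for _ in 1:epochs
        _, loss, _, train_state = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), data, train_state
        )
    end
    return train_state, loss
end

train_state, final_loss = train_loop(train_state, (x, y))
```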
ext/LuxReactantExt/training.jl

Lines changed: 95 additions & 97 deletions

@@ -70,42 +70,42 @@ function compute_gradients_internal(objective_function::F, model, data, ps, st)
     )
 end
 
-Profiler.@annotate "Compile Compute Gradients" function Lux.Training.compute_gradients_impl(
+Profiler.@annotate "Compute Gradients" function Lux.Training.compute_gradients_impl(
     backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
 ) where {F}
-    compiled_gradient_function = with_default_precision_config(ts.parameters) do
-        @compile sync = backend.sync compute_gradients_internal(
-            objective_function, ts.model, data, ts.parameters, ts.states
-        )
+    if (
+        ts.cache isa TrainingBackendCache &&
+        hasfield(typeof(ts.cache.extras), :compiled_gradient_function)
+    )
+        compiled_gradient_function = ts.cache.extras.compiled_gradient_function
+    else
+        compiled_gradient_function = with_default_precision_config(ts.parameters) do
+            @compile sync = backend.sync compute_gradients_internal(
+                objective_function, ts.model, data, ts.parameters, ts.states
+            )
+        end
+
+        if ts.cache isa TrainingBackendCache
+            @set! ts.cache.extras = merge(ts.cache.extras, (; compiled_gradient_function))
+        else
+            cache = TrainingBackendCache(
+                backend, False(), nothing, (; compiled_gradient_function)
+            )
+            @set! ts.cache = cache
+        end
+        @set! ts.objective_function = objective_function
     end
 
     grads, loss, stats, st = compiled_gradient_function(
         objective_function, ts.model, data, ts.parameters, ts.states
     )
 
-    cache = TrainingBackendCache(backend, False(), nothing, (; compiled_gradient_function))
-    @set! ts.cache = cache
-    @set! ts.objective_function = objective_function
-    @set! ts.states = st
-    return grads, loss, stats, ts
-end
-
-Profiler.@annotate "Compute Gradients" function Lux.Training.compute_gradients_impl(
-    ::ReactantBackend,
-    obj_fn::F,
-    data,
-    ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
-) where {F}
-    grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
-        obj_fn, ts.model, data, ts.parameters, ts.states
-    )
     @set! ts.states = st
     return grads, loss, stats, ts
 end
 
 for inplace in ("!", "")
     fname = Symbol(:single_train_step_impl, inplace)
-    internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)
     apply_gradients_fn = Symbol(:apply_gradients, inplace)
     update_fn = Symbol(:update, inplace)
 
@@ -141,110 +141,108 @@ for inplace in ("!", "")
     end
 
     # XXX: recompile with a warning if new input types are used
-    @eval Profiler.@annotate "Compile Train Step" function Lux.Training.$(fname)(
+    @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
         backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
     ) where {F}
-        device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
-        @assert device isa ReactantDevice
-        is_sharded = device.device === nothing
-
-        dps = if backend.return_gradients isa True
-            Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+        if (
+            ts.cache isa TrainingBackendCache &&
+            hasfield(typeof(ts.cache.extras), :compiled_grad_and_step_function)
+        )
+            (; compiled_grad_and_step_function, is_sharded) = ts.cache.extras
+            ps = ts.parameters
+            dparameters = ts.cache.dparameters
         else
-            nothing
-        end
+            device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
+            @assert device isa ReactantDevice
+            is_sharded = device.device === nothing
+
+            dparameters = if backend.return_gradients isa True
+                Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+            else
+                nothing
+            end
 
-        $(ps_expr)
-
-        compiled_grad_and_step_function = with_default_precision_config(ts.parameters) do
-            @compile sync = backend.sync $(internal_fn)(
-                objective_function,
-                ts.model,
-                data,
-                ps,
-                ts.states,
-                ts.optimizer_state,
-                dps,
-                is_sharded,
-            )
+            $(ps_expr)
+
+            compiled_grad_and_step_function =
+                with_default_precision_config(ts.parameters) do
+                    @compile sync = backend.sync compute_gradients_internal_and_step!(
+                        objective_function,
+                        ts.model,
+                        data,
+                        ps,
+                        ts.states,
+                        ts.optimizer_state,
+                        dparameters,
+                        is_sharded,
+                    )
+                end
+
+            if ts.cache isa TrainingBackendCache
+                @set! ts.cache.dparameters = dparameters
+                @set! ts.cache.extras = merge(
+                    ts.cache.extras, (; compiled_grad_and_step_function, is_sharded)
+                )
+            else
+                cache = TrainingBackendCache(
+                    backend,
+                    False(),
+                    dparameters,
+                    (; compiled_grad_and_step_function, is_sharded),
+                )
+                @set! ts.cache = cache
+            end
+            @set! ts.objective_function = objective_function
         end
 
+        @show typeof(dparameters)
+
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function,
             ts.model,
             data,
             ps,
             ts.states,
             ts.optimizer_state,
-            dps,
+            dparameters,
            is_sharded,
        )
 
-        cache = TrainingBackendCache(
-            backend, False(), dps, (; compiled_grad_and_step_function, is_sharded)
-        )
-        @set! ts.cache = cache
-        @set! ts.objective_function = objective_function
-        @set! ts.states = st
-        @set! ts.parameters = ps
-        @set! ts.optimizer_state = opt_state
-        @set! ts.step = ts.step + 1
-
-        return grads, loss, stats, ts
-    end
-
-    @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
-        ::ReactantBackend,
-        obj_fn::F,
-        data,
-        ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
-    ) where {F}
-        grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
-            obj_fn,
-            ts.model,
-            data,
-            ts.parameters,
-            ts.states,
-            ts.optimizer_state,
-            ts.cache.dparameters,
-            ts.cache.extras.is_sharded,
-        )
-
         @set! ts.states = st
         @set! ts.parameters = ps
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1
 
         return grads, loss, stats, ts
     end
+end
 
-    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
-    ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal(
-            objective_function, model, data, ps, st
-        )
+@eval function compute_gradients_internal_and_step!(
+    objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
+) where {F}
+    dps, loss, stats, stₙ = compute_gradients_internal(
+        objective_function, model, data, ps, st
+    )
 
-        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
-        # Ensure sharding of input and output states are consistent
-        is_sharded && mark_same_sharding_group(st, stₙ)
+    opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+    # Ensure sharding of input and output states are consistent
+    is_sharded && mark_same_sharding_group(st, stₙ)
 
-        return nothing, psₙ, loss, stats, stₙ, opt_state
-    end
+    return nothing, psₙ, loss, stats, stₙ, opt_state
+end
 
-    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
-    ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal!(
-            dps, objective_function, model, data, ps, st
-        )
+@eval function compute_gradients_internal_and_step!(
+    objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
) where {F}
+    dps, loss, stats, stₙ = compute_gradients_internal!(
+        dps, objective_function, model, data, ps, st
+    )
 
-        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
-        # Ensure sharding of input and output states are consistent
-        is_sharded && mark_same_sharding_group(st, stₙ)
+    opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+    # Ensure sharding of input and output states are consistent
+    is_sharded && mark_same_sharding_group(st, stₙ)
 
-        return dps, psₙ, loss, stats, stₙ, opt_state
-    end
+    return dps, psₙ, loss, stats, stₙ, opt_state
 end
 
 mark_same_sharding_group(args...) = Functors.fmap(mark_same_sharding_group_inner, args...)

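The `training.jl` change replaces the dispatch-based cached methods (which required the `TrainState`'s cache and objective-function type parameters to line up exactly with the method signature) with single methods that check the cache explicitly and compile only on a miss. A stripped-down sketch of that compile-once-and-reuse pattern is shown below; the `CompileCache` type and `get_or_compile!` helper are illustrative, not part of LuxReactantExt.

```julia
# Illustrative sketch of the compile-once-and-reuse pattern in the diff above.
# CompileCache and get_or_compile! are hypothetical helpers, not extension API.
struct CompileCache
    compiled::Dict{Symbol,Any}   # e.g. :compiled_gradient_function => compiled thunk
end
CompileCache() = CompileCache(Dict{Symbol,Any}())

# Do-block friendly: `get_or_compile!(cache, key) do ... end`
function get_or_compile!(compile_fn, cache::CompileCache, key::Symbol)
    haskey(cache.compiled, key) && return cache.compiled[key]  # cache hit: reuse as-is
    return cache.compiled[key] = compile_fn()                  # miss: compile once, store
end

# Usage shape, mirroring compute_gradients_impl in the diff:
#   f = get_or_compile!(cache, :compiled_gradient_function) do
#       @compile compute_gradients_internal(obj_fn, model, data, ps, st)
#   end
#   grads, loss, stats, st = f(obj_fn, model, data, ps, st)
```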