
Commit 9c9b0d6

fix: caching in Reactant backend
1 parent 8c14c89 commit 9c9b0d6

File tree

5 files changed: +145 / -119 lines changed

.codecov.yml

Lines changed: 0 additions & 4 deletions
This file was deleted.

docs/Project.toml

Lines changed: 0 additions & 2 deletions
@@ -21,7 +21,6 @@ LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
 LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-OpenSSL_jll = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -62,7 +61,6 @@ LuxLib = "1.3.4"
 LuxTestUtils = "2"
 MLDataDevices = "1.6.10"
 NNlib = "0.9.27"
-OpenSSL_jll = "=3.0.16"
 Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"

docs/src/introduction/index.md

Lines changed: 22 additions & 13 deletions
@@ -20,12 +20,11 @@ Pkg.add("Lux")
 
 !!! tip "Pre-Requisites"
 
-    You need to install `Optimisers` and `Zygote` if not done already.
-    `Pkg.add(["Optimisers", "Zygote"])`
+    You need to install `Optimisers`, `Reactant` and `Enzyme` if not done already.
+    `Pkg.add(["Optimisers", "Enzyme", "Reactant"])`
 
 ```@example quickstart
-using Lux, Random, Optimisers, Zygote
-# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
+using Lux, Random, Optimisers, Enzyme, Reactant
 ```
 
 We take randomness very seriously
@@ -40,7 +39,7 @@ Build the model
 
 ```@example quickstart
 # Construct the layer
-model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
+model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 256, tanh), Dense(256, 10)))
 ```
 
 Models don't hold parameters and states so initialize them. From there on, we can just use
@@ -49,7 +48,7 @@ API that provides an uniform API over all supported AD systems.
 
 ```@example quickstart
 # Get the device determined by Lux
-dev = gpu_device()
+dev = reactant_device()
 
 # Parameter and State Variables
 ps, st = Lux.setup(rng, model) |> dev
@@ -58,25 +57,35 @@ ps, st = Lux.setup(rng, model) |> dev
 x = rand(rng, Float32, 128, 2) |> dev
 
 # Run the model
-y, st = Lux.apply(model, x, ps, st)
+## We need to use @jit to compile and run the model with Reactant
+y, st = @jit Lux.apply(model, x, ps, st)
+
+## For best performance, first compile the model with Reactant and then run it
+apply_compiled = @compile Lux.apply(model, x, ps, st)
+apply_compiled(model, x, ps, st)
 
 # Gradients
 ## First construct a TrainState
-train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
+train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))
 
 ## We can compute the gradients using Training.compute_gradients
+## TrainState handles compilation internally
 gs, loss, stats, train_state = Lux.Training.compute_gradients(
-    AutoZygote(), MSELoss(),
-    (x, dev(rand(rng, Float32, 10, 2))), train_state
+    AutoEnzyme(),
+    MSELoss(),
+    (x, dev(rand(rng, Float32, 10, 2))),
+    train_state
 )
 
 ## Optimization
 train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)
 
-# Both these steps can be combined into a single call
+# Both these steps can be combined into a single call (preferred approach)
 gs, loss, stats, train_state = Training.single_train_step!(
-    AutoZygote(), MSELoss(),
-    (x, dev(rand(rng, Float32, 10, 2))), train_state
+    AutoEnzyme(),
+    MSELoss(),
+    (x, dev(rand(rng, Float32, 10, 2))),
+    train_state
 )
 ```
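
Assembled from the hunks above, the updated quickstart runs roughly as the sketch below. The seeded `rng` is assumed from the untouched "We take randomness very seriously" section; everything else mirrors the added lines.

```julia
using Lux, Random, Optimisers, Enzyme, Reactant

# Assumed from the untouched part of the quickstart: a seeded RNG.
rng = Random.default_rng()
Random.seed!(rng, 0)

# Construct the layer
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 256, tanh), Dense(256, 10)))

# Move parameters and states to the Reactant device
dev = reactant_device()
ps, st = Lux.setup(rng, model) |> dev

# Dummy input on the same device
x = rand(rng, Float32, 128, 2) |> dev

# Reactant needs @jit (or @compile plus a call) to compile and run the model
y, st = @jit Lux.apply(model, x, ps, st)

# Training: TrainState handles compilation of the gradient step internally
train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))
gs, loss, stats, train_state = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, dev(rand(rng, Float32, 10, 2))), train_state
)
```

Because this commit caches the compiled step inside the `TrainState`, repeated `single_train_step!` calls reuse the compiled function instead of recompiling it on every step.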

ext/LuxReactantExt/training.jl

Lines changed: 107 additions & 100 deletions
@@ -70,43 +70,43 @@ function compute_gradients_internal(objective_function::F, model, data, ps, st)
     )
 end
 
-Profiler.@annotate "Compile Compute Gradients" function Lux.Training.compute_gradients_impl(
+Profiler.@annotate "Compute Gradients" function Lux.Training.compute_gradients_impl(
     backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
 ) where {F}
-    compiled_gradient_function = with_default_precision_config(ts.parameters) do
-        @compile sync = backend.sync compute_gradients_internal(
-            objective_function, ts.model, data, ts.parameters, ts.states
-        )
+    if (
+        ts.cache isa TrainingBackendCache &&
+        hasfield(typeof(ts.cache.extras), :compiled_gradient_function)
+    )
+        compiled_gradient_function = ts.cache.extras.compiled_gradient_function
+    else
+        compiled_gradient_function = with_default_precision_config(ts.parameters) do
+            @compile sync = backend.sync compute_gradients_internal(
+                objective_function, ts.model, data, ts.parameters, ts.states
+            )
+        end
+
+        if ts.cache isa TrainingBackendCache
+            @set! ts.cache.extras = merge(ts.cache.extras, (; compiled_gradient_function))
+        else
+            cache = TrainingBackendCache(
+                backend, False(), nothing, (; compiled_gradient_function)
+            )
+            @set! ts.cache = cache
+        end
+        @set! ts.objective_function = objective_function
     end
 
     grads, loss, stats, st = compiled_gradient_function(
        objective_function, ts.model, data, ts.parameters, ts.states
     )
 
-    cache = TrainingBackendCache(backend, False(), nothing, (; compiled_gradient_function))
-    @set! ts.cache = cache
-    @set! ts.objective_function = objective_function
-    @set! ts.states = st
-    return grads, loss, stats, ts
-end
-
-Profiler.@annotate "Compute Gradients" function Lux.Training.compute_gradients_impl(
-    ::ReactantBackend,
-    obj_fn::F,
-    data,
-    ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
-) where {F}
-    grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
-        obj_fn, ts.model, data, ts.parameters, ts.states
-    )
     @set! ts.states = st
     return grads, loss, stats, ts
 end
 
 for inplace in ("!", "")
     fname = Symbol(:single_train_step_impl, inplace)
-    internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)
-    apply_gradients_fn = Symbol(:apply_gradients, inplace)
+    apply_gradients_fn = Symbol(:apply_gradients_reactant, inplace)
     update_fn = Symbol(:update, inplace)
 
     # Ideally users never hit this dispatch but it is still good to have as a fallback
@@ -115,7 +115,10 @@ for inplace in ("!", "")
     )(
         ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}}, grads
     )
-        if hasfield(typeof(ts.cache.extras), :update_function)
+        if (
+            ts.cache isa TrainingBackendCache &&
+            hasfield(typeof(ts.cache.extras), :update_function)
+        )
             update_function = ts.cache.extras.update_function
         else
             update_function = with_default_precision_config(ts.parameters) do
@@ -124,10 +127,16 @@ for inplace in ("!", "")
                 )
             end
 
-            @set! ts.cache.extras = merge(ts.cache.extras, (; update_function))
+            if ts.cache isa TrainingBackendCache
+                @set! ts.cache.extras = merge(ts.cache.extras, (; update_function))
+            else
+                cache = TrainingBackendCache(backend, False(), nothing, (; update_function))
+                @set! ts.cache = cache
+            end
         end
 
         opt_state, ps = update_function(ts.optimizer_state, ts.parameters, grads)
+
         @set! ts.parameters = ps
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1
@@ -141,110 +150,108 @@ for inplace in ("!", "")
     end
 
     # XXX: recompile with a warning if new input types are used
-    @eval Profiler.@annotate "Compile Train Step" function Lux.Training.$(fname)(
+    @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
         backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
     ) where {F}
-        device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
-        @assert device isa ReactantDevice
-        is_sharded = device.device === nothing
-
-        dps = if backend.return_gradients isa True
-            Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+        if (
+            ts.cache isa TrainingBackendCache &&
+            hasfield(typeof(ts.cache.extras), :compiled_grad_and_step_function)
+        )
+            (; compiled_grad_and_step_function, is_sharded) = ts.cache.extras
+            ps = ts.parameters
+            dparameters = ts.cache.dparameters
         else
-            nothing
-        end
+            device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
+            @assert device isa ReactantDevice
+            is_sharded = device.device === nothing
+
+            dparameters = if backend.return_gradients isa True
+                Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+            else
+                nothing
+            end
 
-        $(ps_expr)
-
-        compiled_grad_and_step_function = with_default_precision_config(ts.parameters) do
-            @compile sync = backend.sync $(internal_fn)(
-                objective_function,
-                ts.model,
-                data,
-                ps,
-                ts.states,
-                ts.optimizer_state,
-                dps,
-                is_sharded,
-            )
+            $(ps_expr)
+
+            compiled_grad_and_step_function =
+                with_default_precision_config(ts.parameters) do
+                    @compile sync = backend.sync compute_gradients_internal_and_step!(
+                        objective_function,
+                        ts.model,
+                        data,
+                        ps,
+                        ts.states,
+                        ts.optimizer_state,
+                        dparameters,
+                        is_sharded,
+                    )
+                end
+
+            if ts.cache isa TrainingBackendCache
+                @set! ts.cache.dparameters = dparameters
+                @set! ts.cache.extras = merge(
+                    ts.cache.extras, (; compiled_grad_and_step_function, is_sharded)
+                )
+            else
+                cache = TrainingBackendCache(
+                    backend,
+                    False(),
+                    dparameters,
+                    (; compiled_grad_and_step_function, is_sharded),
+                )
+                @set! ts.cache = cache
+            end
+            @set! ts.objective_function = objective_function
         end
 
+        @show typeof(dparameters)
+
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function,
             ts.model,
             data,
             ps,
             ts.states,
             ts.optimizer_state,
-            dps,
+            dparameters,
             is_sharded,
         )
 
-        cache = TrainingBackendCache(
-            backend, False(), dps, (; compiled_grad_and_step_function, is_sharded)
-        )
-        @set! ts.cache = cache
-        @set! ts.objective_function = objective_function
-        @set! ts.states = st
-        @set! ts.parameters = ps
-        @set! ts.optimizer_state = opt_state
-        @set! ts.step = ts.step + 1
-
-        return grads, loss, stats, ts
-    end
-
-    @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
-        ::ReactantBackend,
-        obj_fn::F,
-        data,
-        ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
-    ) where {F}
-        grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
-            obj_fn,
-            ts.model,
-            data,
-            ts.parameters,
-            ts.states,
-            ts.optimizer_state,
-            ts.cache.dparameters,
-            ts.cache.extras.is_sharded,
-        )
-
         @set! ts.states = st
         @set! ts.parameters = ps
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1
 
         return grads, loss, stats, ts
     end
+end
 
-    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
-    ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal(
-            objective_function, model, data, ps, st
-        )
+@eval function compute_gradients_internal_and_step!(
+    objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
+) where {F}
+    dps, loss, stats, stₙ = compute_gradients_internal(
+        objective_function, model, data, ps, st
+    )
 
-        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
-        # Ensure sharding of input and output states are consistent
-        is_sharded && mark_same_sharding_group(st, stₙ)
+    opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+    # Ensure sharding of input and output states are consistent
+    is_sharded && mark_same_sharding_group(st, stₙ)
 
-        return nothing, psₙ, loss, stats, stₙ, opt_state
-    end
+    return nothing, psₙ, loss, stats, stₙ, opt_state
+end
 
-    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
-    ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal!(
-            dps, objective_function, model, data, ps, st
-        )
+@eval function compute_gradients_internal_and_step!(
+    objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
) where {F}
+    dps, loss, stats, stₙ = compute_gradients_internal!(
+        dps, objective_function, model, data, ps, st
+    )
 
-        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
-        # Ensure sharding of input and output states are consistent
-        is_sharded && mark_same_sharding_group(st, stₙ)
+    opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+    # Ensure sharding of input and output states are consistent
+    is_sharded && mark_same_sharding_group(st, stₙ)
 
-        return dps, psₙ, loss, stats, stₙ, opt_state
-    end
+    return dps, psₙ, loss, stats, stₙ, opt_state
 end
 
 mark_same_sharding_group(args...) = Functors.fmap(mark_same_sharding_group_inner, args...)
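
The change in this file is a compile-once-and-cache pattern: the compiled gradient (or gradient-plus-step) function is looked up in `ts.cache.extras` and only built with `@compile` when it is missing. Stripped of the Lux and Reactant specifics, a minimal sketch of that pattern is below; `Cache`, `expensive_compile`, and `cached_call` are illustrative stand-ins, not part of the package, and the real `TrainState` is updated immutably with `@set!` rather than mutated.

```julia
# Illustrative stand-ins for TrainingBackendCache and Reactant's @compile.
mutable struct Cache
    extras::NamedTuple
end

expensive_compile(f) = f  # pretend this is the costly `@compile ...` step

function cached_call(cache::Union{Cache,Nothing}, f, args...)
    if cache !== nothing && hasfield(typeof(cache.extras), :compiled)
        # Fast path: reuse the function compiled on an earlier call.
        compiled = cache.extras.compiled
    else
        # Slow path: compile once, then stash the result for next time.
        compiled = expensive_compile(f)
        if cache === nothing
            cache = Cache((; compiled))
        else
            cache.extras = merge(cache.extras, (; compiled))
        end
    end
    return compiled(args...), cache
end

result, cache = cached_call(nothing, sum, [1, 2, 3])  # compiles and caches
result, cache = cached_call(cache, sum, [4, 5, 6])    # reuses the cached function
```

The commit applies this same check-before-compile logic in `compute_gradients_impl` and in the generated `single_train_step_impl!` methods, so only the first call with a given `TrainState` pays the compilation cost.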
