@@ -70,42 +70,42 @@ function compute_gradients_internal(objective_function::F, model, data, ps, st)
     )
 end
 
-Profiler.@annotate "Compile Compute Gradients" function Lux.Training.compute_gradients_impl(
+Profiler.@annotate "Compute Gradients" function Lux.Training.compute_gradients_impl(
     backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
 ) where {F}
-    compiled_gradient_function = with_default_precision_config(ts.parameters) do
-        @compile sync = backend.sync compute_gradients_internal(
-            objective_function, ts.model, data, ts.parameters, ts.states
-        )
+    if (
+        ts.cache isa TrainingBackendCache &&
+        hasfield(typeof(ts.cache.extras), :compiled_gradient_function)
+    )
+        compiled_gradient_function = ts.cache.extras.compiled_gradient_function
+    else
+        compiled_gradient_function = with_default_precision_config(ts.parameters) do
+            @compile sync = backend.sync compute_gradients_internal(
+                objective_function, ts.model, data, ts.parameters, ts.states
+            )
+        end
+
+        if ts.cache isa TrainingBackendCache
+            @set! ts.cache.extras = merge(ts.cache.extras, (; compiled_gradient_function))
+        else
+            cache = TrainingBackendCache(
+                backend, False(), nothing, (; compiled_gradient_function)
+            )
+            @set! ts.cache = cache
+        end
+        @set! ts.objective_function = objective_function
     end
 
     grads, loss, stats, st = compiled_gradient_function(
         objective_function, ts.model, data, ts.parameters, ts.states
     )
 
-    cache = TrainingBackendCache(backend, False(), nothing, (; compiled_gradient_function))
-    @set! ts.cache = cache
-    @set! ts.objective_function = objective_function
-    @set! ts.states = st
-    return grads, loss, stats, ts
-end
-
-Profiler.@annotate "Compute Gradients" function Lux.Training.compute_gradients_impl(
-    ::ReactantBackend,
-    obj_fn::F,
-    data,
-    ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
-) where {F}
-    grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
-        obj_fn, ts.model, data, ts.parameters, ts.states
-    )
     @set! ts.states = st
     return grads, loss, stats, ts
 end
 
 for inplace in ("!", "")
     fname = Symbol(:single_train_step_impl, inplace)
-    internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)
     apply_gradients_fn = Symbol(:apply_gradients, inplace)
     update_fn = Symbol(:update, inplace)
 
@@ -141,110 +141,108 @@ for inplace in ("!", "")
     end
 
     # XXX: recompile with a warning if new input types are used
-    @eval Profiler.@annotate "Compile Train Step" function Lux.Training.$(fname)(
+    @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
         backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
     ) where {F}
-        device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
-        @assert device isa ReactantDevice
-        is_sharded = device.device === nothing
-
-        dps = if backend.return_gradients isa True
-            Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+        if (
+            ts.cache isa TrainingBackendCache &&
+            hasfield(typeof(ts.cache.extras), :compiled_grad_and_step_function)
+        )
+            (; compiled_grad_and_step_function, is_sharded) = ts.cache.extras
+            ps = ts.parameters
+            dparameters = ts.cache.dparameters
         else
-            nothing
-        end
+            device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
+            @assert device isa ReactantDevice
+            is_sharded = device.device === nothing
+
+            dparameters = if backend.return_gradients isa True
+                Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+            else
+                nothing
+            end
 
-        $(ps_expr)
-
-        compiled_grad_and_step_function = with_default_precision_config(ts.parameters) do
-            @compile sync = backend.sync $(internal_fn)(
-                objective_function,
-                ts.model,
-                data,
-                ps,
-                ts.states,
-                ts.optimizer_state,
-                dps,
-                is_sharded,
-            )
+            $(ps_expr)
+
+            compiled_grad_and_step_function =
+                with_default_precision_config(ts.parameters) do
+                    @compile sync = backend.sync compute_gradients_internal_and_step!(
+                        objective_function,
+                        ts.model,
+                        data,
+                        ps,
+                        ts.states,
+                        ts.optimizer_state,
+                        dparameters,
+                        is_sharded,
+                    )
+                end
+
+            if ts.cache isa TrainingBackendCache
+                @set! ts.cache.dparameters = dparameters
+                @set! ts.cache.extras = merge(
+                    ts.cache.extras, (; compiled_grad_and_step_function, is_sharded)
+                )
+            else
+                cache = TrainingBackendCache(
+                    backend,
+                    False(),
+                    dparameters,
+                    (; compiled_grad_and_step_function, is_sharded),
+                )
+                @set! ts.cache = cache
+            end
+            @set! ts.objective_function = objective_function
         end
 
+        @show typeof(dparameters)
+
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function,
             ts.model,
             data,
             ps,
             ts.states,
             ts.optimizer_state,
-            dps,
+            dparameters,
             is_sharded,
         )
 
-        cache = TrainingBackendCache(
-            backend, False(), dps, (; compiled_grad_and_step_function, is_sharded)
-        )
-        @set! ts.cache = cache
-        @set! ts.objective_function = objective_function
-        @set! ts.states = st
-        @set! ts.parameters = ps
-        @set! ts.optimizer_state = opt_state
-        @set! ts.step = ts.step + 1
-
-        return grads, loss, stats, ts
-    end
-
-    @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
-        ::ReactantBackend,
-        obj_fn::F,
-        data,
-        ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
-    ) where {F}
-        grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
-            obj_fn,
-            ts.model,
-            data,
-            ts.parameters,
-            ts.states,
-            ts.optimizer_state,
-            ts.cache.dparameters,
-            ts.cache.extras.is_sharded,
-        )
-
         @set! ts.states = st
         @set! ts.parameters = ps
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1
 
         return grads, loss, stats, ts
     end
+end
 
-    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
-    ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal(
-            objective_function, model, data, ps, st
-        )
+@eval function compute_gradients_internal_and_step!(
+    objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
+) where {F}
+    dps, loss, stats, stₙ = compute_gradients_internal(
+        objective_function, model, data, ps, st
+    )
 
-        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
-        # Ensure sharding of input and output states are consistent
-        is_sharded && mark_same_sharding_group(st, stₙ)
+    opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+    # Ensure sharding of input and output states are consistent
+    is_sharded && mark_same_sharding_group(st, stₙ)
 
-        return nothing, psₙ, loss, stats, stₙ, opt_state
-    end
+    return nothing, psₙ, loss, stats, stₙ, opt_state
+end
 
-    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
-    ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal!(
-            dps, objective_function, model, data, ps, st
-        )
+@eval function compute_gradients_internal_and_step!(
+    objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
+) where {F}
+    dps, loss, stats, stₙ = compute_gradients_internal!(
+        dps, objective_function, model, data, ps, st
+    )
 
-        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
-        # Ensure sharding of input and output states are consistent
-        is_sharded && mark_same_sharding_group(st, stₙ)
+    opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+    # Ensure sharding of input and output states are consistent
+    is_sharded && mark_same_sharding_group(st, stₙ)
 
-        return dps, psₙ, loss, stats, stₙ, opt_state
-    end
+    return dps, psₙ, loss, stats, stₙ, opt_state
 end
 
 mark_same_sharding_group(args...) = Functors.fmap(mark_same_sharding_group_inner, args...)
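The core change in this diff is that the compiled Reactant function now lives inside the `TrainState` cache: the first call compiles it and stashes it in `cache.extras`, and later calls detect it via `hasfield(typeof(ts.cache.extras), ...)` and reuse it, instead of dispatching to a second, cache-specialized method. The standalone sketch below illustrates that compile-once/reuse pattern in plain Julia. It is illustrative only: `SimpleTrainState`, `expensive_compile`, and `cached_step!` are hypothetical names, and the mutable struct stands in for the immutable `@set!`-updated `TrainState` used in the real code.

```julia
# Minimal sketch of the "check the cache, compile on miss, reuse on hit" pattern.
# All names are hypothetical stand-ins, not Lux or Reactant APIs.

mutable struct SimpleTrainState{P}
    parameters::P
    cache::NamedTuple   # plays the role of TrainState's backend cache extras
end

# Stand-in for the expensive `@compile` step: returns a callable closure.
function expensive_compile(f, ps)
    println("compiling...")   # runs only on a cache miss
    return data -> f(ps, data)
end

function cached_step!(f, ts::SimpleTrainState, data)
    # Reuse the compiled function if the cache already holds one; this mirrors
    # the `hasfield(typeof(ts.cache.extras), :compiled_gradient_function)` check.
    compiled = if haskey(ts.cache, :compiled_fn)
        ts.cache.compiled_fn
    else
        fn = expensive_compile(f, ts.parameters)
        ts.cache = merge(ts.cache, (; compiled_fn=fn))  # stash for later calls
        fn
    end
    return compiled(data)
end

loss(ps, x) = sum(abs2, ps .* x)

ts = SimpleTrainState([1.0, 2.0], NamedTuple())
cached_step!(loss, ts, [3.0, 4.0])  # prints "compiling..." and returns 73.0
cached_step!(loss, ts, [3.0, 4.0])  # cache hit: no recompilation
```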