@@ -70,43 +70,43 @@ function compute_gradients_internal(objective_function::F, model, data, ps, st)
     )
 end

-Profiler.@annotate "Compile Compute Gradients" function Lux.Training.compute_gradients_impl(
+Profiler.@annotate "Compute Gradients" function Lux.Training.compute_gradients_impl(
     backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
 ) where {F}
-    compiled_gradient_function = with_default_precision_config(ts.parameters) do
-        @compile sync = backend.sync compute_gradients_internal(
-            objective_function, ts.model, data, ts.parameters, ts.states
-        )
+    if (
+        ts.cache isa TrainingBackendCache &&
+        hasfield(typeof(ts.cache.extras), :compiled_gradient_function)
+    )
+        compiled_gradient_function = ts.cache.extras.compiled_gradient_function
+    else
+        compiled_gradient_function = with_default_precision_config(ts.parameters) do
+            @compile sync = backend.sync compute_gradients_internal(
+                objective_function, ts.model, data, ts.parameters, ts.states
+            )
+        end
+
+        if ts.cache isa TrainingBackendCache
+            @set! ts.cache.extras = merge(ts.cache.extras, (; compiled_gradient_function))
+        else
+            cache = TrainingBackendCache(
+                backend, False(), nothing, (; compiled_gradient_function)
+            )
+            @set! ts.cache = cache
+        end
+        @set! ts.objective_function = objective_function
     end

     grads, loss, stats, st = compiled_gradient_function(
         objective_function, ts.model, data, ts.parameters, ts.states
     )

-    cache = TrainingBackendCache(backend, False(), nothing, (; compiled_gradient_function))
-    @set! ts.cache = cache
-    @set! ts.objective_function = objective_function
-    @set! ts.states = st
-    return grads, loss, stats, ts
-end
-
-Profiler.@annotate "Compute Gradients" function Lux.Training.compute_gradients_impl(
-    ::ReactantBackend,
-    obj_fn::F,
-    data,
-    ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
-) where {F}
-    grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
-        obj_fn, ts.model, data, ts.parameters, ts.states
-    )
     @set! ts.states = st
     return grads, loss, stats, ts
 end

 for inplace in ("!", "")
     fname = Symbol(:single_train_step_impl, inplace)
-    internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)
-    apply_gradients_fn = Symbol(:apply_gradients, inplace)
+    apply_gradients_fn = Symbol(:apply_gradients_reactant, inplace)
     update_fn = Symbol(:update, inplace)

     # Ideally users never hit this dispatch but it is still good to have as a fallback
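The hunk above folds the old fast-path method (which dispatched on Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F}) into the single compute_gradients_impl: the first call compiles the gradient function and stores it under ts.cache.extras, and later calls reuse it after a runtime check. A minimal, self-contained sketch of that compile-once-then-reuse pattern; fake_compile and cached_call are hypothetical stand-ins, not Lux or Reactant API:

# Hypothetical stand-in for `@compile`: simply wraps the function.
fake_compile(f) = (args...) -> f(args...)

# Reuse a previously "compiled" callable stored in `extras`, compiling on first use.
function cached_call(extras::NamedTuple, f::F, args...) where {F}
    if hasfield(typeof(extras), :compiled)
        compiled = extras.compiled              # cache hit: reuse the compiled callable
    else
        compiled = fake_compile(f)              # cache miss: compile and remember it
        extras = merge(extras, (; compiled))
    end
    return compiled(args...), extras            # caller stores `extras` back on its cache
end

extras = (;)
y1, extras = cached_call(extras, sum, [1, 2, 3])   # first call "compiles"
y2, extras = cached_call(extras, sum, [4, 5, 6])   # second call reuses extras.compiled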
@@ -115,7 +115,10 @@ for inplace in ("!", "")
     )(
         ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}}, grads
     )
-        if hasfield(typeof(ts.cache.extras), :update_function)
+        if (
+            ts.cache isa TrainingBackendCache &&
+            hasfield(typeof(ts.cache.extras), :update_function)
+        )
             update_function = ts.cache.extras.update_function
         else
             update_function = with_default_precision_config(ts.parameters) do
@@ -124,10 +127,16 @@ for inplace in ("!", "")
                 )
             end

-            @set! ts.cache.extras = merge(ts.cache.extras, (; update_function))
+            if ts.cache isa TrainingBackendCache
+                @set! ts.cache.extras = merge(ts.cache.extras, (; update_function))
+            else
+                cache = TrainingBackendCache(backend, False(), nothing, (; update_function))
+                @set! ts.cache = cache
+            end
         end

         opt_state, ps = update_function(ts.optimizer_state, ts.parameters, grads)
+
         @set! ts.parameters = ps
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1
@@ -141,110 +150,108 @@ for inplace in ("!", "")
     end

     # XXX: recompile with a warning if new input types are used
-    @eval Profiler.@annotate "Compile Train Step" function Lux.Training.$(fname)(
+    @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
         backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
     ) where {F}
-        device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
-        @assert device isa ReactantDevice
-        is_sharded = device.device === nothing
-
-        dps = if backend.return_gradients isa True
-            Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+        if (
+            ts.cache isa TrainingBackendCache &&
+            hasfield(typeof(ts.cache.extras), :compiled_grad_and_step_function)
+        )
+            (; compiled_grad_and_step_function, is_sharded) = ts.cache.extras
+            ps = ts.parameters
+            dparameters = ts.cache.dparameters
         else
-            nothing
-        end
+            device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
+            @assert device isa ReactantDevice
+            is_sharded = device.device === nothing
+
+            dparameters = if backend.return_gradients isa True
+                Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+            else
+                nothing
+            end

-        $(ps_expr)
-
-        compiled_grad_and_step_function = with_default_precision_config(ts.parameters) do
-            @compile sync = backend.sync $(internal_fn)(
-                objective_function,
-                ts.model,
-                data,
-                ps,
-                ts.states,
-                ts.optimizer_state,
-                dps,
-                is_sharded,
-            )
+            $(ps_expr)
+
+            compiled_grad_and_step_function =
+                with_default_precision_config(ts.parameters) do
+                    @compile sync = backend.sync compute_gradients_internal_and_step!(
+                        objective_function,
+                        ts.model,
+                        data,
+                        ps,
+                        ts.states,
+                        ts.optimizer_state,
+                        dparameters,
+                        is_sharded,
+                    )
+                end
+
+            if ts.cache isa TrainingBackendCache
+                @set! ts.cache.dparameters = dparameters
+                @set! ts.cache.extras = merge(
+                    ts.cache.extras, (; compiled_grad_and_step_function, is_sharded)
+                )
+            else
+                cache = TrainingBackendCache(
+                    backend,
+                    False(),
+                    dparameters,
+                    (; compiled_grad_and_step_function, is_sharded),
+                )
+                @set! ts.cache = cache
+            end
+            @set! ts.objective_function = objective_function
         end

+        @show typeof(dparameters)
+
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function,
             ts.model,
             data,
             ps,
             ts.states,
             ts.optimizer_state,
-            dps,
+            dparameters,
             is_sharded,
         )

-        cache = TrainingBackendCache(
-            backend, False(), dps, (; compiled_grad_and_step_function, is_sharded)
-        )
-        @set! ts.cache = cache
-        @set! ts.objective_function = objective_function
-        @set! ts.states = st
-        @set! ts.parameters = ps
-        @set! ts.optimizer_state = opt_state
-        @set! ts.step = ts.step + 1
-
-        return grads, loss, stats, ts
-    end
-
-    @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
-        ::ReactantBackend,
-        obj_fn::F,
-        data,
-        ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
-    ) where {F}
-        grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
-            obj_fn,
-            ts.model,
-            data,
-            ts.parameters,
-            ts.states,
-            ts.optimizer_state,
-            ts.cache.dparameters,
-            ts.cache.extras.is_sharded,
-        )
-
         @set! ts.states = st
         @set! ts.parameters = ps
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1

         return grads, loss, stats, ts
     end
+end

-    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
-    ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal(
-            objective_function, model, data, ps, st
-        )
+@eval function compute_gradients_internal_and_step!(
+    objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
+) where {F}
+    dps, loss, stats, stₙ = compute_gradients_internal(
+        objective_function, model, data, ps, st
+    )

-        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
-        # Ensure sharding of input and output states are consistent
-        is_sharded && mark_same_sharding_group(st, stₙ)
+    opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+    # Ensure sharding of input and output states are consistent
+    is_sharded && mark_same_sharding_group(st, stₙ)

-        return nothing, psₙ, loss, stats, stₙ, opt_state
-    end
+    return nothing, psₙ, loss, stats, stₙ, opt_state
+end

-    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
-    ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal!(
-            dps, objective_function, model, data, ps, st
-        )
+@eval function compute_gradients_internal_and_step!(
+    objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
+) where {F}
+    dps, loss, stats, stₙ = compute_gradients_internal!(
+        dps, objective_function, model, data, ps, st
+    )

-        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
-        # Ensure sharding of input and output states are consistent
-        is_sharded && mark_same_sharding_group(st, stₙ)
+    opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+    # Ensure sharding of input and output states are consistent
+    is_sharded && mark_same_sharding_group(st, stₙ)

-        return dps, psₙ, loss, stats, stₙ, opt_state
-    end
+    return dps, psₙ, loss, stats, stₙ, opt_state
 end

 mark_same_sharding_group(args...) = Functors.fmap(mark_same_sharding_group_inner, args...)
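For reference, the caller-facing behaviour is unchanged: the first train step still triggers compilation (and now also populates ts.cache.extras), while every later step with the same objective function and input types takes the cached branch added above instead of hitting a separate dispatch. A rough usage sketch, assuming the usual Lux + Reactant training setup (loss, optimizer, and device helpers here are illustrative and may differ across versions):

using Lux, Reactant, Enzyme, Optimisers, Random

dev = reactant_device()
model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 4, 8) |> dev
y = rand(Float32, 2, 8) |> dev

ts = Training.TrainState(model, ps, st, Adam(0.01f0))
# First step: compiles and caches compiled_grad_and_step_function in ts.cache.extras.
_, loss1, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), ts)
# Second step: reuses the cached compiled function instead of recompiling.
_, loss2, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), ts)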