
Commit 8c14c89

feat: more informative error on constructing trainstate with compiled function (#1547)
* feat: more informative error on constructing trainstate with compiled function
* fix: grammatical error
1 parent 0d0558e commit 8c14c89


2 files changed, +60 -0 lines changed


ext/LuxReactantExt/training.jl

Lines changed: 44 additions & 0 deletions
@@ -1,3 +1,47 @@
+# Common mistake that users make is passing in a compiled function
+function Lux.Training.TrainState(
+    ::Reactant.Compiler.Thunk, ps, st, optimizer::Optimisers.AbstractRule
+)
+    throw(
+        ArgumentError(
+            """
+            Invalid TrainState construction using a compiled function.
+
+            `TrainState` is being constructed with a reactant compiled function, i.e. a
+            `Reactant.Compiler.Thunk`. This is likely a mistake as the model should be
+            passed in directly without being compiled first.
+
+            This is likely originating from the following style of usage:
+
+            ```julia
+            using Lux, Reactant, Random, Optimisers
+
+            rdev = reactant_device()
+
+            model = Dense(10, 10)
+            ps, st = Lux.setup(Random.default_rng(), model) |> rdev
+            x = rand(10) |> rdev
+
+            model_compiled = @compile model(x, ps, st)
+
+            train_state = Training.TrainState(model_compiled, ps, st, Adam())
+            ```
+
+            Instead avoid compiling the model and pass it directly to `TrainState`. When
+            `single_train_step` or other functions are called on the `TrainState`, the
+            model will be compiled automatically.
+
+            ```julia
+            train_state = Training.TrainState(model, ps, st, Adam())
+            ```
+
+            For end-to-end usage example refer to the documentation:
+            <https://lux.csail.mit.edu/stable/manual/compiling_lux_models#compile_lux_model_trainstate>
+            """
+        ),
+    )
+end
+
 function objective_function_wrapper(objective_function::F, model, ps, st, data) where {F}
     loss, stₙ, stats = objective_function(model, ps, st, data)
     return loss, Reactant.ignore_derivatives(stₙ), Reactant.ignore_derivatives(stats)
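
The error text above points users toward `single_train_step`, which is not shown in the diff itself. For context, here is a minimal sketch of the intended flow, assuming the standard Lux training API described on the linked manual page (`Training.single_train_step!`, `MSELoss`, the `AutoEnzyme` backend); the target `y`, batch size, and learning rate are illustrative and not part of this commit:

```julia
using Lux, Reactant, Enzyme, Random, Optimisers

rdev = reactant_device()

model = Dense(10, 10)
ps, st = Lux.setup(Random.default_rng(), model) |> rdev

# Illustrative data; shapes and batch size are assumptions, not from the commit.
x = rand(Float32, 10, 32) |> rdev
y = rand(Float32, 10, 32) |> rdev

# Pass the *uncompiled* model; this is the construction the error message recommends.
train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

# The model is compiled on the first step; no manual @compile is needed.
_, loss, _, train_state = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), train_state
)
```

Because the new method dispatches on `Reactant.Compiler.Thunk`, the mistake described in the error text now fails at `TrainState` construction rather than deep inside the first training step.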

test/reactant/training_tests.jl

Lines changed: 16 additions & 0 deletions
@@ -167,3 +167,19 @@ end
         @test length(Reactant.XLA.devices(Reactant.XLA.sharding(loss.data))) == 8
     end
 end
+
+@testitem "Reactant.Compiler.Thunk in TrainState" tags = [:reactant] setup = [
+    SharedTestSetup
+] begin
+    using Lux, Random, Reactant, Optimisers
+
+    rdev = reactant_device(; force=true)
+
+    model = Dense(10, 10)
+    ps, st = Lux.setup(Random.default_rng(), model) |> rdev
+    x = rand(10) |> rdev
+
+    model_compiled = @compile model(x, ps, st)
+
+    @test_throws ArgumentError Training.TrainState(model_compiled, ps, st, Adam())
+end
