|
| 1 | +# Common mistake that users make is passing in a compiled function |
| 2 | +function Lux.Training.TrainState( |
| 3 | + ::Reactant.Compiler.Thunk, ps, st, optimizer::Optimisers.AbstractRule |
| 4 | +) |
| 5 | + throw( |
| 6 | + ArgumentError( |
| 7 | + """ |
| 8 | +Invalid TrainState construction using a compiled function. |
| 9 | +
|
| 10 | +`TrainState` is being constructed with a reactant compiled function, i.e. a |
| 11 | +`Reactant.Compiler.Thunk`. This is likely a mistake as the model should be |
| 12 | +passed in directly without being compiled first. |
| 13 | +
|
| 14 | +This is likely originating from the following style of usage: |
| 15 | +
|
| 16 | +```julia |
| 17 | +using Lux, Reactant, Random, Optimisers |
| 18 | +
|
| 19 | +rdev = reactant_device() |
| 20 | +
|
| 21 | +model = Dense(10, 10) |
| 22 | +ps, st = Lux.setup(Random.default_rng(), model) |> rdev |
| 23 | +x = rand(10) |> rdev |
| 24 | +
|
| 25 | +model_compiled = @compile model(x, ps, st) |
| 26 | +
|
| 27 | +train_state = Training.TrainState(model_compiled, ps, st, Adam()) |
| 28 | +``` |
| 29 | +
|
| 30 | +Instead avoid compiling the model and pass it directly to `TrainState`. When |
| 31 | +`single_train_step` or other functions are called on the `TrainState`, the |
| 32 | +model will be compiled automatically. |
| 33 | +
|
| 34 | +```julia |
| 35 | +train_state = Training.TrainState(model, ps, st, Adam()) |
| 36 | +``` |
| 37 | +
|
| 38 | +For end-to-end usage example refer to the documentation: |
| 39 | +<https://lux.csail.mit.edu/stable/manual/compiling_lux_models#compile_lux_model_trainstate> |
| 40 | +""" |
| 41 | + ), |
| 42 | + ) |
| 43 | +end |
| 44 | + |
1 | 45 | function objective_function_wrapper(objective_function::F, model, ps, st, data) where {F} |
2 | 46 | loss, stₙ, stats = objective_function(model, ps, st, data) |
3 | 47 | return loss, Reactant.ignore_derivatives(stₙ), Reactant.ignore_derivatives(stats) |
|
0 commit comments