Skip to content

Commit 6ef2ad5

Browse files
committed
fix: switch arg position
1 parent f99bab9 commit 6ef2ad5

File tree

1 file changed

+4
-14
lines changed

1 file changed

+4
-14
lines changed

src/helpers/training.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ Internal fields:
4141
"""
4242
@concrete struct TrainState
4343
cache
44-
allocator_cache
4544
objective_function
45+
allocator_cache
4646
model
4747
parameters
4848
states
@@ -56,8 +56,8 @@ MLDataDevices.isleaf(::TrainState) = true
5656
function Adapt.adapt_structure(to::AbstractDevice, ts::TrainState)
5757
return TrainState(
5858
nothing,
59-
get_allocator_cache(to),
6059
nothing,
60+
get_allocator_cache(to),
6161
ts.model,
6262
to(ts.parameters),
6363
to(ts.states),
@@ -94,17 +94,7 @@ function Adapt.adapt_structure(to::ReactantDevice, ts::TrainState)
9494
This ensures the optimizer state and other internal states are on the device on
9595
construction.
9696
"""
97-
return TrainState(
98-
nothing,
99-
nothing,
100-
nothing,
101-
ts.model,
102-
to(ts.parameters),
103-
to(ts.states),
104-
ts.optimizer,
105-
to(ts.optimizer_state),
106-
ts.step,
107-
)
97+
return @invoke Adapt.adapt_structure(to::AbstractDevice, ts::TrainState)
10898
end
10999

110100
"""
@@ -130,7 +120,7 @@ function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.Abstr
130120
end
131121
st_opt = Optimisers.setup(optimizer, ps)
132122
return TrainState(
133-
nothing, get_allocator_cache(dev), nothing, model, ps, st, optimizer, st_opt, 0
123+
nothing, nothing, get_allocator_cache(dev), model, ps, st, optimizer, st_opt, 0
134124
)
135125
end
136126

0 commit comments

Comments
 (0)