Skip to content

Commit 48822f4

Browse files
authored
fix: clipnorm support for reactant (#1469)
1 parent e1104d3 commit 48822f4

File tree

3 files changed: +35 additions, −1 deletion

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.21.1"
+version = "1.21.2"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/helpers/optimizers.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,18 @@ function make_reactant_compatible(
7070
)
7171
end
7272

73+
function make_reactant_compatible(
    opt::Optimisers.ClipNorm, dev::ReactantDevice, outermost=Val(true)
)
    # Move the norm threshold `omega` onto the Reactant device so it becomes a
    # traced number (`track_numbers=Integer` per the sibling methods above).
    omega_on_device = Utils.to_rarray(
        opt.omega; track_numbers=Integer, _dev_to_kwargs(dev)...
    )
    # Rebuild the rule with the device-resident threshold. The third field is
    # forced to `false` (do not throw on zero/exceeded norm) since throwing is
    # not supported inside compiled Reactant code — TODO confirm against the
    # other `make_reactant_compatible` methods' conventions.
    rebuilt = Optimisers.ClipNorm(omega_on_device, opt.p, false)
    if outermost isa Val{true}
        # Only the outermost rule gets wrapped in the ReactantOptimiser shim.
        return ReactantOptimiser(rebuilt)
    end
    return rebuilt
end
84+
7385
function optimisers_setup_with_jit end
7486

7587
end

test/reactant/training_tests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,25 @@ end
106106
hlo = repr(@code_hlo(Optimisers.update(st_opt, ps, ps)))
107107
@test length(findall("stablehlo.if", hlo)) == (2 + 1 + 2) * 2
108108
end
109+
110+
@testitem "Reactant Optimisers Patch: ClipNorm" tags = [:reactant] setup = [SharedTestSetup] begin
    using Lux, Random, Reactant, Optimisers

    dev = reactant_device(; force=true)

    # Small MLP with a nested Chain so the parameter tree is non-trivial.
    mlp = Chain(
        Dense(2 => 4, relu), Chain(Dense(4 => 2, relu; use_bias=false), Dense(2 => 2))
    )
    params, states = Lux.setup(Random.default_rng(), mlp) |> dev

    input = randn(Float32, 2, 32) |> dev

    # ClipNorm inside an OptimiserChain exercises the Reactant compatibility
    # patch for this rule type.
    tstate = Training.TrainState(
        mlp, params, states, OptimiserChain(ClipNorm(0.5), Descent(0.1))
    )

    # One training step on an identity-reconstruction loss; we only need it to
    # compile and run, so gradients are not returned.
    _, loss, stats, tstate = Training.single_train_step(
        AutoEnzyme(), MSELoss(), (input, input), tstate; return_gradients=Val(false)
    )
    @test loss isa Number
end

0 commit comments

Comments (0)