Skip to content

Commit 254e800

Browse files
committed
feat: introduce a special (::dev)(...) syntax to compile reactant models
1 parent cb92a56 commit 254e800

File tree

4 files changed

+37
-1
lines changed

4 files changed

+37
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,7 +1,7 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "1.2.0"
4+
version = "1.3.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/LuxReactantExt/LuxReactantExt.jl

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -9,6 +9,8 @@ using Static: False
99
using Lux: Lux, LuxOps, Training
1010
using Lux.Training: TrainingBackendCache, ReactantBackend
1111

12+
Lux.is_extension_loaded(::Val{:Reactant}) = true
13+
1214
include("training.jl")
1315

1416
end

src/Lux.jl

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -17,6 +17,7 @@ using Optimisers: Optimisers
1717
using Random: Random, AbstractRNG
1818
using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic
1919
using Reexport: Reexport, @reexport
20+
using Setfield: @set!
2021
using Statistics: mean
2122

2223
import LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer,
@@ -69,6 +70,7 @@ include("helpers/losses.jl")
6970
include("helpers/recursive_ops.jl")
7071
include("helpers/match_eltype.jl")
7172
include("helpers/size_propagator.jl")
73+
include("helpers/compile.jl")
7274

7375
# AutoDiff
7476
include("autodiff/api.jl")

src/helpers/compile.jl

Lines changed: 32 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,32 @@
1+
function compile_reactant_model end
2+
3+
function (dev::AbstractDevice)(model::AbstractExplicitLayer, x, ps, st)
4+
return model, dev(x), dev(ps), dev(st)
5+
end
6+
7+
function (dev::XLADevice)(model::AbstractExplicitLayer, x, ps, st)
8+
if !is_extension_loaded(Val(:Reactant))
9+
error("Reactant.jl needs to be loaded to use XLA compilation")
10+
end
11+
return compile_reactant_model(model, x, ps, st)
12+
end
13+
14+
rewrite_conv_as_cross_correlation(model, ps, st, ::Functors.KeyPath) = model, ps, st
15+
function rewrite_conv_as_cross_correlation(model::Conv, ps, st, ::Functors.KeyPath)
16+
if model.cross_correlation isa False
17+
@set! model.cross_correlation = True()
18+
@allowscalar ps_reversed = reverse(
19+
ps.weight; dims=ntuple(i -> i, ndims(ps.weight) - 2)
20+
)
21+
@set! ps.weight = ps_reversed
22+
end
23+
return model, ps, st
24+
end
25+
26+
function (dev::AMDGPUDevice)(model::AbstractExplicitLayer, x, ps, st)
27+
# Conv is slow in MIOpen, so we rewrite it as a cross-correlation
28+
model, ps, st = Experimental.layer_map(
29+
rewrite_conv_as_cross_correlation, model, ps, st
30+
)
31+
return model, dev(x), dev(ps), dev(st)
32+
end

0 commit comments

Comments (0)