Skip to content

Commit 504f298

Browse files
committed
feat: update LuxTestUtils to support 1.12
1 parent 226beb3 commit 504f298

File tree

4 files changed

+33
-16
lines changed

4 files changed

+33
-16
lines changed

lib/LuxTestUtils/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxTestUtils"
22
uuid = "ac9de150-d08f-4546-94fb-7472b5760531"
33
authors = ["Avik Pal <[email protected]>"]
4-
version = "2.0.1"
4+
version = "2.1.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -32,7 +32,7 @@ Enzyme = "0.13.81"
3232
FiniteDiff = "2.23.1"
3333
ForwardDiff = "0.10.36, 1"
3434
Functors = "0.5"
35-
JET = "0.9.6, 0.10"
35+
JET = "0.9.6, 0.10, 0.11"
3636
MLDataDevices = "1.6.10"
3737
Optimisers = "0.3.4, 0.4"
3838
Test = "1.10"

lib/LuxTestUtils/src/LuxTestUtils.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,41 @@ using Zygote: Zygote
3232
const CRC = ChainRulesCore
3333
const FD = FiniteDiff
3434

35+
const JET_TESTING_ENABLED = Ref{Bool}(false)
36+
const ENZYME_TESTING_ENABLED = Ref{Bool}(false)
37+
const ZYGOTE_TESTING_ENABLED = Ref{Bool}(false)
38+
3539
# Check if JET will work
3640
try
3741
using JET: JET, JETTestFailure, get_reports, report_call, report_opt
38-
# XXX: In 1.11, JET leads to stack overflows
39-
global JET_TESTING_ENABLED = v"1.10-" VERSION < v"1.11-"
42+
JET_TESTING_ENABLED[] = true
4043
catch err
4144
@error "`JET.jl` did not successfully precompile on $(VERSION). All `@jet` tests will \
4245
be skipped." maxlog = 1 err = err
43-
global JET_TESTING_ENABLED = false
46+
JET_TESTING_ENABLED[] = false
4447
end
4548

46-
# Check if Enzyme will work
47-
try
48-
using Enzyme: Enzyme
49-
__ftest(x) = x
50-
Enzyme.autodiff(Enzyme.Reverse, __ftest, Enzyme.Active, Enzyme.Active(2.0))
51-
global ENZYME_TESTING_ENABLED = Sys.islinux()
52-
catch err
53-
global ENZYME_TESTING_ENABLED = false
49+
# Check if Enzyme will work (only on non-prerelease versions)
50+
@static if isempty(VERSION.prerelease)
51+
try
52+
using Enzyme: Enzyme
53+
Enzyme.gradient(Enzyme.Reverse, Base.Fix1(sum, abs2), ones(Float32, 10))
54+
ENZYME_TESTING_ENABLED[] = Sys.islinux()
55+
catch err
56+
@error "`Enzyme.jl` did not successfully differentiate a simple function or \
57+
failed to load on $(VERSION). All Enzyme tests will be \
58+
skipped." maxlog = 1 err = err
59+
ENZYME_TESTING_ENABLED[] = false
60+
end
61+
end
62+
63+
function __init__()
64+
ZYGOTE_TESTING_ENABLED[] = VERSION < v"1.12-"
65+
66+
if JET_TESTING_ENABLED[]
67+
# JET doesn't work nicely on 1.11
68+
JET_TESTING_ENABLED[] = VERSION < v"1.11-" || VERSION v"1.12-"
69+
end
5470
end
5571

5672
include("test_softfail.jl")

lib/LuxTestUtils/src/autodiff.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ function gradient(f::F, ::AutoEnzyme{Nothing}, args...) where {F}
3131
end
3232

3333
function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F}
34-
!ENZYME_TESTING_ENABLED &&
34+
if !ENZYME_TESTING_ENABLED[]
3535
return ntuple(Returns(GradientComputationSkipped()), length(args))
36+
end
3637

3738
args_activity = map(args) do x
3839
needs_gradient(x) && return Enzyme.Duplicated(x, Enzyme.make_zero(x))
@@ -158,7 +159,7 @@ function test_gradients(
158159
total_length 32 && push!(backends, AutoForwardDiff())
159160
total_length 32 && push!(backends, AutoFiniteDiff())
160161
# TODO: Move Enzyme out of here once it supports GPUs
161-
if enable_enzyme_reverse_mode || ENZYME_TESTING_ENABLED
162+
if enable_enzyme_reverse_mode || ENZYME_TESTING_ENABLED[]
162163
mode = if enzyme_set_runtime_activity
163164
Enzyme.set_runtime_activity(Enzyme.Reverse)
164165
else

lib/LuxTestUtils/src/jet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Test Broken
5353
```
5454
"""
5555
macro jet(expr, args...)
56-
!JET_TESTING_ENABLED && return :()
56+
!JET_TESTING_ENABLED[] && return :()
5757

5858
all_args, call_extras, opt_extras = [], [], []
5959
target_modules_set = false

0 commit comments

Comments
 (0)