Skip to content

Commit 1d0d982

Browse files
authored
test: enzyme tests are now passing (#1475)
* test: enzyme tests are now passing * test: try switching off runtime activity * chore: bump min enzyme version * feat: add an enzyme rule to help out * chore: remove dep * fix: cleanup * fix: dispatch * fix: unwanted packages
1 parent 4d94f03 commit 1d0d982

File tree

24 files changed

+97
-132
lines changed

24 files changed

+97
-132
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "1.21.2"
4+
version = "1.21.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -87,8 +87,8 @@ ComponentArrays = "0.15.28"
8787
ConcreteStructs = "0.2.3"
8888
DiffResults = "1.1"
8989
DispatchDoctor = "0.4.26"
90-
Enzyme = "0.13.49"
91-
EnzymeCore = "0.8.12"
90+
Enzyme = "0.13.74"
91+
EnzymeCore = "0.8.13"
9292
FastClosures = "0.3.2"
9393
Flux = "0.16.3"
9494
ForwardDiff = "0.10.36, 1"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ ComponentArrays = "0.15.22"
4747
Documenter = "1.4"
4848
DocumenterCitations = "1.3.6"
4949
DocumenterVitepress = "0.2"
50-
Enzyme = "0.13.35"
50+
Enzyme = "0.13.74"
5151
FiniteDiff = "2.23.1"
5252
Flux = "0.16.3"
5353
ForwardDiff = "0.10.36, 1"

examples/CIFAR10/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ BFloat16s = "0.5.0"
2424
Comonicon = "1.0.8"
2525
ConcreteStructs = "0.2.3"
2626
DataAugmentation = "0.3"
27-
Enzyme = "0.13.35"
27+
Enzyme = "0.13.74"
2828
ImageCore = "0.10.2"
2929
ImageShow = "0.3.8"
3030
Interpolations = "0.15.1, 0.16"

examples/ConvolutionalVAE/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1616
[compat]
1717
ConcreteStructs = "0.2.3"
1818
DataAugmentation = "0.3.2"
19-
Enzyme = "0.13.35"
19+
Enzyme = "0.13.74"
2020
ImageShow = "0.3.8"
2121
Images = "0.26.1"
2222
Lux = "1.4.1"

examples/LSTMEncoderDecoder/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
99

1010
[compat]
1111
CairoMakie = "0.13.6, 0.14, 0.15"
12-
Enzyme = "0.13.44"
12+
Enzyme = "0.13.74"
1313
Lux = "1.12.4"
1414
MLUtils = "0.4.8"
1515
Optimisers = "0.4.6"

examples/PINN2DPDE/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1414
[compat]
1515
ADTypes = "1.10"
1616
CairoMakie = "0.12.10, 0.13, 0.14, 0.15"
17-
Enzyme = "0.13"
17+
Enzyme = "0.13.74"
1818
Lux = "1"
1919
MLUtils = "0.4.4"
2020
OnlineStats = "1.7.1"

examples/RealNVP/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313
[compat]
1414
CairoMakie = "0.13.1, 0.14, 0.15"
1515
ConcreteStructs = "0.2.3"
16-
Enzyme = "0.13.35"
16+
Enzyme = "0.13.74"
1717
Lux = "1.5"
1818
MLUtils = "0.4.5"
1919
Optimisers = "0.4.6"

lib/LuxLib/Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxLib"
22
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
33
authors = ["Avik Pal <[email protected]> and contributors"]
4-
version = "1.11.1"
4+
version = "1.11.2"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -72,8 +72,8 @@ CUDA = "5.8"
7272
ChainRulesCore = "1.24"
7373
Compat = "4.16"
7474
DispatchDoctor = "0.4.12"
75-
Enzyme = "0.13.35"
76-
EnzymeCore = "0.8.12"
75+
Enzyme = "0.13.74"
76+
EnzymeCore = "0.8.13"
7777
FastClosures = "0.3.2"
7878
ForwardDiff = "0.10.36, 1"
7979
Functors = "0.5"

lib/LuxLib/src/impl/batchnorm.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ function batchnorm(
3939
) where {F,xT,N}
4040
(μ, σ²), (rμ, rσ²) = compute_batch_statistics(
4141
x,
42-
reshape_norm_dims(x, rμ),
43-
reshape_norm_dims(x, rσ²),
42+
reshape_norm_dims(rμ, size(x)),
43+
reshape_norm_dims(rσ², size(x)),
4444
batchnorm_reduce_dims(x),
4545
training,
4646
momentum,
@@ -73,7 +73,7 @@ function batchnorm_affine_normalize(
7373
ϵ,
7474
) where {F,xT,μT,σ²T,N}
7575
return affine_normalize(
76-
act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ
76+
act, x, μ, σ², reshape_norm_dims(γ, size(x)), reshape_norm_dims(β, size(x)), ϵ
7777
)
7878
end
7979

lib/LuxLib/src/impl/groupnorm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function groupnorm_affine_normalize(
4242
ϵ,
4343
) where {F,N,xT,μT,σ²T}
4444
return affine_normalize(
45-
act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ
45+
act, x, μ, σ², reshape_norm_dims(γ, size(x)), reshape_norm_dims(β, size(x)), ϵ
4646
)
4747
end
4848

0 commit comments

Comments
 (0)