Skip to content

Commit d275ad2

Browse files
committed
fix: use finite differences to test ground truth
1 parent 60d52ba commit d275ad2

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

lib/LuxTestUtils/src/autodiff.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ struct Constant{T}
22
val::T
33
end
44

5-
# Zygote.jl on CPU
5+
# FiniteDiff.jl on CPU
66
function ground_truth_gradient(f, args...)
77
cdev = cpu_device()
88
f_cpu = try
@@ -12,7 +12,7 @@ function ground_truth_gradient(f, args...)
1212
be fixed by defining overloads using ConstructionBase.jl" err
1313
f
1414
end
15-
return gradient(f_cpu, AutoZygote(), map(cdev, args)...)
15+
return gradient(f_cpu, AutoFiniteDiff(), map(cdev, args)...)
1616
end
1717

1818
# Zygote.jl
@@ -80,13 +80,13 @@ end
8080
"""
8181
test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...)
8282
83-
Test the gradients of `f` with respect to `args` using the specified backends.
83+
Test the gradients of `f` with respect to `args` using the specified backends. The ground
84+
truth gradients are computed using FiniteDiff.jl on CPU.
8485
8586
| Backend | ADType | CPU | GPU | Notes |
8687
|:-------------- |:------------------- |:--- |:--- |:----------------- |
8788
| Zygote.jl | `AutoZygote()` | ✔ | ✔ | |
8889
| ForwardDiff.jl | `AutoForwardDiff()` | ✔ | ✖ | `len ≤ 32` |
89-
| FiniteDiff.jl | `AutoFiniteDiff()` | ✔ | ✖ | `len ≤ 32` |
9090
| Enzyme.jl | `AutoEnzyme()` | ✔ | ✖ | Only Reverse Mode |
9191
9292
## Arguments
@@ -157,7 +157,6 @@ function test_gradients(
157157
push!(backends, AutoZygote())
158158
if !on_gpu
159159
total_length 32 && push!(backends, AutoForwardDiff())
160-
total_length 32 && push!(backends, AutoFiniteDiff())
161160
# TODO: Move Enzyme out of here once it supports GPUs
162161
if enable_enzyme_reverse_mode || ENZYME_TESTING_ENABLED[]
163162
mode = if enzyme_set_runtime_activity

0 commit comments

Comments
 (0)