Open
Description
This is a follow-up to EnzymeAD/Reactant.jl#1749, which reported that forward-mode AD failed when using Reactant.
The latest release fixed this for simple functions, but forward mode still fails for a Lux model, as shown below.
Minimal Example
using Lux, Enzyme, Reactant, Random
Reactant.set_default_backend("cpu")
const cdev = cpu_device()
const xdev = reactant_device(; force=true)
model = Chain(
Dense(4, 16, tanh),
Dense(16, 16, tanh),
Dense(16, 4)
)
ps_model_ra, st_model_ra = Lux.setup(Random.default_rng(), model) |> xdev
model_stateful = StatefulLuxLayer(model, ps_model_ra, st_model_ra)
square_func(x) = x.^2
x = collect(Float32, 1:4)
x_onehot = Enzyme.onehot(x)
x_ra = Reactant.to_rarray(x)
x_ra_onehot = @jit Enzyme.onehot(x_ra)
m = @jit model_stateful(x_ra)
# Forward mode, with one batched one-hot tangent per input element:
@jit Enzyme.autodiff(Forward, square_func, BatchDuplicated(x_ra, x_ra_onehot)) # works
@jit Enzyme.autodiff(Forward, model_stateful, BatchDuplicated(x_ra, x_ra_onehot)) # fails
The last line causes the Julia process to abort after printing this error:
loc(callsite("overloaded_mul!/dot_general"("~/.julia/packages/Reactant/IgTfV/src/stdlibs/LinearAlgebra.jl":255:0) at "traced_call/call"("~/.julia/packages/Reactant/IgTfV/src/ControlFlow.jl":8:0))): error: Mismatched ranks of types2 vs 1
LLVM ERROR: Failed to infer result type(s):
"stablehlo.add"(...) {} : (tensor<4x16xf32>, tensor<16xf32>) -> ( ??? )
Environment
(jl_CeLRAA) pkg> st
Status `/tmp/jl_CeLRAA/Project.toml`
[7da242da] Enzyme v0.13.87
[b2108857] Lux v1.24.0
[3c362404] Reactant v0.2.171
[9a3f8284] Random v1.11.0