Forward mode AD fails and aborts the repl #1523

@jacobleft

This is a follow-up to EnzymeAD/Reactant.jl#1749, which reported that forward-mode AD fails under Reactant.
The latest release fixed that for simple functions, but forward mode still fails for a Lux model, as shown below.

Minimal Example

using Lux, Enzyme, Reactant, Random

Reactant.set_default_backend("cpu")
const cdev = cpu_device()
const xdev = reactant_device(; force=true)

model = Chain(
    Dense(4, 16, tanh),
    Dense(16, 16, tanh),
    Dense(16, 4)
)


ps_model_ra, st_model_ra = Lux.setup(Random.default_rng(), model) |> xdev
model_stateful = StatefulLuxLayer(model, ps_model_ra, st_model_ra)


square_func(x) = x.^2

x = collect(Float32, 1:4)
x_onehot = Enzyme.onehot(x)

x_ra = Reactant.to_rarray(x)
x_ra_onehot = @jit Enzyme.onehot(x_ra)

m = @jit model_stateful(x_ra)

# Forward:
@jit Enzyme.autodiff(Forward, square_func, BatchDuplicated(x_ra, x_ra_onehot)) # works
@jit Enzyme.autodiff(Forward, model_stateful, BatchDuplicated(x_ra, x_ra_onehot)) # fails

The last line aborts the Julia process after printing this error:

loc(callsite("overloaded_mul!/dot_general"("~/.julia/packages/Reactant/IgTfV/src/stdlibs/LinearAlgebra.jl":255:0) at "traced_call/call"("~/.julia/packages/Reactant/IgTfV/src/ControlFlow.jl":8:0))): error: Mismatched ranks of types2 vs 1
LLVM ERROR: Failed to infer result type(s):
"stablehlo.add"(...) {} : (tensor<4x16xf32>, tensor<16xf32>) -> ( ??? )
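
The shapes in the message suggest (this is a guess from the log, not a confirmed diagnosis) that the batched tangent of the first `Dense` layer's matrix product reaches the bias addition as a rank-2 tensor while the bias stays rank 1, and `stablehlo.add` does not broadcast implicitly. The plain-Julia sketch below only illustrates that kind of rank mismatch; it does not touch Reactant or Enzyme, and the names `W`, `b`, `X`, `Y` are illustrative stand-ins, not values taken from the model.

```julia
# Illustration only: plain Julia arrays with the same shapes as Dense(4, 16).
W = ones(Float32, 16, 4)                         # stand-in for the layer weight
b = zeros(Float32, 16)                           # stand-in for the layer bias
X = Float32[i == j for i in 1:4, j in 1:4]       # four one-hot directions as columns

Y = W * X                                        # rank 2: one column per direction

# Broadcasting the bias over the batch of directions is fine.
broadcast_ok = size(Y .+ b) == (16, 4)

# A non-broadcasting add of a rank-2 and a rank-1 array is rejected,
# which is analogous to what the stablehlo.add shape inference complains about.
plain_add_fails = try
    Y + b
    false
catch err
    err isa DimensionMismatch
end
```

If this reading is right, the bias add in the generated batched code would need an explicit broadcast along the batch dimension.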

Environment

(jl_CeLRAA) pkg> st
Status `/tmp/jl_CeLRAA/Project.toml`
  [7da242da] Enzyme v0.13.87
  [b2108857] Lux v1.24.0
  [3c362404] Reactant v0.2.171
  [9a3f8284] Random v1.11.0
