
Exporting Lux Models to Jax example broken on jax >= 0.7.0 #1558

@bartvanerp

Description


The example for exporting a Lux model to Jax appears to be broken on JAX versions 0.7.0 and later. With enzyme-ad 0.0.10 (built from source on a Raspberry Pi), the example throws an error in the Python part:

Traceback (most recent call last):
  File "/home/pi/Documents/_/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2076, in default_process_primitive
    out_avals, effs = _cached_abstract_eval(primitive, *aval_qdds, **params)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pi/Documents/_/.venv/lib/python3.11/site-packages/jax/_src/util.py", line 298, in wrapper
    return cached(config.trace_context() if trace_context_in_key else _ignore(),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: unhashable type: 'dict'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/pi/Documents/_/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 1907, in _verify_params_are_hashable
    hash(v)
TypeError: unhashable type: 'dict'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/pi/Documents/_/lux_example/example.py", line 61, in <module>
    jax.jit(run_lux_model)(
  File "/home/pi/Documents/_/lux_example/example.py", line 23, in run_lux_model
    return hlo_call(
           ^^^^^^^^^
  File "/home/pi/Documents/_/.venv/lib/python3.11/site-packages/enzyme_ad/jax/primitives.py", line 1361, in hlo_call
    return _enzyme_primal_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^
TypeError: As of JAX v0.7, parameters to jaxpr equations must have __hash__ and __eq__ methods. In a call to primitive enzyme_primal, the value of parameter source was not hashable: (PyTreeDef((*, *, *, *, *, *, *, *, *, *, *)), {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10}, {0: -1}, 'module @"reactant_Chain{@..." attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {\n  func.func @main(%arg0: tensor<4x1x28x28xf32> {enzymexla.memory_effects = []}, %arg1: tensor<6x1x5x5xf32> {enzymexla.memory_effects = []}, %arg2: tensor<6xf32> {enzymexla.memory_effects = []}, %arg3: tensor<16x6x5x5xf32> {enzymexla.memory_effects = []}, %arg4: tensor<16xf32> {enzymexla.memory_effects = []}, %arg5: tensor<256x128xf32> {enzymexla.memory_effects = []}, %arg6: tensor<128xf32> {enzymexla.memory_effects = []}, %arg7: tensor<128x84xf32> {enzymexla.memory_effects = []}, %arg8: tensor<84xf32> {enzymexla.memory_effects = []}, %arg9: tensor<84x10xf32> {enzymexla.memory_effects = []}, %arg10: tensor<10xf32> {enzymexla.memory_effects = []}) -> tensor<4x10xf32> attributes {enzymexla.memory_effects = []} {\n    %cst = stablehlo.constant dense<0.000000e+00> : tensor<84x4xf32>\n    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x4xf32>\n    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<4x16x8x8xf32>\n    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x4xf32>\n    %cst_3 = stablehlo.constant dense<0xFF800000> : tensor<f32>\n    %0 = stablehlo.reverse %arg1, dims = [3, 2] : tensor<6x1x5x5xf32>\n    %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 1, 0]x[o, i, 1, 0]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x1x28x28xf32>, tensor<6x1x5x5xf32>) -> tensor<24x24x6x4xf32>\n    %2 = stablehlo.broadcast_in_dim %arg2, dims 
= [2] : (tensor<6xf32>) -> tensor<24x24x6x4xf32>\n    %3 = stablehlo.add %1, %2 : tensor<24x24x6x4xf32>\n    %4 = stablehlo.maximum %cst_2, %3 : tensor<24x24x6x4xf32>\n    %5 = "stablehlo.reduce_window"(%4, %cst_3) <{base_dilations = array<i64: 1, 1, 1, 1>, padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({\n    ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):\n      %24 = stablehlo.maximum %arg11, %arg12 : tensor<f32>\n      stablehlo.return %24 : tensor<f32>\n    }) : (tensor<24x24x6x4xf32>, tensor<f32>) -> tensor<12x12x6x4xf32>\n    %6 = stablehlo.reverse %arg3, dims = [3, 2] : tensor<16x6x5x5xf32>\n    %7 = stablehlo.convolution(%5, %6) dim_numbers = [0, 1, f, b]x[o, i, 1, 0]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<12x12x6x4xf32>, tensor<16x6x5x5xf32>) -> tensor<4x16x8x8xf32>\n    %8 = stablehlo.broadcast_in_dim %arg4, dims = [1] : (tensor<16xf32>) -> tensor<4x16x8x8xf32>\n    %9 = stablehlo.add %7, %8 : tensor<4x16x8x8xf32>\n    %10 = stablehlo.maximum %cst_1, %9 : tensor<4x16x8x8xf32>\n    %11 = "stablehlo.reduce_window"(%10, %cst_3) <{base_dilations = array<i64: 1, 1, 1, 1>, padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 2, 2>}> ({\n    ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):\n      %24 = stablehlo.maximum %arg11, %arg12 : tensor<f32>\n      stablehlo.return %24 : tensor<f32>\n    }) : (tensor<4x16x8x8xf32>, tensor<f32>) -> tensor<4x16x4x4xf32>\n    %12 = stablehlo.reshape %11 : (tensor<4x16x4x4xf32>) -> tensor<4x256xf32>\n    %13 = stablehlo.dot_general %arg5, %12, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : 
(tensor<256x128xf32>, tensor<4x256xf32>) -> tensor<128x4xf32>\n    %14 = stablehlo.broadcast_in_dim %arg6, dims = [0] : (tensor<128xf32>) -> tensor<128x4xf32>\n    %15 = stablehlo.add %13, %14 : tensor<128x4xf32>\n    %16 = stablehlo.maximum %cst_0, %15 : tensor<128x4xf32>\n    %17 = stablehlo.dot_general %arg7, %16, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<128x84xf32>, tensor<128x4xf32>) -> tensor<84x4xf32>\n    %18 = stablehlo.broadcast_in_dim %arg8, dims = [0] : (tensor<84xf32>) -> tensor<84x4xf32>\n    %19 = stablehlo.add %17, %18 : tensor<84x4xf32>\n    %20 = stablehlo.maximum %cst, %19 : tensor<84x4xf32>\n    %21 = stablehlo.broadcast_in_dim %arg10, dims = [1] : (tensor<10xf32>) -> tensor<4x10xf32>\n    %22 = stablehlo.dot_general %20, %arg9, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<84x4xf32>, tensor<84x10xf32>) -> tensor<4x10xf32>\n    %23 = stablehlo.add %22, %21 : tensor<4x10xf32>\n    return %23 : tensor<4x10xf32>\n  }\n}', {})
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
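The root cause is visible in the final error: as of JAX v0.7, every parameter attached to a jaxpr equation must implement `__hash__` and `__eq__`, and the `source` parameter that enzyme-ad passes to `enzyme_primal` contains plain dicts, which are unhashable. This is not the actual enzyme-ad fix, but a minimal sketch of the usual pattern for satisfying the new requirement: wrap the unhashable payload in a frozen, hashable container before binding it as a primitive parameter.

```python
# Hypothetical sketch (not the real enzyme-ad patch): JAX >= 0.7 rejects
# primitive parameters without __hash__/__eq__. Wrapping a dict payload in
# an immutable container restores hashability without changing its contents.

class HashableDict:
    """Immutable, hashable wrapper around a plain dict."""

    def __init__(self, d):
        # Store contents as a sorted tuple of items so that equal dicts
        # always produce equal (and equally hashed) wrappers.
        self._items = tuple(sorted(d.items()))

    def __hash__(self):
        return hash(self._items)

    def __eq__(self, other):
        return isinstance(other, HashableDict) and self._items == other._items

    def unwrap(self):
        # Recover the original dict where the payload is actually consumed.
        return dict(self._items)


# The dicts from the failing `source` parameter would be wrapped like this:
params = HashableDict({0: 0, 1: 1, 2: 2})
print(hash(params) == hash(HashableDict({2: 2, 1: 1, 0: 0})))  # True
```

A wrapper along these lines would let the `(PyTreeDef, dict, dict, module_str, dict)` tuple in `source` pass JAX's `_verify_params_are_hashable` check; the real fix belongs in enzyme-ad's `primitives.py`, not user code.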

Temporary fix: downgrade JAX to 0.6.2.
