-
Notifications
You must be signed in to change notification settings - Fork 81
Open
Description
The example for exporting a Lux model to JAX seems to be broken on JAX starting from version 0.7.0. With enzyme-ad=0.0.10 (built from source on a Raspberry Pi), the example throws an error in the Python part:
Traceback (most recent call last):
File "/home/pi/Documents/_/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 2076, in default_process_primitive
out_avals, effs = _cached_abstract_eval(primitive, *aval_qdds, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pi/Documents/_/.venv/lib/python3.11/site-packages/jax/_src/util.py", line 298, in wrapper
return cached(config.trace_context() if trace_context_in_key else _ignore(),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: unhashable type: 'dict'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/pi/Documents/_/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 1907, in _verify_params_are_hashable
hash(v)
TypeError: unhashable type: 'dict'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/pi/Documents/_/lux_example/example.py", line 61, in <module>
jax.jit(run_lux_model)(
File "/home/pi/Documents/_/lux_example/example.py", line 23, in run_lux_model
return hlo_call(
^^^^^^^^^
File "/home/pi/Documents/_/.venv/lib/python3.11/site-packages/enzyme_ad/jax/primitives.py", line 1361, in hlo_call
return _enzyme_primal_p.bind(
^^^^^^^^^^^^^^^^^^^^^^
TypeError: As of JAX v0.7, parameters to jaxpr equations must have __hash__ and __eq__ methods. In a call to primitive enzyme_primal, the value of parameter source was not hashable: (PyTreeDef((*, *, *, *, *, *, *, *, *, *, *)), {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10}, {0: -1}, 'module @"reactant_Chain{@..." attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {\n func.func @main(%arg0: tensor<4x1x28x28xf32> {enzymexla.memory_effects = []}, %arg1: tensor<6x1x5x5xf32> {enzymexla.memory_effects = []}, %arg2: tensor<6xf32> {enzymexla.memory_effects = []}, %arg3: tensor<16x6x5x5xf32> {enzymexla.memory_effects = []}, %arg4: tensor<16xf32> {enzymexla.memory_effects = []}, %arg5: tensor<256x128xf32> {enzymexla.memory_effects = []}, %arg6: tensor<128xf32> {enzymexla.memory_effects = []}, %arg7: tensor<128x84xf32> {enzymexla.memory_effects = []}, %arg8: tensor<84xf32> {enzymexla.memory_effects = []}, %arg9: tensor<84x10xf32> {enzymexla.memory_effects = []}, %arg10: tensor<10xf32> {enzymexla.memory_effects = []}) -> tensor<4x10xf32> attributes {enzymexla.memory_effects = []} {\n %cst = stablehlo.constant dense<0.000000e+00> : tensor<84x4xf32>\n %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x4xf32>\n %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<4x16x8x8xf32>\n %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x4xf32>\n %cst_3 = stablehlo.constant dense<0xFF800000> : tensor<f32>\n %0 = stablehlo.reverse %arg1, dims = [3, 2] : tensor<6x1x5x5xf32>\n %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 1, 0]x[o, i, 1, 0]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x1x28x28xf32>, tensor<6x1x5x5xf32>) -> tensor<24x24x6x4xf32>\n %2 = stablehlo.broadcast_in_dim %arg2, dims = [2] : (tensor<6xf32>) 
-> tensor<24x24x6x4xf32>\n %3 = stablehlo.add %1, %2 : tensor<24x24x6x4xf32>\n %4 = stablehlo.maximum %cst_2, %3 : tensor<24x24x6x4xf32>\n %5 = "stablehlo.reduce_window"(%4, %cst_3) <{base_dilations = array<i64: 1, 1, 1, 1>, padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({\n ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):\n %24 = stablehlo.maximum %arg11, %arg12 : tensor<f32>\n stablehlo.return %24 : tensor<f32>\n }) : (tensor<24x24x6x4xf32>, tensor<f32>) -> tensor<12x12x6x4xf32>\n %6 = stablehlo.reverse %arg3, dims = [3, 2] : tensor<16x6x5x5xf32>\n %7 = stablehlo.convolution(%5, %6) dim_numbers = [0, 1, f, b]x[o, i, 1, 0]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<12x12x6x4xf32>, tensor<16x6x5x5xf32>) -> tensor<4x16x8x8xf32>\n %8 = stablehlo.broadcast_in_dim %arg4, dims = [1] : (tensor<16xf32>) -> tensor<4x16x8x8xf32>\n %9 = stablehlo.add %7, %8 : tensor<4x16x8x8xf32>\n %10 = stablehlo.maximum %cst_1, %9 : tensor<4x16x8x8xf32>\n %11 = "stablehlo.reduce_window"(%10, %cst_3) <{base_dilations = array<i64: 1, 1, 1, 1>, padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 2, 2>}> ({\n ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):\n %24 = stablehlo.maximum %arg11, %arg12 : tensor<f32>\n stablehlo.return %24 : tensor<f32>\n }) : (tensor<4x16x8x8xf32>, tensor<f32>) -> tensor<4x16x4x4xf32>\n %12 = stablehlo.reshape %11 : (tensor<4x16x4x4xf32>) -> tensor<4x256xf32>\n %13 = stablehlo.dot_general %arg5, %12, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x128xf32>, tensor<4x256xf32>) -> tensor<128x4xf32>\n %14 = 
stablehlo.broadcast_in_dim %arg6, dims = [0] : (tensor<128xf32>) -> tensor<128x4xf32>\n %15 = stablehlo.add %13, %14 : tensor<128x4xf32>\n %16 = stablehlo.maximum %cst_0, %15 : tensor<128x4xf32>\n %17 = stablehlo.dot_general %arg7, %16, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<128x84xf32>, tensor<128x4xf32>) -> tensor<84x4xf32>\n %18 = stablehlo.broadcast_in_dim %arg8, dims = [0] : (tensor<84xf32>) -> tensor<84x4xf32>\n %19 = stablehlo.add %17, %18 : tensor<84x4xf32>\n %20 = stablehlo.maximum %cst, %19 : tensor<84x4xf32>\n %21 = stablehlo.broadcast_in_dim %arg10, dims = [1] : (tensor<10xf32>) -> tensor<4x10xf32>\n %22 = stablehlo.dot_general %20, %arg9, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<84x4xf32>, tensor<84x10xf32>) -> tensor<4x10xf32>\n %23 = stablehlo.add %22, %21 : tensor<4x10xf32>\n return %23 : tensor<4x10xf32>\n }\n}', {})
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Temporary fix: downgrade JAX to 0.6.2.
Metadata
Metadata
Assignees
Labels
No labels