JAX Metal error: failed to legalize operation 'mhlo.scatter'

I only get this error on the JAX Metal device (the CPU backend works fine). It seems to happen whenever I update values of an array using .at[...].set(...).

note: see current operation: 
%2903 = "mhlo.scatter"(%arg3, %2902, %2893) ({
^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>):
  "mhlo.return"(%arg5) : (tensor<f32>) -> ()
}) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<10x100x4xf32>, tensor<1xsi32>, tensor<10x4xf32>) -> tensor<10x100x4xf32>
        blocks = blocks.at[i].set(
...
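Below is a minimal sketch of the failing pattern for anyone trying to reproduce it. The shapes come from the log above: a 10x100x4 operand updated with a 10x4 slice along axis 1. The function names are hypothetical, and the real code behind `blocks.at[i].set(` is elided in the post. The second function is one possible scatter-free workaround, assuming a single index. It builds a boolean mask and uses `jnp.where`, which should lower to a select rather than a scatter. Both versions should give the same result on CPU.

```python
import jax
import jax.numpy as jnp

# Hypothetical names; shapes taken from the error log:
# operand (10, 100, 4), update (10, 4), index along axis 1.
def set_row(blocks, i, row):
    # Functional update (returns a new array). In the log above,
    # this kind of .at[...].set(...) shows up as mhlo.scatter.
    return blocks.at[:, i].set(row)

def set_row_masked(blocks, i, row):
    # Possible workaround: mark position i along axis 1 with a boolean
    # mask, then pick `row` there and `blocks` everywhere else.
    mask = (jnp.arange(blocks.shape[1]) == i)[None, :, None]  # (1, 100, 1)
    return jnp.where(mask, row[:, None, :], blocks)           # broadcasts to (10, 100, 4)

blocks = jnp.zeros((10, 100, 4), dtype=jnp.float32)
row = jnp.ones((10, 4), dtype=jnp.float32)
i = 7

out_scatter = jax.jit(set_row)(blocks, i, row)
out_masked = jax.jit(set_row_masked)(blocks, i, row)
print(jnp.array_equal(out_scatter, out_masked))  # True on CPU
```

The mask approach does more work than a real scatter (it touches every element), so it is only a stopgap until the backend handles scatter.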
Asked by Cemlyn

Replies

Thanks for reporting this. Several advanced-indexing bugs involving the conversion of GatherOp and ScatterOp have been fixed at tip, so the example in this post should now work. The fixes will be included in the next release of jax-metal.

  • No worries @dingshuhan, love all the work the Apple team is doing to get JAX running on Apple GPUs! Hope you have a great week.


Hi, is there an ETA on the next release with these fixes? Thanks in advance!

I would like to add to this thread that 64-bit mode doesn't work for me when the GPU is enabled.
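For reference, this is the standard way to turn on 64-bit mode in JAX (the `jax_enable_x64` config flag). It is only a sketch for checking the setting: on CPU the default float dtype becomes float64, while the reply above reports that this mode does not work with the Metal GPU enabled, so the Metal behavior is not shown here.

```python
import jax

# Enable 64-bit types; do this at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.arange(3.0)
print(x.dtype)        # float64 when x64 mode is active (checked on CPU)
print(jax.devices())  # shows which backend the arrays live on
```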

Thanks for the great code!

@dingshuhan any ETA on this?

Hi @dingshuhan, any update on this release? The scatter bug means that Jacobians can't be computed in JAX on jax-metal.
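A small sketch of the kind of function that hits this: it updates an array with `.at[...].set(...)` and is then differentiated with `jax.jacrev` / `jax.jacfwd`. The function `f` is hypothetical. Per the log above, such updates are typically lowered to a scatter, so differentiating through them presumably runs into the same conversion bug on jax-metal. On CPU it works and both modes agree.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example: overwrite element 1 with x[0] ** 2,
    # so f(x) = [x0, x0**2, x2].
    return x.at[1].set(x[0] ** 2)

x = jnp.array([3.0, 5.0, 7.0])
J = jax.jacrev(f)(x)      # reverse-mode Jacobian
J_fwd = jax.jacfwd(f)(x)  # forward-mode Jacobian, for comparison

# Expected Jacobian at x = [3, 5, 7]: rows [1, 0, 0], [6, 0, 0], [0, 0, 1]
print(J)
print(jnp.allclose(J, J_fwd))
```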

This seems to be fixed in the latest release of jax-metal (0.0.5).
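To check which jax-metal version is installed before retesting, a small sketch using only the standard library (`importlib.metadata`) could look like this. The helper names are hypothetical, and the 0.0.5 threshold is just the release mentioned above, not an official compatibility table.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_jax_metal_version():
    """Return the installed jax-metal version string, or None if it is not installed."""
    try:
        return version("jax-metal")
    except PackageNotFoundError:
        return None

def version_tuple(v):
    """Keep the leading numeric parts of a version, e.g. "0.0.5rc1" -> (0, 0, 5)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

v = installed_jax_metal_version()
if v is None:
    print("jax-metal is not installed")
elif version_tuple(v) >= (0, 0, 5):
    print(f"jax-metal {v}: at or after 0.0.5, where the fix was reported")
else:
    print(f"jax-metal {v}: older than 0.0.5, consider upgrading")
```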