Error while using JAX

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-23 22:04:38.947506: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

loc("-":0:0): error: current mps dialect version is 1.0.0, can't parse version 1.1.0
/AppleInternal/Library/BuildRoots/495c257e-668e-11ee-93ce-926038f30c31/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1097: failed assertion `Error importing MLIR bytecode. '
zsh: abort      python -c 'import jax; print(jax.numpy.arange(10))'
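The `zsh: abort` line means the interpreter was killed by SIGABRT from the failed MPSGraph assertion, not stopped by a Python exception, so wrapping the call in `try`/`except` will not catch it. A minimal sketch, assuming you want to probe the Metal backend without taking down your own process: it runs the one-liner from this post in a child interpreter and reports how it ended. The helper name `run_isolated` is hypothetical, not part of JAX.

```python
import signal
import subprocess
import sys

# The failing one-liner from this post; it needs jax and jax-metal installed.
JAX_SMOKE_TEST = "import jax; print(jax.numpy.arange(10))"


def run_isolated(code, timeout=120):
    """Run `code` in a fresh Python interpreter and classify the outcome.

    A failed native assertion aborts the whole process (SIGABRT), which
    try/except cannot intercept. Running it in a child process lets the
    parent see a negative return code instead of crashing as well.
    Raises subprocess.TimeoutExpired if the child runs longer than `timeout`.
    """
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    if proc.returncode == 0:
        return "ok", proc.stdout
    if proc.returncode < 0:
        # On POSIX, -N means the child was terminated by signal N.
        name = signal.Signals(-proc.returncode).name
        return f"killed by {name}", proc.stderr
    return f"exited with {proc.returncode}", proc.stderr


if __name__ == "__main__":
    status, output = run_isolated(JAX_SMOKE_TEST)
    print(status)
    print(output)
```

On a setup hit by this error the status should read `killed by SIGABRT`, with the MPS dialect message in the captured stderr.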

  • I was able to resolve this by updating macOS from 14.0 to 14.4. I'm currently using jax/jaxlib 0.4.23 and jax-metal 0.0.6.
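A minimal sketch, assuming you want to compare your machine against the versions reported in this comment before running JAX on Metal. The helper names (`parse_version`, `macos_at_least`, `installed_version`) are hypothetical, and the 14.4 threshold is only the release this comment reports working, not a documented minimum.

```python
import platform
from importlib import metadata

# Assumption: the macOS release this thread reports working, not an official minimum.
REPORTED_WORKING_MACOS = (14, 4)


def parse_version(text):
    """Turn a string like '14.4.1' into a tuple of ints, e.g. (14, 4, 1)."""
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def macos_at_least(minimum, release=None):
    """Return True if the macOS release is >= `minimum`.

    `release` defaults to platform.mac_ver()[0], which is '' off macOS.
    """
    if release is None:
        release = platform.mac_ver()[0]
    if not release:
        return False
    return parse_version(release) >= minimum


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("macOS:", platform.mac_ver()[0] or "not macOS")
    for dist in ("jax", "jaxlib", "jax-metal"):
        print(f"{dist}:", installed_version(dist))
    if not macos_at_least(REPORTED_WORKING_MACOS):
        print("macOS is older than 14.4; this thread reports that upgrading fixed the MPS dialect error.")
```

Tuple comparison keeps the check simple: `(14, 0) < (14, 4) <= (14, 4, 1)`, so point releases at or above 14.4 pass.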


Replies

Thanks, it works for me. I hope it won't interfere with torch.backends.