Probability density/mass functions from jax.scipy.stats not supported on Metal

(Copied from https://github.com/google/jax/issues/20835)

I am attempting to use JAX on Metal (on an M1 Pro chip) to model discrete (count) data. I've installed the latest version of jax-metal (0.0.6) using pip.

The installation seems to have worked overall, as I can perform basic JAX array operations on the GPU. However, when I try to compute the (log-)PMFs/PDFs of random variables that are defined in terms of the (log-)gamma function, I get errors like the one below. These seem to indicate that lax.lgamma (which lowers to the chlo.lgamma operation) is not supported by the Metal backend on M1.
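Since the failure appears to be confined to the lgamma operation, one possible stopgap is to compute log Gamma(x) from elementary operations (log, division, addition) that the backend may already lower. The sketch below uses the Lanczos approximation (g = 7 with the commonly tabulated nine coefficients) together with the recurrence log Gamma(x) = log Gamma(x + 1) - log(x), so it assumes x > 0. The function name lgamma_lanczos is hypothetical, and I have not verified that every operation it emits is supported by jax-metal.

```python
import math

import jax
import jax.numpy as jnp

# Lanczos coefficients for g = 7, n = 9 (the commonly tabulated set).
_LANCZOS_G = 7.0
_LANCZOS_COEFS = (
    0.99999999999980993,
    676.5203681218851,
    -1259.1392167224028,
    771.32342877765313,
    -176.61502916214059,
    12.507343278686905,
    -0.13857109526572012,
    9.9843695780195716e-6,
    1.5056327351493116e-7,
)
_HALF_LOG_2PI = 0.5 * math.log(2.0 * math.pi)


@jax.jit
def lgamma_lanczos(x):
    """Hypothetical log-gamma for x > 0 built only from elementwise jnp ops.

    The Lanczos form evaluates Gamma(z + 1); with z = x we get
    log Gamma(x + 1), and then subtract log(x) to obtain log Gamma(x).
    This keeps the series argument >= 1 for any x > 0.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    z = x
    series = jnp.full_like(z, _LANCZOS_COEFS[0])
    for i, c in enumerate(_LANCZOS_COEFS[1:], start=1):
        series = series + c / (z + i)
    t = z + _LANCZOS_G + 0.5
    log_gamma_x_plus_1 = _HALF_LOG_2PI + (z + 0.5) * jnp.log(t) - t + jnp.log(series)
    return log_gamma_x_plus_1 - jnp.log(x)
```

Note that jax.scipy.stats would still call lax.lgamma internally, so using this would mean writing the affected log-densities by hand, e.g. a binomial log-PMF as lgamma_lanczos(n + 1) - lgamma_lanczos(k + 1) - lgamma_lanczos(n - k + 1) + k * log(p) + (n - k) * log1p(-p).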

This is essential functionality for a wide class of probabilistic machine learning models. Note that the following functions (among others) are broken as a result:

  • jax.scipy.stats.binom.logpmf
  • jax.scipy.stats.nbinom.logpmf
  • jax.scipy.stats.poisson.logpmf
  • jax.scipy.stats.dirichlet.logpdf
  • jax.scipy.stats.beta.logpdf
  • jax.scipy.stats.gamma.logpdf
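A small script along these lines can be used to check which of the listed functions fail on the default device. The input values are arbitrary valid arguments chosen for illustration, and the exception is caught broadly because only a pass/fail summary is wanted. On a CPU-only install, all six cases should report ok.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

# One small, valid call per affected (log-)density. Input values are arbitrary.
CASES = {
    "binom.logpmf": lambda: stats.binom.logpmf(1, n=2, p=0.5),
    "nbinom.logpmf": lambda: stats.nbinom.logpmf(1, n=2, p=0.5),
    "poisson.logpmf": lambda: stats.poisson.logpmf(1, mu=2.0),
    "dirichlet.logpdf": lambda: stats.dirichlet.logpdf(
        jnp.array([0.2, 0.8]), jnp.array([1.0, 2.0])
    ),
    "beta.logpdf": lambda: stats.beta.logpdf(0.3, a=2.0, b=3.0),
    "gamma.logpdf": lambda: stats.gamma.logpdf(1.0, a=2.0),
}


def check_cases(cases=CASES):
    """Run each case on the default device and report 'ok' or the exception type."""
    results = {}
    for name, fn in cases.items():
        try:
            fn().block_until_ready()  # force execution so backend errors surface here
            results[name] = "ok"
        except Exception as err:  # broad on purpose: only a pass/fail summary is wanted
            results[name] = f"failed ({type(err).__name__})"
    return results


if __name__ == "__main__":
    print(jax.devices())
    for name, status in check_cases().items():
        print(f"{name:18s} {status}")
```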

...

>>> jax.scipy.stats.binom.logpmf(1, n=2, p=0.5)

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/stats/binom.py", line 31, in logpmf
    gammaln(n + 1),
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/special.py", line 44, in gammaln
    return lax.lgamma(x)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/lax/special.py", line 46, in lgamma
    return lgamma_p.bind(x)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'chlo.lgamma'
<stdin>:1:0: note: see current operation: %0 = "chlo.lgamma"(%arg0) : (tensor<f32>) -> tensor<f32>
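Until the operation is supported, a workaround that may apply here is to evaluate lgamma-dependent code on the CPU backend that ships with jaxlib, using jax.default_device. This sketch assumes jax.devices("cpu") is available alongside the METAL device in a jax-metal install; the helper name on_cpu is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


def on_cpu(fn, *args, **kwargs):
    """Hypothetical helper: evaluate fn with the CPU device as the default device."""
    cpu = jax.devices("cpu")[0]
    with jax.default_device(cpu):
        return fn(*args, **kwargs)


# The call from the traceback, evaluated on CPU instead of METAL.
logp = on_cpu(stats.binom.logpmf, 1, n=2, p=0.5)
# Analytically, log(C(2, 1) * 0.5 * 0.5) = log(0.5).
print(logp, logp.devices())
```

Arrays that have already been committed to the METAL device (for example via jax.device_put) determine where computation runs, so they would need to be moved with jax.device_put(x, jax.devices("cpu")[0]) before the call.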

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.23
numpy:  1.26.4
python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='PHS027794', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')
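The block above follows the format printed by jax.print_environment_info(), so it can be regenerated (for example, after upgrading jax-metal or jaxlib) with:

```python
import jax

# return_string=True returns the summary instead of only printing it.
info = jax.print_environment_info(return_string=True)
print(info)
```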