Support for complex numbers

Hi,

Are there plans to support complex numbers?

Something simple like this:

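import jax.numpy as jnp

def return_complex(x):
    # float32 array plus a Python complex scalar; expected result dtype is complex64
    return x*1+1.0j

x = jnp.ones((10))
print(return_complex(x))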

results in an error.

Replies

When I run the snippet below:

import jax
import jax.numpy as jnp

def return_complex(x):
    return x*1+1.0j

x = jnp.ones((10))
print(return_complex(x))

I get no errors.
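To confirm that the snippet really produces complex output, and not a silently truncated real array, you can inspect the result's dtype. This is a minimal check that assumes JAX's default 32-bit mode (`jax_enable_x64` off). In that mode the float32 input combined with the Python complex scalar `1.0j` should be promoted to complex64:

```python
import jax.numpy as jnp

def return_complex(x):
    return x * 1 + 1.0j

x = jnp.ones((10,))
y = return_complex(x)

# With 64-bit mode disabled (the JAX default), float32 + Python complex -> complex64
print(y.dtype)
# Every element should have real part 1 and imaginary part 1
print(y.real[:3], y.imag[:3])
```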

https://colab.research.google.com/notebooks/welcome.ipynb#scrollTo=VpzLDQeWxgyY&line=8&uniqifier=1

It runs on other platforms (like Colab) but not on jax-metal.

There is a related discussion about Metal support here: https://github.com/google/jax/issues/8074
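If the failure is specific to the Metal backend, one possible workaround is to check which backend is active and pin the complex-valued computation to the CPU device. This is only a sketch. It assumes the CPU backend, which comes with every jaxlib install, handles complex dtypes, and it does not address Metal's complex support itself. `jax.default_backend`, `jax.devices` and the `jax.default_device` context manager are standard JAX APIs, but whether a CPU fallback is acceptable depends on your workload.

```python
import jax
import jax.numpy as jnp

def return_complex(x):
    return x * 1 + 1.0j

# Report which backend JAX selected by default, e.g. "cpu" or "gpu"
print("default backend:", jax.default_backend())

# Assumption: the CPU backend supports complex dtypes, so run this block there
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    x = jnp.ones((10,))
    y = return_complex(x)

print(y.dtype)
```

When Metal is the default backend, everything inside the `with` block runs on the CPU, while code outside it keeps using the default device.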