tf.einsum runs on CPU only with tensorflow-macos and tensorflow-metal

I found that tf.einsum is extremely slow on Apple Silicon because it is placed on the CPU instead of the GPU. This leads to very low GPU utilization and makes the operation itself slow. I tried forcing it onto the GPU, but that does not work.

It can be reproduced with the following code, which logs device placement for tf.einsum.

import tensorflow as tf

# https://www.tensorflow.org/guide/gpu#logging_device_placement
tf.debugging.set_log_device_placement(True)

print(tf.__version__)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

B = 1000
S = 200
N = 16
K = 64
# Two random float32 tensors of shape [B, N, S, K].
a = tf.random.uniform(shape=[B, N, S, K], minval=0.0, maxval=100.0, dtype=tf.float32)
b = tf.random.uniform(shape=[B, N, S, K], minval=0.0, maxval=100.0, dtype=tf.float32)
tf.einsum('BNSK,BNTK->BNTS', a, b)  # Notice in the output that Einsum is placed on the CPU.

Output:

Executing op Einsum in device /job:localhost/replica:0/task:0/device:CPU:0
2023-03-08 16:44:01.119144: I tensorflow/core/common_runtime/placer.cc:114] inputs_0: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
2023-03-08 16:44:01.119182: I tensorflow/core/common_runtime/placer.cc:114] inputs_1: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
2023-03-08 16:44:01.119191: I tensorflow/core/common_runtime/placer.cc:114] Einsum: (Einsum): /job:localhost/replica:0/task:0/device:CPU:0
2023-03-08 16:44:01.119197: I tensorflow/core/common_runtime/placer.cc:114] output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:CPU:0

If I use tf.matmul, everything happens on the GPU, which is what I want. With this limitation I basically have to convert all the code that uses tf.einsum to tf.matmul, which is very inconvenient. Has anyone noticed the same problem? What's your solution?
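
For reference, here is a minimal sketch of the kind of rewrite I mean for the einsum in the repro above. It assumes I am reading the contraction correctly: 'BNSK,BNTK->BNTS' sums over K, so it should equal a batched matmul of b with a transposed over its last two axes. The small shapes and the separate T dimension are only for illustration and are not from my real code.

```python
import tensorflow as tf

# Small illustrative shapes (the repro uses B=1000, S=200, N=16, K=64).
# T is a hypothetical separate length so the output axis order is visible.
B, N, S, T, K = 2, 3, 5, 4, 6
a = tf.random.uniform(shape=[B, N, S, K], dtype=tf.float32)
b = tf.random.uniform(shape=[B, N, T, K], dtype=tf.float32)

# einsum: out[b, n, t, s] = sum over k of a[b, n, s, k] * b[b, n, t, k]
via_einsum = tf.einsum('BNSK,BNTK->BNTS', a, b)

# Same contraction as a batched matmul: b @ transpose(a) on the last two axes.
via_matmul = tf.matmul(b, a, transpose_b=True)

# Compare the two results; they should agree up to float32 rounding.
max_abs_diff = float(tf.reduce_max(tf.abs(via_einsum - via_matmul)))
print(via_matmul.shape, max_abs_diff)
```

This one is easy to convert, but every einsum with a different index pattern needs its own transpose/reshape bookkeeping, which is the inconvenient part.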