tensorflow-metal


TensorFlow accelerates machine learning model training with Metal on Mac GPUs.

tensorflow-metal Documentation

Posts under tensorflow-metal tag

127 Posts
Post not yet marked as solved
0 Replies
31 Views
(Copied from https://github.com/google/jax/issues/20835)
I am attempting to use JAX on Metal (on an M1 Pro chip) to model discrete (count) data. I've installed the latest version, jax-metal 0.0.6, using pip. The installation seems to have worked overall, as I can perform basic JAX array operations on the GPU. However, when I try to compute the (log-)PMFs/PDFs of random variables that are defined in terms of the (log-)Gamma function, I get errors like the one below, which seems to indicate that the lax.lgamma function is not supported under the hood on M1 Metal. This is essential functionality for a wide class of probabilistic machine learning models. Note that the following functions (among others) are broken as a result:
jax.scipy.stats.binom.logpmf
jax.scipy.stats.nbinom.logpmf
jax.scipy.stats.poisson.logpmf
jax.scipy.stats.dirichlet.logpdf
jax.scipy.stats.beta.logpdf
jax.scipy.stats.gamma.logpdf
...
>>> jax.scipy.stats.binom.logpmf(1, n=2, p=0.5)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception: Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/stats/binom.py", line 31, in logpmf gammaln(n + 1), File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/special.py", line 44, in gammaln return lax.lgamma(x) File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/lax/special.py", line 46, in lgamma return lgamma_p.bind(x) File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind return self.bind_with_trace(find_top_trace(args), args, params) File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace out = trace.process_primitive(self, map(trace.full_raise, args), params) File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive return primitive.impl(*tracers, **params) File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive outs = fun(*args) jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'chlo.lgamma' <stdin>:1:0: note: see current operation: %0 = "chlo.lgamma"(%arg0) : (tensor<f32>) -> tensor<f32> System info (python version, jaxlib version, accelerator, etc.) jax: 0.4.26 jaxlib: 0.4.23 numpy: 1.26.4 python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ] jax.devices (1 total, 1 local): [METAL(id=0)] process_count: 1 platform: uname_result(system='Darwin', node='PHS027794', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')
Posted
by lblaine.
Last updated
.
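One possible stopgap for the lgamma report above, until the Metal backend supports the operation, is to pin the affected log-PMF evaluation to the host CPU. This is a minimal sketch, assuming the CPU backend is still registered alongside the Metal plugin; it has not been verified against jax-metal 0.0.6, and the helper name poisson_logpmf_on_cpu is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import poisson

# Assumption: jax.devices("cpu") is available even when the Metal plugin is installed.
cpu = jax.devices("cpu")[0]


def poisson_logpmf_on_cpu(k, mu):
    """Evaluate the Poisson log-PMF with the computation placed on the CPU."""
    with jax.default_device(cpu):
        # Arrays created inside this context are committed to the CPU device,
        # so the lgamma call inside logpmf is compiled for the CPU backend.
        return poisson.logpmf(jnp.asarray(k), jnp.asarray(mu))


value = poisson_logpmf_on_cpu(2, 3.0)
print(float(value))
```

The same context manager could wrap the other functions listed in the report; the trade-off is that those evaluations no longer run on the GPU.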
Post not yet marked as solved
0 Replies
156 Views
Copying from https://github.com/google/jax/issues/20750: import jax import jax.numpy as jnp def test_func(x, y): return x, y def main(): # Print available JAX devices print("JAX devices:", jax.devices()) # Create two random matrices a = jnp.array([[1.0, 2.0], [3.0, 4.0]]) b = jnp.array([[5.0, 6.0], [7.0, 8.0]]) # Perform matrix multiplication c = jnp.dot(a, b) # Print the result print("Result of matrix multiplication:") print(c) # Compute the gradient of sum of c with respect to a grad_a = jax.grad(lambda a: jnp.sum(jnp.dot(a, b)))(a) print("Gradient with respect to a:") print(grad_a) rng = jax.random.PRNGKey(0) test_input = jax.random.normal(key=rng, shape=(5,5,5)) initial_state = jax.numpy.array(0.0) x, y = jax.lax.scan(test_func, initial_state, test_input) if __name__ == "__main__": main() Gets: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported! 2024-04-15 18:22:28.994752: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported! Metal device set to: Apple M2 Pro systemMemory: 16.00 GB maxCacheSize: 5.33 GB JAX devices: [METAL(id=0)] Result of matrix multiplication: [[19. 22.] [43. 50.]] Gradient with respect to a: [[11. 15.] [11. 15.]] zsh: segmentation fault python JAXTest.py With more info from the debugger: Current thread 0x00000001fdd3bac0 (most recent call first): File "/Users/.../anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1213 in __call__ My configuration is: jax-metal : 0.0.6 jax: 0.4.26 jaxlib: 0.4.23 numpy: 1.24.3 python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ] jax.devices (1 total, 1 local): [METAL(id=0)] process_count: 1 platform: uname_result(system='Darwin', root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64') macOS 14.4.1 (23E224) Before in 3.9+0.0.3 etc it wasn't happening.
Posted Last updated
.
Post not yet marked as solved
0 Replies
104 Views
Hi, I encountered a segfault when calling a function via jax.lax.scan. A minimal failing example is pasted below:
$ ipython
Python 3.9.6 (default, Feb 3 2024, 15:58:27)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.18.1 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import jax
In [2]: jax.__version__
Out[2]: '0.4.22'
In [3]: import jaxlib
In [4]: jaxlib.__version__
Out[4]: '0.4.22'
In [6]: import jax.numpy as jnp
In [7]: def f(carry, x):
   ...:     return carry + x * x, x * x
   ...:
   ...: jax.lax.scan(f, jnp.zeros((), dtype=jnp.float32), jnp.arange(3, dtype=jnp.float32))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-04-16 01:03:52.483015: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max
systemMemory: 36.00 GB
maxCacheSize: 13.50 GB
zsh: segmentation fault ipython
This might be related to the thread below: https://developer.apple.com/forums/thread/749080
Strangely, when we call it jax.lax.scan is a very important building block, so I would greatly appreciate it if this can be resolved soon.
Posted Last updated
.
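Both scan reports above crash inside jax.lax.scan. For short sequences, a plain Python loop reproduces the same carry/output semantics and might serve as a temporary substitute. This is only a sketch: whether it avoids the segfault on Metal is untested, it handles a single array input rather than arbitrary pytrees, and the name scan_unrolled is hypothetical.

```python
import jax.numpy as jnp


def scan_unrolled(f, init, xs):
    """Python-loop stand-in for jax.lax.scan over the leading axis of one array."""
    carry = init
    ys = []
    for i in range(xs.shape[0]):
        # Same contract as lax.scan: f maps (carry, x) to (new_carry, y).
        carry, y = f(carry, xs[i])
        ys.append(y)
    return carry, jnp.stack(ys)


def f(carry, x):
    return carry + x * x, x * x


final_carry, ys = scan_unrolled(
    f, jnp.zeros((), dtype=jnp.float32), jnp.arange(3, dtype=jnp.float32)
)
print(final_carry, ys)
```

Under jax.jit the loop is unrolled at trace time, so compile time grows with the sequence length; lax.scan remains the better choice wherever it works.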
Post not yet marked as solved
2 Replies
116 Views
Hardware: 16" 2023 MBP M3 Pro
OS: 14.4.1
Memory: 36 GB
Python version: 3.8.16
TF-Metal version: tensorflow-metal 1.0.1 installed via pip
TF version: 2.13.0
With tensorflow-metal installed, TFLite inference starts out slow, at approximately 10s/iteration, and over the course of 36 iterations progressively slows down to over 120s/iteration. The info log prints that TFLite is using XNNPack. I can't share the TFLite model, but it is relatively shallow, small, and simple. I uninstalled TF-Metal and installed tensorflow. Inference speed picks right up and is rock solid at 0.78s/iteration. What is going on???
TLDR, TFLite inference speed: TF Metal = 120s/iteration, TF = 0.78s/iteration
Posted
by markostam.
Last updated
.
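A progressive slowdown like the one reported above is easier to discuss with per-iteration timings attached. The following is a minimal, library-agnostic sketch using only the standard library; fake_inference is a hypothetical placeholder for whatever call runs one inference step, since the original model and code were not shared.

```python
import time


def time_iterations(step, n_iters):
    """Return the wall-clock duration of each call to step(), in seconds."""
    durations = []
    for _ in range(n_iters):
        start = time.perf_counter()
        step()
        durations.append(time.perf_counter() - start)
    return durations


def fake_inference():
    # Hypothetical stand-in for a single inference call.
    sum(i * i for i in range(10_000))


durations = time_iterations(fake_inference, 36)
print(f"first: {durations[0]:.4f}s  last: {durations[-1]:.4f}s")
```

Comparing the first and last entries with and without the plugin installed would show whether the drift is tied to the plugin.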
Post not yet marked as solved
3 Replies
265 Views
Tried various how-tos on YouTube and GitHub. Have conda. Third step fails.
conda install -c apple tensorflow-deps
pip install tensorflow-macos
pip install tensorflow-metal
ERROR: Could not find a version that satisfies the requirement tensorflow-metal (from versions: none)
ERROR: No matching distribution found for tensorflow-metal
I see a lot of fixes for Intel-based Macs. None for M3. HELP!?
Posted
by gumstead.
Last updated
.
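pip prints "from versions: none" when no published file is compatible with the running interpreter. A quick way to see what pip is matching against is to print the interpreter's version and architecture from inside the same environment. This is a diagnostic sketch, not a fix; the assumption (not confirmed by the post) is that an unsupported Python version or an x86_64 interpreter running on Apple silicon is a plausible cause. The function name describe_interpreter is hypothetical.

```python
import platform
import sys


def describe_interpreter():
    """Collect the interpreter details that determine which wheels pip accepts."""
    return {
        "python": ".".join(str(part) for part in sys.version_info[:3]),
        "machine": platform.machine(),   # e.g. arm64 or x86_64
        "system": platform.system(),     # Darwin on macOS
        "macos": platform.mac_ver()[0],  # empty string on non-macOS systems
    }


info = describe_interpreter()
for key, value in info.items():
    print(f"{key}: {value}")
```

Running this with the same python that runs pip (for example via python -m pip) keeps the two in agreement.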
Post not yet marked as solved
2 Replies
226 Views
I tried running inference with the 2B model from https://github.com/google-deepmind/gemma on my M2 MacBook Pro, but it segfaults during sampling: https://pastebin.com/KECyz60T Note: out of the box it will try to load bfloat16 weights, which will fail. To avoid this, I patched line 30 in gemma/params.py to explicitly cast to float32: param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)
Posted
by mononofu.
Last updated
.
Post not yet marked as solved
1 Replies
245 Views
TensorFlow Metal was working on my MacBook Pro (M3 Pro) until yesterday. Then my code started freezing. I ran the test script from https://developer.apple.com/metal/tensorflow-plugin/ and it now crashes. This used to work fine, but all of a sudden it does not. The results are shown below. Has anyone seen anything like this? Could this be a hardware problem?
MacBook-Pro-3: carl$ python mac_tensorflow_test.py
Epoch 1/5
1/782 [..............................] - ETA: 51:53 - loss: 6.0044 - accuracy: 0.0312Error: command buffer exited with error status.
The Metal Performance Shaders operations encoded on it may not have completed.
Error: (null) Ignored (for causing prior/excessive GPU errors) (00000004:kIOGPUCommandBufferCallbackErrorSubmissionsIgnored)
<AGXG15XFamilyCommandBuffer: 0x1172515e0>
label = <none>
device = <AGXG15SDevice: 0x1588e6000>
name = Apple M3 Pro
commandQueue = <AGXG15XFamilyCommandQueue: 0x17427e400>
label = <none>
device = <AGXG15SDevice: 0x1588e6000>
name = Apple M3 Pro
retainedReferences = 1
Error: command buffer exited with error status.
The Metal Performance Shaders operations encoded on it may not have completed.
Error: (null) Ignored (for causing prior/excessive GPU errors) (00000004:kIOGPUCommandBufferCallbackErrorSubmissionsIgnored)
<AGXG15XFamilyCommandBuffer: 0x117257b40>
label = <none>
device = <AGXG15SDevice: 0x1588e6000>
name = Apple M3 Pro
commandQueue = <AGXG15XFamilyCommandQueue: 0x17427e400>
label = <none>
device = <AGXG15SDevice: 0x1588e6000>
name = Apple M3 Pro
retainedReferences = 1
Many more rows of similar printouts follow.
Posted
by carl24k.
Last updated
.
Post not yet marked as solved
5 Replies
499 Views
Hi, I am trying to set up tensorflow-metal as instructed by https://developer.apple.com/metal/tensorflow-plugin/. When running the line (python -m pip install tensorflow-metal) I get the following error:
ERROR: Could not find a version that satisfies the requirement tensorflow-metal (from versions: none)
ERROR: No matching distribution found for tensorflow-metal
According to the troubleshooting section: "Check that the Python version used in the environment is supported (Python 3.8, Python 3.9, Python 3.10)." My current version is Python 3.9.12. Any insight would be great!
Posted Last updated
.
Post not yet marked as solved
0 Replies
259 Views
Hi, I am looking for a routine to perform complex-valued linear algebra on the GPU in Python for scientific programming, in particular quantum physics simulations. At the moment I am looking for a routine for complex-valued matrix multiplication. I found that MLX has a routine for float matrix multiplication, but it does not directly work for complex-valued matrices. I figured out a workaround by splitting the complex-valued matrix into its real and imaginary parts and working with the pair, but that makes it cumbersome to integrate with the remainder of the code. I was hoping for a library-based implementation similar to cupy. I also tried the TensorFlow linear algebra routines, but so far I couldn't get them to run on the GPU. Specifically, a test file with a tensorflow.keras.applications.ResNet50 routine runs on the GPU, but the routines from tensorflow.linalg and tensorflow.math that I tested (matmul, expm, eigh) were not running on the GPU. Any advice on how to make linear algebra calculations on Mac GPUs work is highly appreciated! For my application the unified memory might be especially beneficial. Thank you!
Posted
by MG607.
Last updated
.
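The real/imaginary workaround described above can be wrapped once so that the rest of the code deals with a single function. The sketch below uses NumPy purely to show and check the algebra; the matmul parameter is a hypothetical hook where a real-valued GPU matrix multiply could be passed in, which is an assumption rather than a tested integration with any particular library.

```python
import numpy as np


def complex_matmul_via_real(a_re, a_im, b_re, b_im, matmul=np.matmul):
    """Multiply complex matrices using four real matrix products.

    (A_re + i A_im)(B_re + i B_im)
        = (A_re B_re - A_im B_im) + i (A_re B_im + A_im B_re)
    """
    c_re = matmul(a_re, b_re) - matmul(a_im, b_im)
    c_im = matmul(a_re, b_im) + matmul(a_im, b_re)
    return c_re, c_im


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 3)) + 1j * rng.standard_normal((4, 3))
b = rng.standard_normal((3, 5)) + 1j * rng.standard_normal((3, 5))

c_re, c_im = complex_matmul_via_real(a.real, a.imag, b.real, b.imag)
c = c_re + 1j * c_im
print(np.allclose(c, a @ b))
```

Keeping the split inside one helper limits the bookkeeping to the boundary where arrays enter and leave the GPU routine.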
Post not yet marked as solved
0 Replies
223 Views
InvalidArgumentError: Cannot assign a device for operation don_nn/model_2/branch_hidden0/MatMul/ReadVariableOp: Could not satisfy explicit device specification '' because the node {{colocation_node don_nn/model_2/branch_hidden0/MatMul/ReadVariableOp}} was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/device:GPU:0'. All available devices [/job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0].
Posted
by brrr.
Last updated
.
Post not yet marked as solved
2 Replies
448 Views
Problem
I am trying to use the jax.numpy.einsum function (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.einsum.html). However, for some subscripts, this seems to fail.
Hardware
Apple M1 Max, 32GB RAM
Steps to Reproduce
Follow the installation steps from https://developer.apple.com/metal/jax/
conda create -n 'jax_metal_demo' python=3.11
conda activate jax_metal_demo
python -m pip install numpy wheel ml-dtypes==0.2.0
python -m pip install jax-metal
Save the following code in a file called minimal_example.py
import numpy as np
from jax import device_put
import jax.numpy as jnp
np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12)
b = np.random.rand(11, 12, 13)
subscripts = 'ijklm,ijk->lmk'
# intended result
print(np.einsum(subscripts, a, b))
# will cause crash
a, b = device_put(a), device_put(b)
print(jnp.einsum(subscripts, a, b))
Run the code
python minimal_example.py
Output
I was expecting
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-02-12 16:45:34.684973: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max systemMemory: 32.00 GB maxCacheSize: 10.67 GB Traceback (most recent call last): File "/Users/linus/workspace/minimal_example.py", line 15, in <module> print(jnp.einsum(subscripts, a, b)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 3369, in einsum return _einsum_computation(operands, contractions, precision, # type: ignore[operator] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/contextlib.py", line 81, in inner return func(*args, **kwds) ^^^^^^^^^^^^^^^^^^^ jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/linus/workspace/minimal_example.py:15:6: error: failed to legalize operation 'mhlo.dot_general' print(jnp.einsum(subscripts, a, b)) ^ /Users/linus/workspace/minimal_example.py:15:6: note: see current operation: %0 = "mhlo.dot_general"(%arg1, %arg0) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [2], rhs_batching_dimensions = [2], lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [0, 1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<11x12x13xf32>, tensor<11x12x13x11x12xf32>) -> tensor<13x11x12xf32> -------------------- For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. Conclusion I would greatly appreciate any ideas for workarounds.
Posted Last updated
.
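As a possible workaround for the einsum report above, the contraction 'ijklm,ijk->lmk' can be written as a broadcasted multiply followed by a sum, which does not go through the dot_general operation that fails to legalize. This is a sketch only: it has not been run on the Metal backend, it materializes the full product array, and the function name is hypothetical.

```python
import numpy as np
import jax.numpy as jnp


def einsum_ijklm_ijk_to_lmk(a, b):
    """Compute einsum('ijklm,ijk->lmk', a, b) using multiply and sum only."""
    # Broadcast b over the trailing l and m axes: the product has axes (i, j, k, l, m).
    prod = a * b[:, :, :, None, None]
    # Sum out i and j, leaving axes (k, l, m).
    summed = prod.sum(axis=(0, 1))
    # Reorder to (l, m, k) to match the requested output subscripts.
    return jnp.transpose(summed, (1, 2, 0))


np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12)
b = np.random.rand(11, 12, 13)

expected = np.einsum('ijklm,ijk->lmk', a, b)
result = einsum_ijklm_ijk_to_lmk(jnp.asarray(a), jnp.asarray(b))
print(np.allclose(np.asarray(result), expected, rtol=1e-4))
```

The intermediate product has the same size as a, so memory use is higher than a fused contraction would need.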
Post not yet marked as solved
3 Replies
735 Views
MacBook Pro M2 Max / 64G / macOS 13.2.1 (22D68)
import tensorflow as tf

def runMnist(device = '/device:CPU:0'):
    with tf.device(device):
        #tf.config.set_default_device(device)
        mnist = tf.keras.datasets.mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        x_train, x_test = x_train / 255.0, x_test / 255.0
        model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10)
        ])
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
        model.fit(x_train, y_train, epochs=10)

runMnist(device = '/device:CPU:0')
runMnist(device = '/device:GPU:0')
Posted Last updated
.
Post not yet marked as solved
2 Replies
385 Views
Hi, I have an issue with jax.numpy.linalg.inv(a).
import jax.numpy as jnp
import jax.numpy.linalg as jnpl
B = jnp.identity(2)
jnpl.inv(B)
Throws the following error:
XlaRuntimeError: UNKNOWN: /var/folders/pw/wk5rfkjj6qggqp8r8zb2bw8w0000gn/T/ipykernel_34334/2572982404.py:9:0: error: failed to legalize operation 'mhlo.triangular_solve'
/var/folders/pw/wk5rfkjj6qggqp8r8zb2bw8w0000gn/T/ipykernel_34334/2572982404.py:9:0: note: called from
/var/folders/pw/wk5rfkjj6qggqp8r8zb2bw8w0000gn/T/ipykernel_34334/2572982404.py:9:0: note: see current operation: %120 = "mhlo.triangular_solve"(%42#4, %119) {left_side = true, lower = true, transpose_a = #mhlo<transpose NO_TRANSPOSE>, unit_diagonal = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
Any ideas what could be the issue or how to solve it?
Posted Last updated
.
Post not yet marked as solved
5 Replies
698 Views
Hello, We all face issues with the latest TensorFlow GPU support: incorrect results, errors, etc. We all agreed to pay extra for the M1/2/3 so we could work on a professional-grade computer, but in the end we must use the CPU. When will Apple actually comment on that and provide updates? I totally understand these issues aren't fixed overnight and take some time, but I've never seen any Apple dev answer saying that they understand and are working on a fix. I basically bought a Mac M3 Pro to be able to run some stuff on the GPU without having to purchase a server, and it's now useless. It's really frustrating.
Posted
by ivxnszn.
Last updated
.
Post not yet marked as solved
0 Replies
321 Views
I haven't used the GPU implementation for over a year now due to constant issues (I use tf.config.set_visible_devices([], 'GPU') to use the CPU only). I have also had a couple of issues with model convergence using the GPU; however, this issue seems more prominent, and possibly unrelated. Here is an example of code that causes a memory leak using the GPU (I cannot link the dataset, but it is called "Text classification documentation", by TANISHQ DUBLISH on Kaggle).
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
df = pd.read_csv('df_file.csv')
df.head()
train_df = df.sample(frac=0.7, random_state=42)
val_df = df.drop(train_df.index).sample(frac=0.5, random_state=42)
test_df = df.drop(train_df.index).drop(val_df.index)
train_dataset = tf.data.Dataset.from_tensor_slices((train_df['Text'].values, train_df['Label'].values)).batch(32).prefetch(tf.data.AUTOTUNE)
val_dataset = tf.data.Dataset.from_tensor_slices((val_df['Text'].values, val_df['Label'].values)).batch(32).prefetch(tf.data.AUTOTUNE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_df['Text'].values, test_df['Label'].values)).batch(32).prefetch(tf.data.AUTOTUNE)
text_vectorizer = tf.keras.layers.TextVectorization(max_tokens=100_000, output_mode='int', output_sequence_length=1000, pad_to_max_tokens=True)
text_vectorizer.adapt(train_df['Text'].values)
embedding = tf.keras.layers.Embedding(input_dim=len(text_vectorizer.get_vocabulary()), output_dim=128, input_length=1000)
inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)
x = text_vectorizer(inputs)
x = embedding(x)
x = tf.keras.layers.LSTM(64)(x)
outputs = tf.keras.layers.Dense(5, activation='softmax')(x)
model_2 = tf.keras.Model(inputs, outputs, name='model_2_lstm')
model_2.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(), optimizer=tf.keras.optimizers.legacy.Adam(), metrics=['accuracy'])
model_2_history = model_2.fit(train_dataset, epochs=50, validation_data=val_dataset, callbacks=[
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint(model_2.name, save_best_only=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=5, verbose=1)
])
Posted
by 09jtip.
Last updated
.
Post not yet marked as solved
0 Replies
956 Views
On an Apple M1 with Ventura 13.6, I followed the steps on the Get started with tensorflow-metal page here: https://developer.apple.com/metal/tensorflow-plugin/
python3 -m venv ~/venv-metal
source ~/venv-metal/bin/activate
python -m pip install -U pip
python -m pip install tensorflow
python -m pip install tensorflow-metal
With a clean start I also tried pinning the version:
python -m pip install tensorflow==2.13.0
where the output included: Successfully installed tensorflow-metal-1.0.0
The table here suggested this should work: https://pypi.org/project/tensorflow-metal/
But I got the same error... Running Python code without the tensorflow import was not a problem. I found forum posts with a similar error on M1 Macs, but none of the proposed solutions worked. Are there suggested steps to get the "get started" tutorial working?
Posted Last updated
.
Post not yet marked as solved
0 Replies
356 Views
Kia ora, I've been having heaps of trouble recently trying to get TensorFlow working. It just suddenly stopped, and the kernel crashes every time I try to import tf. I've tried just about everything, e.g. a fresh install of Python and reinstalling the Xcode dev tools. Below are the relevant lines of pip freeze (using python 1.10.13, btw):
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorboard-plugin-wit==1.8.1
tensorflow==2.15.0
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.34.0
tensorflow-macos==2.15.0
tensorflow-metal==0.5.0
Below is the cell in question that is killing the kernel:
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Conv2D, MaxPool2D, Dense, Flatten, InputLayer, BatchNormalization, Dropout
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers.legacy import Adam
I'll be around all day, so if you have anything that can help, I'll be sure to give it a go as soon as you post it and get back to you! Looking forward to your replies. Nga mihi, Kane
Posted Last updated
.
Post not yet marked as solved
1 Replies
370 Views
I did a clean install of Python (v. 3.10), then TensorFlow and tensorflow-metal, following exactly the process stated on Apple's plugin support page. Now, every time I run ANY Python code with TensorFlow, it crashes at the model.fit call. It does not matter what I feed into it, even code that used to run perfectly on my previous MacBook (Intel)... I've researched ad nauseam for answers, but Apple washes its hands, stating that it is a TensorFlow problem, and TensorFlow does the same. Fact is that exactly the same code runs flawlessly on my Windows NVIDIA PC setup. I purchased the M3 laptop with the hope of being able to train my neural networks "on the go"... now I have lost $5,000 USD, I can't make it work, and it is a total disaster. I am extremely competent in Python development and have been developing neural networks for years. So if you are going to comment, please avoid suggestions like "check your Python version" etc. This is DEFINITELY due to the M3 Mac. The exact same setup is working OK on an M1 Ultra Mac Studio. It is just not portable... Does anyone have any specific advice on how to make a proper setup of TensorFlow for the Mac M3??
Posted
by Mopi.
Last updated
.
Post not yet marked as solved
0 Replies
309 Views
Running grouped convolutions on an M2 with the Metal plugin, I get an error. Example code:
Using TF 2.11 and no Metal plugin I get
import tensorflow as tf
tf.keras.layers.Conv1D(5,1,padding="same", kernel_initializer="ones", groups=5)(tf.ones((1,1,5)))
# displays
<tf.Tensor: shape=(1, 1, 5), dtype=float32, numpy=array([[[1., 1., 1., 1., 1.]]], dtype=float32)>
On TF 2.14 with the plugin I received
import tensorflow as tf
tf.keras.layers.Conv1D(5,1,padding="same", kernel_initializer="ones", groups=5)(tf.ones((1,1,5)))
# displays
...
NotFoundError: Exception encountered when calling layer 'conv1d_3' (type Conv1D).
could not find registered platform with id: 0x104d8f6f0 [Op:__inference__jit_compiled_convolution_op_78]
Call arguments received by layer 'conv1d_3' (type Conv1D):
• inputs=tf.Tensor(shape=(1, 1, 5), dtype=float32)
Posted
by roebel.
Last updated
.
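For checking results once the layer runs, the grouped convolution in the report above (kernel size 1, groups=5, all-ones kernel, no bias contribution) can be reproduced with a small NumPy reference. This is an illustrative sketch restricted to kernel size 1; it assumes the Keras kernel layout (kernel_size, in_channels // groups, filters), and the function name grouped_pointwise_conv1d is hypothetical.

```python
import numpy as np


def grouped_pointwise_conv1d(x, kernel, groups):
    """Grouped 1D convolution with kernel size 1.

    x:      (batch, length, channels)
    kernel: (1, channels // groups, filters)
    """
    channels = x.shape[-1]
    filters = kernel.shape[-1]
    in_per_group = channels // groups
    out_per_group = filters // groups
    outputs = []
    for g in range(groups):
        # Each group sees only its own slice of input channels and filters.
        x_g = x[..., g * in_per_group:(g + 1) * in_per_group]
        k_g = kernel[0, :, g * out_per_group:(g + 1) * out_per_group]
        outputs.append(x_g @ k_g)
    return np.concatenate(outputs, axis=-1)


x = np.ones((1, 1, 5), dtype=np.float32)
kernel = np.ones((1, 1, 5), dtype=np.float32)
out = grouped_pointwise_conv1d(x, kernel, groups=5)
print(out)
```

With these inputs the reference yields an all-ones tensor of shape (1, 1, 5), matching the TF 2.11 output quoted in the post.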