Description

Hi,

I am trying to follow this tutorial:
https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb

Setting the runtime to TPU on Colab used to work a couple of days ago, but now it crashes with the following error:

TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

This happens at this step:

training_loop = training.Loop(model, ...
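For context, the failing cell passes `model`, `train_task`, `eval_tasks=[eval_task]`, and `output_dir` to `training.Loop` (see the traceback below). The snippet here is not the tutorial's actual model, only a minimal, self-contained sketch of the same kind of Loop construction; I have not verified that it reproduces the error on its own:

```python
import numpy as np
import trax
from trax import layers as tl
from trax.supervised import training

def data_stream():
    # Tiny synthetic (inputs, targets, weights) batches, just to exercise Loop.
    while True:
        x = np.random.rand(8, 4).astype(np.float32)
        y = np.random.randint(0, 2, size=(8,))
        w = np.ones_like(y, dtype=np.float32)
        yield (x, y, w)

# Placeholder model instead of the tutorial's Transformer/Reformer.
model = tl.Serial(tl.Dense(2), tl.LogSoftmax())

train_task = training.TrainTask(
    labeled_data=data_stream(),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
)
eval_task = training.EvalTask(
    labeled_data=data_stream(),
    metrics=[tl.WeightedCategoryCrossEntropy()],
)

# The Loop constructor itself is where the tutorial crashes on the TPU runtime.
training_loop = training.Loop(
    model, train_task, eval_tasks=[eval_task], output_dir='./tmp_model')
```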
Environment information

OS:
NAME="Ubuntu"
VERSION="18.04.6 LTS (Bionic Beaver)"
ID=ubuntu
ID_LIKE=debian
PRETTY_NAME="Ubuntu 18.04.6 LTS"
VERSION_ID="18.04"
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
VERSION_CODENAME=bionic
UBUNTU_CODENAME=bionic

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.10.0
tensorflow-datasets==4.6.0
tensorflow-estimator==2.10.0
tensorflow-gcs-config==2.8.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.26.0
tensorflow-metadata==1.10.0
tensorflow-probability==0.16.0
tensorflow-text==2.10.0

$ pip freeze | grep jax
jax==0.3.17
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl

$ python -V
Python 3.7.13
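(Not part of the original report, just a quick sanity check: on a Colab TPU runtime JAX should list TPU devices.)

```python
import jax
print(jax.__version__)   # 0.3.17 in this environment
print(jax.devices())     # expect TpuDevice entries when the TPU runtime is active
```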
For bugs: reproduction and error logs

Steps to reproduce:
https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb
...
Error logs:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-2021642a85f0> in <module>
      9     train_task,
     10     eval_tasks=[eval_task],
---> 11     output_dir=output_dir)

16 frames

/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)

/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in <genexpr>(.0)
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)

/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in _init_trainer(self, task)
    348       task.optimizer.tree_init(model_in_training.weights)
    349       return optimizers.Trainer(
--> 350           model_in_training, task.optimizer, adasum=self._adasum)

/usr/local/lib/python3.7/dist-packages/trax/optimizers/trainer.py in __init__(self, model_with_loss, optimizer, n_devices, adasum)
     57     # optimizer slots and opt_params may need to be replicated
     58     self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices(
---> 59         (self._optimizer.slots, self._optimizer.opt_params), self._n_devices))

/usr/local/lib/python3.7/dist-packages/trax/layers/acceleration.py in on_cpu(x)
    250   """Puts ``x`` in CPU memory in JAX."""
    251   if fastmath.is_backend(fastmath.Backend.JAX):
--> 252     return jax.device_put(x, jax.devices('cpu')[0])

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in device_put(x, device)
   2723   with config_explicit_device_put_scope():
-> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
    203   leaves, treedef = tree_flatten(tree, is_leaf)
    204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
--> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in <lambda>(y)
-> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)

/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, *args, **params)
    323     assert (not config.jax_enable_checks or
    324             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 325     return self.bind_with_trace(find_top_trace(args), args, params)

/usr/local/lib/python3.7/dist-packages/jax/core.py in bind_with_trace(self, trace, args, params)
--> 328     out = trace.process_primitive(self, map(trace.full_raise, args), params)

/usr/local/lib/python3.7/dist-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
--> 686     return primitive.impl(*tracers, **params)

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_impl(x, device)
   1219     raise TypeError(
   1220         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
-> 1221   return aval_to_result_handler(device, a)(None, *device_put(x, device))

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in device_put(x, device)
   1113   x = xla.canonicalize_dtype(x)
   1114   try:
-> 1115     return device_put_handlers[type(x)](x, device)

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_array(x, device)
   1124   if x.dtype == dtypes.float0:
   1125     x = np.zeros(x.shape, dtype=np.dtype(bool))
-> 1126   return (backend.buffer_from_pyval(x, device),)

/usr/local/lib/python3.7/dist-packages/jax/_src/device_array.py in __array__(self, dtype, context)
    265   def __array__(self, dtype=None, context=None):
--> 266     return np.asarray(self._value, dtype=dtype)

/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in _sda_value(self)
    803   npy_value = np.empty(self.aval.shape, self.aval.dtype)
    804   for i in self.one_replica_buffer_indices:
--> 805     npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])

TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
...
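Reading the traceback, the failure seems to be in trax's `tl.on_cpu`, which calls `jax.device_put(..., jax.devices('cpu')[0])` on optimizer slots that `tl.for_n_devices` has already replicated across the TPU cores. The sketch below is my reading of that operation in isolation (names and shapes are illustrative, and I have not verified it as an independent repro):

```python
# Hypothetical sketch of the operation the traceback points at: moving a
# pmap-sharded TPU array back to the CPU device with jax.device_put.
import jax
import jax.numpy as jnp

n = jax.device_count()                                      # 8 cores on a Colab TPU
replicated = jax.pmap(lambda x: x)(jnp.zeros((n, 4)))       # array sharded across TPU devices
moved = jax.device_put(replicated, jax.devices('cpu')[0])   # the call that raises above
```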