TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer' #1760

Open

agoliaei opened this issue Sep 10, 2022 · 0 comments
Description

Hi,
I am trying to follow this tutorial:
https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb
Setting the runtime to TPU on Colab used to work a couple of days ago, but now it crashes with the following error:

TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

This happens at this step:
training_loop = training.Loop(model,.....
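
For context, here is a sketch of the crashing call, reconstructed from the traceback below; `model`, `train_task`, `eval_task` and `output_dir` are built earlier in the notebook.

```python
from trax.supervised import training

# The crashing step, reconstructed from the traceback below.
# model, train_task, eval_task and output_dir come from earlier notebook cells.
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
```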

Environment information

OS: 
NAME="Ubuntu"
VERSION="18.04.6 LTS (Bionic Beaver)"
ID=ubuntu
ID_LIKE=debian
PRETTY_NAME="Ubuntu 18.04.6 LTS"
VERSION_ID="18.04"
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
VERSION_CODENAME=bionic
UBUNTU_CODENAME=bionic

$ pip freeze | grep trax
# trax==1.4.1


$ pip freeze | grep tensor
# tensorboard==2.10.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.10.0
tensorflow-datasets==4.6.0
tensorflow-estimator==2.10.0
tensorflow-gcs-config==2.8.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.26.0
tensorflow-metadata==1.10.0
tensorflow-probability==0.16.0
tensorflow-text==2.10.0

$ pip freeze | grep jax
# jax==0.3.17
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl

$ python -V
# Python 3.7.13

For bugs: reproduction and error logs

# Steps to reproduce:
https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb

...
# Error logs:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-2021642a85f0> in <module>
      9                               train_task,
     10                               eval_tasks=[eval_task],
---> 11                               output_dir=output_dir)

16 frames
/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
    278 
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281 
    282     # Sync layers weights/state in memory effcient trainer layers.

/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in <genexpr>(.0)
    278 
    279     # Create the optimizer for the training loss function.
--> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    281 
    282     # Sync layers weights/state in memory effcient trainer layers.

/usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in _init_trainer(self, task)
    348         task.optimizer.tree_init(model_in_training.weights)
    349       return optimizers.Trainer(
--> 350           model_in_training, task.optimizer, adasum=self._adasum)
    351     # In the memory-efficient path, we initialize the model here.
    352     blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(

/usr/local/lib/python3.7/dist-packages/trax/optimizers/trainer.py in __init__(self, model_with_loss, optimizer, n_devices, adasum)
     57     # optimizer slots and opt_params may need to be replicated
     58     self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices(
---> 59         (self._optimizer.slots, self._optimizer.opt_params), self._n_devices))
     60 
     61     # accelerated version of model+loss to replicate weights and state

/usr/local/lib/python3.7/dist-packages/trax/layers/acceleration.py in on_cpu(x)
    250   """Puts ``x`` in CPU memory in JAX."""
    251   if fastmath.is_backend(fastmath.Backend.JAX):
--> 252     return jax.device_put(x, jax.devices('cpu')[0])
    253   else:
    254     return x

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in device_put(x, device)
   2722   """
   2723   with config_explicit_device_put_scope():
-> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
   2725 
   2726 

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
    203   leaves, treedef = tree_flatten(tree, is_leaf)
    204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    206 
    207 def build_tree(treedef, xs):

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
    203   leaves, treedef = tree_flatten(tree, is_leaf)
    204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    206 
    207 def build_tree(treedef, xs):

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in <lambda>(y)
   2722   """
   2723   with config_explicit_device_put_scope():
-> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
   2725 
   2726 

/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, *args, **params)
    323     assert (not config.jax_enable_checks or
    324             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 325     return self.bind_with_trace(find_top_trace(args), args, params)
    326 
    327   def bind_with_trace(self, trace, args, params):

/usr/local/lib/python3.7/dist-packages/jax/core.py in bind_with_trace(self, trace, args, params)
    326 
    327   def bind_with_trace(self, trace, args, params):
--> 328     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    329     return map(full_lower, out) if self.multiple_results else full_lower(out)
    330 

/usr/local/lib/python3.7/dist-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
    684 
    685   def process_primitive(self, primitive, tracers, params):
--> 686     return primitive.impl(*tracers, **params)
    687 
    688   def process_call(self, primitive, f, tracers, params):

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_impl(x, device)
   1219     raise TypeError(
   1220         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
-> 1221   return aval_to_result_handler(device, a)(None, *device_put(x, device))
   1222 
   1223 

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in device_put(x, device)
   1113   x = xla.canonicalize_dtype(x)
   1114   try:
-> 1115     return device_put_handlers[type(x)](x, device)
   1116   except KeyError as err:
   1117     raise TypeError(f"No device_put handler for type: {type(x)}") from err

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_array(x, device)
   1124   if x.dtype == dtypes.float0:
   1125     x = np.zeros(x.shape, dtype=np.dtype(bool))
-> 1126   return (backend.buffer_from_pyval(x, device),)
   1127 
   1128 def _device_put_scalar(x, device):

/usr/local/lib/python3.7/dist-packages/jax/_src/device_array.py in __array__(self, dtype, context)
    264 
    265   def __array__(self, dtype=None, context=None):
--> 266     return np.asarray(self._value, dtype=dtype)
    267 
    268   setattr(device_array, "__array__", __array__)

/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in _sda_value(self)
    803     npy_value = np.empty(self.aval.shape, self.aval.dtype)
    804     for i in self.one_replica_buffer_indices:
--> 805       npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
    806     self._npy_value = npy_value
    807   return self._npy_value

TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
...
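
For what it's worth, the traceback suggests the crash comes from `trax.layers.acceleration.on_cpu` calling `jax.device_put` on TPU-replicated optimizer slots (a `ShardedDeviceArray` backed by `PyTpuBuffer`s). A hypothetical minimal reproduction along those lines, assuming the Colab TPU runtime with the same jax/jaxlib versions, would be:

```python
import jax
import jax.numpy as jnp

# Assumed to run on the Colab TPU runtime (jax 0.3.17, tpu_driver backend).
n = jax.local_device_count()

# Replicate an array across the TPU cores; the result is a ShardedDeviceArray
# whose per-device buffers are PyTpuBuffers, similar to the replicated
# optimizer slots trax builds in optimizers/trainer.py.
replicated = jax.pmap(lambda x: x)(jnp.zeros((n, 4)))

# trax.layers.acceleration.on_cpu then pins such values to CPU memory;
# this is the call that fails in the traceback above.
jax.device_put(replicated, jax.devices('cpu')[0])
```

If this small snippet fails the same way on the TPU runtime, the problem would seem to be in jax's device_put handling of PyTpuBuffer-backed arrays rather than in the notebook itself.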