diff --git a/graphcast/xarray_jax.py b/graphcast/xarray_jax.py index ed88f43..5630c73 100644 --- a/graphcast/xarray_jax.py +++ b/graphcast/xarray_jax.py @@ -404,11 +404,14 @@ class JaxArrayWrapper(np.lib.mixins.NDArrayOperatorsMixin): """Wraps a JAX array into a duck-typed array suitable for use with xarray. This uses an older duck-typed array protocol based on __array_ufunc__ and - __array_function__ which works with numpy and xarray. This is in the process - of being superseded by the Python array API standard - (https://data-apis.org/array-api/latest/index.html), but JAX and xarray - haven't implemented it yet. Once they have, we should be able to get rid of + __array_function__ which works with numpy and xarray. (In newer versions + of xarray it implements xarray.namedarray._typing._array_function.) + + This is in the process of being superseded by the Python array API standard + (https://data-apis.org/array-api/latest/index.html), but JAX hasn't + implemented it yet. Once they have, we should be able to get rid of this wrapper and use JAX arrays directly with xarray. + """ def __init__(self, jax_array): @@ -464,6 +467,14 @@ def ndim(self): def size(self): return self.jax_array.size + @property + def real(self): + return self.jax_array.real + + @property + def imag(self): + return self.jax_array.imag + # Array methods not covered by NDArrayOperatorsMixin: # Allows conversion to numpy array using np.asarray etc. Warning: doing this