diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py index c938c6e1..ba502449 100644 --- a/zeus/utils/framework.py +++ b/zeus/utils/framework.py @@ -138,10 +138,20 @@ def all_reduce( if jax_is_available(): # Check if not distributed jax = MODULE_CACHE["jax"] - if jax.process_count() == 1: + if jax.device_count() == 1: return object - raise NotImplementedError("JAX distributed all-reduce not yet implemented") + array = jax.numpy.array(object) + + if operation == "sum": + reduced = jax.lax.psum(array) + elif operation == "max": + reduced = jax.lax.pmax(array) + else: + raise ValueError(f"all_reduce unsupported operation: {operation}") + + # Convert back to list and return + return reduced.tolist() raise RuntimeError("No framework is available.") @@ -153,5 +163,5 @@ def is_distributed() -> bool: return torch.distributed.is_available() and torch.distributed.is_initialized() if jax_is_available(): jax = MODULE_CACHE["jax"] - return jax.process_count() > 1 + return jax.device_count() > 1 raise RuntimeError("No framework is available.")