From 96b39bba59273728d60584d229c23ee6f85d95f7 Mon Sep 17 00:00:00 2001 From: Parth Raut Date: Fri, 20 Dec 2024 20:44:22 -0500 Subject: [PATCH] impl ax --- zeus/utils/framework.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py index c938c6e1..ba502449 100644 --- a/zeus/utils/framework.py +++ b/zeus/utils/framework.py @@ -138,10 +138,20 @@ def all_reduce( if jax_is_available(): # Check if not distributed jax = MODULE_CACHE["jax"] - if jax.process_count() == 1: + if jax.device_count() == 1: return object - raise NotImplementedError("JAX distributed all-reduce not yet implemented") + array = jax.numpy.array(object) + + if operation == "sum": + reduced = jax.lax.psum(array) + elif operation == "max": + reduced = jax.lax.pmax(array) + else: + raise ValueError(f"all_reduce unsupported operation: {operation}") + + # Convert back to list and return + return reduced.tolist() raise RuntimeError("No framework is available.") @@ -153,5 +163,5 @@ def is_distributed() -> bool: return torch.distributed.is_available() and torch.distributed.is_initialized() if jax_is_available(): jax = MODULE_CACHE["jax"] - return jax.process_count() > 1 + return jax.device_count() > 1 raise RuntimeError("No framework is available.")