Skip to content

Commit

Permalink
impl ax
Browse files Browse the repository at this point in the history
  • Loading branch information
parthraut committed Dec 21, 2024
1 parent 4140af5 commit 96b39bb
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions zeus/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,20 @@ def all_reduce(
if jax_is_available():
# Check if not distributed
jax = MODULE_CACHE["jax"]
if jax.process_count() == 1:
if jax.device_count() == 1:
return object

raise NotImplementedError("JAX distributed all-reduce not yet implemented")
array = jax.numpy.array(object)

if operation == "sum":
reduced = jax.lax.psum(array)
elif operation == "max":
reduced = jax.lax.pmax(array)
else:
raise ValueError(f"all_reduce unsupported operation: {operation}")

# Convert back to list and return
return reduced.tolist()

raise RuntimeError("No framework is available.")

Expand All @@ -153,5 +163,5 @@ def is_distributed() -> bool:
return torch.distributed.is_available() and torch.distributed.is_initialized()
if jax_is_available():
jax = MODULE_CACHE["jax"]
return jax.process_count() > 1
return jax.device_count() > 1
raise RuntimeError("No framework is available.")

0 comments on commit 96b39bb

Please sign in to comment.