Skip to content

Commit

Permalink
[JAX] Add a new jax_num_cpu_devices flag that allows the user to spec…
Browse files Browse the repository at this point in the history
…ify the number of CPU directly.

This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Jan 8, 2025
1 parent f96339b commit 51b9fe3
Show file tree
Hide file tree
Showing 18 changed files with 48 additions and 172 deletions.
31 changes: 11 additions & 20 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,29 +498,20 @@ def device_supports_buffer_donation():
)


@contextmanager
def set_host_platform_device_count(nr_devices: int):
"""Context manager to set host platform device count if not specified by user.
def request_cpu_devices(nr_devices: int):
"""Requests at least `nr_devices` CPU devices.
request_cpu_devices should be called at the top-level of a test module before
main() runs.
This should only be used by tests at the top level in setUpModule(); it will
not work correctly if applied to individual test cases.
It is not guaranteed that the number of CPU devices will be exactly
`nr_devices`: it may be more or less, depending on how exactly the test is
invoked. Test cases that require a specific number of devices should skip
themselves if that number is not met.
"""
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
f" --xla_force_host_platform_device_count={nr_devices}")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
try:
yield
finally:
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
if xla_bridge.NUM_CPU_DEVICES.value < nr_devices:
xla_bridge.get_backend.cache_clear()
config.update("jax_num_cpu_devices", nr_devices)


def skip_on_flag(flag_name, skip_value):
Expand Down
22 changes: 20 additions & 2 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@
"inline without async dispatch.",
)

NUM_CPU_DEVICES = config.int_flag(
name="jax_num_cpu_devices",
default=-1,
help="Number of CPU devices to use. If not provided, the value of "
"the XLA flag --xla_force_host_platform_device_count is used."
" Must be set before JAX is initialized.",
)


# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
Expand Down Expand Up @@ -249,8 +257,8 @@ def make_cpu_client(
if collectives is None:
collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
collectives_impl = 'gloo'
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
collectives_impl = 'gloo'
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
'deprecated. Please use `jax.config.update('
'"jax_cpu_collectives_implementation", "gloo")` instead.',
DeprecationWarning,
Expand All @@ -268,12 +276,22 @@ def make_cpu_client(
f"{collectives_impl}. Available implementations are "
f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")

num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
if xla_client._version < 303 and num_devices is not None:
xla_flags = os.getenv("XLA_FLAGS") or ""
os.environ["XLA_FLAGS"] = (
f"{xla_flags} --xla_force_host_platform_device_count={num_devices}"
)
num_devices = None
# TODO(phawkins): pass num_devices directly when version 303 is the minimum.
kwargs = {} if num_devices is None else {"num_devices": num_devices}
return xla_client.make_cpu_client(
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
distributed_client=distributed.global_state.client,
node_id=distributed.global_state.process_id,
num_nodes=distributed.global_state.num_processes,
collectives=collectives,
**kwargs,
)


Expand Down
9 changes: 1 addition & 8 deletions jax/experimental/array_serialization/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for serialization and deserialization of GDA."""

import asyncio
import contextlib
import math
from functools import partial
import os
Expand All @@ -36,13 +35,7 @@
import tensorstore as ts

jax.config.parse_flags_with_absl()
_exit_stack = contextlib.ExitStack()

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))

def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(8)


class CheckpointTest(jtu.JaxTestCase):
Expand Down
7 changes: 1 addition & 6 deletions jax/experimental/jax2tf/tests/sharding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""
from collections.abc import Sequence
import contextlib
from functools import partial
import logging
import re
Expand Down Expand Up @@ -47,16 +46,14 @@
import tensorflow as tf

config.parse_flags_with_absl()
jtu.request_cpu_devices(8)

# Must come after initializing the flags
from jax.experimental.jax2tf.tests import tf_test_util

_exit_stack = contextlib.ExitStack()
topology = None

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))

global topology
if jtu.test_device_matches(["tpu"]):
with jtu.ignore_warning(message="the imp module is deprecated"):
Expand All @@ -67,8 +64,6 @@ def setUpModule():
else:
topology = None

def tearDownModule():
_exit_stack.close()


class ShardingTest(tf_test_util.JaxToTfTestCase):
Expand Down
10 changes: 1 addition & 9 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,12 @@
from jax._src import prng

jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(8)

with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator

# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))

def tearDownModule():
_exit_stack.close()


def create_array(shape, sharding, global_data=None):
if global_data is None:
Expand Down
14 changes: 1 addition & 13 deletions tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import threading
import time
from typing import Sequence
Expand All @@ -29,6 +28,7 @@
import numpy as np

config.parse_flags_with_absl()
jtu.request_cpu_devices(8)


def _colocated_cpu_devices(
Expand All @@ -53,18 +53,6 @@ def _colocated_cpu_devices(
_count_colocated_python_specialization_cache_miss = jtu.count_events(
"colocated_python_func._get_specialized_func")

_exit_stack = contextlib.ExitStack()


def setUpModule():
# TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT
# prepares CPU devices by its own.
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))


def tearDownModule():
_exit_stack.close()


class ColocatedPythonTest(jtu.JaxTestCase):

Expand Down
10 changes: 1 addition & 9 deletions tests/debugger_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from collections.abc import Sequence
import contextlib
import io
import re
import textwrap
Expand All @@ -29,6 +28,7 @@
import numpy as np

jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(2)

def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]:
fake_stdin = io.StringIO()
Expand All @@ -41,14 +41,6 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI
def _format_multiline(text):
return textwrap.dedent(text).lstrip()

_exit_stack = contextlib.ExitStack()

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))

def tearDownModule():
_exit_stack.close()

foo = 2

class CliDebuggerTest(jtu.JaxTestCase):
Expand Down
9 changes: 1 addition & 8 deletions tests/debugging_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import functools
import textwrap
import unittest
Expand All @@ -35,19 +34,13 @@
rich = None

jax.config.parse_flags_with_absl()
jtu.request_cpu_devices(2)

debug_print = debugging.debug_print

def _format_multiline(text):
return textwrap.dedent(text).lstrip()

_exit_stack = contextlib.ExitStack()

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))

def tearDownModule():
_exit_stack.close()

class DummyDevice:
def __init__(self, platform, id):
Expand Down
8 changes: 1 addition & 7 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,8 @@
CAN_SERIALIZE = False

config.parse_flags_with_absl()
jtu.request_cpu_devices(8)

_exit_stack = contextlib.ExitStack()

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))

def tearDownModule():
_exit_stack.close()

### Setup for testing lowering with effects
@dataclasses.dataclass(frozen=True)
Expand Down
12 changes: 2 additions & 10 deletions tests/jaxpr_effects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib

import threading
import unittest

Expand All @@ -34,6 +34,7 @@
import numpy as np

config.parse_flags_with_absl()
jtu.request_cpu_devices(2)

effect_p = core.Primitive('effect')
effect_p.multiple_results = True
Expand Down Expand Up @@ -132,15 +133,6 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
mlir.register_lowering(callback_p, callback_effect_lowering)


_exit_stack = contextlib.ExitStack()

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))

def tearDownModule():
_exit_stack.close()


class JaxprEffectsTest(jtu.JaxTestCase):

def test_trivial_jaxpr_has_no_effects(self):
Expand Down
10 changes: 1 addition & 9 deletions tests/layout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import math
from functools import partial
from absl.testing import absltest
Expand All @@ -28,14 +27,7 @@
from jax.experimental.compute_on import compute_on

config.parse_flags_with_absl()

_exit_stack = contextlib.ExitStack()

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))

def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(8)


class LayoutTest(jtu.JaxTestCase):
Expand Down
11 changes: 1 addition & 10 deletions tests/multi_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from unittest import SkipTest
import tracemalloc as tm

Expand All @@ -25,15 +24,7 @@
from jax._src import test_util as jtu

jax.config.parse_flags_with_absl()

# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))

def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(8)


class MultiDeviceTest(jtu.JaxTestCase):
Expand Down
10 changes: 1 addition & 9 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from collections import OrderedDict, namedtuple
import contextlib
import re
from functools import partial
import logging
Expand Down Expand Up @@ -64,14 +63,7 @@

config.parse_flags_with_absl()

# Run all tests with 8 CPU devices.
_exit_stack = contextlib.ExitStack()

def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))

def tearDownModule():
_exit_stack.close()
jtu.request_cpu_devices(8)

def create_array(global_shape, global_mesh, mesh_axes, global_data=None,
dtype=np.float32):
Expand Down
Loading

0 comments on commit 51b9fe3

Please sign in to comment.