From 51b9fe3010dc6beeab349529e1a6ab87ec49f0e0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 8 Jan 2025 06:37:02 -0800 Subject: [PATCH] [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU devices directly. This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS. In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads. PiperOrigin-RevId: 713272197 --- jax/_src/test_util.py | 31 +++++++------------ jax/_src/xla_bridge.py | 22 +++++++++++-- .../array_serialization/serialization_test.py | 9 +----- .../jax2tf/tests/sharding_test.py | 7 +---- tests/array_test.py | 10 +----- tests/colocated_python_test.py | 14 +-------- tests/debugger_test.py | 10 +----- tests/debugging_primitives_test.py | 9 +----- tests/export_test.py | 8 +---- tests/jaxpr_effects_test.py | 12 ++----- tests/layout_test.py | 10 +----- tests/multi_device_test.py | 11 +------ tests/pjit_test.py | 10 +----- tests/pmap_test.py | 10 +----- tests/python_callback_test.py | 9 +----- tests/roofline_test.py | 14 +-------- tests/shard_alike_test.py | 12 +------ tests/shard_map_test.py | 12 +------ 18 files changed, 48 insertions(+), 172 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index a2f60887706b..46c442d63024 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -498,29 +498,20 @@ def device_supports_buffer_donation(): ) -@contextmanager -def set_host_platform_device_count(nr_devices: int): - """Context manager to set host platform device count if not specified by user. +def request_cpu_devices(nr_devices: int): + """Requests at least `nr_devices` CPU devices.
+ + request_cpu_devices should be called at the top-level of a test module before + main() runs. - This should only be used by tests at the top level in setUpModule(); it will - not work correctly if applied to individual test cases. + It is not guaranteed that the number of CPU devices will be exactly + `nr_devices`: it may be more or less, depending on how exactly the test is + invoked. Test cases that require a specific number of devices should skip + themselves if that number is not met. """ - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - f" --xla_force_host_platform_device_count={nr_devices}") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - try: - yield - finally: - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags + if xla_bridge.NUM_CPU_DEVICES.value < nr_devices: xla_bridge.get_backend.cache_clear() + config.update("jax_num_cpu_devices", nr_devices) def skip_on_flag(flag_name, skip_value): diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 28148761c8a4..bbe6631753cb 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -122,6 +122,14 @@ "inline without async dispatch.", ) +NUM_CPU_DEVICES = config.int_flag( + name="jax_num_cpu_devices", + default=-1, + help="Number of CPU devices to use. If not provided, the value of " + "the XLA flag --xla_force_host_platform_device_count is used." + " Must be set before JAX is initialized.", +) + # Warn the user if they call fork(), because it's not going to go well for them. 
def _at_fork(): @@ -249,8 +257,8 @@ def make_cpu_client( if collectives is None: collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value if _CPU_ENABLE_GLOO_COLLECTIVES.value: - collectives_impl = 'gloo' - warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is ' + collectives_impl = 'gloo' + warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is ' 'deprecated. Please use `jax.config.update(' '"jax_cpu_collectives_implementation", "gloo")` instead.', DeprecationWarning, @@ -268,12 +276,22 @@ def make_cpu_client( f"{collectives_impl}. Available implementations are " f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.") + num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None + if xla_client._version < 303 and num_devices is not None: + xla_flags = os.getenv("XLA_FLAGS") or "" + os.environ["XLA_FLAGS"] = ( + f"{xla_flags} --xla_force_host_platform_device_count={num_devices}" + ) + num_devices = None + # TODO(phawkins): pass num_devices directly when version 303 is the minimum. 
+ kwargs = {} if num_devices is None else {"num_devices": num_devices} return xla_client.make_cpu_client( asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value, distributed_client=distributed.global_state.client, node_id=distributed.global_state.process_id, num_nodes=distributed.global_state.num_processes, collectives=collectives, + **kwargs, ) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index a4bb168efb2f..6ec621d68ff7 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -14,7 +14,6 @@ """Tests for serialization and deserialization of GDA.""" import asyncio -import contextlib import math from functools import partial import os @@ -36,13 +35,7 @@ import tensorstore as ts jax.config.parse_flags_with_absl() -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class CheckpointTest(jtu.JaxTestCase): diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 8fe9a1dd9254..05d2352a4577 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -19,7 +19,6 @@ """ from collections.abc import Sequence -import contextlib from functools import partial import logging import re @@ -47,16 +46,14 @@ import tensorflow as tf config.parse_flags_with_absl() +jtu.request_cpu_devices(8) # Must come after initializing the flags from jax.experimental.jax2tf.tests import tf_test_util -_exit_stack = contextlib.ExitStack() topology = None def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - global topology if jtu.test_device_matches(["tpu"]): with jtu.ignore_warning(message="the imp module is deprecated"): @@ -67,8 +64,6 @@ def 
setUpModule(): else: topology = None -def tearDownModule(): - _exit_stack.close() class ShardingTest(tf_test_util.JaxToTfTestCase): diff --git a/tests/array_test.py b/tests/array_test.py index 9618a8cf4665..97bf71a5216b 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -43,20 +43,12 @@ from jax._src import prng jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(8) with contextlib.suppress(ImportError): import pytest pytestmark = pytest.mark.multiaccelerator -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() - def create_array(shape, sharding, global_data=None): if global_data is None: diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 4485f5d4f41e..f9dd3ce52b58 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import threading import time from typing import Sequence @@ -29,6 +28,7 @@ import numpy as np config.parse_flags_with_absl() +jtu.request_cpu_devices(8) def _colocated_cpu_devices( @@ -53,18 +53,6 @@ def _colocated_cpu_devices( _count_colocated_python_specialization_cache_miss = jtu.count_events( "colocated_python_func._get_specialized_func") -_exit_stack = contextlib.ExitStack() - - -def setUpModule(): - # TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT - # prepares CPU devices by its own. - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - - -def tearDownModule(): - _exit_stack.close() - class ColocatedPythonTest(jtu.JaxTestCase): diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 18693a7bb2c3..419e7b18dfed 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
from collections.abc import Sequence -import contextlib import io import re import textwrap @@ -29,6 +28,7 @@ import numpy as np jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(2) def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]: fake_stdin = io.StringIO() @@ -41,14 +41,6 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI def _format_multiline(text): return textwrap.dedent(text).lstrip() -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() - foo = 2 class CliDebuggerTest(jtu.JaxTestCase): diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 6afb41645405..0fc9665ceaa5 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import collections -import contextlib import functools import textwrap import unittest @@ -35,19 +34,13 @@ rich = None jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(2) debug_print = debugging.debug_print def _format_multiline(text): return textwrap.dedent(text).lstrip() -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() class DummyDevice: def __init__(self, platform, id): diff --git a/tests/export_test.py b/tests/export_test.py index da0e9daf2f00..63fe4a8bc47d 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -56,14 +56,8 @@ CAN_SERIALIZE = False config.parse_flags_with_absl() +jtu.request_cpu_devices(8) -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() ### Setup for testing lowering with effects @dataclasses.dataclass(frozen=True) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 2e91792aa950..922b37ffa440 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import contextlib + import threading import unittest @@ -34,6 +34,7 @@ import numpy as np config.parse_flags_with_absl() +jtu.request_cpu_devices(2) effect_p = core.Primitive('effect') effect_p.multiple_results = True @@ -132,15 +133,6 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out mlir.register_lowering(callback_p, callback_effect_lowering) -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() - - class JaxprEffectsTest(jtu.JaxTestCase): def test_trivial_jaxpr_has_no_effects(self): diff --git a/tests/layout_test.py b/tests/layout_test.py index f958de5cf5bc..903b17886283 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import math from functools import partial from absl.testing import absltest @@ -28,14 +27,7 @@ from jax.experimental.compute_on import compute_on config.parse_flags_with_absl() - -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class LayoutTest(jtu.JaxTestCase): diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 057731cb5d55..1fc6fe1e9298 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib from unittest import SkipTest import tracemalloc as tm @@ -25,15 +24,7 @@ from jax._src import test_util as jtu jax.config.parse_flags_with_absl() - -# Run all tests with 8 CPU devices. 
-_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class MultiDeviceTest(jtu.JaxTestCase): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 77bcee296b0a..3fcc5c81ad91 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -13,7 +13,6 @@ # limitations under the License. from collections import OrderedDict, namedtuple -import contextlib import re from functools import partial import logging @@ -64,14 +63,7 @@ config.parse_flags_with_absl() -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) def create_array(global_shape, global_mesh, mesh_axes, global_data=None, dtype=np.float32): diff --git a/tests/pmap_test.py b/tests/pmap_test.py index a9de8c896414..795f7d4bf9e8 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -15,7 +15,6 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor -import contextlib from functools import partial import itertools as it import gc @@ -54,15 +53,8 @@ from jax._src.util import safe_map, safe_zip config.parse_flags_with_absl() +jtu.request_cpu_devices(8) -# Run all tests with 8 CPU devices. 
-_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]] diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 199b90fe524e..efa877fd3a91 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -36,14 +36,7 @@ import numpy as np config.parse_flags_with_absl() - -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(2) map, unsafe_map = util.safe_map, map diff --git a/tests/roofline_test.py b/tests/roofline_test.py index e5003947181b..aec34ff22a57 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -14,7 +14,6 @@ from __future__ import annotations from functools import partial -import contextlib from absl.testing import absltest from jax.sharding import PartitionSpec as P @@ -28,6 +27,7 @@ jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(8) def create_inputs( @@ -45,18 +45,6 @@ def create_inputs( return mesh, tuple(arrays) -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - - -def tearDownModule(): - _exit_stack.close() - - class RooflineTest(jtu.JaxTestCase): def test_scalar_collectives(self): a_spec = P("z", ("x", "y")) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 383746899570..25d46c5add2e 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import contextlib - import jax import jax.numpy as jnp import numpy as np @@ -24,15 +22,7 @@ from jax.experimental.shard_map import shard_map jax.config.parse_flags_with_absl() - -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class ShardAlikeDownstreamTest(jtu.JaxTestCase): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index ec846a32a903..19cc870881cf 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -15,7 +15,6 @@ from __future__ import annotations from collections.abc import Callable, Generator, Iterable, Iterator, Sequence -import contextlib from functools import partial import itertools as it import math @@ -53,6 +52,7 @@ from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member config.parse_flags_with_absl() +jtu.request_cpu_devices(8) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -70,16 +70,6 @@ def create_inputs(a_sharding, b_sharding): return mesh, m1, m2 -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() - - class ShardMapTest(jtu.JaxTestCase): def test_identity(self):