diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 92fef2cc29af..551b9ef3bba7 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04 + ref: 'f7a74a685d78d98203fa991fc19a5d1fda57c212' # Latest commit as of 2025-01-07 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 346c399b3332..1835f0857c1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Changes: * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025. + * The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum + supported version until June 2025. * {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than `optimize='optimal'`. This avoids exponentially-scaling trace-time in the case of many arguments ({jax-issue}`#25214`). diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_gpu.sh index 7bc2492781b2..416d985d380c 100644 --- a/ci/run_pytest_gpu.sh +++ b/ci/run_pytest_gpu.sh @@ -56,6 +56,5 @@ echo "Running GPU tests..." "$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \ tests examples \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ ---deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \ --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \ ---deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric \ No newline at end of file +--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md index c56e82c77450..eaa3bc7317c8 100644 --- a/docs/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -247,7 +247,7 @@ perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) ### Hessian-vector products with `jax.grad`-of-`jax.grad` -One thing you can do with higher-order {func}`jax.vmap` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.) +One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.) A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)). @@ -259,11 +259,11 @@ for any $v \in \mathbb{R}^n$. The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store. -Luckily, {func}`jax.vmap` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity: +Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity: $\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$, -where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.vmap` is efficient. +where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient. In JAX code, you can just write this: @@ -357,7 +357,7 @@ To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfw ### Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff) -JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.vmap` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background. +JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background. #### JVPs in math @@ -473,7 +473,7 @@ vjp :: (a -> b) -> a -> (b, CT b -> CT a) where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`. -This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.vmap` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters. +This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters. There's a cost, though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). diff --git a/docs/aot.md b/docs/aot.md index a5dd69a72b8f..2dc4eadf388f 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -73,14 +73,8 @@ see the {ref}`export` APIs. See the {mod}`jax.stages` documentation for more details on what functionality the lowering and compiled functions provide. -In place of `jax.jit` above, you can also `lower(...)` the result of -{func}`jax.pmap`, as well as `pjit` and `xmap` (from -{mod}`jax.experimental.pjit` and {mod}`jax.experimental.maps` respectively). In -each case, you can `compile()` the result similarly. - All optional arguments to `jit`---such as `static_argnums`---are respected in -the corresponding lowering, compilation, and execution. Again the same goes for -`pmap`, `pjit`, and `xmap`. +the corresponding lowering, compilation, and execution. In the example above, we can replace the arguments to `lower` with any objects that have `shape` and `dtype` attributes: diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 3ef927e056f2..0938a5da944f 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -354,6 +354,61 @@ print_saved_residuals(loss_checkpoint2, params, x, y) Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`. +#### Custom policies for offload + +You may consider offloading to CPU memory instead of recomputing when checkpointing to save accelerator memory. `jax.checkpoint_policies.offload_dot_with_no_batch_dims` can offload the results of matrix multiplications with no batch dimensions to the CPU. + +```{code-cell} +from jax.ad_checkpoint import checkpoint + +def checkpoint_offload_dot_with_no_batch_dims(self): + policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( + "device", "pinned_host") + + @functools.partial(checkpoint, policy=policy) + def f(x): + x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) + x = jnp.sin(x) + x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) + x = jnp.sin(x) + x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST) + x = jnp.sin(x) + x = jnp.sum(x) + return x +``` + +One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPUs. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which has four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offloading source, and destination. Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed. + +```{code-cell} +from jax.ad_checkpoint import checkpoint, checkpoint_name +from jax._src import test_util as jtu + +def checkpoint_names_saved_offloaded_recomputed(self): + mesh = jtu.create_mesh((2,), ("x",)) + shape = (256, 128) + np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) + s = NamedSharding(mesh, P("x")) + inp = jax.device_put(np_inp, s) + + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"], + offload_src='device', offload_dst='pinned_host') + + @functools.partial(checkpoint, policy=policy) + def f(x): + def g(ys, _): + y, _ = ys + y = checkpoint_name(jnp.sin(y), "y") + z = checkpoint_name(jnp.sin(y), "z") + z = z.T + w = checkpoint_name(jnp.sin(z), "w") + return (w.T, jnp.sum(w)), None + _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0] + return scan_out +``` + +The code defines a function `f` that which applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside `f`, there is a nested function `g` that performs the core computations. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data. + #### List of policies The policies are: diff --git a/examples/ffi/src/jax_ffi_example/cpu_examples.cc b/examples/ffi/src/jax_ffi_example/cpu_examples.cc index 3832c86b29b2..8d808ecd8e30 100644 --- a/examples/ffi/src/jax_ffi_example/cpu_examples.cc +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.cc @@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( Counter, CounterImpl, ffi::Ffi::Bind().Attr("index").Ret>()); +// -------- +// Aliasing +// -------- +// +// This example demonstrates how input-output aliasing works. The handler +// doesn't do anything except to check that the input and output pointers +// address the same data. + +ffi::Error AliasingImpl(ffi::AnyBuffer input, + ffi::Result output) { + if (input.element_type() != output->element_type() || + input.element_count() != output->element_count()) { + return ffi::Error::InvalidArgument( + "The input and output data types and sizes must match."); + } + if (input.untyped_data() != output->untyped_data()) { + return ffi::Error::InvalidArgument( + "When aliased, the input and output buffers should point to the same " + "data."); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + Aliasing, AliasingImpl, + ffi::Ffi::Bind().Arg().Ret()); + // Boilerplate for exposing handlers to Python NB_MODULE(_cpu_examples, m) { m.def("registrations", []() { @@ -111,9 +138,8 @@ NB_MODULE(_cpu_examples, m) { nb::capsule(reinterpret_cast(ArrayAttr)); registrations["dictionary_attr"] = nb::capsule(reinterpret_cast(DictionaryAttr)); - registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); - + registrations["aliasing"] = nb::capsule(reinterpret_cast(Aliasing)); return registrations; }); } diff --git a/examples/ffi/src/jax_ffi_example/cpu_examples.py b/examples/ffi/src/jax_ffi_example/cpu_examples.py index 563e5a911b99..155e100dcd77 100644 --- a/examples/ffi/src/jax_ffi_example/cpu_examples.py +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.py @@ -39,3 +39,9 @@ def dictionary_attr(**kwargs): def counter(index): return jax.ffi.ffi_call( "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) + + +def aliasing(x): + return jax.ffi.ffi_call( + "aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype), + input_output_aliases={0: 0})(x) diff --git a/examples/ffi/tests/cpu_examples_test.py b/examples/ffi/tests/cpu_examples_test.py index 0e2cfde02db6..8db524f6264b 100644 --- a/examples/ffi/tests/cpu_examples_test.py +++ b/examples/ffi/tests/cpu_examples_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax import jax.numpy as jnp @@ -91,5 +91,16 @@ def counter_fun(x): self.assertEqual(counter_fun(0)[1], 3) +class AliasingTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + + @parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),)) + def test_basic(self, x): + self.assertAllClose(cpu_examples.aliasing(x), x) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 5c42a3b44de3..21bc8038448a 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -33,6 +33,7 @@ from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import version as jaxlib_version +from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir import numpy as np @@ -198,7 +199,10 @@ def get_compile_options( logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.") if env_options_overrides is None: env_options_overrides = {} - env_options_overrides['xla_gpu_enable_command_buffer'] = '' + if xla_extension_version > 302: + env_options_overrides['xla_gpu_enable_command_buffer'] = '' + else: + env_options_overrides['xla_gpu_graph_min_graph_size'] = '100000' if env_options_overrides is not None: # Some overrides are passed directly on build_options. diff --git a/jax/_src/core.py b/jax/_src/core.py index 0868990e1d3b..5017f7f5f676 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -39,7 +39,7 @@ from jax._src import effects from jax._src import compute_on from jax._src import mesh as mesh_lib -from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton +from jax._src.partition_spec import PartitionSpec as P from jax._src.errors import ( ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, TracerIntegerConversionError, UnexpectedTracerError) @@ -1659,12 +1659,12 @@ def _maybe_modify_sharding(sharding): new_spec = [] for s in sharding.spec: - if s is None or isinstance(s, UnconstrainedSingleton): + if s is None: new_spec.append(s) else: temp_s = s[0] if isinstance(s, tuple) else s new_spec.append( - P.UNCONSTRAINED + None if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s) return sharding.with_spec(new_spec) @@ -1762,8 +1762,6 @@ def _get_shape_sharding_str(shape, spec): for s1, s2 in zip(shape, spec): if s2 is None: out.append(f"{s1}") - elif isinstance(s2, UnconstrainedSingleton): - out.append(f"{s1}") elif isinstance(s2, tuple): ss = ','.join(s for s in s2) out.append(f"{s1}@({ss})") diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 4cf34f200e5e..6b0ef293e807 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -190,7 +190,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, closed_jaxpr, name="tmp_xla_computation", platforms=module_context.platforms, - backend_or_name=module_context.backend_or_name, + backend=module_context.backend, axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)), ) result_sharding = _pack_result_sharding(result_shape, result_shardings) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index f011e756da31..471b0c4af74d 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -22,39 +22,38 @@ import enum from functools import partial import itertools -import time -from typing import Any, NamedTuple import logging import threading - -import numpy as np +import time +from typing import Any, NamedTuple import jax +from jax._src import api +from jax._src import array from jax._src import basearray from jax._src import config from jax._src import core -from jax._src import api -from jax._src import array from jax._src import dtypes +from jax._src import lib from jax._src import source_info_util from jax._src import traceback_util from jax._src import util +from jax._src.abstract_arrays import array_types from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.abstract_arrays import array_types from jax._src.interpreters import mlir -from jax._src.interpreters import xla from jax._src.interpreters import pxla -from jax._src import lib -from jax._src.mesh import AbstractMesh, Mesh +from jax._src.interpreters import xla +from jax._src.layout import DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.monitoring import record_event_duration_secs +from jax._src.mesh import AbstractMesh, Mesh +from jax._src.monitoring import record_event_duration_secs, record_event_time_span from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding -from jax._src.sharding_impls import ( - SingleDeviceSharding, NamedSharding, TransferToMemoryKind, +from jax._src.sharding_impls import ( NamedSharding, + SingleDeviceSharding, TransferToMemoryKind, is_single_device_sharding) -from jax._src.layout import Layout, DeviceLocalLayout +import numpy as np JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" @@ -177,12 +176,14 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None): log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG start_time = time.time() yield - elapsed_time = time.time() - start_time + end_time = time.time() + elapsed_time = end_time - start_time if logger.isEnabledFor(log_priority): logger.log(log_priority, fmt.format( fun_name=fun_name, elapsed_time=elapsed_time)) if event is not None: record_event_duration_secs(event, elapsed_time) + record_event_time_span(event, start_time, end_time) def should_tuple_args(num_args: int, platform: str) -> bool: diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 8ba43083d7a3..6b1945746255 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -896,7 +896,7 @@ def is_token(typ, attrs): with ir.InsertionPoint(entry_block): # Make a context just for lowering the dimension value computations module_context = mlir.ModuleContext( - backend_or_name="cpu", platforms=["cpu"], + backend=None, platforms=["cpu"], axis_context=sharding_impls.ShardingContext(0), keepalives=[], channel_iterator=itertools.count(1), host_callbacks=[], module=wrapped_module, context=context, diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 85da4e53b1f7..f552a779a844 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -50,7 +50,6 @@ from jax._src.layout import AutoLayout, DeviceLocalLayout from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import AUTO, NamedSharding -from jax._src.partition_spec import UnconstrainedSingleton from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension from jax._src.lib.mlir import dialects, ir, passmanager @@ -698,10 +697,11 @@ class ModuleContext: module: ir.Module ip: ir.InsertionPoint symbol_table: ir.SymbolTable - backend_or_name: str | xb.XlaBackend | None # The lowering platforms for the module. Can be more than one only when # exporting. platforms: Sequence[str] + # See ModuleContext.get_backend() for backend and platforms usage. + backend: xb.XlaBackend | None axis_context: AxisContext keepalives: list[Any] channel_iterator: Iterator[int] @@ -725,8 +725,8 @@ def axis_env(self) -> sharding_impls.AxisEnv: def __init__( self, *, - backend_or_name: str | xb.XlaBackend | None, platforms: Sequence[str], + backend: xb.XlaBackend | None, axis_context: AxisContext, keepalives: list[Any], channel_iterator: Iterator[int], @@ -745,7 +745,7 @@ def __init__( self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context)) self.ip = ip or ir.InsertionPoint(self.module.body) self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation) - self.backend_or_name = backend_or_name + self.backend = backend self.platforms = platforms self.axis_context = axis_context self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None @@ -760,17 +760,20 @@ def __init__( self.all_default_mem_kind = all_default_mem_kind self.lowering_parameters = lowering_parameters - @property - def backend(self) -> xb.XlaBackend: - # TODO(necula): clean the use of backend and backend_or_name vs. platforms + def get_backend(self) -> xb.XlaBackend: if len(self.platforms) > 1: raise NotImplementedError( "accessing .backend in multi-lowering setting. This can occur when " "lowering a primitive that has not been adapted to multi-platform " "lowering") - if self.backend_or_name is None or isinstance(self.backend_or_name, str): - return xb.get_backend(self.backend_or_name) - return self.backend_or_name + if self.backend is not None: + if xb.canonicalize_platform(self.backend.platform) != self.platforms[0]: + raise ValueError( + "the platform for the specified backend " + f"{xb.canonicalize_platform(self.backend.platform)} is different " + f"from the lowering platform {self.platforms[0]}") + return self.backend + return xb.get_backend(self.platforms[0]) def new_channel(self) -> int: channel = next(self.channel_iterator) @@ -1072,14 +1075,14 @@ def _get_unconstrained_dimensions(s, aval): return (us, all_unconstrained(s, aval), ({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None)) - def lower_jaxpr_to_module( module_name: str, jaxpr: core.ClosedJaxpr, *, ordered_effects: list[core.Effect], - backend_or_name: str | xb.XlaBackend | None, + # See ModuleContext.get_backend() for backend and platforms usage. platforms: Sequence[str], + backend: xb.XlaBackend | None, axis_context: AxisContext, name_stack: source_info_util.NameStack, donated_args: Sequence[bool], @@ -1170,7 +1173,7 @@ def lower_jaxpr_to_module( else: dim_vars = () - ctx = ModuleContext(backend_or_name=backend_or_name, + ctx = ModuleContext(backend=backend, platforms=platforms, axis_context=axis_context, keepalives=keepalives, channel_iterator=channel_iter, @@ -1202,10 +1205,6 @@ def lower_jaxpr_to_module( arg_layouts=in_layouts, result_layouts=out_layouts, propagated_out_mem_kinds=propagated_out_mem_kinds) - if config.use_shardy_partitioner.value: - pipeline = passmanager.PassManager.parse( - 'builtin.module(sdy-lift-inlined-meshes)') - pipeline.run(ctx.module.operation) try: if not ctx.module.operation.verify(): @@ -1224,6 +1223,12 @@ def emit_diagnostic_info(d): raise ValueError("\n".join(msg_lines) + "\n" + dump_module_message(ctx.module, "verification")) from e + if config.use_shardy_partitioner.value: + with ctx.context: + pipeline = passmanager.PassManager.parse( + 'builtin.module(sdy-lift-inlined-meshes)') + pipeline.run(ctx.module.operation) + return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks, ctx.shape_poly_state) @@ -2595,8 +2600,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None): if aval.sharding.mesh._any_axis_collective: unspecified_dims = set(range(aval.ndim)) elif aval.sharding.mesh._any_axis_auto: - unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) - if isinstance(s, UnconstrainedSingleton)} + unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None} return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims) @@ -2892,7 +2896,7 @@ def emit_python_callback( if platform not in {"cpu", "cuda", "rocm", "tpu"}: raise ValueError( f"`EmitPythonCallback` not supported on {platform} backend.") - backend = ctx.module_context.backend + backend = ctx.module_context.get_backend() result_shapes = util.flatten( [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) operand_shapes = util.flatten( @@ -3012,13 +3016,14 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function def build_mlir_module_helper( closed_jaxpr: core.ClosedJaxpr, *, name: str, platforms: Sequence[str], - backend_or_name: str, axis_context: AxisContext) -> ir.Module: + backend: xb.XlaBackend | None, + axis_context: AxisContext) -> ir.Module: """Helper to generate pmap-style XLA computations for custom partitioners.""" unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects) if unlowerable_effects: raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}') lowering_result = lower_jaxpr_to_module(name, closed_jaxpr, - backend_or_name=backend_or_name, ordered_effects=[], + backend=backend, ordered_effects=[], name_stack=source_info_util.NameStack(), donated_args=[False] * len(closed_jaxpr.jaxpr.invars), axis_context=axis_context, platforms=platforms, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a61600402b74..ec2d52e48e9c 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -871,7 +871,7 @@ def lower_parallel_callable( module_name, closed_jaxpr, ordered_effects=ordered_effects, - backend_or_name=backend, + backend=backend, platforms=platforms, axis_context=sharding_impls.ReplicaAxisContext(axis_env), name_stack=name_stack, @@ -1179,7 +1179,7 @@ def __str__(self): class ResultsHandler: - # `out_avals` is the `Array` global avals when using pjit or xmap. It is the + # `out_avals` is the `Array` global avals when using pjit. It is the # local one when using `pmap`. __slots__ = ("handlers", "out_shardings", "out_avals") @@ -1954,7 +1954,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, module_name, closed_jaxpr, ordered_effects=ordered_effects, - backend_or_name=backend, + backend=backend, platforms=platforms, axis_context=axis_ctx, name_stack=name_stack, @@ -2150,7 +2150,7 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts, donated_invars, out_shardings, out_layouts) -def _concretize_abstract_shardings(shardings, avals, device_assignment): +def _concretize_abstract_out_shardings(shardings, avals, device_assignment): np_dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(np.arange(len(device_assignment))) @@ -2163,8 +2163,14 @@ def _abstract_to_concrete_mesh(abstract_mesh): out = [] for s, a in zip(shardings, avals): if isinstance(s, UnspecifiedValue) and a.sharding is not None: - out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh), - a.sharding.spec)) + if config.use_shardy_partitioner.value: + spec = a.sharding.spec + else: + spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp + for sp in a.sharding.spec]) + if a.sharding.mesh._any_axis_auto else a.sharding.spec) + out.append(NamedSharding( + _abstract_to_concrete_mesh(a.sharding.mesh), spec)) else: out.append(s) return tuple(out) @@ -2243,7 +2249,7 @@ def lower_sharding_computation( unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings] if config.sharding_in_types.value: - out_shardings = _concretize_abstract_shardings( + out_shardings = _concretize_abstract_out_shardings( out_shardings, global_out_avals, device_assignment) # TODO(parkers): One _raw_platform has been unified with platform, diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 25fb2b38f7fa..34f485684e85 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -76,7 +76,7 @@ def __repr__(self): @util.cache(max_size=128, trace_context_in_key=False) def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: if global_mesh.empty: - return global_mesh + return global_mesh is_local_device = np.vectorize( lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices) subcube_indices = [] @@ -96,9 +96,9 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: # subcube that hull will contain non-local devices. if not is_local_device[subcube_indices_tuple].all(): raise ValueError( - "When passing host local inputs to pjit or xmap, devices " - "connected to a single host must form a contiguous subcube of the " - "global device mesh") + "When passing host local inputs to pjit, devices connected to a single" + " host must form a contiguous subcube of the global device mesh" + ) return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names) diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py index 3b291de0061a..99e957733ba2 100644 --- a/jax/_src/monitoring.py +++ b/jax/_src/monitoring.py @@ -39,8 +39,17 @@ def __call__(self, event: str, duration_secs: float, ... +class EventTimeSpanListenerWithMetadata(Protocol): + + def __call__( + self, event: str, start_time: float, end_time: float, **kwargs: str | int + ) -> None: + ... + + _event_listeners: list[EventListenerWithMetadata] = [] _event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = [] +_event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = [] def record_event(event: str, **kwargs: str | int) -> None: @@ -64,6 +73,14 @@ def record_event_duration_secs(event: str, duration: float, callback(event, duration, **kwargs) +def record_event_time_span( + event: str, start_time: float, end_time: float, **kwargs: str | int +) -> None: + """Record an event start and end time in seconds (float).""" + for callback in _event_time_span_listeners: + callback(event, start_time, end_time, **kwargs) + + def register_event_listener( callback: EventListenerWithMetadata, ) -> None: @@ -71,6 +88,13 @@ def register_event_listener( _event_listeners.append(callback) +def register_event_time_span_listener( + callback: EventTimeSpanListenerWithMetadata, +) -> None: + """Register a callback to be invoked during record_event_time_span().""" + _event_time_span_listeners.append(callback) + + def register_event_duration_secs_listener( callback : EventDurationListenerWithMetadata) -> None: """Register a callback to be invoked during record_event_duration_secs().""" @@ -80,15 +104,22 @@ def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]: """Get event duration listeners.""" return list(_event_duration_secs_listeners) + +def get_event_time_span_listeners() -> list[EventTimeSpanListenerWithMetadata]: + """Get event time span listeners.""" + return list(_event_time_span_listeners) + + def get_event_listeners() -> list[EventListenerWithMetadata]: """Get event listeners.""" return list(_event_listeners) def clear_event_listeners(): """Clear event listeners.""" - global _event_listeners, _event_duration_secs_listeners + global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners _event_listeners = [] _event_duration_secs_listeners = [] + _event_time_span_listeners = [] def _unregister_event_duration_listener_by_callback( callback: EventDurationListenerWithMetadata) -> None: @@ -108,6 +139,18 @@ def _unregister_event_duration_listener_by_index(index: int) -> None: assert -size <= index < size del _event_duration_secs_listeners[index] + +def _unregister_event_time_span_listener_by_callback( + callback: EventTimeSpanListenerWithMetadata, +) -> None: + """Unregister an event time span listener by callback. + + This function is supposed to be called for testing only. + """ + assert callback in _event_time_span_listeners + _event_time_span_listeners.remove(callback) + + def _unregister_event_listener_by_callback( callback: EventListenerWithMetadata) -> None: """Unregister an event listener by callback. diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index a126cbba0ce1..20c160e0e436 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -473,6 +473,8 @@ def to_block_mapping( mapping.check_invariants() return mapping + replace = dataclasses.replace + class NoBlockSpec: def __repr__(self): @@ -1028,18 +1030,37 @@ def to_json(self) -> bytes: core_map_p = jax_core.Primitive("core_map") core_map_p.multiple_results = True -def core_map(mesh): + +def core_map( + mesh, + *, + compiler_params: Any | None = None, + interpret: bool = False, + debug: bool = False, + cost_estimate: CostEstimate | None = None, +): """Runs a function on a mesh, mapping it over the devices in the mesh. The function should be stateful in that it takes in no inputs and returns no outputs but can mutate closed-over Refs, for example. + + Args: + mesh: The mesh to run the function on. + compiler_params: The compiler parameters to pass to the backend. + interpret: Whether to run the function in interpret mode. + debug: Whether or not to out helpful debugging information. + cost_estimate: The cost estimate of the function. """ def wrapped(f): flat_args, in_tree = tree_util.tree_flatten(((), {})) flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) with jax_core.extend_axis_env_nd(mesh.shape.items()): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) - out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh) + out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh, + compiler_params=compiler_params, + interpret=interpret, + debug=debug, + cost_estimate=cost_estimate) if out: raise ValueError("core_map-ped functions must not return any outputs.") return tree_util.tree_unflatten(out_tree_thunk(), out) @@ -1047,7 +1068,7 @@ def wrapped(f): @core_map_p.def_effectful_abstract_eval -def _core_map_abstract_eval(*args, jaxpr, mesh): +def _core_map_abstract_eval(*args, jaxpr, mesh, **_): del args if jaxpr.outvars: raise ValueError("core_map must not return any outputs.") @@ -1074,6 +1095,9 @@ def default_mesh_discharge_rule( compiler_params, backend, jaxpr, + debug, + interpret, + cost_estimate, ): """Discharges a ``core_map`` over a mesh to a ``pallas_call``.""" del out_avals # Unused. @@ -1103,6 +1127,9 @@ def body(*args): grid=grid, compiler_params=compiler_params, backend=backend, + interpret=interpret, + debug=debug, + cost_estimate=cost_estimate, )(*args) # ``outs`` lacks the unmodified inputs. Add them back in. all_outs = [None] * len(args) @@ -1120,8 +1147,8 @@ def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwa ) -def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh): - del in_atoms +def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh, **kwargs): + del in_atoms, kwargs with jax_core.extend_axis_env_nd(tuple(mesh.shape.items())): jax_core.check_jaxpr(jaxpr) effs = set() diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index ad9a6cb13f42..80f8a7de544c 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -92,6 +92,8 @@ class TPUCompilerParams(pallas_core.CompilerParams): serialization_format: int = 1 device_type: str | None = None + replace = dataclasses.replace + class TPUMemorySpace(enum.Enum): ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY. VMEM = "vmem" @@ -240,19 +242,38 @@ def _tensorcore_mesh_discharge_rule( *args, mesh, jaxpr, + compiler_params: TPUCompilerParams, + interpret: bool, + debug: bool, + cost_estimate: pallas_core.CostEstimate | None, ): assert isinstance(mesh, TensorCoreMesh) + if compiler_params and not isinstance(compiler_params, TPUCompilerParams): + raise ValueError( + "compiler_params must be a pltpu.TPUCompilerParams" + ) + if not compiler_params: + compiler_params = TPUCompilerParams() if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") core_axis_name, num_cores = list(mesh.shape.items())[0] + if compiler_params.dimension_semantics is not None: + raise ValueError( + "dimension_semantics must be None for TensorCoreMesh" + ) return pallas_core.default_mesh_discharge_rule( in_avals, out_avals, *args, jaxpr=jaxpr, grid=((core_axis_name, num_cores),), - compiler_params=TPUCompilerParams(dimension_semantics=("parallel",)), + compiler_params=compiler_params.replace( + dimension_semantics=("parallel",) + ), + debug=debug, + interpret=interpret, backend="mosaic_tpu", + cost_estimate=cost_estimate, ) pallas_core._core_map_mesh_rules[TensorCoreMesh] = ( diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index d77ae4358703..36e6e47cbf4f 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -483,7 +483,6 @@ class GPUMesh: # Those are NOT CUDA threads. On Hopper they correspond to warpgroups. num_threads: int | None = None axis_names: tuple[str, ...] = () - approx_math: bool = False def __post_init__(self): if len(self.axis_names) != len(self.grid) + (self.num_threads is not None): @@ -521,12 +520,24 @@ def _gpu_mesh_discharge_rule( *args, mesh, jaxpr, + compiler_params, + interpret, + debug, + cost_estimate, ): - assert isinstance(mesh, GPUMesh) + if not isinstance(mesh, GPUMesh): + raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}") if mesh.cluster: raise NotImplementedError if mesh.num_threads is None: raise NotImplementedError + if compiler_params and not isinstance(compiler_params, GPUCompilerParams): + raise TypeError( + "Compiler params must be a GPUCompilerParams, got" + f" {type(compiler_params)}" + ) + if not compiler_params: + compiler_params = GPUCompilerParams() return pallas_core.default_mesh_discharge_rule( in_avals, out_avals, @@ -534,8 +545,13 @@ def _gpu_mesh_discharge_rule( jaxpr=jaxpr, grid=tuple(mesh.shape.items()), backend="mosaic_gpu", - compiler_params=GPUCompilerParams(approx_math=mesh.approx_math), + compiler_params=compiler_params, + debug=debug, + interpret=interpret, + cost_estimate=cost_estimate, ) + + pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 1805f8c0923a..59b1b86f33fc 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -61,14 +61,14 @@ def pallas_call_lowering( "scalar prefetch not implemented in the Triton backend" ) triton_params = compiler_params.get("triton", compiler_params) - num_warps = triton_params.pop("num_warps", 4) + num_warps = triton_params.get("num_warps", 4) num_warps = 4 if num_warps is None else num_warps [lowering_platform] = ctx.platforms or ctx.module_context.platforms if lowering_platform == "rocm": - num_stages = triton_params.pop("num_stages", 1) + num_stages = triton_params.get("num_stages", 1) num_stages = 1 if num_stages is None else num_stages else: - num_stages = triton_params.pop("num_stages", 3) + num_stages = triton_params.get("num_stages", 3) num_stages = 3 if num_stages is None else num_stages if debug: diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 21199b9bfd68..88f8853f53cd 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -148,45 +148,51 @@ def _eval_jaxpr_discharge_state( if d and isinstance(v.aval, AbstractRef)} for eqn in jaxpr.eqns: - should_discharge = [id(v.aval) in refs_to_discharge for v in eqn.invars] - if eqn.primitive is core.mutable_array_p: - [invar], [outvar] = eqn.invars, eqn.outvars - ans = env.read(invar) - refs_to_discharge.add(id(outvar.aval)) - elif eqn.primitive is core.freeze_p: - [invar], [outvar] = eqn.invars, eqn.outvars - ans = env.read(invar) - refs_to_discharge.remove(id(invar.aval)) - elif (any(should_discharge) - or core.internal_mutable_array_effect in eqn.effects - ): - if eqn.primitive in _partial_discharge_rules: - rule: DischargeRule = partial(_partial_discharge_rules[eqn.primitive], should_discharge) - elif eqn.primitive in _discharge_rules: - rule = _discharge_rules[eqn.primitive] + name_stack = ( + source_info_util.current_name_stack() + eqn.source_info.name_stack + ) + traceback = eqn.source_info.traceback + with source_info_util.user_context( + traceback, name_stack=name_stack), eqn.ctx.manager: + should_discharge = [id(v.aval) in refs_to_discharge for v in eqn.invars] + if eqn.primitive is core.mutable_array_p: + [invar], [outvar] = eqn.invars, eqn.outvars + ans = env.read(invar) + refs_to_discharge.add(id(outvar.aval)) + elif eqn.primitive is core.freeze_p: + [invar], [outvar] = eqn.invars, eqn.outvars + ans = env.read(invar) + refs_to_discharge.remove(id(invar.aval)) + elif (any(should_discharge) + or core.internal_mutable_array_effect in eqn.effects + ): + if eqn.primitive in _partial_discharge_rules: + rule: DischargeRule = partial(_partial_discharge_rules[eqn.primitive], should_discharge) + elif eqn.primitive in _discharge_rules: + rule = _discharge_rules[eqn.primitive] + else: + raise NotImplementedError("No state discharge rule implemented for " + f"primitive: {eqn.primitive}") + invals = map(env.read, eqn.invars) + in_avals = [v.aval for v in eqn.invars] + out_avals = [v.aval for v in eqn.outvars] + new_invals, ans = rule( + in_avals, out_avals, *invals, **eqn.params) + for invar, should, new_inval in zip(eqn.invars, should_discharge, new_invals): + if new_inval is not None: + if not should: + raise ValueError( + f"Did not ask for inval to be discharged but it was. ({invar=}," + f" {new_inval=})" + ) + env.write(invar, new_inval) # type: ignore[arg-type] else: - raise NotImplementedError("No state discharge rule implemented for " - f"primitive: {eqn.primitive}") - invals = map(env.read, eqn.invars) - in_avals = [v.aval for v in eqn.invars] - out_avals = [v.aval for v in eqn.outvars] - new_invals, ans = rule( - in_avals, out_avals, *invals, **eqn.params) - for invar, should, new_inval in zip(eqn.invars, should_discharge, new_invals): - if new_inval is not None: - if not should: - raise ValueError( - f"Did not ask for inval to be discharged but it was. ({invar=}," - f" {new_inval=})" - ) - env.write(invar, new_inval) # type: ignore[arg-type] - else: - # Default primitive rule, similar to `core.eval_jaxpr`. Note that here - # we assume any higher-order primitives inside of the jaxpr are *not* - # stateful. - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - ans = eqn.primitive.bind(*subfuns, *map(env.read, eqn.invars), - **bind_params) + # Default primitive rule, similar to `core.eval_jaxpr`. Note that here + # we assume any higher-order primitives inside of the jaxpr are *not* + # stateful. + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + ans = eqn.primitive.bind(*subfuns, *map(env.read, eqn.invars), + **bind_params) if eqn.primitive.multiple_results: map(env.write, eqn.outvars, ans) else: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index a2f60887706b..cb0cf39568d1 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -17,6 +17,7 @@ import collections from collections.abc import Callable, Generator, Iterable, Sequence +from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack, contextmanager import datetime import functools @@ -31,6 +32,7 @@ import tempfile import textwrap import threading +import time from typing import Any, TextIO import unittest import warnings @@ -115,6 +117,12 @@ 'deterministic, interactive'), ) +TEST_NUM_THREADS = config.int_flag( + 'jax_test_num_threads', 0, + help='Number of threads to use for running tests. 0 means run everything ' + 'in the main thread. Using > 1 thread is experimental.' +) + # We sanitize test names to ensure they work with "unitttest -k" and # "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k # does not. We replace sequences of problematic characters with a single '_'. @@ -498,29 +506,20 @@ def device_supports_buffer_donation(): ) -@contextmanager -def set_host_platform_device_count(nr_devices: int): - """Context manager to set host platform device count if not specified by user. +def request_cpu_devices(nr_devices: int): + """Requests at least `nr_devices` CPU devices. - This should only be used by tests at the top level in setUpModule(); it will - not work correctly if applied to individual test cases. + request_cpu_devices should be called at the top-level of a test module before + main() runs. + + It is not guaranteed that the number of CPU devices will be exactly + `nr_devices`: it may be more or less, depending on how exactly the test is + invoked. Test cases that require a specific number of devices should skip + themselves if that number is not met. """ - prev_xla_flags = os.getenv("XLA_FLAGS") - flags_str = prev_xla_flags or "" - # Don't override user-specified device count, or other XLA flags. - if "xla_force_host_platform_device_count" not in flags_str: - os.environ["XLA_FLAGS"] = (flags_str + - f" --xla_force_host_platform_device_count={nr_devices}") - # Clear any cached backends so new CPU backend will pick up the env var. - xla_bridge.get_backend.cache_clear() - try: - yield - finally: - if prev_xla_flags is None: - del os.environ["XLA_FLAGS"] - else: - os.environ["XLA_FLAGS"] = prev_xla_flags + if xla_bridge.NUM_CPU_DEVICES.value < nr_devices: xla_bridge.get_backend.cache_clear() + config.update("jax_num_cpu_devices", nr_devices) def skip_on_flag(flag_name, skip_value): @@ -1007,8 +1006,140 @@ def sample_product(*args, **kw): """ return parameterized.parameters(*sample_product_testcases(*args, **kw)) +# We use a reader-writer lock to protect test execution. Tests that may run in +# parallel acquire a read lock; tests that are not thread-safe acquire a write +# lock. +if hasattr(util, 'Mutex'): + _test_rwlock = util.Mutex() + + def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): + _test_rwlock.reader_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.reader_unlock() + + + @contextmanager + def thread_hostile_test(): + "Decorator for tests that are not thread-safe." + _test_rwlock.assert_reader_held() + _test_rwlock.reader_unlock() + _test_rwlock.writer_lock() + try: + yield + finally: + _test_rwlock.writer_unlock() + _test_rwlock.reader_lock() +else: + # TODO(phawkins): remove this branch when jaxlib 0.5.0 is the minimum. + _test_rwlock = threading.Lock() + + def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): + _test_rwlock.acquire() + try: + test(result) # type: ignore + finally: + _test_rwlock.release() + + + @contextmanager + def thread_hostile_test(): + yield # No reader-writer lock, so we get no parallelism. + +class ThreadSafeTestResult: + """ + Wraps a TestResult to make it thread safe. + + We do this by accumulating API calls and applying them in a batch under a + lock at the conclusion of each test case. + + We duck type instead of inheriting from TestResult because we aren't actually + a perfect implementation of TestResult, and would rather get a loud error + for things we haven't implemented. + """ + def __init__(self, lock: threading.Lock, result: unittest.TestResult): + self.lock = lock + self.test_result = result + self.actions: list[Callable] = [] + + def startTest(self, test: unittest.TestCase): + del test + self.start_time = time.time() + + def stopTest(self, test: unittest.TestCase): + stop_time = time.time() + with self.lock: + # We assume test_result is an ABSL _TextAndXMLTestResult, so we can + # override how it gets the time. + time_getter = self.test_result.time_getter + try: + self.test_result.time_getter = lambda: self.start_time + self.test_result.startTest(test) + for callback in self.actions: + callback() + self.test_result.time_getter = lambda: stop_time + self.test_result.stopTest(test) + finally: + self.test_result.time_getter = time_getter + + def addSuccess(self, test: unittest.TestCase): + self.actions.append(lambda: self.test_result.addSuccess(test)) + + def addSkip(self, test: unittest.TestCase, reason: str): + self.actions.append(lambda: self.test_result.addSkip(test, reason)) + + def addError(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addError(test, err)) + + def addFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addFailure(test, err)) + + def addExpectedFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addExpectedFailure(test, err)) + + def addDuration(self, test: unittest.TestCase, elapsed): + self.actions.append(lambda: self.test_result.addDuration(test, elapsed)) + + +class JaxTestSuite(unittest.TestSuite): + """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1. + + Caution: this test suite does not run setUpClass or setUpModule methods if + thread parallelism is enabled. + """ + + def __init__(self, suite: unittest.TestSuite): + super().__init__(list(suite)) + + def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult: + if TEST_NUM_THREADS.value <= 0: + return super().run(result) + + executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) + lock = threading.Lock() + futures = [] + + def run_test(test): + "Recursively runs tests in a test suite or test case." + if isinstance(test, unittest.TestSuite): + for subtest in test: + run_test(subtest) + else: + test_result = ThreadSafeTestResult(lock, result) + futures.append(executor.submit(_run_one_test, test, test_result)) + + with executor: + run_test(self) + for future in futures: + future.result() + + return result + class JaxTestLoader(absltest.TestLoader): + suiteClass = JaxTestSuite + def getTestCaseNames(self, testCaseClass): names = super().getTestCaseNames(testCaseClass) if _TEST_TARGETS.value: @@ -1091,10 +1222,8 @@ class JaxTestCase(parameterized.TestCase): 'jax_legacy_prng_key': 'error', } - _compilation_cache_exit_stack: ExitStack | None = None + _context_stack: ExitStack | None = None - def tearDown(self) -> None: - assert core.reset_trace_state() def setUp(self): super().setUp() @@ -1105,25 +1234,26 @@ def setUp(self): # b) it returns values in int32 range, which RandomState requires. self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode())) - @classmethod - def setUpClass(cls): - cls._compilation_cache_exit_stack = ExitStack() - stack = cls._compilation_cache_exit_stack - stack.enter_context(global_config_context(**cls._default_config)) + # TODO(phawkins): use TestCase.enterContext once Python 3.11 is the minimum + # version. + self._context_stack = ExitStack() + self.addCleanup(self._context_stack.close) + stack = self._context_stack + stack.enter_context(global_config_context(**self._default_config)) if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: + assert TEST_NUM_THREADS.value <= 1, "Persistent compilation cache is not thread-safe." stack.enter_context(config.enable_compilation_cache(True)) stack.enter_context(config.raise_persistent_cache_errors(True)) stack.enter_context(config.persistent_cache_min_compile_time_secs(0)) stack.enter_context(config.persistent_cache_min_entry_size_bytes(0)) - tmp_dir = stack.enter_context(tempfile.TemporaryDirectory()) - compilation_cache.set_cache_dir(tmp_dir) - stack.callback(lambda: compilation_cache.reset_cache()) + stack.enter_context(config.compilation_cache_dir(tmp_dir)) + stack.callback(compilation_cache.reset_cache) - @classmethod - def tearDownClass(cls): - cls._compilation_cache_exit_stack.close() + def tearDown(self) -> None: + assert core.reset_trace_state() + super().tearDown() def rng(self): return self._rng diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index bb92afebe8e9..b9645cbefb5e 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -377,6 +377,7 @@ def _lower_tpu_kernel( pipeline = [ ( "func.func(tpu-infer-vector-layout{" + f" hardware-generation={hardware_generation}" f" sublane-count={sl_cnt} lane-count={l_cnt}" "})" ), diff --git a/jax/_src/util.py b/jax/_src/util.py index f262570b9a1b..9450659a2ef4 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -685,3 +685,6 @@ def test_event(name: str, *args) -> None: if not test_event_listener: return test_event_listener(name, *args) + +if hasattr(jaxlib_utils, "Mutex"): + Mutex = jaxlib_utils.Mutex diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 28148761c8a4..bbe6631753cb 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -122,6 +122,14 @@ "inline without async dispatch.", ) +NUM_CPU_DEVICES = config.int_flag( + name="jax_num_cpu_devices", + default=-1, + help="Number of CPU devices to use. If not provided, the value of " + "the XLA flag --xla_force_host_platform_device_count is used." + " Must be set before JAX is initialized.", +) + # Warn the user if they call fork(), because it's not going to go well for them. def _at_fork(): @@ -249,8 +257,8 @@ def make_cpu_client( if collectives is None: collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value if _CPU_ENABLE_GLOO_COLLECTIVES.value: - collectives_impl = 'gloo' - warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is ' + collectives_impl = 'gloo' + warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is ' 'deprecated. Please use `jax.config.update(' '"jax_cpu_collectives_implementation", "gloo")` instead.', DeprecationWarning, @@ -268,12 +276,22 @@ def make_cpu_client( f"{collectives_impl}. Available implementations are " f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.") + num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None + if xla_client._version < 303 and num_devices is not None: + xla_flags = os.getenv("XLA_FLAGS") or "" + os.environ["XLA_FLAGS"] = ( + f"{xla_flags} --xla_force_host_platform_device_count={num_devices}" + ) + num_devices = None + # TODO(phawkins): pass num_devices directly when version 303 is the minimum. + kwargs = {} if num_devices is None else {"num_devices": num_devices} return xla_client.make_cpu_client( asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value, distributed_client=distributed.global_state.client, node_id=distributed.global_state.process_id, num_nodes=distributed.global_state.num_processes, collectives=collectives, + **kwargs, ) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index a4bb168efb2f..6ec621d68ff7 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -14,7 +14,6 @@ """Tests for serialization and deserialization of GDA.""" import asyncio -import contextlib import math from functools import partial import os @@ -36,13 +35,7 @@ import tensorstore as ts jax.config.parse_flags_with_absl() -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class CheckpointTest(jtu.JaxTestCase): diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index f23bd58c48d3..0cde96aeb36d 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -69,18 +69,6 @@ def _named_test(**kwargs): class CallTfTest(tf_test_util.JaxToTfTestCase): - @classmethod - def setUpClass(cls): - # One TF device of each device_type - cls.tf_devices = [] - for tf_device in tf.config.list_logical_devices(): - if tf_device.device_type == "TPU_SYSTEM": - continue # A virtual device - if all(tf_device.device_type != d.device_type for d in cls.tf_devices): - cls.tf_devices.append(tf_device) - - super().setUpClass() - def setUp(self): if tf is None: raise unittest.SkipTest("Test requires tensorflow") @@ -88,6 +76,13 @@ def setUp(self): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + # One TF device of each device_type + self.tf_devices = [] + for tf_device in tf.config.list_logical_devices(): + if tf_device.device_type == "TPU_SYSTEM": + continue # A virtual device + if all(tf_device.device_type != d.device_type for d in self.tf_devices): + self.tf_devices.append(tf_device) self.warning_ctx = jtu.ignore_warning( message=( "(jax2tf.convert with native_serialization=False has been deprecated" @@ -798,7 +793,7 @@ def f_jax(x): jax_and_tf_platforms = ( set(jax_platforms) & {d.device_type.lower() - for d in self.__class__.tf_devices}) + for d in self.tf_devices}) lowering_platforms = ("tpu", "cpu", "cuda") @@ -833,7 +828,7 @@ def f_jax(x): f_jax, native_serialization=True, native_serialization_platforms=lowering_platforms)) - for tf_device in self.__class__.tf_devices: + for tf_device in self.tf_devices: with self.subTest(tf_device.device_type): logging.info( f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}") diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 7d3313be6c92..bea2b76cb7cf 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -50,22 +50,17 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): - @classmethod - def setUpClass(cls): + def setUp(self): + super().setUp() # One TF device of each device_type - cls.tf_devices = [] + self.tf_devices = [] for tf_device in (tf.config.list_logical_devices("TPU") + tf.config.list_logical_devices("GPU") + tf.config.list_logical_devices()): if tf_device.device_type == "TPU_SYSTEM": continue # A virtual device - if all(tf_device.device_type != d.device_type for d in cls.tf_devices): - cls.tf_devices.append(tf_device) - - super().setUpClass() - - def setUp(self): - super().setUp() + if all(tf_device.device_type != d.device_type for d in self.tf_devices): + self.tf_devices.append(tf_device) self.warning_ctx = jtu.ignore_warning( message="jax2tf.convert with native_serialization=False has been deprecated" ) @@ -1666,7 +1661,7 @@ def f_jax(x): f_jax, native_serialization=True, native_serialization_platforms=("cpu", "cuda", "tpu")) - for tf_device in self.__class__.tf_devices: + for tf_device in self.tf_devices: logging.info( f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}") with tf.device(tf_device): diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 8fe9a1dd9254..653ddce7dca4 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -19,13 +19,13 @@ """ from collections.abc import Sequence -import contextlib from functools import partial import logging import re from typing import Any import unittest +from absl import app from absl.testing import absltest import jax @@ -47,16 +47,15 @@ import tensorflow as tf config.parse_flags_with_absl() +jtu.request_cpu_devices(8) # Must come after initializing the flags from jax.experimental.jax2tf.tests import tf_test_util -_exit_stack = contextlib.ExitStack() topology = None -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) +def initialize_tf_tpu(): global topology if jtu.test_device_matches(["tpu"]): with jtu.ignore_warning(message="the imp module is deprecated"): @@ -67,8 +66,7 @@ def setUpModule(): else: topology = None -def tearDownModule(): - _exit_stack.close() +app.call_after_init(initialize_tf_tpu) class ShardingTest(tf_test_util.JaxToTfTestCase): diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 6e9ba6a382db..3fd34ccdad9d 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -607,7 +607,7 @@ def ref(q, k, v): exit(0) batch_size = 1 - num_q_heads = 4 + num_q_heads = 2 num_kv_heads = 1 prof_spec = None seq_lens = (4096, 32768) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 275fcd84e44c..31de258b2a2f 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -200,9 +200,11 @@ def run(refs): grid=(batch_size, num_q_tiles, num_q_heads), num_threads=3, axis_names=("batch", "q_seq", "heads", "wg"), - approx_math=True, ) - @pl.core_map(mesh) + + @pl.core_map( + mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) + ) def _kernel_entry(): compute_wgs = 2 tiling = plgpu.TilingTransform((64, 64)) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 599e4ab0e56c..910fa4728d69 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -652,10 +652,11 @@ def _shard_map_lowering_shardy( sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) args = (*ctx.dim_var_values, *in_nodes) - manual_axes = sub_ctx.axis_context.manual_axes - mesh_shape = mesh.shape - manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes]) - if manual_axes_size == 1: + # The order of manual axes should match the order of mesh.axis_names to avoid + # non-determinism issues. + manual_axes = [a for a in mesh.axis_names + if a in sub_ctx.axis_context.manual_axes] + if np.prod([mesh.shape[a] for a in manual_axes]) == 1: # No need for a `ManualComputationOp` if all manual axes are size 1. with core.extend_axis_env_nd(tuple(mesh.shape.items())): out_nodes, _ = mlir.jaxpr_subcomp( diff --git a/jax/monitoring.py b/jax/monitoring.py index 374e301b970c..4c9996da582c 100644 --- a/jax/monitoring.py +++ b/jax/monitoring.py @@ -22,9 +22,11 @@ """ from jax._src.monitoring import ( + clear_event_listeners as clear_event_listeners, record_event_duration_secs as record_event_duration_secs, + record_event_time_span as record_event_time_span, record_event as record_event, register_event_duration_secs_listener as register_event_duration_secs_listener, register_event_listener as register_event_listener, - clear_event_listeners as clear_event_listeners, + register_event_time_span_listener as register_event_time_span_listener, ) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index e69432e89384..843ccb112871 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -214,6 +214,7 @@ nanobind_extension( deps = [ "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", "@nanobind", "@xla//third_party/python_runtime:headers", ], diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 37f9a35596d6..1aef1ebdd86d 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -240,6 +240,7 @@ cc_test( deps = [ ":tpu_dialect", "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 6ef809c4cb6a..a486d8fef84d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -847,6 +847,7 @@ def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncO ]; let constructor = "::mlir::tpu::createInferVectorLayoutPass()"; let options = [ + Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, Option<"lane_count", "lane-count", "int", /*default=*/"128", "">, Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, ]; diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 307f3582f007..0156798ca88d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -79,7 +79,9 @@ std::unique_ptr> createCanonicalizeMosaicPass( int hardware_generation = -1); std::unique_ptr> createInferVectorLayoutPass( - std::array target_shape = {8, 128}); + int hardware_generation = -1, + std::array target_shape = {8, 128}, + const TpuTilingFlags &tpu_tiling_flags = {}); std::unique_ptr> createRelayoutInsertionPass( std::array target_shape = {8, 128}); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 3a8263573544..4c6353f8c504 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -429,7 +429,8 @@ FailureOr appendConstant(RewriteContext &ctx, func::FuncOp func, MemRefType arg_type, inferMemref( MemRefType::get(value_ty.getShape(), value_ty.getElementType()), - ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{})); + ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{}, + /*is_kernel_argument=*/true)); const BlockArgument argument = entry_block.insertArgument( entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx)); const FunctionType func_ty = func.getFunctionType(); @@ -3334,21 +3335,38 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); if (needs_physical_broadcast == std::array{true, false}) { // Sublane broadcast - if (layout_in.bitwidth() != 32) { - return op.emitOpError( - "Not implemented: Only 32-bit supported for sublane broadcast"); - } + const int bitwidth = layout_in.bitwidth(); + const int packing = layout_in.packing(); if (num_tiles != 1) { return op.emitOpError( "Not implemented: Only native tiling supported"); } TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 2), 1); TPU_ASSERT_OP(offsets_in[0].has_value()); - const int64_t offset = *offsets_in[0]; + const int64_t sublane_offset = *offsets_in[0] / packing; + const int64_t subelement_offset = *offsets_in[0] % packing; const DenseI32ArrayAttr indices = builder.getDenseI32ArrayAttr( - SmallVector(ctx.target_shape[0], offset)); + SmallVector(ctx.target_shape[0], sublane_offset)); src_tiles.Each([&](const absl::Span src_idx, - Value *const src_tile) { + Value *const src_vreg) { + Value dst_vreg = *src_vreg; + // Replicate the value within each sublane. + if (packing != 1) { + auto vreg_int_ty = getNativeVregType( + builder.getIntegerType(bitwidth), ctx.target_shape); + auto src_vreg_int = + builder.create(vreg_int_ty, dst_vreg); + auto unpack_elem = builder.create( + getNativeVregType(builder.getI32Type(), ctx.target_shape), + src_vreg_int, subelement_offset, tpu::PackFormat::kInterleaved); + SmallVector packed_vregs(packing, unpack_elem); + auto vreg_int = builder.create( + vreg_int_ty, packed_vregs, tpu::PackFormat::kInterleaved); + dst_vreg = builder.create(dst_vreg.getType(), + vreg_int); + } + dst_vreg = builder.create(dst_vreg.getType(), dst_vreg, + indices, 0); SmallVector dst_starts(dst_tiles_implicit_shape.size()); SmallVector dst_limits(dst_tiles_implicit_shape.size()); for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { @@ -3360,10 +3378,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, dst_limits[i] = dst_starts[i] + 1; } } - updateSlice(dst_tiles, - builder.create( - src_tile->getType(), *src_tile, indices, 0), - dst_starts, dst_limits); + updateSlice(dst_tiles, dst_vreg, dst_starts, dst_limits); }); } else if (needs_physical_broadcast == std::array{false, true}) { // Lane broadcast diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 958fdea96945..e1f690efa85e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -302,8 +302,8 @@ LogicalResult canonicalize_elementwise(int hardware_generation_, // TODO(mvoz): Look into (1) what it would take to support these ops // natively on later hardware, and (2) how to better organize this list. bool needs_cast = hardware_generation_ <= 5 || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op); + isa(op) || isa(op) || + isa(op); if (needs_cast && element_type.isBF16()) { auto target_f32 = builder.create(op.getLoc(), target_f32_ty, operand) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 046b642f98a3..05667a847691 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -43,10 +43,12 @@ namespace mlir::tpu { // tpu_tiling_flags: A struct of flags indicating which large tiling modes are // enabled by XLA for memrefs. // bitwidth: The bitwidth of the element type of the operand. +// is_kernel_argument: Whether the operand is a kernel argument. int getTilingFactor(const int num_lanes, const int hardware_generation, const int64_t sublane_count, const TpuTilingFlags &tpu_tiling_flags, - const int8_t bitwidth) { + const int8_t bitwidth, + const bool is_kernel_argument) { CHECK(llvm::isPowerOf2_32(bitwidth)); CHECK_LE(4, bitwidth); CHECK_LE(bitwidth, 32); @@ -61,7 +63,11 @@ int getTilingFactor(const int num_lanes, const int hardware_generation, if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) { return sublane_count * 4; } - if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) { + // 16-bit values are generally always possible to relayout on the fly in v6, + // so we allow large 2nd minor tiling whenever possible. We can't do this + // for kernel arguments, because the layout of those is controlled by XLA. + if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor || + (!is_kernel_argument && hardware_generation >= 6))) { return sublane_count * 2; } return sublane_count; @@ -84,6 +90,7 @@ FailureOr inferLayout(MemRefType memref_ty, const int hardware_generation, std::array target_shape, const TpuTilingFlags &tpu_tiling_flags, + bool is_kernel_argument, int64_t leading_tile_rows = 0) { if (auto tiled_layout_attr = dyn_cast(memref_ty.getLayout())) { @@ -119,7 +126,8 @@ FailureOr inferLayout(MemRefType memref_ty, const int64_t leading_tile = getTilingFactor( llvm::divideCeil(memref_ty.getShape().back(), lane_count), - hardware_generation, sublane_count, tpu_tiling_flags, bitwidth) * + hardware_generation, sublane_count, tpu_tiling_flags, bitwidth, + is_kernel_argument) * lane_count; SmallVector tiles{xla::Tile({leading_tile})}; if (bitwidth != 32) { @@ -139,7 +147,7 @@ FailureOr inferLayout(MemRefType memref_ty, if (leading_tile_rows == 0) { leading_tile_rows = getTilingFactor(second_minor, hardware_generation, sublane_count, - tpu_tiling_flags, bitwidth); + tpu_tiling_flags, bitwidth, is_kernel_argument); } SmallVector tiles{xla::Tile({leading_tile_rows, lane_count})}; if (bitwidth != 32) { @@ -186,6 +194,7 @@ FailureOr inferMemref(MemRefType memref, const int hardware_generation, std::array target_shape, const TpuTilingFlags &tpu_tiling_flags, + bool is_kernel_argument, int64_t leading_tile_rows) { if (isa(memref.getElementType())) { const Attribute semaphore_mem = tpu::MemorySpaceAttr::get( @@ -209,7 +218,7 @@ FailureOr inferMemref(MemRefType memref, FAILUREOR_ASSIGN_OR_RETURN( const TiledLayoutAttr layout, inferLayout(memref, hardware_generation, target_shape, tpu_tiling_flags, - leading_tile_rows)); + is_kernel_argument, leading_tile_rows)); const ArrayRef tiles = layout.getTiles(); if (failed(checkTiles(memref.getContext(), tiles))) { @@ -248,7 +257,8 @@ LogicalResult inferOp(Operation &op, const int hardware_generation, FAILUREOR_ASSIGN_OR_RETURN( const MemRefType new_memref_ty, inferMemref(memref_ty, hardware_generation, target_shape, - tpu_tiling_flags, leading_tile_rows)); + tpu_tiling_flags, /*is_kernel_argument=*/false, + leading_tile_rows)); alloca_op.getResult().setType(new_memref_ty); if (memref_ty != new_memref_ty) { OpBuilder builder(alloca_op->getContext()); @@ -265,9 +275,10 @@ LogicalResult inferOp(Operation &op, const int hardware_generation, } else if (auto alloca_op = dyn_cast(op)) { TypedValue arg = alloca_op.getResult(); const MemRefType memref_ty = alloca_op.getResult().getType(); - FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation, - target_shape, tpu_tiling_flags)); + FAILUREOR_ASSIGN_OR_RETURN( + const MemRefType new_memref_ty, + inferMemref(memref_ty, hardware_generation, target_shape, + tpu_tiling_flags, /*is_kernel_argument=*/false)); alloca_op.getResult().setType(new_memref_ty); if (memref_ty != new_memref_ty) { OpBuilder builder(alloca_op->getContext()); @@ -320,7 +331,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation, FAILUREOR_ASSIGN_OR_RETURN( MemRefType new_memref_ty, inferMemref(memref_ty, hardware_generation, target_shape, - tpu_tiling_flags, leading_tile_rows)); + tpu_tiling_flags, /*is_kernel_argument=*/true, + leading_tile_rows)); arg.setType(new_memref_ty); new_arg_types.push_back(arg.getType()); if (memref_ty != new_memref_ty) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h index ed2a34793536..f2ab7c624eb1 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h @@ -14,6 +14,7 @@ namespace mlir::tpu { FailureOr inferMemref(MemRefType memref, int hardware_generation, std::array target_shape, const TpuTilingFlags& tpu_tiling_flags, + bool is_kernel_argument, int64_t leading_tile_rows = 0); const std::string_view kLeadingTileRows = "leading_tile_rows"; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index c5448a3df514..d189994d9564 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -92,9 +92,13 @@ LogicalResult verifyDivisibleIndex(Value tiled_index, int64_t tiling, int dim, // have corresponding native instructions. class VectorLayoutInferer { public: - explicit VectorLayoutInferer(std::array target_shape) - : target_shape_({target_shape[0], target_shape[1]}), - default_tiling_(target_shape) {} + explicit VectorLayoutInferer(int hardware_generation, + std::array target_shape, + const TpuTilingFlags &tpu_tiling_flags) + : hardware_generation_(hardware_generation), + target_shape_({target_shape[0], target_shape[1]}), + default_tiling_(target_shape), + tpu_tiling_flags_(tpu_tiling_flags) {} #define TPU_CHECK_OP(cond, msg) \ if (!(cond)) { \ @@ -1062,16 +1066,14 @@ class VectorLayoutInferer { // should always use that when sublane broadcasting is required. if (src_tiled_ishape[0] != dst_tiled_ishape[0] && layout.offsets()[0] != std::nullopt) { - if (layout.bitwidth() != kNativeBitwidth) { - NYI("Only 32-bit broadcasts supported"); - } LayoutOffsets offsets = layout.offsets(); // At the moment relayout can only produce replicated sublanes when // converting to (8, 128) if the input was in (1, 128) tiling - if (layout.tiling()[0] == 1) { + if (layout.tiling()[0] == 1 && layout.bitwidth() == kNativeBitwidth) { offsets[0] = std::nullopt; } - layout = VectorLayout(layout.bitwidth(), offsets, default_tiling_, + layout = VectorLayout(layout.bitwidth(), offsets, + nativeTiling(layout.bitwidth()), layout.implicit_dim()); } LayoutOffsets offsets = layout.offsets(); @@ -1705,6 +1707,21 @@ class VectorLayoutInferer { } auto &layout = *some_layout; bool select_native = allUsersRequireNativeTiling(op->getResult(0)); + // We might want to reconsider enabling native this aggressively in cases + // when it would introduce a lot of padding (e.g. when the value only has + // a small second minor size, but large minor size). + if (dst_ty.getElementTypeBitWidth() == 16) { + // TPUv6 has good support for compute in 16-bit and cheap retiling between + // large 2nd minor and the default tiling, so we bias towards large tiles. + select_native |= hardware_generation_ >= 6 || + tpu_tiling_flags_.use_x16_large_second_minor; + } else if (dst_ty.getElementTypeBitWidth() == 8) { + select_native |= tpu_tiling_flags_.use_x8_large_second_minor; + } else if (dst_ty.getElementTypeBitWidth() == 4) { + select_native |= tpu_tiling_flags_.use_x4_large_second_minor; + } else { + return op->emitOpError("Unsupported target bitwidth for truncation"); + } auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_, layout.implicit_dim()); auto dst_layout = VectorLayout( @@ -2019,15 +2036,15 @@ class VectorLayoutInferer { default_tiling_[1]}; } + int hardware_generation_; std::array target_shape_; std::array default_tiling_; + TpuTilingFlags tpu_tiling_flags_; // TODO(b/342235360): Deprecate force_first_tile_offsets_ once we fully // remove the restriction that offsets must fall within the first tile. bool force_first_tile_offsets_ = false; - // Address alignment requirement, counted in 32-bit increments. - static constexpr int64_t kVmemAlignment32 = 128; // TODO(apaszke): This is not really native on newer generations of TPUs. // Get rid of this temporary stopgap. static constexpr int8_t kNativeBitwidth = 32; @@ -2035,24 +2052,39 @@ class VectorLayoutInferer { struct InferVectorLayoutPass : public impl::InferVectorLayoutPassBase { - InferVectorLayoutPass(std::array target_shape) { + InferVectorLayoutPass(int hardware_generation, + std::array target_shape, + TpuTilingFlags tpu_tiling_flags) { + this->hardware_generation = hardware_generation; this->sublane_count = target_shape[0]; this->lane_count = target_shape[1]; + this->tpu_tiling_flags = tpu_tiling_flags; } void runOnOperation() override { + // Fail if hardware_generation has not been set from the default value. + if (hardware_generation < 0) { + getOperation().emitError("hardware_generation must be set") << hardware_generation; + signalPassFailure(); + return; + } func::FuncOp func = getOperation(); - VectorLayoutInferer run({sublane_count, lane_count}); + VectorLayoutInferer run(hardware_generation, {sublane_count, lane_count}, + tpu_tiling_flags); if (run.infer(func).failed()) { signalPassFailure(); } } + + TpuTilingFlags tpu_tiling_flags; }; } // namespace std::unique_ptr> createInferVectorLayoutPass( - std::array target_shape) { - return std::make_unique(target_shape); + int hardware_generation, std::array target_shape, + const TpuTilingFlags &tpu_tiling_flags) { + return std::make_unique( + hardware_generation, target_shape, tpu_tiling_flags); } } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc index 7dc5c13c073e..75c15f6a9f6e 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/include/mlir/IR/Types.h" #include "mlir/include/mlir/IR/Value.h" +#include "mlir/include/mlir/IR/ValueRange.h" #include "mlir/include/mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" @@ -91,8 +92,6 @@ TypedValue getZerosLikeVector(ImplicitLocOpBuilder &builder, FailureOr> getX32VmaskByPaddingEnd( ImplicitLocOpBuilder &builder, int64_t padding, const std::array target_shape, int64_t dim) { - VectorType i32_vreg_ty = - getNativeVregType(builder.getI32Type(), target_shape); if (dim != 0 && dim != 1) { return builder.emitError() << "Expected a 2D vector for getX32VmaskByPaddingEnd"; @@ -100,22 +99,29 @@ FailureOr> getX32VmaskByPaddingEnd( if (padding < 0 || padding > target_shape[dim]) { return builder.emitError() - << "Padding must be in [0, target_shape[dim]). Padding: " << padding + << "Padding must be in [0, target_shape[dim]]. Padding: " << padding << ", target_shape[dim]: " << target_shape[dim]; } - Value padding_vreg = - getFullVector(builder, i32_vreg_ty, - builder.getI32IntegerAttr(target_shape[dim] - padding)); - - return cast>( - builder - .create( - arith::CmpIPredicate::slt, - builder.create(i32_vreg_ty, - builder.getI32IntegerAttr(dim)), - padding_vreg) - .getResult()); + auto idx_const = [&builder](int64_t idx) { + return IdxConst(idx, builder, builder.getLoc()); + }; + + tpu::CreateMaskOp mask_op; + const VectorType vmask_ty = getNativeVregOrVmaskType( + builder.getI1Type(), /*layout_bitwidth=*/32, target_shape); + if (dim == 0) { + mask_op = builder.create( + vmask_ty, ValueRange{idx_const(0), idx_const(0)}, + ValueRange{idx_const(target_shape[0] - padding), + idx_const(target_shape[1])}); + } else { + mask_op = builder.create( + vmask_ty, ValueRange{idx_const(0), idx_const(0)}, + ValueRange{idx_const(target_shape[0]), + idx_const(target_shape[1] - padding)}); + } + return cast>(mask_op.getResult()); } LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder, diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc index dadbac133fbf..4b9da9505d75 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc @@ -21,9 +21,12 @@ limitations under the License. #include #include +#include "llvm/include/llvm/ADT/TypeSwitch.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/include/mlir/IR/Attributes.h" #include "mlir/include/mlir/IR/Builders.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/BuiltinOps.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" @@ -38,39 +41,56 @@ namespace mlir::tpu { namespace { +using ::testing::ElementsAre; using ::testing::Eq; using ::testing::Optional; -MATCHER_P2(IsConstantOpWithSplatValue, type, splat_value, "") { +MATCHER_P2(IsConstantOpWithSplatOrScalarValue, type, value, "") { auto constant_op = dyn_cast(arg.getDefiningOp()); if (constant_op == nullptr) { *result_listener << "Expected a constant op, got " << debugString(arg); return false; } - auto dense_attr = dyn_cast(constant_op.getValue()); - if (dense_attr == nullptr) { - *result_listener << "Expected a dense elements attr, got " - << debugString(arg); - return false; - } - if (dense_attr.getType() != type) { - *result_listener << "Expected a dense elements attr with type " - << debugString(type) << ", got " - << debugString(dense_attr.getType()); - return false; - } - if (!dense_attr.isSplat()) { - *result_listener << "Expected a splat dense elements attr, got " - << debugString(dense_attr); - return false; - } - if (auto s = dense_attr.template getSplatValue(); - s != splat_value) { - *result_listener << "Expected a splat dense elements attr with value " - << splat_value << ", got " << s; - return false; - } - return true; + + return llvm::TypeSwitch(constant_op.getValue()) + .template Case([&](auto attr) { + // If it's dense, it must be splat. + if (attr.getType() != type) { + *result_listener << "Expected a dense elements attr with type " + << debugString(type) << ", got " + << debugString(attr.getType()); + return false; + } + if (!attr.isSplat()) { + *result_listener << "Expected a splat dense elements attr, got " + << debugString(attr); + return false; + } + if (auto s = attr.template getSplatValue(); + s != value) { + *result_listener << "Expected a splat dense elements attr with value " + << value << ", got " << s; + return false; + } + return true; + }) + .template Case([&](auto attr) { + if (attr.getType() != type) { + *result_listener << "Expected a attr with type " << debugString(type) + << ", got " << debugString(attr.getType()); + return false; + } + if (auto s = attr.getInt(); s != value) { + *result_listener << "Expected a attr with value " << value << ", got " + << s; + return false; + } + return true; + }) + .template Default([&](auto attr) { + *result_listener << "Unsupported attribute type: " << debugString(attr); + return false; + }); } MATCHER_P2(IsVectorTypeWithShape, shape, elem_ty, "") { @@ -150,7 +170,7 @@ TEST_F(VregUtilTest, GetFullVector) { TypedValue vec = getFullVector(Builder(), vty, Builder().getI32IntegerAttr(0x1)); - EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0x1})); + EXPECT_THAT(vec, IsConstantOpWithSplatOrScalarValue(vty, int32_t{0x1})); } TEST_F(VregUtilTest, GetFullLikeVector) { @@ -161,14 +181,14 @@ TEST_F(VregUtilTest, GetFullLikeVector) { TypedValue vec = getFullLikeVector(Builder(), in_vec, Builder().getF32FloatAttr(2.0f)); - EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{2.0f})); + EXPECT_THAT(vec, IsConstantOpWithSplatOrScalarValue(vty, float{2.0f})); } TEST_F(VregUtilTest, GetZerosVector) { VectorType vty = VectorType::get({2, 4}, Builder().getI32Type()); TypedValue vec = getZerosVector(Builder(), vty); - EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0})); + EXPECT_THAT(vec, IsConstantOpWithSplatOrScalarValue(vty, int32_t{0})); } TEST_F(VregUtilTest, GetZerosLikeVector) { @@ -178,7 +198,7 @@ TEST_F(VregUtilTest, GetZerosLikeVector) { vty.getElementType(), Builder().getF32FloatAttr(1.0f))); TypedValue vec = getZerosLikeVector(Builder(), in_vec); - EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{0.0f})); + EXPECT_THAT(vec, IsConstantOpWithSplatOrScalarValue(vty, float{0.0f})); } TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim0) { @@ -188,18 +208,18 @@ TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim0) { /*dim=*/0); ASSERT_TRUE(succeeded(vec)); - auto cmp_op = dyn_cast(vec.value().getDefiningOp()); - ASSERT_TRUE(cmp_op != nullptr); - EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt); - - auto iota_op = dyn_cast(cmp_op.getLhs().getDefiningOp()); - ASSERT_TRUE(iota_op != nullptr); - EXPECT_THAT(iota_op.getDimension(), Optional(Eq(0))); - - EXPECT_THAT( - cmp_op.getRhs(), - IsConstantOpWithSplatValue( - VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{3})); + auto mask_op = dyn_cast(vec.value().getDefiningOp()); + ASSERT_TRUE(mask_op != nullptr); + EXPECT_THAT(ArrayRef({mask_op.getLow()[0], mask_op.getLow()[1]}), + ElementsAre(IsConstantOpWithSplatOrScalarValue( + Builder().getIndexType(), int64_t{0}), + IsConstantOpWithSplatOrScalarValue( + Builder().getIndexType(), int64_t{0}))); + EXPECT_THAT(ArrayRef({mask_op.getHigh()[0], mask_op.getHigh()[1]}), + ElementsAre(IsConstantOpWithSplatOrScalarValue( + Builder().getIndexType(), int64_t{3}), + IsConstantOpWithSplatOrScalarValue( + Builder().getIndexType(), int64_t{8}))); } TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim1) { @@ -209,18 +229,18 @@ TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim1) { /*dim=*/1); ASSERT_TRUE(succeeded(vec)); - auto cmp_op = dyn_cast(vec.value().getDefiningOp()); - ASSERT_TRUE(cmp_op != nullptr); - EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt); - - auto iota_op = dyn_cast(cmp_op.getLhs().getDefiningOp()); - ASSERT_TRUE(iota_op != nullptr); - EXPECT_THAT(iota_op.getDimension(), Optional(Eq(1))); - - EXPECT_THAT( - cmp_op.getRhs(), - IsConstantOpWithSplatValue( - VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{5})); + auto mask_op = dyn_cast(vec.value().getDefiningOp()); + ASSERT_TRUE(mask_op != nullptr); + EXPECT_THAT(ArrayRef({mask_op.getLow()[0], mask_op.getLow()[1]}), + ElementsAre(IsConstantOpWithSplatOrScalarValue( + Builder().getIndexType(), int64_t{0}), + IsConstantOpWithSplatOrScalarValue( + Builder().getIndexType(), int64_t{0}))); + EXPECT_THAT(ArrayRef({mask_op.getHigh()[0], mask_op.getHigh()[1]}), + ElementsAre(IsConstantOpWithSplatOrScalarValue( + Builder().getIndexType(), int64_t{4}), + IsConstantOpWithSplatOrScalarValue( + Builder().getIndexType(), int64_t{5}))); } } // namespace diff --git a/jaxlib/setup.py b/jaxlib/setup.py index c2efd3d7b7a7..b3a37a25f1b2 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -61,8 +61,7 @@ def has_ext_modules(self): packages=['jaxlib', 'jaxlib.xla_extension'], python_requires='>=3.10', install_requires=[ - 'scipy>=1.10', - "scipy>=1.11.1; python_version>='3.12'", + 'scipy>=1.11.1', 'numpy>=1.25', 'ml_dtypes>=0.2.0', ], diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index 28201233566a..6b612b26dce8 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -18,6 +18,7 @@ limitations under the License. #include "nanobind/nanobind.h" #include "absl/cleanup/cleanup.h" #include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" namespace nb = nanobind; @@ -225,4 +226,19 @@ NB_MODULE(utils, m) { PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr())); m.attr("safe_zip") = nb::steal( PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr())); + + // Python has no reader-writer lock in its standard library, so we expose + // bindings around absl::Mutex. + nb::class_(m, "Mutex") + .def(nb::init<>()) + .def("lock", &absl::Mutex::Lock, nb::call_guard()) + .def("unlock", &absl::Mutex::Unlock) + .def("assert_held", &absl::Mutex::AssertHeld) + .def("reader_lock", &absl::Mutex::ReaderLock, + nb::call_guard()) + .def("reader_unlock", &absl::Mutex::ReaderUnlock) + .def("assert_reader_held", &absl::Mutex::AssertReaderHeld) + .def("writer_lock", &absl::Mutex::WriterLock, + nb::call_guard()) + .def("writer_unlock", &absl::Mutex::WriterUnlock); } \ No newline at end of file diff --git a/setup.py b/setup.py index b3bd4a3466d7..39508388ba8a 100644 --- a/setup.py +++ b/setup.py @@ -60,8 +60,7 @@ def load_version_module(pkg_path): 'numpy>=1.25', "numpy>=1.26.0; python_version>='3.12'", 'opt_einsum', - 'scipy>=1.10', - "scipy>=1.11.1; python_version>='3.12'", + 'scipy>=1.11.1', ], extras_require={ # Minimum jaxlib version; used in testing. diff --git a/tests/array_test.py b/tests/array_test.py index 9618a8cf4665..97bf71a5216b 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -43,20 +43,12 @@ from jax._src import prng jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(8) with contextlib.suppress(ImportError): import pytest pytestmark = pytest.mark.multiaccelerator -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() - def create_array(shape, sharding, global_data=None): if global_data is None: diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 4485f5d4f41e..f9dd3ce52b58 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import threading import time from typing import Sequence @@ -29,6 +28,7 @@ import numpy as np config.parse_flags_with_absl() +jtu.request_cpu_devices(8) def _colocated_cpu_devices( @@ -53,18 +53,6 @@ def _colocated_cpu_devices( _count_colocated_python_specialization_cache_miss = jtu.count_events( "colocated_python_func._get_specialized_func") -_exit_stack = contextlib.ExitStack() - - -def setUpModule(): - # TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT - # prepares CPU devices by its own. - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - - -def tearDownModule(): - _exit_stack.close() - class ColocatedPythonTest(jtu.JaxTestCase): diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 428e518eab51..73d76c1a4938 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -52,18 +52,12 @@ FAKE_COMPILE_TIME = 10 _counts = Counter() # Map event name to count - -def setUpModule(): - monitoring.register_event_listener(increment_event_count) - - -def tearDownModule(): - monitoring._unregister_event_listener_by_callback(increment_event_count) - - def increment_event_count(event): _counts[event] += 1 +monitoring.register_event_listener(increment_event_count) + + def msg_exists_in_logs(msg: str, records: list[logging.LogRecord], level: int | None = None) -> bool: return any(msg in record.getMessage() for record in records diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 18693a7bb2c3..419e7b18dfed 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -13,7 +13,6 @@ # limitations under the License. from collections.abc import Sequence -import contextlib import io import re import textwrap @@ -29,6 +28,7 @@ import numpy as np jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(2) def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]: fake_stdin = io.StringIO() @@ -41,14 +41,6 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI def _format_multiline(text): return textwrap.dedent(text).lstrip() -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() - foo = 2 class CliDebuggerTest(jtu.JaxTestCase): diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 6afb41645405..0fc9665ceaa5 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections -import contextlib import functools import textwrap import unittest @@ -35,19 +34,13 @@ rich = None jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(2) debug_print = debugging.debug_print def _format_multiline(text): return textwrap.dedent(text).lstrip() -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() class DummyDevice: def __init__(self, platform, id): diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index f6625e86ca14..df79e6aaf6df 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -1487,6 +1487,7 @@ def f(i): class JumbleTest(jtu.JaxTestCase): def setUp(self): + super().setUp() if jax.config.x64_enabled: raise unittest.SkipTest() @parameterized.parameters((True,), (False,)) diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index e8b1afc224b7..d5878fa50a5a 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -48,11 +48,11 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]: class PrimitiveTest(jtu.JaxTestCase): - @classmethod - def setUpClass(cls): + def setUp(self): + super().setUp() # Pick one device from each available platform - cls.devices = [] - cls.platforms = [] + self.devices = [] + self.platforms = [] for backend in ["cpu", "gpu", "tpu"]: try: devices = jax.devices(backend) @@ -60,10 +60,9 @@ def setUpClass(cls): devices = [] for d in devices: - if d.platform not in cls.platforms: - cls.platforms.append(d.platform) - cls.devices.append(d) - super().setUpClass() + if d.platform not in self.platforms: + self.platforms.append(d.platform) + self.devices.append(d) # For each primitive we export for all platforms that are available and # compare the results of running the exported code and running the native @@ -128,7 +127,7 @@ def export_and_compare_to_native( tol: float | None = None): devices = [ d - for d in self.__class__.devices + for d in self.devices if d.platform not in unimplemented_platforms ] logging.info("Using devices %s", [str(d) for d in devices]) diff --git a/tests/export_test.py b/tests/export_test.py index da0e9daf2f00..b13cf3a623e6 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -56,14 +56,8 @@ CAN_SERIALIZE = False config.parse_flags_with_absl() +jtu.request_cpu_devices(8) -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() ### Setup for testing lowering with effects @dataclasses.dataclass(frozen=True) @@ -165,17 +159,16 @@ def serde_exported(*fun_args, **fun_kwargs): @jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version) class JaxExportTest(jtu.JaxTestCase): - @classmethod - def setUpClass(cls): + def setUp(self): + super().setUp() # Find the available platforms - cls.platforms = [] + self.platforms = [] for backend in ["cpu", "gpu", "tpu"]: try: jax.devices(backend) except RuntimeError: continue - cls.platforms.append(backend) - super().setUpClass() + self.platforms.append(backend) def test_basic_export_only(self): @jax.jit @@ -1505,7 +1498,7 @@ def test_multi_platform(self): module_str) # Call with argument placed on different plaforms - for platform in self.__class__.platforms: + for platform in self.platforms: x_device = jax.device_put(x, jax.devices(platform)[0]) res_exp = exp.call(x_device) self.assertAllClose( @@ -1530,7 +1523,7 @@ def test_multi_platform_nested(self): self.assertEqual(1, count_sine) # Call with argument placed on different plaforms - for platform in self.__class__.platforms: + for platform in self.platforms: if platform == "tpu": continue x_device = jax.device_put(x, jax.devices(platform)[0]) res_exp = exp2.call(x_device) @@ -1674,7 +1667,7 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a) # Call with argument placed on different plaforms - for platform in self.__class__.platforms: + for platform in self.platforms: run_devices = jax.devices(platform)[0:len(export_devices)] if len(run_devices) != len(export_devices): continue diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 2e91792aa950..922b37ffa440 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib + import threading import unittest @@ -34,6 +34,7 @@ import numpy as np config.parse_flags_with_absl() +jtu.request_cpu_devices(2) effect_p = core.Primitive('effect') effect_p.multiple_results = True @@ -132,15 +133,6 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out mlir.register_lowering(callback_p, callback_effect_lowering) -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() - - class JaxprEffectsTest(jtu.JaxTestCase): def test_trivial_jaxpr_has_no_effects(self): diff --git a/tests/layout_test.py b/tests/layout_test.py index f958de5cf5bc..903b17886283 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import math from functools import partial from absl.testing import absltest @@ -28,14 +27,7 @@ from jax.experimental.compute_on import compute_on config.parse_flags_with_absl() - -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class LayoutTest(jtu.JaxTestCase): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 0da09e232deb..65f7c8145138 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1686,7 +1686,7 @@ def testTriangularSolveSingularBatched(self): @jtu.sample_product( n=[1, 4, 5, 20, 50, 100], - batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()], + batch_size=[(), (2,), (3, 4)], dtype=int_types + float_types + complex_types ) def testExpm(self, n, batch_size, dtype): diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index b84903618fab..7fb87086d9e6 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -32,9 +32,9 @@ class MockGPUTest(jtu.JaxTestCase): def setUp(self): + super().setUp() if not jtu.test_device_matches(["gpu"]): self.skipTest("Mocking devices only works on the GPU backend.") - super().setUp() @jtu.skip_under_pytest("Test must run in an isolated process") def testMockDeviceCount(self): diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py index 44ec4e2f9529..71ce8f1dde1c 100644 --- a/tests/mock_gpu_topology_test.py +++ b/tests/mock_gpu_topology_test.py @@ -31,9 +31,9 @@ class MockGPUTopologyTest(jtu.JaxTestCase): def setUp(self): + super().setUp() if not jtu.test_device_matches(["gpu"]): self.skipTest("Mocking devices only works on the GPU backend.") - super().setUp() @jtu.skip_under_pytest("Test must run in an isolated process") def testMockDeviceCount(self): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index dd41b264aa1c..d8ae6d3c2ff6 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -171,7 +171,7 @@ def setUp(self): self.context = mlir.make_ir_context() if mgpu_dialect is not None: mgpu_dialect.register_dialect(self.context) - self.enter_context(jtu.global_config_context(jax_traceback_filtering="off")) + self.enter_context(config.traceback_filtering("off")) self.enter_context(self.context) self.enter_context(ir.Location.unknown()) @@ -1756,13 +1756,13 @@ def kernel(ctx, src, dst, _): class TorchTest(TestCase): - @classmethod - def setUpClass(cls): + def setUp(self): + super().setUp() try: import torch except ImportError: raise unittest.SkipTest("Test requires PyTorch") - cls.torch = torch + self.torch = torch def test_basic(self): def kernel(ctx, i_gmem, o_gmem, _): diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 057731cb5d55..1fc6fe1e9298 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib from unittest import SkipTest import tracemalloc as tm @@ -25,15 +24,7 @@ from jax._src import test_util as jtu jax.config.parse_flags_with_absl() - -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class MultiDeviceTest(jtu.JaxTestCase): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 4bd64a5a8dce..a9d5361e7c10 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1481,14 +1481,11 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem): index_map=lambda i, j: (0, 0)) ], ) - mesh = plgpu.GPUMesh( - grid=(1,), - num_threads=3, - axis_names=("_", "wg",), - approx_math=True, - ) + mesh = plgpu.GPUMesh(grid=(1,), num_threads=3, axis_names=("_", "wg")) def run(refs): - @pl.core_map(mesh) + @pl.core_map( + mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) + ) def _kernel_entry(): pipeline(*refs) @jax.jit @@ -1535,13 +1532,12 @@ def tiled_add_kernel(x_smem, y_smem, o_smem): transforms=[])], ) mesh = plgpu.GPUMesh( - grid=(1,), - num_threads=num_compute_wgs + 1, - axis_names=("_", "wg",), - approx_math=True, + grid=(1,), num_threads=num_compute_wgs + 1, axis_names=("_", "wg") ) def run(refs): - @pl.core_map(mesh) + @pl.core_map( + mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) + ) def _kernel_entry(): pipeline(*refs) @jax.jit diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index ec4e5805d056..290d35490bef 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1116,18 +1116,35 @@ def kernel(x_ref, y_ref, out_ref): @parameterized.parameters( ("int32", "float32"), ("float32", "float32"), + ("bfloat16", "bfloat16"), ) def test_true_divide(self, dtype, out_dtype): + if jtu.test_device_matches(["tpu"]): + if out_dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6): + self.skipTest("bfloat16 is not supported on older TPU generations") + if not jtu.if_cloud_tpu_at_least(2025, 1, 9): + self.skipTest("Requires libtpu built after 2025-01-09") + elif jtu.test_device_matches(["gpu"]): + if dtype == "bfloat16": + self.skipTest("bfloat16 not supported") + @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8,), out_dtype), + out_shape=jax.ShapeDtypeStruct((8, 8), out_dtype), ) def kernel(x_ref, y_ref, o_ref): o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...]) x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) - np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y)) + x = jnp.repeat(x, 8, axis=0).reshape(8, 8) + y = jnp.tile(y, 8).reshape(8, 8) + rtol = 8e-3 if dtype == "bfloat16" else 1e-6 + np.testing.assert_allclose( + jnp.true_divide(x, y).astype(jnp.float32), + kernel(x, y).astype(jnp.float32), + rtol=rtol, + ) @parameterized.parameters("float16", "bfloat16") def test_true_divide_unsupported(self, dtype): diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 373388d97691..b53a057f46fd 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -30,10 +30,12 @@ from jax._src import api_util from jax._src import checkify from jax._src import config +from jax._src import core as jax_core from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax.control_flow.for_loop import for_loop from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl import jax.numpy as jnp @@ -41,8 +43,10 @@ if sys.platform != "win32": from jax.experimental.pallas import tpu as pltpu + from jax.experimental.pallas import triton as plgpu else: pltpu = None + plgpu = None # TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs. @@ -2361,5 +2365,47 @@ class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest): INTERPRET = True +def _find_pallas_call_in_jaxpr( + jaxpr: jax_core.Jaxpr) -> jax_core.JaxprEqn | None: + for eqn in jaxpr.eqns: + call_eqn = None + if eqn.primitive == pallas_call.pallas_call_p: + call_eqn = eqn + elif 'jaxpr' in eqn.params: + call_eqn = _find_pallas_call_in_jaxpr(eqn.params['jaxpr']) + if call_eqn is not None: + return call_eqn + return None + + +class PallasCompilerParamsTest(PallasBaseTest): + def test_triton_params_consistent_across_double_jit(self): + # Test for https://github.com/jax-ml/jax/issues/25714 + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Triton backend only works on GPU.") + params = plgpu.TritonCompilerParams(num_warps=8) + + @jax.jit + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), + compiler_params=params) + def copy_kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + @functools.partial(jax.jit, static_argnames=["z"]) + def plus_z(x, z): + return copy_kernel(x+z) + + x = 0. + extracted_params = _find_pallas_call_in_jaxpr( + plus_z.trace(x, 1).jaxpr).params["compiler_params"] + self.assertEqual(plus_z(0., 1.), 1.) + self.assertEqual(extracted_params["triton"]["num_warps"], 8) + extracted_params = _find_pallas_call_in_jaxpr( + plus_z.trace(x, 2).jaxpr).params["compiler_params"] + self.assertEqual(plus_z(0., 2.), 2.) + self.assertEqual(extracted_params["triton"]["num_warps"], 8) + + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index cb92794b2758..8f948050b8a5 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -176,6 +176,23 @@ def kernel(x_ref, y_ref, out_ref): )(x, y) np.testing.assert_array_equal(out, inp.reshape(m * 2, n)) + @parameterized.parameters([jnp.int32, jnp.int16, jnp.int8]) + def test_row_broadcast(self, dtype): + if not jtu.if_cloud_tpu_at_least(2024, 1, 9): + self.skipTest("Requires libtpu built after 2024-01-09") + if not self.INTERPRET and jtu.get_tpu_version() < 5: + self.skipTest("Requires TPUv5+") + def kernel(x_ref, y_ref): + y_ref[...] = jnp.broadcast_to(x_ref[pl.ds(3, 1)], y_ref.shape) + m, n = 4, 1024 + x = jax.random.randint( + jax.random.key(12), (m, n), minval=-1000, maxval=1000, dtype=jnp.int32 + ).astype(dtype) + y = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((m, n), dtype) + )(x) + np.testing.assert_array_equal(y, jnp.broadcast_to(x[3:4], y.shape)) + def test_tpu_unsigned_int(self): def body(x_ref, o_ref): # Test cast from uint16 -> uint32 diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 05154ae0376c..46add86e2784 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -37,6 +37,7 @@ import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec import numpy as np +from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member jax.config.parse_flags_with_absl() @@ -89,15 +90,19 @@ def testPGLEProfilerGetFDOProfileLarge(self): mesh = jtu.create_mesh((2,), ('x',)) its = 500 + compiler_options = { + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + } + # TODO(b/37664749): Remove this flag once the bug is fixed. + if xla_extension_version > 302: + compiler_options['xla_gpu_enable_command_buffer'] = '' + else: + compiler_options['xla_gpu_graph_min_graph_size'] = '100000' @partial( jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={ - 'xla_gpu_enable_latency_hiding_scheduler': 'True', - # TODO(b/37664749): Remove this flag once the bug is fixed. - 'xla_gpu_enable_command_buffer': '', - }, + compiler_options=compiler_options, ) def f(x): agg = x @@ -127,15 +132,19 @@ def testAutoPgle(self): mesh = jtu.create_mesh((2,), ('x',)) with tempfile.TemporaryDirectory() as dump_dir: + compile_options = { + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + 'xla_dump_to': dump_dir, + 'xla_gpu_experimental_dump_fdo_profiles': 'True', + } + # TODO(b/376647494): Remove this flag once the bug is fixed. + if xla_extension_version <= 302: + compile_options['xla_gpu_graph_min_graph_size'] = '100000' @partial( jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={ - 'xla_gpu_enable_latency_hiding_scheduler': 'True', - 'xla_dump_to': dump_dir, - 'xla_gpu_experimental_dump_fdo_profiles': 'True' - }, + compiler_options=compile_options, ) def f(x): return x * 2 @@ -209,15 +218,19 @@ def testAutoPgleWithPersistentCache(self): mesh = jtu.create_mesh((2,), ('x',)) with tempfile.TemporaryDirectory() as dump_dir: + compiler_options = { + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + 'xla_dump_to': dump_dir, + 'xla_gpu_experimental_dump_fdo_profiles': 'True', + } + # TODO(b/376647494): Remove this flag once the bug is fixed. + if xla_extension_version <= 302: + compiler_options['xla_gpu_graph_min_graph_size'] = '100000' @partial( jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={ - 'xla_gpu_enable_latency_hiding_scheduler': 'True', - 'xla_dump_to': dump_dir, - 'xla_gpu_experimental_dump_fdo_profiles': 'True' - }, + compiler_options=compiler_options, ) def f(x): agg = x diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 63620e5ad0c9..3fcc5c81ad91 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -13,7 +13,6 @@ # limitations under the License. from collections import OrderedDict, namedtuple -import contextlib import re from functools import partial import logging @@ -64,14 +63,7 @@ config.parse_flags_with_absl() -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) def create_array(global_shape, global_mesh, mesh_axes, global_data=None, dtype=np.float32): @@ -5589,11 +5581,11 @@ def test_only_auto(self, mesh): @jax.jit def f(x, x2): y = x * 2 - self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, None)) + self.assertEqual(y.sharding.spec, P(None, None)) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, None)) + self.assertEqual(z.sharding.spec, P(None, None)) a = z @ x2 - self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + self.assertEqual(a.sharding.spec, P(None, None)) return a out = f(arr, arr.T) @@ -5626,9 +5618,9 @@ def f(x, x2): arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None))) out = f(arr, arr2) - self.assertEqual(out.sharding, NamedSharding(mesh2, P('x', None))) + self.assertEqual(out.sharding, NamedSharding(mesh2, P('x',))) lowered_text = f.lower(arr, arr2).as_text() - self.assertTrue(lowered_text.count("unspecified_dims") == 3) + self.assertTrue(lowered_text.count("unspecified_dims") == 5) mesh3 = jtu.create_mesh((2, 2), ('x', 'y'), axis_types={mesh_lib.AxisTypes.User: 'y', @@ -5669,11 +5661,11 @@ def f(x): {mesh_lib.AxisTypes.Auto: ('x', 'y')}) y = sharding_cast(y, y.sharding.with_mesh(auto_mesh)) with mesh_lib.set_abstract_mesh(auto_mesh): - self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + self.assertEqual(y.sharding.spec, P(None, None)) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + self.assertEqual(z.sharding.spec, P(None, None)) a = z @ z.T - self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + self.assertEqual(a.sharding.spec, P(None, None)) a = sharding_cast( a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None))) self.assertEqual(a.sharding.spec, P('x', None)) @@ -5707,7 +5699,7 @@ def f(x): self.assertEqual(a.sharding.spec, P(None, None)) a = sharding_cast( a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None))) - self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, None)) + self.assertEqual(a.sharding.spec, P(None, None)) return a out = f(arr) @@ -5729,11 +5721,11 @@ def f(x): {mesh_lib.AxisTypes.Auto: 'x', mesh_lib.AxisTypes.User: 'y'}) y = sharding_cast(y, y.sharding.with_mesh(mix_mesh)) with mesh_lib.set_abstract_mesh(mix_mesh): - self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, 'y')) + self.assertEqual(y.sharding.spec, P(None, 'y')) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, 'y')) + self.assertEqual(z.sharding.spec, P(None, 'y')) a = z @ z.T - self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + self.assertEqual(a.sharding.spec, P(None, None)) a = sharding_cast( a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None))) self.assertEqual(a.sharding.spec, P('x', None)) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index a9de8c896414..795f7d4bf9e8 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -15,7 +15,6 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor -import contextlib from functools import partial import itertools as it import gc @@ -54,15 +53,8 @@ from jax._src.util import safe_map, safe_zip config.parse_flags_with_absl() +jtu.request_cpu_devices(8) -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]] diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 199b90fe524e..efa877fd3a91 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -36,14 +36,7 @@ import numpy as np config.parse_flags_with_absl() - -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(2) map, unsafe_map = util.safe_map, map diff --git a/tests/roofline_test.py b/tests/roofline_test.py index e5003947181b..aec34ff22a57 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -14,7 +14,6 @@ from __future__ import annotations from functools import partial -import contextlib from absl.testing import absltest from jax.sharding import PartitionSpec as P @@ -28,6 +27,7 @@ jax.config.parse_flags_with_absl() +jtu.request_cpu_devices(8) def create_inputs( @@ -45,18 +45,6 @@ def create_inputs( return mesh, tuple(arrays) -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - - -def tearDownModule(): - _exit_stack.close() - - class RooflineTest(jtu.JaxTestCase): def test_scalar_collectives(self): a_spec = P("z", ("x", "y")) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 383746899570..25d46c5add2e 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib - import jax import jax.numpy as jnp import numpy as np @@ -24,15 +22,7 @@ from jax.experimental.shard_map import shard_map jax.config.parse_flags_with_absl() - -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() +jtu.request_cpu_devices(8) class ShardAlikeDownstreamTest(jtu.JaxTestCase): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 74fdb7a47888..19cc870881cf 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -15,7 +15,6 @@ from __future__ import annotations from collections.abc import Callable, Generator, Iterable, Iterator, Sequence -import contextlib from functools import partial import itertools as it import math @@ -53,6 +52,7 @@ from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member config.parse_flags_with_absl() +jtu.request_cpu_devices(8) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -70,16 +70,6 @@ def create_inputs(a_sharding, b_sharding): return mesh, m1, m2 -# Run all tests with 8 CPU devices. -_exit_stack = contextlib.ExitStack() - -def setUpModule(): - _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - -def tearDownModule(): - _exit_stack.close() - - class ShardMapTest(jtu.JaxTestCase): def test_identity(self): @@ -1925,7 +1915,7 @@ def f(x): self.assertAllClose(v*v, f(v), check_dtypes=False) def test_partial_auto_propagate_through(self): - mesh = jtu.create_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2, 2), ('i', 'j', 'k')) sharding = jax.sharding.NamedSharding(mesh, P('i')) def g(x): @@ -1943,16 +1933,17 @@ def f(x): )(x) v = jnp.arange(32.0).reshape(4, 8) - v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i'))) + v = jax.device_put(v, sharding) if config.use_shardy_partitioner.value: self.assertIn( 'in_shardings=[<@mesh, [{?}, {?}]>]' - ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j"}', + ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j", "k"}', f.lower(v).as_text(), ) else: self.assertIn( - 'sharding={devices=[1,1,2,2]<=[2,2]T(1,0) last_tile_dims={manual, replicated}}', + 'sharding={devices=[1,1,4,2]<=[2,4]T(1,0) last_tile_dims={manual,' + ' replicated}}', f.lower(v).as_text('hlo'), ) actual = f(v)