diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml
index 92fef2cc29af..551b9ef3bba7 100644
--- a/.github/workflows/jax-array-api.yml
+++ b/.github/workflows/jax-array-api.yml
@@ -28,7 +28,7 @@ jobs:
       with:
         repository: data-apis/array-api-tests
         # TODO(jakevdp) update this to a stable release/tag when available.
-        ref: '1572b129c6682211abfe139e112592226c361a6c'  # Latest commit as of 2024-12-04
+        ref: 'f7a74a685d78d98203fa991fc19a5d1fda57c212'  # Latest commit as of 2025-01-07
         submodules: 'true'
         path: 'array-api-tests'
     - name: Set up Python ${{ matrix.python-version }}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 346c399b3332..1835f0857c1d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 * Changes:
   * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
     supported version until June 2025.
+  * The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
+    supported version until June 2025.
   * {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
     `optimize='optimal'`. This avoids exponentially-scaling trace-time in
     the case of many arguments ({jax-issue}`#25214`).
diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_gpu.sh
index 7bc2492781b2..416d985d380c 100644
--- a/ci/run_pytest_gpu.sh
+++ b/ci/run_pytest_gpu.sh
@@ -56,6 +56,5 @@ echo "Running GPU tests..."
 "$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \
 tests examples \
 --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \
---deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \
 --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \
---deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric
\ No newline at end of file
+--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric
diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md
index c56e82c77450..eaa3bc7317c8 100644
--- a/docs/advanced-autodiff.md
+++ b/docs/advanced-autodiff.md
@@ -247,7 +247,7 @@ perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
 
 ### Hessian-vector products with `jax.grad`-of-`jax.grad`
 
-One thing you can do with higher-order {func}`jax.vmap` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
+One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
 
 A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)).
 
@@ -259,11 +259,11 @@ for any $v \in \mathbb{R}^n$.
 
 The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.
 
-Luckily, {func}`jax.vmap` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:
+Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:
 
 $\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$,
 
-where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.vmap` is efficient.
+where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient.
 
 In JAX code, you can just write this:
 
@@ -357,7 +357,7 @@ To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfw
 
 ### Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff)
 
-JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.vmap` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.
+JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.
 
 
 #### JVPs in math
@@ -473,7 +473,7 @@ vjp :: (a -> b) -> a -> (b, CT b -> CT a)
 
 where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`.
 
-This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.vmap` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
+This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
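+
+For example, here is a minimal sketch of that "gradient in one call" claim, using only the public `jax.vjp` and `jax.grad` APIs:
+
+```{code-cell}
+def scalar_fun(w):  # a scalar-valued function of a vector-valued argument
+  return jnp.sum(jnp.tanh(w)) ** 2
+
+w = jnp.arange(4.0)
+y, pullback = jax.vjp(scalar_fun, w)   # one forward pass
+grad_w, = pullback(jnp.ones_like(y))   # one reverse pass: pulls back the cotangent v = 1.0
+print(jnp.allclose(grad_w, jax.grad(scalar_fun)(w)))
+```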
 
 There's a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).
 
diff --git a/docs/aot.md b/docs/aot.md
index a5dd69a72b8f..2dc4eadf388f 100644
--- a/docs/aot.md
+++ b/docs/aot.md
@@ -73,14 +73,8 @@ see the {ref}`export` APIs.
 See the {mod}`jax.stages` documentation for more details on what functionality
 the lowering and compiled functions provide.
 
-In place of `jax.jit` above, you can also `lower(...)` the result of
-{func}`jax.pmap`, as well as `pjit` and `xmap` (from
-{mod}`jax.experimental.pjit` and {mod}`jax.experimental.maps` respectively). In
-each case, you can `compile()` the result similarly.
-
 All optional arguments to `jit`---such as `static_argnums`---are respected in
-the corresponding lowering, compilation, and execution. Again the same goes for
-`pmap`, `pjit`, and `xmap`.
+the corresponding lowering, compilation, and execution.
 
 In the example above, we can replace the arguments to `lower` with any objects
 that have `shape` and `dtype` attributes:
diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md
index 3ef927e056f2..0938a5da944f 100644
--- a/docs/gradient-checkpointing.md
+++ b/docs/gradient-checkpointing.md
@@ -354,6 +354,61 @@ print_saved_residuals(loss_checkpoint2, params, x, y)
 
 Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.
 
+#### Custom policies for offload
+
+To save accelerator memory, you can offload intermediates to CPU memory instead of recomputing them when checkpointing. The `jax.checkpoint_policies.offload_dot_with_no_batch_dims` policy offloads the results of matrix multiplications with no batch dimensions to the CPU.
+
+```{code-cell}
+import functools
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+from jax.ad_checkpoint import checkpoint
+
+policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
+    "device", "pinned_host")
+
+@functools.partial(checkpoint, policy=policy)
+def f(x):
+  x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
+  x = jnp.sin(x)
+  x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
+  x = jnp.sin(x)
+  x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
+  x = jnp.sin(x)
+  x = jnp.sum(x)
+  return x
+```
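+
+With `f` defined this way you can differentiate it as usual; the policy only changes which residuals are kept (offloaded to `pinned_host`) rather than recomputed. As a rough sanity check (a sketch; the exact output depends on your backend), `jax.ad_checkpoint.print_saved_residuals` lists the intermediates that are saved instead of recomputed:
+
+```{code-cell}
+from jax.ad_checkpoint import print_saved_residuals
+
+x = jnp.ones((256, 256), dtype=jnp.float32)
+print_saved_residuals(f, x)
+```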
+
+One of JAX's checkpoint policies lets you offload specified checkpoint names to the CPU. This policy is implemented by `jax.checkpoint_policies.save_and_offload_only_these_names`, which takes four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, and the offload source and destination (`offload_src`, `offload_dst`). Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and all other names or unnamed operations are recomputed. For example, given checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed.
+
+```{code-cell}
+import math
+
+import numpy as np
+
+from jax.ad_checkpoint import checkpoint, checkpoint_name
+from jax.sharding import NamedSharding, PartitionSpec as P
+
+mesh = jax.make_mesh((2,), ("x",))
+shape = (256, 128)
+np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
+s = NamedSharding(mesh, P("x"))
+inp = jax.device_put(np_inp, s)
+
+policy = jax.checkpoint_policies.save_and_offload_only_these_names(
+    names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
+    offload_src='device', offload_dst='pinned_host')
+
+@functools.partial(checkpoint, policy=policy)
+def f(x):
+  def g(ys, _):
+    y, _ = ys
+    y = checkpoint_name(jnp.sin(y), "y")
+    z = checkpoint_name(jnp.sin(y), "z")
+    z = z.T
+    w = checkpoint_name(jnp.sin(z), "w")
+    return (w.T, jnp.sum(w)), None
+  _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
+  return scan_out
+```
+
+The code above defines a function `f` that applies checkpointing with a custom policy. The policy determines which computations can be saved or offloaded during execution. Inside `f`, a nested function `g` performs the core computations, and `jax.lax.scan` applies `g` repeatedly over the input data.
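+
+To exercise this policy under differentiation, you can take the gradient of `f` as usual (a minimal sketch; running it requires a backend that exposes the `pinned_host` memory space):
+
+```{code-cell}
+grads = jax.jit(jax.grad(f))(inp)
+```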
+
 #### List of policies
 
 The policies are:
diff --git a/examples/ffi/src/jax_ffi_example/cpu_examples.cc b/examples/ffi/src/jax_ffi_example/cpu_examples.cc
index 3832c86b29b2..8d808ecd8e30 100644
--- a/examples/ffi/src/jax_ffi_example/cpu_examples.cc
+++ b/examples/ffi/src/jax_ffi_example/cpu_examples.cc
@@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
     Counter, CounterImpl,
     ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
 
+// --------
+// Aliasing
+// --------
+//
+// This example demonstrates how input-output aliasing works. The handler
+// doesn't do anything except to check that the input and output pointers
+// address the same data.
+
+ffi::Error AliasingImpl(ffi::AnyBuffer input,
+                        ffi::Result<ffi::AnyBuffer> output) {
+  if (input.element_type() != output->element_type() ||
+      input.element_count() != output->element_count()) {
+    return ffi::Error::InvalidArgument(
+        "The input and output data types and sizes must match.");
+  }
+  if (input.untyped_data() != output->untyped_data()) {
+    return ffi::Error::InvalidArgument(
+        "When aliased, the input and output buffers should point to the same "
+        "data.");
+  }
+  return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    Aliasing, AliasingImpl,
+    ffi::Ffi::Bind().Arg<ffi::AnyBuffer>().Ret<ffi::AnyBuffer>());
+
 // Boilerplate for exposing handlers to Python
 NB_MODULE(_cpu_examples, m) {
   m.def("registrations", []() {
@@ -111,9 +138,8 @@ NB_MODULE(_cpu_examples, m) {
         nb::capsule(reinterpret_cast<void *>(ArrayAttr));
     registrations["dictionary_attr"] =
         nb::capsule(reinterpret_cast<void *>(DictionaryAttr));
-
     registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));
-
+    registrations["aliasing"] = nb::capsule(reinterpret_cast<void *>(Aliasing));
     return registrations;
   });
 }
diff --git a/examples/ffi/src/jax_ffi_example/cpu_examples.py b/examples/ffi/src/jax_ffi_example/cpu_examples.py
index 563e5a911b99..155e100dcd77 100644
--- a/examples/ffi/src/jax_ffi_example/cpu_examples.py
+++ b/examples/ffi/src/jax_ffi_example/cpu_examples.py
@@ -39,3 +39,9 @@ def dictionary_attr(**kwargs):
 def counter(index):
   return jax.ffi.ffi_call(
     "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))
+
+
+def aliasing(x):
+  return jax.ffi.ffi_call(
+      "aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype),
+      input_output_aliases={0: 0})(x)
diff --git a/examples/ffi/tests/cpu_examples_test.py b/examples/ffi/tests/cpu_examples_test.py
index 0e2cfde02db6..8db524f6264b 100644
--- a/examples/ffi/tests/cpu_examples_test.py
+++ b/examples/ffi/tests/cpu_examples_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 
 import jax
 import jax.numpy as jnp
@@ -91,5 +91,16 @@ def counter_fun(x):
     self.assertEqual(counter_fun(0)[1], 3)
 
 
+class AliasingTests(jtu.JaxTestCase):
+  def setUp(self):
+    super().setUp()
+    if not jtu.test_device_matches(["cpu"]):
+      self.skipTest("Unsupported platform")
+
+  @parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),))
+  def test_basic(self, x):
+    self.assertAllClose(cpu_examples.aliasing(x), x)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py
index 5c42a3b44de3..21bc8038448a 100644
--- a/jax/_src/compiler.py
+++ b/jax/_src/compiler.py
@@ -33,6 +33,7 @@
 from jax._src import traceback_util
 from jax._src.interpreters import mlir
 from jax._src.lib import version as jaxlib_version
+from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 import numpy as np
@@ -198,7 +199,10 @@ def get_compile_options(
     logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
     if env_options_overrides is None:
       env_options_overrides = {}
-    env_options_overrides['xla_gpu_enable_command_buffer'] = ''
+    if xla_extension_version > 302:
+      env_options_overrides['xla_gpu_enable_command_buffer'] = ''
+    else:
+      env_options_overrides['xla_gpu_graph_min_graph_size'] = '100000'
 
   if env_options_overrides is not None:
     # Some overrides are passed directly on build_options.
diff --git a/jax/_src/core.py b/jax/_src/core.py
index 0868990e1d3b..5017f7f5f676 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -39,7 +39,7 @@
 from jax._src import effects
 from jax._src import compute_on
 from jax._src import mesh as mesh_lib
-from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
+from jax._src.partition_spec import PartitionSpec as P
 from jax._src.errors import (
     ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
     TracerIntegerConversionError, UnexpectedTracerError)
@@ -1659,12 +1659,12 @@ def _maybe_modify_sharding(sharding):
 
   new_spec = []
   for s in sharding.spec:
-    if s is None or isinstance(s, UnconstrainedSingleton):
+    if s is None:
       new_spec.append(s)
     else:
       temp_s = s[0] if isinstance(s, tuple) else s
       new_spec.append(
-          P.UNCONSTRAINED
+          None
           if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
   return sharding.with_spec(new_spec)
 
@@ -1762,8 +1762,6 @@ def _get_shape_sharding_str(shape, spec):
   for s1, s2 in zip(shape, spec):
     if s2 is None:
       out.append(f"{s1}")
-    elif isinstance(s2, UnconstrainedSingleton):
-      out.append(f"{s1}")
     elif isinstance(s2, tuple):
       ss = ','.join(s for s in s2)
       out.append(f"{s1}@({ss})")
diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py
index 4cf34f200e5e..6b0ef293e807 100644
--- a/jax/_src/custom_partitioning.py
+++ b/jax/_src/custom_partitioning.py
@@ -190,7 +190,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
         closed_jaxpr,
         name="tmp_xla_computation",
         platforms=module_context.platforms,
-        backend_or_name=module_context.backend_or_name,
+        backend=module_context.backend,
         axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
     )
   result_sharding = _pack_result_sharding(result_shape, result_shardings)
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index f011e756da31..471b0c4af74d 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -22,39 +22,38 @@
 import enum
 from functools import partial
 import itertools
-import time
-from typing import Any, NamedTuple
 import logging
 import threading
-
-import numpy as np
+import time
+from typing import Any, NamedTuple
 
 import jax
+from jax._src import api
+from jax._src import array
 from jax._src import basearray
 from jax._src import config
 from jax._src import core
-from jax._src import api
-from jax._src import array
 from jax._src import dtypes
+from jax._src import lib
 from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import util
+from jax._src.abstract_arrays import array_types
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
-from jax._src.abstract_arrays import array_types
 from jax._src.interpreters import mlir
-from jax._src.interpreters import xla
 from jax._src.interpreters import pxla
-from jax._src import lib
-from jax._src.mesh import AbstractMesh, Mesh
+from jax._src.interpreters import xla
+from jax._src.layout import DeviceLocalLayout, Layout
 from jax._src.lib import xla_client as xc
-from jax._src.monitoring import record_event_duration_secs
+from jax._src.mesh import AbstractMesh, Mesh
+from jax._src.monitoring import record_event_duration_secs, record_event_time_span
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import (
-    SingleDeviceSharding, NamedSharding, TransferToMemoryKind,
+from jax._src.sharding_impls import (
+    NamedSharding, SingleDeviceSharding, TransferToMemoryKind,
     is_single_device_sharding)
-from jax._src.layout import Layout, DeviceLocalLayout
+import numpy as np
 
 
 JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@@ -177,12 +176,14 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
     log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
     start_time = time.time()
     yield
-    elapsed_time = time.time() - start_time
+    end_time = time.time()
+    elapsed_time = end_time - start_time
     if logger.isEnabledFor(log_priority):
       logger.log(log_priority, fmt.format(
           fun_name=fun_name, elapsed_time=elapsed_time))
     if event is not None:
       record_event_duration_secs(event, elapsed_time)
+      record_event_time_span(event, start_time, end_time)
 
 
 def should_tuple_args(num_args: int, platform: str) -> bool:
diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py
index 8ba43083d7a3..6b1945746255 100644
--- a/jax/_src/export/_export.py
+++ b/jax/_src/export/_export.py
@@ -896,7 +896,7 @@ def is_token(typ, attrs):
     with ir.InsertionPoint(entry_block):
       # Make a context just for lowering the dimension value computations
       module_context = mlir.ModuleContext(
-          backend_or_name="cpu", platforms=["cpu"],
+          backend=None, platforms=["cpu"],
           axis_context=sharding_impls.ShardingContext(0),
           keepalives=[], channel_iterator=itertools.count(1),
           host_callbacks=[], module=wrapped_module, context=context,
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 85da4e53b1f7..f552a779a844 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -50,7 +50,6 @@
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import AUTO, NamedSharding
-from jax._src.partition_spec import UnconstrainedSingleton
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
@@ -698,10 +697,11 @@ class ModuleContext:
   module: ir.Module
   ip: ir.InsertionPoint
   symbol_table: ir.SymbolTable
-  backend_or_name: str | xb.XlaBackend | None
   # The lowering platforms for the module. Can be more than one only when
   # exporting.
   platforms: Sequence[str]
+  # See ModuleContext.get_backend() for backend and platforms usage.
+  backend: xb.XlaBackend | None
   axis_context: AxisContext
   keepalives: list[Any]
   channel_iterator: Iterator[int]
@@ -725,8 +725,8 @@ def axis_env(self) -> sharding_impls.AxisEnv:
   def __init__(
       self,
       *,
-      backend_or_name: str | xb.XlaBackend | None,
       platforms: Sequence[str],
+      backend: xb.XlaBackend | None,
       axis_context: AxisContext,
       keepalives: list[Any],
       channel_iterator: Iterator[int],
@@ -745,7 +745,7 @@ def __init__(
     self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
     self.ip = ip or ir.InsertionPoint(self.module.body)
     self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
-    self.backend_or_name = backend_or_name
+    self.backend = backend
     self.platforms = platforms
     self.axis_context = axis_context
     self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
@@ -760,17 +760,20 @@ def __init__(
     self.all_default_mem_kind = all_default_mem_kind
     self.lowering_parameters = lowering_parameters
 
-  @property
-  def backend(self) -> xb.XlaBackend:
-    # TODO(necula): clean the use of backend and backend_or_name vs. platforms
+  def get_backend(self) -> xb.XlaBackend:
     if len(self.platforms) > 1:
       raise NotImplementedError(
         "accessing .backend in multi-lowering setting. This can occur when "
         "lowering a primitive that has not been adapted to multi-platform "
         "lowering")
-    if self.backend_or_name is None or isinstance(self.backend_or_name, str):
-      return xb.get_backend(self.backend_or_name)
-    return self.backend_or_name
+    if self.backend is not None:
+      if xb.canonicalize_platform(self.backend.platform) != self.platforms[0]:
+        raise ValueError(
+          "the platform for the specified backend "
+          f"{xb.canonicalize_platform(self.backend.platform)} is different "
+          f"from the lowering platform {self.platforms[0]}")
+      return self.backend
+    return xb.get_backend(self.platforms[0])
 
   def new_channel(self) -> int:
     channel = next(self.channel_iterator)
@@ -1072,14 +1075,14 @@ def _get_unconstrained_dimensions(s, aval):
   return (us, all_unconstrained(s, aval),
           ({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None))
 
-
 def lower_jaxpr_to_module(
     module_name: str,
     jaxpr: core.ClosedJaxpr,
     *,
     ordered_effects: list[core.Effect],
-    backend_or_name: str | xb.XlaBackend | None,
+    # See ModuleContext.get_backend() for backend and platforms usage.
     platforms: Sequence[str],
+    backend: xb.XlaBackend | None,
     axis_context: AxisContext,
     name_stack: source_info_util.NameStack,
     donated_args: Sequence[bool],
@@ -1170,7 +1173,7 @@ def lower_jaxpr_to_module(
   else:
     dim_vars = ()
 
-  ctx = ModuleContext(backend_or_name=backend_or_name,
+  ctx = ModuleContext(backend=backend,
                       platforms=platforms, axis_context=axis_context,
                       keepalives=keepalives,
                       channel_iterator=channel_iter,
@@ -1202,10 +1205,6 @@ def lower_jaxpr_to_module(
         arg_layouts=in_layouts,
         result_layouts=out_layouts,
         propagated_out_mem_kinds=propagated_out_mem_kinds)
-    if config.use_shardy_partitioner.value:
-      pipeline = passmanager.PassManager.parse(
-          'builtin.module(sdy-lift-inlined-meshes)')
-      pipeline.run(ctx.module.operation)
 
   try:
     if not ctx.module.operation.verify():
@@ -1224,6 +1223,12 @@ def emit_diagnostic_info(d):
     raise ValueError("\n".join(msg_lines) + "\n" +
                      dump_module_message(ctx.module, "verification")) from e
 
+  if config.use_shardy_partitioner.value:
+    with ctx.context:
+      pipeline = passmanager.PassManager.parse(
+          'builtin.module(sdy-lift-inlined-meshes)')
+      pipeline.run(ctx.module.operation)
+
   return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
                         ctx.shape_poly_state)
 
@@ -2595,8 +2600,7 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
   if aval.sharding.mesh._any_axis_collective:
     unspecified_dims = set(range(aval.ndim))
   elif aval.sharding.mesh._any_axis_auto:
-    unspecified_dims = {i for i, s in enumerate(aval.sharding.spec)
-                        if isinstance(s, UnconstrainedSingleton)}
+    unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
   return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)
 
 
@@ -2892,7 +2896,7 @@ def emit_python_callback(
   if platform not in {"cpu", "cuda", "rocm", "tpu"}:
     raise ValueError(
         f"`EmitPythonCallback` not supported on {platform} backend.")
-  backend = ctx.module_context.backend
+  backend = ctx.module_context.get_backend()
   result_shapes = util.flatten(
       [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
   operand_shapes = util.flatten(
@@ -3012,13 +3016,14 @@ def _wrapped_callback(token, *args):  # type: ignore  # pylint: disable=function
 def build_mlir_module_helper(
     closed_jaxpr: core.ClosedJaxpr, *, name: str,
     platforms: Sequence[str],
-    backend_or_name: str, axis_context: AxisContext) -> ir.Module:
+    backend: xb.XlaBackend | None,
+    axis_context: AxisContext) -> ir.Module:
   """Helper to generate pmap-style XLA computations for custom partitioners."""
   unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects)
   if unlowerable_effects:
     raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}')
   lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
-      backend_or_name=backend_or_name, ordered_effects=[],
+      backend=backend, ordered_effects=[],
       name_stack=source_info_util.NameStack(),
       donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
       axis_context=axis_context, platforms=platforms,
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index a61600402b74..ec2d52e48e9c 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -871,7 +871,7 @@ def lower_parallel_callable(
           module_name,
           closed_jaxpr,
           ordered_effects=ordered_effects,
-          backend_or_name=backend,
+          backend=backend,
           platforms=platforms,
           axis_context=sharding_impls.ReplicaAxisContext(axis_env),
           name_stack=name_stack,
@@ -1179,7 +1179,7 @@ def __str__(self):
 
 
 class ResultsHandler:
-  # `out_avals` is the `Array` global avals when using pjit or xmap. It is the
+  # `out_avals` is the `Array` global avals when using pjit. It is the
   # local one when using `pmap`.
   __slots__ = ("handlers", "out_shardings", "out_avals")
 
@@ -1954,7 +1954,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
         module_name,
         closed_jaxpr,
         ordered_effects=ordered_effects,
-        backend_or_name=backend,
+        backend=backend,
         platforms=platforms,
         axis_context=axis_ctx,
         name_stack=name_stack,
@@ -2150,7 +2150,7 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
   return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
           donated_invars, out_shardings, out_layouts)
 
-def _concretize_abstract_shardings(shardings, avals, device_assignment):
+def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
   np_dev = np.vectorize(lambda i: device_assignment[i],
                         otypes=[object])(np.arange(len(device_assignment)))
 
@@ -2163,8 +2163,14 @@ def _abstract_to_concrete_mesh(abstract_mesh):
   out = []
   for s, a in zip(shardings, avals):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
-      out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
-                               a.sharding.spec))
+      if config.use_shardy_partitioner.value:
+        spec = a.sharding.spec
+      else:
+        spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
+                               for sp in a.sharding.spec])
+                if a.sharding.mesh._any_axis_auto else a.sharding.spec)
+      out.append(NamedSharding(
+          _abstract_to_concrete_mesh(a.sharding.mesh), spec))
     else:
       out.append(s)
   return tuple(out)
@@ -2243,7 +2249,7 @@ def lower_sharding_computation(
   unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]
 
   if config.sharding_in_types.value:
-    out_shardings = _concretize_abstract_shardings(
+    out_shardings = _concretize_abstract_out_shardings(
         out_shardings, global_out_avals, device_assignment)
 
   # TODO(parkers): One _raw_platform has been unified with platform,
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index 25fb2b38f7fa..34f485684e85 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -76,7 +76,7 @@ def __repr__(self):
 @util.cache(max_size=128, trace_context_in_key=False)
 def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
   if global_mesh.empty:
-      return global_mesh
+    return global_mesh
   is_local_device = np.vectorize(
       lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices)
   subcube_indices = []
@@ -96,9 +96,9 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
   # subcube that hull will contain non-local devices.
   if not is_local_device[subcube_indices_tuple].all():
     raise ValueError(
-        "When passing host local inputs to pjit or xmap, devices "
-        "connected to a single host must form a contiguous subcube of the "
-        "global device mesh")
+        "When passing host local inputs to pjit, devices connected to a single"
+        " host must form a contiguous subcube of the global device mesh"
+    )
   return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)
 
 
diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py
index 3b291de0061a..99e957733ba2 100644
--- a/jax/_src/monitoring.py
+++ b/jax/_src/monitoring.py
@@ -39,8 +39,17 @@ def __call__(self, event: str, duration_secs: float,
     ...
 
 
+class EventTimeSpanListenerWithMetadata(Protocol):
+
+  def __call__(
+      self, event: str, start_time: float, end_time: float, **kwargs: str | int
+  ) -> None:
+    ...
+
+
 _event_listeners: list[EventListenerWithMetadata] = []
 _event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = []
+_event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = []
 
 
 def record_event(event: str, **kwargs: str | int) -> None:
@@ -64,6 +73,14 @@ def record_event_duration_secs(event: str, duration: float,
     callback(event, duration, **kwargs)
 
 
+def record_event_time_span(
+    event: str, start_time: float, end_time: float, **kwargs: str | int
+) -> None:
+  """Record an event start and end time in seconds (float)."""
+  for callback in _event_time_span_listeners:
+    callback(event, start_time, end_time, **kwargs)
+
+
 def register_event_listener(
     callback: EventListenerWithMetadata,
 ) -> None:
@@ -71,6 +88,13 @@ def register_event_listener(
   _event_listeners.append(callback)
 
 
+def register_event_time_span_listener(
+    callback: EventTimeSpanListenerWithMetadata,
+) -> None:
+  """Register a callback to be invoked during record_event_time_span()."""
+  _event_time_span_listeners.append(callback)
+
+
 def register_event_duration_secs_listener(
     callback : EventDurationListenerWithMetadata) -> None:
   """Register a callback to be invoked during record_event_duration_secs()."""
@@ -80,15 +104,22 @@ def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]:
   """Get event duration listeners."""
   return list(_event_duration_secs_listeners)
 
+
+def get_event_time_span_listeners() -> list[EventTimeSpanListenerWithMetadata]:
+  """Get event time span listeners."""
+  return list(_event_time_span_listeners)
+
+
 def get_event_listeners() -> list[EventListenerWithMetadata]:
   """Get event listeners."""
   return list(_event_listeners)
 
 def clear_event_listeners():
   """Clear event listeners."""
-  global _event_listeners, _event_duration_secs_listeners
+  global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners
   _event_listeners = []
   _event_duration_secs_listeners = []
+  _event_time_span_listeners = []
 
 def _unregister_event_duration_listener_by_callback(
     callback: EventDurationListenerWithMetadata) -> None:
@@ -108,6 +139,18 @@ def _unregister_event_duration_listener_by_index(index: int) -> None:
   assert -size <= index < size
   del _event_duration_secs_listeners[index]
 
+
+def _unregister_event_time_span_listener_by_callback(
+    callback: EventTimeSpanListenerWithMetadata,
+) -> None:
+  """Unregister an event time span listener by callback.
+
+  This function is supposed to be called for testing only.
+  """
+  assert callback in _event_time_span_listeners
+  _event_time_span_listeners.remove(callback)
+
+
 def _unregister_event_listener_by_callback(
     callback: EventListenerWithMetadata) -> None:
   """Unregister an event listener by callback.
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index a126cbba0ce1..20c160e0e436 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -473,6 +473,8 @@ def to_block_mapping(
     mapping.check_invariants()
     return mapping
 
+  replace = dataclasses.replace
+
 
 class NoBlockSpec:
   def __repr__(self):
@@ -1028,18 +1030,37 @@ def to_json(self) -> bytes:
 core_map_p = jax_core.Primitive("core_map")
 core_map_p.multiple_results = True
 
-def core_map(mesh):
+
+def core_map(
+    mesh,
+    *,
+    compiler_params: Any | None = None,
+    interpret: bool = False,
+    debug: bool = False,
+    cost_estimate: CostEstimate | None = None,
+):
   """Runs a function on a mesh, mapping it over the devices in the mesh.
 
   The function should be stateful in that it takes in no inputs and returns
   no outputs but can mutate closed-over Refs, for example.
+
+  Args:
+    mesh: The mesh to run the function on.
+    compiler_params: The compiler parameters to pass to the backend.
+    interpret: Whether to run the function in interpret mode.
+    debug: Whether or not to print out helpful debugging information.
+    cost_estimate: The cost estimate of the function.
   """
   def wrapped(f):
     flat_args, in_tree = tree_util.tree_flatten(((), {}))
     flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
     with jax_core.extend_axis_env_nd(mesh.shape.items()):
       jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args)
-    out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh)
+    out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh,
+                          compiler_params=compiler_params,
+                          interpret=interpret,
+                          debug=debug,
+                          cost_estimate=cost_estimate)
     if out:
       raise ValueError("core_map-ped functions must not return any outputs.")
     return tree_util.tree_unflatten(out_tree_thunk(), out)
@@ -1047,7 +1068,7 @@ def wrapped(f):
 
 
 @core_map_p.def_effectful_abstract_eval
-def _core_map_abstract_eval(*args, jaxpr, mesh):
+def _core_map_abstract_eval(*args, jaxpr, mesh, **_):
   del args
   if jaxpr.outvars:
     raise ValueError("core_map must not return any outputs.")
@@ -1074,6 +1095,9 @@ def default_mesh_discharge_rule(
     compiler_params,
     backend,
     jaxpr,
+    debug,
+    interpret,
+    cost_estimate,
 ):
   """Discharges a ``core_map`` over a mesh to a ``pallas_call``."""
   del out_avals  # Unused.
@@ -1103,6 +1127,9 @@ def body(*args):
       grid=grid,
       compiler_params=compiler_params,
       backend=backend,
+      interpret=interpret,
+      debug=debug,
+      cost_estimate=cost_estimate,
   )(*args)
   # ``outs`` lacks the unmodified inputs. Add them back in.
   all_outs = [None] * len(args)
@@ -1120,8 +1147,8 @@ def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwa
   )
 
 
-def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh):
-  del in_atoms
+def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh, **kwargs):
+  del in_atoms, kwargs
   with jax_core.extend_axis_env_nd(tuple(mesh.shape.items())):
     jax_core.check_jaxpr(jaxpr)
   effs = set()
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index ad9a6cb13f42..80f8a7de544c 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -92,6 +92,8 @@ class TPUCompilerParams(pallas_core.CompilerParams):
   serialization_format: int = 1
   device_type: str | None = None
 
+  replace = dataclasses.replace
+
 class TPUMemorySpace(enum.Enum):
   ANY = "any"  # TODO(b/368401328): Remove this and just use pl.ANY.
   VMEM = "vmem"
@@ -240,19 +242,38 @@ def _tensorcore_mesh_discharge_rule(
     *args,
     mesh,
     jaxpr,
+    compiler_params: TPUCompilerParams,
+    interpret: bool,
+    debug: bool,
+    cost_estimate: pallas_core.CostEstimate | None,
 ):
   assert isinstance(mesh, TensorCoreMesh)
+  if compiler_params and not isinstance(compiler_params, TPUCompilerParams):
+    raise ValueError(
+        "compiler_params must be a pltpu.TPUCompilerParams"
+    )
+  if not compiler_params:
+    compiler_params = TPUCompilerParams()
   if len(mesh.shape) > 1:
     raise NotImplementedError("Mesh must be 1D")
   core_axis_name, num_cores = list(mesh.shape.items())[0]
+  if compiler_params.dimension_semantics is not None:
+    raise ValueError(
+        "dimension_semantics must be None for TensorCoreMesh"
+    )
   return pallas_core.default_mesh_discharge_rule(
       in_avals,
       out_avals,
       *args,
       jaxpr=jaxpr,
       grid=((core_axis_name, num_cores),),
-      compiler_params=TPUCompilerParams(dimension_semantics=("parallel",)),
+      compiler_params=compiler_params.replace(
+          dimension_semantics=("parallel",)
+      ),
+      debug=debug,
+      interpret=interpret,
       backend="mosaic_tpu",
+      cost_estimate=cost_estimate,
   )
 
 pallas_core._core_map_mesh_rules[TensorCoreMesh] = (
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index d77ae4358703..36e6e47cbf4f 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -483,7 +483,6 @@ class GPUMesh:
   # Those are NOT CUDA threads. On Hopper they correspond to warpgroups.
   num_threads: int | None = None
   axis_names: tuple[str, ...] = ()
-  approx_math: bool = False
 
   def __post_init__(self):
     if len(self.axis_names) != len(self.grid) + (self.num_threads is not None):
@@ -521,12 +520,24 @@ def _gpu_mesh_discharge_rule(
     *args,
     mesh,
     jaxpr,
+    compiler_params,
+    interpret,
+    debug,
+    cost_estimate,
 ):
-  assert isinstance(mesh, GPUMesh)
+  if not isinstance(mesh, GPUMesh):
+    raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}")
   if mesh.cluster:
     raise NotImplementedError
   if mesh.num_threads is None:
     raise NotImplementedError
+  if compiler_params and not isinstance(compiler_params, GPUCompilerParams):
+    raise TypeError(
+        "Compiler params must be a GPUCompilerParams, got"
+        f" {type(compiler_params)}"
+    )
+  if not compiler_params:
+    compiler_params = GPUCompilerParams()
   return pallas_core.default_mesh_discharge_rule(
       in_avals,
       out_avals,
@@ -534,8 +545,13 @@ def _gpu_mesh_discharge_rule(
       jaxpr=jaxpr,
       grid=tuple(mesh.shape.items()),
       backend="mosaic_gpu",
-      compiler_params=GPUCompilerParams(approx_math=mesh.approx_math),
+      compiler_params=compiler_params,
+      debug=debug,
+      interpret=interpret,
+      cost_estimate=cost_estimate,
   )
+
+
 pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule
 
 
diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py
index 1805f8c0923a..59b1b86f33fc 100644
--- a/jax/_src/pallas/triton/pallas_call_registration.py
+++ b/jax/_src/pallas/triton/pallas_call_registration.py
@@ -61,14 +61,14 @@ def pallas_call_lowering(
         "scalar prefetch not implemented in the Triton backend"
     )
   triton_params = compiler_params.get("triton", compiler_params)
-  num_warps = triton_params.pop("num_warps", 4)
+  num_warps = triton_params.get("num_warps", 4)
   num_warps = 4 if num_warps is None else num_warps
   [lowering_platform] = ctx.platforms or ctx.module_context.platforms
   if lowering_platform == "rocm":
-    num_stages = triton_params.pop("num_stages", 1)
+    num_stages = triton_params.get("num_stages", 1)
     num_stages = 1 if num_stages is None else num_stages
   else:
-    num_stages = triton_params.pop("num_stages", 3)
+    num_stages = triton_params.get("num_stages", 3)
     num_stages = 3 if num_stages is None else num_stages
 
   if debug:
diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py
index 21199b9bfd68..88f8853f53cd 100644
--- a/jax/_src/state/discharge.py
+++ b/jax/_src/state/discharge.py
@@ -148,45 +148,51 @@ def _eval_jaxpr_discharge_state(
                        if d and isinstance(v.aval, AbstractRef)}
 
   for eqn in jaxpr.eqns:
-    should_discharge = [id(v.aval) in refs_to_discharge for v in eqn.invars]
-    if eqn.primitive is core.mutable_array_p:
-      [invar], [outvar] = eqn.invars, eqn.outvars
-      ans = env.read(invar)
-      refs_to_discharge.add(id(outvar.aval))
-    elif eqn.primitive is core.freeze_p:
-      [invar], [outvar] = eqn.invars, eqn.outvars
-      ans = env.read(invar)
-      refs_to_discharge.remove(id(invar.aval))
-    elif (any(should_discharge)
-          or core.internal_mutable_array_effect in eqn.effects
-      ):
-      if eqn.primitive in _partial_discharge_rules:
-        rule: DischargeRule = partial(_partial_discharge_rules[eqn.primitive], should_discharge)
-      elif eqn.primitive in _discharge_rules:
-        rule = _discharge_rules[eqn.primitive]
+    name_stack = (
+        source_info_util.current_name_stack() + eqn.source_info.name_stack
+    )
+    traceback = eqn.source_info.traceback
+    with source_info_util.user_context(
+        traceback, name_stack=name_stack), eqn.ctx.manager:
+      should_discharge = [id(v.aval) in refs_to_discharge for v in eqn.invars]
+      if eqn.primitive is core.mutable_array_p:
+        [invar], [outvar] = eqn.invars, eqn.outvars
+        ans = env.read(invar)
+        refs_to_discharge.add(id(outvar.aval))
+      elif eqn.primitive is core.freeze_p:
+        [invar], [outvar] = eqn.invars, eqn.outvars
+        ans = env.read(invar)
+        refs_to_discharge.remove(id(invar.aval))
+      elif (any(should_discharge)
+            or core.internal_mutable_array_effect in eqn.effects
+        ):
+        if eqn.primitive in _partial_discharge_rules:
+          rule: DischargeRule = partial(_partial_discharge_rules[eqn.primitive], should_discharge)
+        elif eqn.primitive in _discharge_rules:
+          rule = _discharge_rules[eqn.primitive]
+        else:
+          raise NotImplementedError("No state discharge rule implemented for "
+              f"primitive: {eqn.primitive}")
+        invals = map(env.read, eqn.invars)
+        in_avals = [v.aval for v in eqn.invars]
+        out_avals = [v.aval for v in eqn.outvars]
+        new_invals, ans = rule(
+            in_avals, out_avals, *invals, **eqn.params)
+        for invar, should, new_inval in zip(eqn.invars, should_discharge, new_invals):
+          if new_inval is not None:
+            if not should:
+              raise ValueError(
+                  f"Did not ask for inval to be discharged but it was. ({invar=},"
+                  f" {new_inval=})"
+              )
+            env.write(invar, new_inval)  # type: ignore[arg-type]
       else:
-        raise NotImplementedError("No state discharge rule implemented for "
-            f"primitive: {eqn.primitive}")
-      invals = map(env.read, eqn.invars)
-      in_avals = [v.aval for v in eqn.invars]
-      out_avals = [v.aval for v in eqn.outvars]
-      new_invals, ans = rule(
-          in_avals, out_avals, *invals, **eqn.params)
-      for invar, should, new_inval in zip(eqn.invars, should_discharge, new_invals):
-        if new_inval is not None:
-          if not should:
-            raise ValueError(
-                f"Did not ask for inval to be discharged but it was. ({invar=},"
-                f" {new_inval=})"
-            )
-          env.write(invar, new_inval)  # type: ignore[arg-type]
-    else:
-      # Default primitive rule, similar to `core.eval_jaxpr`. Note that here
-      # we assume any higher-order primitives inside of the jaxpr are *not*
-      # stateful.
-      subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
-      ans = eqn.primitive.bind(*subfuns, *map(env.read, eqn.invars),
-                               **bind_params)
+        # Default primitive rule, similar to `core.eval_jaxpr`. Note that here
+        # we assume any higher-order primitives inside of the jaxpr are *not*
+        # stateful.
+        subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
+        ans = eqn.primitive.bind(*subfuns, *map(env.read, eqn.invars),
+                                **bind_params)
     if eqn.primitive.multiple_results:
       map(env.write, eqn.outvars, ans)
     else:
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index a2f60887706b..cb0cf39568d1 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -17,6 +17,7 @@
 
 import collections
 from collections.abc import Callable, Generator, Iterable, Sequence
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import ExitStack, contextmanager
 import datetime
 import functools
@@ -31,6 +32,7 @@
 import tempfile
 import textwrap
 import threading
+import time
 from typing import Any, TextIO
 import unittest
 import warnings
@@ -115,6 +117,12 @@
           'deterministic, interactive'),
 )
 
+TEST_NUM_THREADS = config.int_flag(
+    'jax_test_num_threads', 0,
+    help='Number of threads to use for running tests. 0 means run everything '
+    'in the main thread. Using > 1 thread is experimental.'
+)
+
 # We sanitize test names to ensure they work with "unitttest -k" and
 # "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k
 # does not. We replace sequences of problematic characters with a single '_'.
@@ -498,29 +506,20 @@ def device_supports_buffer_donation():
   )
 
 
-@contextmanager
-def set_host_platform_device_count(nr_devices: int):
-  """Context manager to set host platform device count if not specified by user.
+def request_cpu_devices(nr_devices: int):
+  """Requests at least `nr_devices` CPU devices.
 
-  This should only be used by tests at the top level in setUpModule(); it will
-  not work correctly if applied to individual test cases.
+  request_cpu_devices should be called at the top level of a test module before
+  main() runs.
+
+  It is not guaranteed that the number of CPU devices will be exactly
+  `nr_devices`: it may be more or less, depending on how exactly the test is
+  invoked. Test cases that require a specific number of devices should skip
+  themselves if that number is not met.
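+
+  For example, a test module that needs several CPU devices can call this at
+  module scope (assuming the module imports this file as ``jtu``)::
+
+    jtu.request_cpu_devices(8)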
   """
-  prev_xla_flags = os.getenv("XLA_FLAGS")
-  flags_str = prev_xla_flags or ""
-  # Don't override user-specified device count, or other XLA flags.
-  if "xla_force_host_platform_device_count" not in flags_str:
-    os.environ["XLA_FLAGS"] = (flags_str +
-                               f" --xla_force_host_platform_device_count={nr_devices}")
-  # Clear any cached backends so new CPU backend will pick up the env var.
-  xla_bridge.get_backend.cache_clear()
-  try:
-    yield
-  finally:
-    if prev_xla_flags is None:
-      del os.environ["XLA_FLAGS"]
-    else:
-      os.environ["XLA_FLAGS"] = prev_xla_flags
+  if xla_bridge.NUM_CPU_DEVICES.value < nr_devices:
     xla_bridge.get_backend.cache_clear()
+    config.update("jax_num_cpu_devices", nr_devices)
 
 
 def skip_on_flag(flag_name, skip_value):
@@ -1007,8 +1006,140 @@ def sample_product(*args, **kw):
   """
   return parameterized.parameters(*sample_product_testcases(*args, **kw))
 
+# We use a reader-writer lock to protect test execution. Tests that may run in
+# parallel acquire a read lock; tests that are not thread-safe acquire a write
+# lock.
+if hasattr(util, 'Mutex'):
+  _test_rwlock = util.Mutex()
+
+  def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
+    _test_rwlock.reader_lock()
+    try:
+      test(result)  # type: ignore
+    finally:
+      _test_rwlock.reader_unlock()
+
+
+  @contextmanager
+  def thread_hostile_test():
+    "Decorator for tests that are not thread-safe."
+    _test_rwlock.assert_reader_held()
+    _test_rwlock.reader_unlock()
+    _test_rwlock.writer_lock()
+    try:
+      yield
+    finally:
+      _test_rwlock.writer_unlock()
+      _test_rwlock.reader_lock()
+else:
+  # TODO(phawkins): remove this branch when jaxlib 0.5.0 is the minimum.
+  _test_rwlock = threading.Lock()
+
+  def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
+    _test_rwlock.acquire()
+    try:
+      test(result)  # type: ignore
+    finally:
+      _test_rwlock.release()
+
+
+  @contextmanager
+  def thread_hostile_test():
+    yield  # No reader-writer lock, so we get no parallelism.
+
+class ThreadSafeTestResult:
+  """
+  Wraps a TestResult to make it thread safe.
+
+  We do this by accumulating API calls and applying them in a batch under a
+  lock at the conclusion of each test case.
+
+  We duck type instead of inheriting from TestResult because we aren't actually
+  a perfect implementation of TestResult, and would rather get a loud error
+  for things we haven't implemented.
+  """
+  def __init__(self, lock: threading.Lock, result: unittest.TestResult):
+    self.lock = lock
+    self.test_result = result
+    self.actions: list[Callable] = []
+
+  def startTest(self, test: unittest.TestCase):
+    del test
+    self.start_time = time.time()
+
+  def stopTest(self, test: unittest.TestCase):
+    stop_time = time.time()
+    with self.lock:
+      # We assume test_result is an ABSL _TextAndXMLTestResult, so we can
+      # override how it gets the time.
+      time_getter = self.test_result.time_getter
+      try:
+        self.test_result.time_getter = lambda: self.start_time
+        self.test_result.startTest(test)
+        for callback in self.actions:
+          callback()
+        self.test_result.time_getter = lambda: stop_time
+        self.test_result.stopTest(test)
+      finally:
+        self.test_result.time_getter = time_getter
+
+  def addSuccess(self, test: unittest.TestCase):
+    self.actions.append(lambda: self.test_result.addSuccess(test))
+
+  def addSkip(self, test: unittest.TestCase, reason: str):
+    self.actions.append(lambda: self.test_result.addSkip(test, reason))
+
+  def addError(self, test: unittest.TestCase, err):
+    self.actions.append(lambda: self.test_result.addError(test, err))
+
+  def addFailure(self, test: unittest.TestCase, err):
+    self.actions.append(lambda: self.test_result.addFailure(test, err))
+
+  def addExpectedFailure(self, test: unittest.TestCase, err):
+    self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))
+
+  def addDuration(self, test: unittest.TestCase, elapsed):
+    self.actions.append(lambda: self.test_result.addDuration(test, elapsed))
+
+
+class JaxTestSuite(unittest.TestSuite):
+  """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1.
+
+  Caution: this test suite does not run setUpClass or setUpModule methods if
+  thread parallelism is enabled.
+  """
+
+  def __init__(self, suite: unittest.TestSuite):
+    super().__init__(list(suite))
+
+  def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult:
+    if TEST_NUM_THREADS.value <= 0:
+      return super().run(result)
+
+    executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
+    lock = threading.Lock()
+    futures = []
+
+    def run_test(test):
+      "Recursively runs tests in a test suite or test case."
+      if isinstance(test, unittest.TestSuite):
+        for subtest in test:
+          run_test(subtest)
+      else:
+        test_result = ThreadSafeTestResult(lock, result)
+        futures.append(executor.submit(_run_one_test, test, test_result))
+
+    with executor:
+      run_test(self)
+      for future in futures:
+        future.result()
+
+    return result
+
 
 class JaxTestLoader(absltest.TestLoader):
+  suiteClass = JaxTestSuite
+
   def getTestCaseNames(self, testCaseClass):
     names = super().getTestCaseNames(testCaseClass)
     if _TEST_TARGETS.value:
@@ -1091,10 +1222,8 @@ class JaxTestCase(parameterized.TestCase):
     'jax_legacy_prng_key': 'error',
   }
 
-  _compilation_cache_exit_stack: ExitStack | None = None
+  _context_stack: ExitStack | None = None
 
-  def tearDown(self) -> None:
-    assert core.reset_trace_state()
 
   def setUp(self):
     super().setUp()
@@ -1105,25 +1234,26 @@ def setUp(self):
     # b) it returns values in int32 range, which RandomState requires.
     self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))
 
-  @classmethod
-  def setUpClass(cls):
-    cls._compilation_cache_exit_stack = ExitStack()
-    stack = cls._compilation_cache_exit_stack
-    stack.enter_context(global_config_context(**cls._default_config))
+    # TODO(phawkins): use TestCase.enterContext once Python 3.11 is the minimum
+    # version.
+    self._context_stack = ExitStack()
+    self.addCleanup(self._context_stack.close)
+    stack = self._context_stack
+    stack.enter_context(global_config_context(**self._default_config))
 
     if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
+      assert TEST_NUM_THREADS.value <= 1, "Persistent compilation cache is not thread-safe."
       stack.enter_context(config.enable_compilation_cache(True))
       stack.enter_context(config.raise_persistent_cache_errors(True))
       stack.enter_context(config.persistent_cache_min_compile_time_secs(0))
       stack.enter_context(config.persistent_cache_min_entry_size_bytes(0))
-
       tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
-      compilation_cache.set_cache_dir(tmp_dir)
-      stack.callback(lambda: compilation_cache.reset_cache())
+      stack.enter_context(config.compilation_cache_dir(tmp_dir))
+      stack.callback(compilation_cache.reset_cache)
 
-  @classmethod
-  def tearDownClass(cls):
-    cls._compilation_cache_exit_stack.close()
+  def tearDown(self) -> None:
+    assert core.reset_trace_state()
+    super().tearDown()
 
   def rng(self):
     return self._rng
diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py
index bb92afebe8e9..b9645cbefb5e 100644
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -377,6 +377,7 @@ def _lower_tpu_kernel(
     pipeline = [
         (
             "func.func(tpu-infer-vector-layout{"
+            f" hardware-generation={hardware_generation}"
             f" sublane-count={sl_cnt} lane-count={l_cnt}"
             "})"
         ),
diff --git a/jax/_src/util.py b/jax/_src/util.py
index f262570b9a1b..9450659a2ef4 100644
--- a/jax/_src/util.py
+++ b/jax/_src/util.py
@@ -685,3 +685,6 @@ def test_event(name: str, *args) -> None:
   if not test_event_listener:
     return
   test_event_listener(name, *args)
+
+if hasattr(jaxlib_utils, "Mutex"):
+  Mutex = jaxlib_utils.Mutex
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 28148761c8a4..bbe6631753cb 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -122,6 +122,14 @@
     "inline without async dispatch.",
 )
 
+NUM_CPU_DEVICES = config.int_flag(
+    name="jax_num_cpu_devices",
+    default=-1,
+    help="Number of CPU devices to use. If not provided, the value of "
+         "the XLA flag --xla_force_host_platform_device_count is used."
+         " Must be set before JAX is initialized.",
+)
+
 
 # Warn the user if they call fork(), because it's not going to go well for them.
 def _at_fork():
@@ -249,8 +257,8 @@ def make_cpu_client(
   if collectives is None:
     collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value
     if _CPU_ENABLE_GLOO_COLLECTIVES.value:
-        collectives_impl = 'gloo'
-        warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
+      collectives_impl = 'gloo'
+      warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
                       'deprecated. Please use `jax.config.update('
                       '"jax_cpu_collectives_implementation", "gloo")` instead.',
                       DeprecationWarning,
@@ -268,12 +276,22 @@ def make_cpu_client(
                         f"{collectives_impl}. Available implementations are "
                         f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")
 
+  num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
+  if xla_client._version < 303 and num_devices is not None:
+    xla_flags = os.getenv("XLA_FLAGS") or ""
+    os.environ["XLA_FLAGS"] = (
+        f"{xla_flags} --xla_force_host_platform_device_count={num_devices}"
+    )
+    num_devices = None
+  # TODO(phawkins): pass num_devices directly when version 303 is the minimum.
+  kwargs = {} if num_devices is None else {"num_devices": num_devices}
   return xla_client.make_cpu_client(
     asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
     distributed_client=distributed.global_state.client,
     node_id=distributed.global_state.process_id,
     num_nodes=distributed.global_state.num_processes,
     collectives=collectives,
+    **kwargs,
   )
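
As a usage note for the new flag: the CPU device count can now be requested through JAX's own configuration instead of editing XLA_FLAGS. A sketch (assuming the flag is settable via jax.config.update like other jax_* options; it must run before any backend is initialized):

    import jax

    # Must run before the first computation initializes the CPU backend.
    jax.config.update("jax_num_cpu_devices", 8)

    print(jax.devices("cpu"))  # expect 8 CPU devices

    # Pre-flag equivalent, still used as a fallback for older jaxlib above:
    #   import os
    #   os.environ["XLA_FLAGS"] = (
    #       os.environ.get("XLA_FLAGS", "")
    #       + " --xla_force_host_platform_device_count=8")
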
 
 
diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py
index a4bb168efb2f..6ec621d68ff7 100644
--- a/jax/experimental/array_serialization/serialization_test.py
+++ b/jax/experimental/array_serialization/serialization_test.py
@@ -14,7 +14,6 @@
 """Tests for serialization and deserialization of GDA."""
 
 import asyncio
-import contextlib
 import math
 from functools import partial
 import os
@@ -36,13 +35,7 @@
 import tensorstore as ts
 
 jax.config.parse_flags_with_absl()
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-def tearDownModule():
-  _exit_stack.close()
+jtu.request_cpu_devices(8)
 
 
 class CheckpointTest(jtu.JaxTestCase):
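
The one-line jtu.request_cpu_devices(8) call replaces the old module-level ExitStack/setUpModule/tearDownModule pattern in this and the other test files touched below. A sketch of the resulting module shape (the assumption here is that the helper simply requests the extra CPU devices before the backend comes up, e.g. via the new jax_num_cpu_devices flag):

    import jax
    from jax._src import test_util as jtu

    jax.config.parse_flags_with_absl()
    jtu.request_cpu_devices(8)  # must run at import time, before backend init


    class MyMultiDeviceTest(jtu.JaxTestCase):

      def test_enough_devices(self):
        if jtu.test_device_matches(["cpu"]):
          self.assertLen(jax.devices(), 8)
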
diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py
index f23bd58c48d3..0cde96aeb36d 100644
--- a/jax/experimental/jax2tf/tests/call_tf_test.py
+++ b/jax/experimental/jax2tf/tests/call_tf_test.py
@@ -69,18 +69,6 @@ def _named_test(**kwargs):
 
 class CallTfTest(tf_test_util.JaxToTfTestCase):
 
-  @classmethod
-  def setUpClass(cls):
-    # One TF device of each device_type
-    cls.tf_devices = []
-    for tf_device in tf.config.list_logical_devices():
-      if tf_device.device_type == "TPU_SYSTEM":
-        continue  # A virtual device
-      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
-        cls.tf_devices.append(tf_device)
-
-    super().setUpClass()
-
   def setUp(self):
     if tf is None:
       raise unittest.SkipTest("Test requires tensorflow")
@@ -88,6 +76,13 @@ def setUp(self):
     # bug in TensorFlow.
     _ = tf.add(1, 1)
     super().setUp()
+    # One TF device of each device_type
+    self.tf_devices = []
+    for tf_device in tf.config.list_logical_devices():
+      if tf_device.device_type == "TPU_SYSTEM":
+        continue  # A virtual device
+      if all(tf_device.device_type != d.device_type for d in self.tf_devices):
+        self.tf_devices.append(tf_device)
     self.warning_ctx = jtu.ignore_warning(
         message=(
             "(jax2tf.convert with native_serialization=False has been deprecated"
@@ -798,7 +793,7 @@ def f_jax(x):
 
     jax_and_tf_platforms = (
       set(jax_platforms) & {d.device_type.lower()
-                            for d in self.__class__.tf_devices})
+                            for d in self.tf_devices})
 
     lowering_platforms = ("tpu", "cpu", "cuda")
 
@@ -833,7 +828,7 @@ def f_jax(x):
       f_jax,
       native_serialization=True,
       native_serialization_platforms=lowering_platforms))
-    for tf_device in self.__class__.tf_devices:
+    for tf_device in self.tf_devices:
       with self.subTest(tf_device.device_type):
         logging.info(
           f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py
index 7d3313be6c92..bea2b76cb7cf 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_test.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -50,22 +50,17 @@
 
 class Jax2TfTest(tf_test_util.JaxToTfTestCase):
 
-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     # One TF device of each device_type
-    cls.tf_devices = []
+    self.tf_devices = []
     for tf_device in (tf.config.list_logical_devices("TPU") +
                       tf.config.list_logical_devices("GPU") +
                       tf.config.list_logical_devices()):
       if tf_device.device_type == "TPU_SYSTEM":
         continue  # A virtual device
-      if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
-        cls.tf_devices.append(tf_device)
-
-    super().setUpClass()
-
-  def setUp(self):
-    super().setUp()
+      if all(tf_device.device_type != d.device_type for d in self.tf_devices):
+        self.tf_devices.append(tf_device)
     self.warning_ctx = jtu.ignore_warning(
         message="jax2tf.convert with native_serialization=False has been deprecated"
     )
@@ -1666,7 +1661,7 @@ def f_jax(x):
       f_jax,
       native_serialization=True,
       native_serialization_platforms=("cpu", "cuda", "tpu"))
-    for tf_device in self.__class__.tf_devices:
+    for tf_device in self.tf_devices:
       logging.info(
         f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
       with tf.device(tf_device):
diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py
index 8fe9a1dd9254..653ddce7dca4 100644
--- a/jax/experimental/jax2tf/tests/sharding_test.py
+++ b/jax/experimental/jax2tf/tests/sharding_test.py
@@ -19,13 +19,13 @@
 
 """
 from collections.abc import Sequence
-import contextlib
 from functools import partial
 import logging
 import re
 from typing import Any
 import unittest
 
+from absl import app
 from absl.testing import absltest
 
 import jax
@@ -47,16 +47,15 @@
 import tensorflow as tf
 
 config.parse_flags_with_absl()
+jtu.request_cpu_devices(8)
 
 # Must come after initializing the flags
 from jax.experimental.jax2tf.tests import tf_test_util
 
-_exit_stack = contextlib.ExitStack()
 topology = None
 
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
 
+def initialize_tf_tpu():
   global topology
   if jtu.test_device_matches(["tpu"]):
     with jtu.ignore_warning(message="the imp module is deprecated"):
@@ -67,8 +66,7 @@ def setUpModule():
   else:
     topology = None
 
-def tearDownModule():
-  _exit_stack.close()
+app.call_after_init(initialize_tf_tpu)
 
 
 class ShardingTest(tf_test_util.JaxToTfTestCase):
diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py
index 6e9ba6a382db..3fd34ccdad9d 100644
--- a/jax/experimental/mosaic/gpu/examples/flash_attention.py
+++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py
@@ -607,7 +607,7 @@ def ref(q, k, v):
     exit(0)
 
   batch_size = 1
-  num_q_heads = 4
+  num_q_heads = 2
   num_kv_heads = 1
   prof_spec = None
   seq_lens = (4096, 32768)
diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
index 275fcd84e44c..31de258b2a2f 100644
--- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -200,9 +200,11 @@ def run(refs):
         grid=(batch_size, num_q_tiles, num_q_heads),
         num_threads=3,
         axis_names=("batch", "q_seq", "heads", "wg"),
-        approx_math=True,
     )
-    @pl.core_map(mesh)
+
+    @pl.core_map(
+        mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
+    )
     def _kernel_entry():
       compute_wgs = 2
       tiling = plgpu.TilingTransform((64, 64))
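
The change above moves approx_math off the mesh and onto the compiler parameters supplied at core_map time, so GPUMesh describes only the launch grid. The resulting call pattern, reusing the names from this file:

    mesh = plgpu.GPUMesh(
        grid=(batch_size, num_q_tiles, num_q_heads),
        num_threads=3,
        axis_names=("batch", "q_seq", "heads", "wg"),
    )

    @pl.core_map(
        mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
    )
    def _kernel_entry():
      ...  # kernel body unchanged
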
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 599e4ab0e56c..910fa4728d69 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -652,10 +652,11 @@ def _shard_map_lowering_shardy(
   sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
   args = (*ctx.dim_var_values, *in_nodes)
 
-  manual_axes = sub_ctx.axis_context.manual_axes
-  mesh_shape = mesh.shape
-  manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes])
-  if manual_axes_size == 1:
+  # The order of manual axes should match the order of mesh.axis_names to avoid
+  # non-determinism issues.
+  manual_axes = [a for a in mesh.axis_names
+                 if a in sub_ctx.axis_context.manual_axes]
+  if np.prod([mesh.shape[a] for a in manual_axes]) == 1:
     # No need for a `ManualComputationOp` if all manual axes are size 1.
     with core.extend_axis_env_nd(tuple(mesh.shape.items())):
       out_nodes, _ = mlir.jaxpr_subcomp(
diff --git a/jax/monitoring.py b/jax/monitoring.py
index 374e301b970c..4c9996da582c 100644
--- a/jax/monitoring.py
+++ b/jax/monitoring.py
@@ -22,9 +22,11 @@
 """
 
 from jax._src.monitoring import (
+    clear_event_listeners as clear_event_listeners,
     record_event_duration_secs as record_event_duration_secs,
+    record_event_time_span as record_event_time_span,
     record_event as record_event,
     register_event_duration_secs_listener as register_event_duration_secs_listener,
     register_event_listener as register_event_listener,
-    clear_event_listeners as clear_event_listeners,
+    register_event_time_span_listener as register_event_time_span_listener,
 )
diff --git a/jaxlib/BUILD b/jaxlib/BUILD
index e69432e89384..843ccb112871 100644
--- a/jaxlib/BUILD
+++ b/jaxlib/BUILD
@@ -214,6 +214,7 @@ nanobind_extension(
     deps = [
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/synchronization",
         "@nanobind",
         "@xla//third_party/python_runtime:headers",
     ],
diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD
index 37f9a35596d6..1aef1ebdd86d 100644
--- a/jaxlib/mosaic/BUILD
+++ b/jaxlib/mosaic/BUILD
@@ -240,6 +240,7 @@ cc_test(
     deps = [
         ":tpu_dialect",
         "//testing/base/public:gunit_main",
+        "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 6ef809c4cb6a..a486d8fef84d 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -847,6 +847,7 @@ def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncO
   ];
   let constructor = "::mlir::tpu::createInferVectorLayoutPass()";
   let options = [
+    Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
     Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
     Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
   ];
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
index 307f3582f007..0156798ca88d 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -79,7 +79,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
     int hardware_generation = -1);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
-    std::array<int64_t, 2> target_shape = {8, 128});
+    int hardware_generation = -1,
+    std::array<int64_t, 2> target_shape = {8, 128},
+    const TpuTilingFlags &tpu_tiling_flags = {});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
     std::array<int64_t, 2> target_shape = {8, 128});
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 3a8263573544..4c6353f8c504 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -429,7 +429,8 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
       MemRefType arg_type,
       inferMemref(
           MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
-          ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{}));
+          ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{},
+          /*is_kernel_argument=*/true));
   const BlockArgument argument = entry_block.insertArgument(
       entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx));
   const FunctionType func_ty = func.getFunctionType();
@@ -3334,21 +3335,38 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
       int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape);
       if (needs_physical_broadcast ==
           std::array{true, false}) {  // Sublane broadcast
-        if (layout_in.bitwidth() != 32) {
-          return op.emitOpError(
-              "Not implemented: Only 32-bit supported for sublane broadcast");
-        }
+        const int bitwidth = layout_in.bitwidth();
+        const int packing = layout_in.packing();
         if (num_tiles != 1) {
           return op.emitOpError(
               "Not implemented: Only native tiling supported");
         }
         TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 2), 1);
         TPU_ASSERT_OP(offsets_in[0].has_value());
-        const int64_t offset = *offsets_in[0];
+        const int64_t sublane_offset = *offsets_in[0] / packing;
+        const int64_t subelement_offset = *offsets_in[0] % packing;
         const DenseI32ArrayAttr indices = builder.getDenseI32ArrayAttr(
-            SmallVector<int32_t>(ctx.target_shape[0], offset));
+            SmallVector<int32_t>(ctx.target_shape[0], sublane_offset));
         src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
-                           Value *const src_tile) {
+                           Value *const src_vreg) {
+          Value dst_vreg = *src_vreg;
+          // Replicate the value within each sublane.
+          if (packing != 1) {
+            auto vreg_int_ty = getNativeVregType(
+                builder.getIntegerType(bitwidth), ctx.target_shape);
+            auto src_vreg_int =
+                builder.create<tpu::BitcastVregOp>(vreg_int_ty, dst_vreg);
+            auto unpack_elem = builder.create<tpu::UnpackSubelementsOp>(
+                getNativeVregType(builder.getI32Type(), ctx.target_shape),
+                src_vreg_int, subelement_offset, tpu::PackFormat::kInterleaved);
+            SmallVector<Value> packed_vregs(packing, unpack_elem);
+            auto vreg_int = builder.create<tpu::PackSubelementsOp>(
+                vreg_int_ty, packed_vregs, tpu::PackFormat::kInterleaved);
+            dst_vreg = builder.create<tpu::BitcastVregOp>(dst_vreg.getType(),
+                                                          vreg_int);
+          }
+          dst_vreg = builder.create<tpu::GatherOp>(dst_vreg.getType(), dst_vreg,
+                                                   indices, 0);
           SmallVector<int64_t> dst_starts(dst_tiles_implicit_shape.size());
           SmallVector<int64_t> dst_limits(dst_tiles_implicit_shape.size());
           for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) {
@@ -3360,10 +3378,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
               dst_limits[i] = dst_starts[i] + 1;
             }
           }
-          updateSlice<Value>(dst_tiles,
-                             builder.create<tpu::GatherOp>(
-                                 src_tile->getType(), *src_tile, indices, 0),
-                             dst_starts, dst_limits);
+          updateSlice<Value>(dst_tiles, dst_vreg, dst_starts, dst_limits);
         });
       } else if (needs_physical_broadcast ==
                  std::array{false, true}) {  // Lane broadcast
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
index 958fdea96945..e1f690efa85e 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -302,8 +302,8 @@ LogicalResult canonicalize_elementwise(int hardware_generation_,
       // TODO(mvoz): Look into (1) what it would take to support these ops
       // natively on later hardware, and (2) how to better organize this list.
       bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op) ||
-                        isa<arith::DivFOp>(op) || isa<math::TanhOp>(op) ||
-                        isa<math::ExpOp>(op) || isa<math::LogOp>(op);
+                        isa<math::TanhOp>(op) || isa<math::ExpOp>(op) ||
+                        isa<math::LogOp>(op);
       if (needs_cast && element_type.isBF16()) {
         auto target_f32 =
             builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
index 046b642f98a3..05667a847691 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
@@ -43,10 +43,12 @@ namespace mlir::tpu {
 //   tpu_tiling_flags: A struct of flags indicating which large tiling modes are
 //     enabled by XLA for memrefs.
 //   bitwidth: The bitwidth of the element type of the operand.
+//   is_kernel_argument: Whether the operand is a kernel argument.
 int getTilingFactor(const int num_lanes, const int hardware_generation,
                     const int64_t sublane_count,
                     const TpuTilingFlags &tpu_tiling_flags,
-                    const int8_t bitwidth) {
+                    const int8_t bitwidth,
+                    const bool is_kernel_argument) {
   CHECK(llvm::isPowerOf2_32(bitwidth));
   CHECK_LE(4, bitwidth);
   CHECK_LE(bitwidth, 32);
@@ -61,7 +63,11 @@ int getTilingFactor(const int num_lanes, const int hardware_generation,
     if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
       return sublane_count * 4;
     }
-    if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) {
+    // 16-bit values can generally be relaid out on the fly on v6, so we allow
+    // large 2nd-minor tiling whenever possible. We can't do this for kernel
+    // arguments, because their layout is controlled by XLA.
+    if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor ||
+                           (!is_kernel_argument && hardware_generation >= 6))) {
       return sublane_count * 2;
     }
     return sublane_count;
@@ -84,6 +90,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
                                        const int hardware_generation,
                                        std::array<int64_t, 2> target_shape,
                                        const TpuTilingFlags &tpu_tiling_flags,
+                                       bool is_kernel_argument,
                                        int64_t leading_tile_rows = 0) {
   if (auto tiled_layout_attr =
           dyn_cast<TiledLayoutAttr>(memref_ty.getLayout())) {
@@ -119,7 +126,8 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
       const int64_t leading_tile =
           getTilingFactor(
               llvm::divideCeil(memref_ty.getShape().back(), lane_count),
-              hardware_generation, sublane_count, tpu_tiling_flags, bitwidth) *
+              hardware_generation, sublane_count, tpu_tiling_flags, bitwidth,
+              is_kernel_argument) *
           lane_count;
       SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
       if (bitwidth != 32) {
@@ -139,7 +147,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
     if (leading_tile_rows == 0) {
       leading_tile_rows =
           getTilingFactor(second_minor, hardware_generation, sublane_count,
-                          tpu_tiling_flags, bitwidth);
+                          tpu_tiling_flags, bitwidth, is_kernel_argument);
     }
     SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, lane_count})};
     if (bitwidth != 32) {
@@ -186,6 +194,7 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
                                   const int hardware_generation,
                                   std::array<int64_t, 2> target_shape,
                                   const TpuTilingFlags &tpu_tiling_flags,
+                                  bool is_kernel_argument,
                                   int64_t leading_tile_rows) {
   if (isa<SemaphoreType, DMASemaphoreType>(memref.getElementType())) {
     const Attribute semaphore_mem = tpu::MemorySpaceAttr::get(
@@ -209,7 +218,7 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
   FAILUREOR_ASSIGN_OR_RETURN(
       const TiledLayoutAttr layout,
       inferLayout(memref, hardware_generation, target_shape, tpu_tiling_flags,
-                  leading_tile_rows));
+                  is_kernel_argument, leading_tile_rows));
 
   const ArrayRef<xla::Tile> tiles = layout.getTiles();
   if (failed(checkTiles(memref.getContext(), tiles))) {
@@ -248,7 +257,8 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
     FAILUREOR_ASSIGN_OR_RETURN(
         const MemRefType new_memref_ty,
         inferMemref(memref_ty, hardware_generation, target_shape,
-                    tpu_tiling_flags, leading_tile_rows));
+                    tpu_tiling_flags, /*is_kernel_argument=*/false,
+                    leading_tile_rows));
     alloca_op.getResult().setType(new_memref_ty);
     if (memref_ty != new_memref_ty) {
       OpBuilder builder(alloca_op->getContext());
@@ -265,9 +275,10 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
   } else if (auto alloca_op = dyn_cast<tpu::AllocaSemaphoreOp>(op)) {
     TypedValue<MemRefType> arg = alloca_op.getResult();
     const MemRefType memref_ty = alloca_op.getResult().getType();
-    FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
-                               inferMemref(memref_ty, hardware_generation,
-                                           target_shape, tpu_tiling_flags));
+    FAILUREOR_ASSIGN_OR_RETURN(
+        const MemRefType new_memref_ty,
+        inferMemref(memref_ty, hardware_generation, target_shape,
+                    tpu_tiling_flags, /*is_kernel_argument=*/false));
     alloca_op.getResult().setType(new_memref_ty);
     if (memref_ty != new_memref_ty) {
       OpBuilder builder(alloca_op->getContext());
@@ -320,7 +331,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
     FAILUREOR_ASSIGN_OR_RETURN(
         MemRefType new_memref_ty,
         inferMemref(memref_ty, hardware_generation, target_shape,
-                    tpu_tiling_flags, leading_tile_rows));
+                    tpu_tiling_flags, /*is_kernel_argument=*/true,
+                    leading_tile_rows));
     arg.setType(new_memref_ty);
     new_arg_types.push_back(arg.getType());
     if (memref_ty != new_memref_ty) {
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h
index ed2a34793536..f2ab7c624eb1 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h
@@ -14,6 +14,7 @@ namespace mlir::tpu {
 FailureOr<MemRefType> inferMemref(MemRefType memref, int hardware_generation,
                                   std::array<int64_t, 2> target_shape,
                                   const TpuTilingFlags& tpu_tiling_flags,
+                                  bool is_kernel_argument,
                                   int64_t leading_tile_rows = 0);
 
 const std::string_view kLeadingTileRows = "leading_tile_rows";
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
index c5448a3df514..d189994d9564 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -92,9 +92,13 @@ LogicalResult verifyDivisibleIndex(Value tiled_index, int64_t tiling, int dim,
 // have corresponding native instructions.
 class VectorLayoutInferer {
  public:
-  explicit VectorLayoutInferer(std::array<int64_t, 2> target_shape)
-      : target_shape_({target_shape[0], target_shape[1]}),
-        default_tiling_(target_shape) {}
+  explicit VectorLayoutInferer(int hardware_generation,
+                               std::array<int64_t, 2> target_shape,
+                               const TpuTilingFlags &tpu_tiling_flags)
+      : hardware_generation_(hardware_generation),
+        target_shape_({target_shape[0], target_shape[1]}),
+        default_tiling_(target_shape),
+        tpu_tiling_flags_(tpu_tiling_flags) {}
 
 #define TPU_CHECK_OP(cond, msg) \
   if (!(cond)) {                \
@@ -1062,16 +1066,14 @@ class VectorLayoutInferer {
       // should always use that when sublane broadcasting is required.
       if (src_tiled_ishape[0] != dst_tiled_ishape[0] &&
           layout.offsets()[0] != std::nullopt) {
-        if (layout.bitwidth() != kNativeBitwidth) {
-          NYI("Only 32-bit broadcasts supported");
-        }
         LayoutOffsets offsets = layout.offsets();
         // At the moment relayout can only produce replicated sublanes when
         // converting to (8, 128) if the input was in (1, 128) tiling
-        if (layout.tiling()[0] == 1) {
+        if (layout.tiling()[0] == 1 && layout.bitwidth() == kNativeBitwidth) {
           offsets[0] = std::nullopt;
         }
-        layout = VectorLayout(layout.bitwidth(), offsets, default_tiling_,
+        layout = VectorLayout(layout.bitwidth(), offsets,
+                              nativeTiling(layout.bitwidth()),
                               layout.implicit_dim());
       }
       LayoutOffsets offsets = layout.offsets();
@@ -1705,6 +1707,21 @@ class VectorLayoutInferer {
     }
     auto &layout = *some_layout;
     bool select_native = allUsersRequireNativeTiling(op->getResult(0));
+    // We might want to reconsider enabling native tiling this aggressively in
+    // cases where it would introduce a lot of padding (e.g. when the value
+    // only has a small second-minor size but a large minor size).
+    if (dst_ty.getElementTypeBitWidth() == 16) {
+      // TPUv6 has good support for compute in 16-bit and cheap retiling between
+      // large 2nd minor and the default tiling, so we bias towards large tiles.
+      select_native |= hardware_generation_ >= 6 ||
+                       tpu_tiling_flags_.use_x16_large_second_minor;
+    } else if (dst_ty.getElementTypeBitWidth() == 8) {
+      select_native |= tpu_tiling_flags_.use_x8_large_second_minor;
+    } else if (dst_ty.getElementTypeBitWidth() == 4) {
+      select_native |= tpu_tiling_flags_.use_x4_large_second_minor;
+    } else {
+      return op->emitOpError("Unsupported target bitwidth for truncation");
+    }
     auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
                                    layout.implicit_dim());
     auto dst_layout = VectorLayout(
@@ -2019,15 +2036,15 @@ class VectorLayoutInferer {
             default_tiling_[1]};
   }
 
+  int hardware_generation_;
   std::array<int64_t, 2> target_shape_;
   std::array<int64_t, 2> default_tiling_;
+  TpuTilingFlags tpu_tiling_flags_;
 
   // TODO(b/342235360): Deprecate force_first_tile_offsets_ once we fully
   // remove the restriction that offsets must fall within the first tile.
   bool force_first_tile_offsets_ = false;
 
-  // Address alignment requirement, counted in 32-bit increments.
-  static constexpr int64_t kVmemAlignment32 = 128;
   // TODO(apaszke): This is not really native on newer generations of TPUs.
   // Get rid of this temporary stopgap.
   static constexpr int8_t kNativeBitwidth = 32;
@@ -2035,24 +2052,39 @@ class VectorLayoutInferer {
 
 struct InferVectorLayoutPass
     : public impl::InferVectorLayoutPassBase<InferVectorLayoutPass> {
-  InferVectorLayoutPass(std::array<int64_t, 2> target_shape) {
+  InferVectorLayoutPass(int hardware_generation,
+                        std::array<int64_t, 2> target_shape,
+                        TpuTilingFlags tpu_tiling_flags) {
+    this->hardware_generation = hardware_generation;
     this->sublane_count = target_shape[0];
     this->lane_count = target_shape[1];
+    this->tpu_tiling_flags = tpu_tiling_flags;
   }
   void runOnOperation() override {
+    // Fail if hardware_generation was left at its default (unset) value.
+    if (hardware_generation < 0) {
+      getOperation().emitError("hardware_generation must be set, got: ")
+          << hardware_generation;
+      signalPassFailure();
+      return;
+    }
     func::FuncOp func = getOperation();
-    VectorLayoutInferer run({sublane_count, lane_count});
+    VectorLayoutInferer run(hardware_generation, {sublane_count, lane_count},
+                            tpu_tiling_flags);
     if (run.infer(func).failed()) {
       signalPassFailure();
     }
   }
+
+  TpuTilingFlags tpu_tiling_flags;
 };
 
 }  // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
-    std::array<int64_t, 2> target_shape) {
-  return std::make_unique<InferVectorLayoutPass>(target_shape);
+    int hardware_generation, std::array<int64_t, 2> target_shape,
+    const TpuTilingFlags &tpu_tiling_flags) {
+  return std::make_unique<InferVectorLayoutPass>(
+      hardware_generation, target_shape, tpu_tiling_flags);
 }
 
 }  // namespace mlir::tpu
diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc
index 7dc5c13c073e..75c15f6a9f6e 100644
--- a/jaxlib/mosaic/dialect/tpu/vreg_util.cc
+++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/include/mlir/IR/Types.h"
 #include "mlir/include/mlir/IR/Value.h"
+#include "mlir/include/mlir/IR/ValueRange.h"
 #include "mlir/include/mlir/Support/LLVM.h"
 #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
 #include "jaxlib/mosaic/dialect/tpu/util.h"
@@ -91,8 +92,6 @@ TypedValue<VectorType> getZerosLikeVector(ImplicitLocOpBuilder &builder,
 FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
     ImplicitLocOpBuilder &builder, int64_t padding,
     const std::array<int64_t, 2> target_shape, int64_t dim) {
-  VectorType i32_vreg_ty =
-      getNativeVregType(builder.getI32Type(), target_shape);
   if (dim != 0 && dim != 1) {
     return builder.emitError()
            << "Expected a 2D vector for getX32VmaskByPaddingEnd";
@@ -100,22 +99,29 @@ FailureOr<TypedValue<VectorType>> getX32VmaskByPaddingEnd(
 
   if (padding < 0 || padding > target_shape[dim]) {
     return builder.emitError()
-           << "Padding must be in [0, target_shape[dim]). Padding: " << padding
+           << "Padding must be in [0, target_shape[dim]]. Padding: " << padding
            << ", target_shape[dim]: " << target_shape[dim];
   }
 
-  Value padding_vreg =
-      getFullVector(builder, i32_vreg_ty,
-                    builder.getI32IntegerAttr(target_shape[dim] - padding));
-
-  return cast<TypedValue<VectorType>>(
-      builder
-          .create<arith::CmpIOp>(
-              arith::CmpIPredicate::slt,
-              builder.create<tpu::IotaOp>(i32_vreg_ty,
-                                          builder.getI32IntegerAttr(dim)),
-              padding_vreg)
-          .getResult());
+  auto idx_const = [&builder](int64_t idx) {
+    return IdxConst(idx, builder, builder.getLoc());
+  };
+
+  tpu::CreateMaskOp mask_op;
+  const VectorType vmask_ty = getNativeVregOrVmaskType(
+      builder.getI1Type(), /*layout_bitwidth=*/32, target_shape);
+  if (dim == 0) {
+    mask_op = builder.create<tpu::CreateMaskOp>(
+        vmask_ty, ValueRange{idx_const(0), idx_const(0)},
+        ValueRange{idx_const(target_shape[0] - padding),
+                   idx_const(target_shape[1])});
+  } else {
+    mask_op = builder.create<tpu::CreateMaskOp>(
+        vmask_ty, ValueRange{idx_const(0), idx_const(0)},
+        ValueRange{idx_const(target_shape[0]),
+                   idx_const(target_shape[1] - padding)});
+  }
+  return cast<TypedValue<VectorType>>(mask_op.getResult());
 }
 
 LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder,
diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc
index dadbac133fbf..4b9da9505d75 100644
--- a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc
+++ b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc
@@ -21,9 +21,12 @@ limitations under the License.
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "llvm/include/llvm/ADT/TypeSwitch.h"
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/include/mlir/IR/Attributes.h"
 #include "mlir/include/mlir/IR/Builders.h"
+#include "mlir/include/mlir/IR/BuiltinAttributes.h"
 #include "mlir/include/mlir/IR/BuiltinOps.h"
 #include "mlir/include/mlir/IR/BuiltinTypes.h"
 #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
@@ -38,39 +41,56 @@ namespace mlir::tpu {
 
 namespace {
 
+using ::testing::ElementsAre;
 using ::testing::Eq;
 using ::testing::Optional;
 
-MATCHER_P2(IsConstantOpWithSplatValue, type, splat_value, "") {
+MATCHER_P2(IsConstantOpWithSplatOrScalarValue, type, value, "") {
   auto constant_op = dyn_cast<arith::ConstantOp>(arg.getDefiningOp());
   if (constant_op == nullptr) {
     *result_listener << "Expected a constant op, got " << debugString(arg);
     return false;
   }
-  auto dense_attr = dyn_cast<DenseElementsAttr>(constant_op.getValue());
-  if (dense_attr == nullptr) {
-    *result_listener << "Expected a dense elements attr, got "
-                     << debugString(arg);
-    return false;
-  }
-  if (dense_attr.getType() != type) {
-    *result_listener << "Expected a dense elements attr with type "
-                     << debugString(type) << ", got "
-                     << debugString(dense_attr.getType());
-    return false;
-  }
-  if (!dense_attr.isSplat()) {
-    *result_listener << "Expected a splat dense elements attr, got "
-                     << debugString(dense_attr);
-    return false;
-  }
-  if (auto s = dense_attr.template getSplatValue<decltype(splat_value)>();
-      s != splat_value) {
-    *result_listener << "Expected a splat dense elements attr with value "
-                     << splat_value << ", got " << s;
-    return false;
-  }
-  return true;
+
+  return llvm::TypeSwitch<Attribute, bool>(constant_op.getValue())
+      .template Case<DenseElementsAttr>([&](auto attr) {
+        // If it's dense, it must be splat.
+        if (attr.getType() != type) {
+          *result_listener << "Expected a dense elements attr with type "
+                           << debugString(type) << ", got "
+                           << debugString(attr.getType());
+          return false;
+        }
+        if (!attr.isSplat()) {
+          *result_listener << "Expected a splat dense elements attr, got "
+                           << debugString(attr);
+          return false;
+        }
+        if (auto s = attr.template getSplatValue<decltype(value)>();
+            s != value) {
+          *result_listener << "Expected a splat dense elements attr with value "
+                           << value << ", got " << s;
+          return false;
+        }
+        return true;
+      })
+      .template Case<IntegerAttr>([&](auto attr) {
+        if (attr.getType() != type) {
+          *result_listener << "Expected a attr with type " << debugString(type)
+                           << ", got " << debugString(attr.getType());
+          return false;
+        }
+        if (auto s = attr.getInt(); s != value) {
+          *result_listener << "Expected a attr with value " << value << ", got "
+                           << s;
+          return false;
+        }
+        return true;
+      })
+      .template Default([&](auto attr) {
+        *result_listener << "Unsupported attribute type: " << debugString(attr);
+        return false;
+      });
 }
 
 MATCHER_P2(IsVectorTypeWithShape, shape, elem_ty, "") {
@@ -150,7 +170,7 @@ TEST_F(VregUtilTest, GetFullVector) {
   TypedValue<VectorType> vec =
       getFullVector(Builder(), vty, Builder().getI32IntegerAttr(0x1));
 
-  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0x1}));
+  EXPECT_THAT(vec, IsConstantOpWithSplatOrScalarValue(vty, int32_t{0x1}));
 }
 
 TEST_F(VregUtilTest, GetFullLikeVector) {
@@ -161,14 +181,14 @@ TEST_F(VregUtilTest, GetFullLikeVector) {
   TypedValue<VectorType> vec =
       getFullLikeVector(Builder(), in_vec, Builder().getF32FloatAttr(2.0f));
 
-  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{2.0f}));
+  EXPECT_THAT(vec, IsConstantOpWithSplatOrScalarValue(vty, float{2.0f}));
 }
 
 TEST_F(VregUtilTest, GetZerosVector) {
   VectorType vty = VectorType::get({2, 4}, Builder().getI32Type());
   TypedValue<VectorType> vec = getZerosVector(Builder(), vty);
 
-  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, int32_t{0}));
+  EXPECT_THAT(vec, IsConstantOpWithSplatOrScalarValue(vty, int32_t{0}));
 }
 
 TEST_F(VregUtilTest, GetZerosLikeVector) {
@@ -178,7 +198,7 @@ TEST_F(VregUtilTest, GetZerosLikeVector) {
                vty.getElementType(), Builder().getF32FloatAttr(1.0f)));
   TypedValue<VectorType> vec = getZerosLikeVector(Builder(), in_vec);
 
-  EXPECT_THAT(vec, IsConstantOpWithSplatValue(vty, float{0.0f}));
+  EXPECT_THAT(vec, IsConstantOpWithSplatOrScalarValue(vty, float{0.0f}));
 }
 
 TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim0) {
@@ -188,18 +208,18 @@ TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim0) {
       /*dim=*/0);
   ASSERT_TRUE(succeeded(vec));
 
-  auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
-  ASSERT_TRUE(cmp_op != nullptr);
-  EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
-
-  auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
-  ASSERT_TRUE(iota_op != nullptr);
-  EXPECT_THAT(iota_op.getDimension(), Optional(Eq(0)));
-
-  EXPECT_THAT(
-      cmp_op.getRhs(),
-      IsConstantOpWithSplatValue(
-          VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{3}));
+  auto mask_op = dyn_cast<tpu::CreateMaskOp>(vec.value().getDefiningOp());
+  ASSERT_TRUE(mask_op != nullptr);
+  EXPECT_THAT(ArrayRef<Value>({mask_op.getLow()[0], mask_op.getLow()[1]}),
+              ElementsAre(IsConstantOpWithSplatOrScalarValue(
+                              Builder().getIndexType(), int64_t{0}),
+                          IsConstantOpWithSplatOrScalarValue(
+                              Builder().getIndexType(), int64_t{0})));
+  EXPECT_THAT(ArrayRef<Value>({mask_op.getHigh()[0], mask_op.getHigh()[1]}),
+              ElementsAre(IsConstantOpWithSplatOrScalarValue(
+                              Builder().getIndexType(), int64_t{3}),
+                          IsConstantOpWithSplatOrScalarValue(
+                              Builder().getIndexType(), int64_t{8})));
 }
 
 TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim1) {
@@ -209,18 +229,18 @@ TEST_F(VregUtilTest, GetX32VmaskByPaddingEndDim1) {
       /*dim=*/1);
   ASSERT_TRUE(succeeded(vec));
 
-  auto cmp_op = dyn_cast<arith::CmpIOp>(vec.value().getDefiningOp());
-  ASSERT_TRUE(cmp_op != nullptr);
-  EXPECT_EQ(cmp_op.getPredicate(), arith::CmpIPredicate::slt);
-
-  auto iota_op = dyn_cast<tpu::IotaOp>(cmp_op.getLhs().getDefiningOp());
-  ASSERT_TRUE(iota_op != nullptr);
-  EXPECT_THAT(iota_op.getDimension(), Optional(Eq(1)));
-
-  EXPECT_THAT(
-      cmp_op.getRhs(),
-      IsConstantOpWithSplatValue(
-          VectorType::get(kTargetShape, Builder().getI32Type()), int32_t{5}));
+  auto mask_op = dyn_cast<tpu::CreateMaskOp>(vec.value().getDefiningOp());
+  ASSERT_TRUE(mask_op != nullptr);
+  EXPECT_THAT(ArrayRef<Value>({mask_op.getLow()[0], mask_op.getLow()[1]}),
+              ElementsAre(IsConstantOpWithSplatOrScalarValue(
+                              Builder().getIndexType(), int64_t{0}),
+                          IsConstantOpWithSplatOrScalarValue(
+                              Builder().getIndexType(), int64_t{0})));
+  EXPECT_THAT(ArrayRef<Value>({mask_op.getHigh()[0], mask_op.getHigh()[1]}),
+              ElementsAre(IsConstantOpWithSplatOrScalarValue(
+                              Builder().getIndexType(), int64_t{4}),
+                          IsConstantOpWithSplatOrScalarValue(
+                              Builder().getIndexType(), int64_t{5})));
 }
 
 }  // namespace
diff --git a/jaxlib/setup.py b/jaxlib/setup.py
index c2efd3d7b7a7..b3a37a25f1b2 100644
--- a/jaxlib/setup.py
+++ b/jaxlib/setup.py
@@ -61,8 +61,7 @@ def has_ext_modules(self):
     packages=['jaxlib', 'jaxlib.xla_extension'],
     python_requires='>=3.10',
     install_requires=[
-        'scipy>=1.10',
-        "scipy>=1.11.1; python_version>='3.12'",
+        'scipy>=1.11.1',
         'numpy>=1.25',
         'ml_dtypes>=0.2.0',
     ],
diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc
index 28201233566a..6b612b26dce8 100644
--- a/jaxlib/utils.cc
+++ b/jaxlib/utils.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "nanobind/nanobind.h"
 #include "absl/cleanup/cleanup.h"
 #include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
 
 namespace nb = nanobind;
 
@@ -225,4 +226,19 @@ NB_MODULE(utils, m) {
       PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr()));
   m.attr("safe_zip") = nb::steal<nb::object>(
       PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
+
+  // Python has no reader-writer lock in its standard library, so we expose
+  // bindings around absl::Mutex.
+  nb::class_<absl::Mutex>(m, "Mutex")
+      .def(nb::init<>())
+      .def("lock", &absl::Mutex::Lock, nb::call_guard<nb::gil_scoped_release>())
+      .def("unlock", &absl::Mutex::Unlock)
+      .def("assert_held", &absl::Mutex::AssertHeld)
+      .def("reader_lock", &absl::Mutex::ReaderLock,
+           nb::call_guard<nb::gil_scoped_release>())
+      .def("reader_unlock", &absl::Mutex::ReaderUnlock)
+      .def("assert_reader_held", &absl::Mutex::AssertReaderHeld)
+      .def("writer_lock", &absl::Mutex::WriterLock,
+           nb::call_guard<nb::gil_scoped_release>())
+      .def("writer_unlock", &absl::Mutex::WriterUnlock);
 }
\ No newline at end of file
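
A usage sketch for the new binding, as re-exported from jax._src.util when the installed jaxlib provides it (the guard mirrors the hasattr check added in util.py; the reader/writer discipline below is an assumption about intended use, not an API requirement):

    import threading

    from jax._src import util

    assert hasattr(util, "Mutex"), "requires a jaxlib with the Mutex binding"

    mu = util.Mutex()        # absl::Mutex exposed via nanobind
    shared = {"count": 0}

    def writer():
      mu.writer_lock()       # exclusive; releases the GIL while blocking
      try:
        shared["count"] += 1
      finally:
        mu.writer_unlock()

    def reader():
      mu.reader_lock()       # shared; many readers may hold it concurrently
      try:
        return shared["count"]
      finally:
        mu.reader_unlock()

    threads = [threading.Thread(target=writer) for _ in range(4)]
    for t in threads:
      t.start()
    for t in threads:
      t.join()
    assert reader() == 4
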
diff --git a/setup.py b/setup.py
index b3bd4a3466d7..39508388ba8a 100644
--- a/setup.py
+++ b/setup.py
@@ -60,8 +60,7 @@ def load_version_module(pkg_path):
         'numpy>=1.25',
         "numpy>=1.26.0; python_version>='3.12'",
         'opt_einsum',
-        'scipy>=1.10',
-        "scipy>=1.11.1; python_version>='3.12'",
+        'scipy>=1.11.1',
     ],
     extras_require={
         # Minimum jaxlib version; used in testing.
diff --git a/tests/array_test.py b/tests/array_test.py
index 9618a8cf4665..97bf71a5216b 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -43,20 +43,12 @@
 from jax._src import prng
 
 jax.config.parse_flags_with_absl()
+jtu.request_cpu_devices(8)
 
 with contextlib.suppress(ImportError):
   import pytest
   pytestmark = pytest.mark.multiaccelerator
 
-# Run all tests with 8 CPU devices.
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-def tearDownModule():
-  _exit_stack.close()
-
 
 def create_array(shape, sharding, global_data=None):
   if global_data is None:
diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py
index 4485f5d4f41e..f9dd3ce52b58 100644
--- a/tests/colocated_python_test.py
+++ b/tests/colocated_python_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
 import threading
 import time
 from typing import Sequence
@@ -29,6 +28,7 @@
 import numpy as np
 
 config.parse_flags_with_absl()
+jtu.request_cpu_devices(8)
 
 
 def _colocated_cpu_devices(
@@ -53,18 +53,6 @@ def _colocated_cpu_devices(
 _count_colocated_python_specialization_cache_miss = jtu.count_events(
   "colocated_python_func._get_specialized_func")
 
-_exit_stack = contextlib.ExitStack()
-
-
-def setUpModule():
-  # TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT
-  # prepares CPU devices by its own.
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-
-def tearDownModule():
-  _exit_stack.close()
-
 
 class ColocatedPythonTest(jtu.JaxTestCase):
 
diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py
index 428e518eab51..73d76c1a4938 100644
--- a/tests/compilation_cache_test.py
+++ b/tests/compilation_cache_test.py
@@ -52,18 +52,12 @@
 FAKE_COMPILE_TIME = 10
 _counts = Counter()  # Map event name to count
 
-
-def setUpModule():
-  monitoring.register_event_listener(increment_event_count)
-
-
-def tearDownModule():
-  monitoring._unregister_event_listener_by_callback(increment_event_count)
-
-
 def increment_event_count(event):
   _counts[event] += 1
 
+monitoring.register_event_listener(increment_event_count)
+
+
 def msg_exists_in_logs(msg: str, records: list[logging.LogRecord],
                        level: int | None = None) -> bool:
   return any(msg in record.getMessage() for record in records
diff --git a/tests/debugger_test.py b/tests/debugger_test.py
index 18693a7bb2c3..419e7b18dfed 100644
--- a/tests/debugger_test.py
+++ b/tests/debugger_test.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from collections.abc import Sequence
-import contextlib
 import io
 import re
 import textwrap
@@ -29,6 +28,7 @@
 import numpy as np
 
 jax.config.parse_flags_with_absl()
+jtu.request_cpu_devices(2)
 
 def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]:
   fake_stdin = io.StringIO()
@@ -41,14 +41,6 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI
 def _format_multiline(text):
   return textwrap.dedent(text).lstrip()
 
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(2))
-
-def tearDownModule():
-  _exit_stack.close()
-
 foo = 2
 
 class CliDebuggerTest(jtu.JaxTestCase):
diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py
index 6afb41645405..0fc9665ceaa5 100644
--- a/tests/debugging_primitives_test.py
+++ b/tests/debugging_primitives_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections
-import contextlib
 import functools
 import textwrap
 import unittest
@@ -35,19 +34,13 @@
   rich = None
 
 jax.config.parse_flags_with_absl()
+jtu.request_cpu_devices(2)
 
 debug_print = debugging.debug_print
 
 def _format_multiline(text):
   return textwrap.dedent(text).lstrip()
 
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(2))
-
-def tearDownModule():
-  _exit_stack.close()
 
 class DummyDevice:
   def __init__(self, platform, id):
diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py
index f6625e86ca14..df79e6aaf6df 100644
--- a/tests/dynamic_api_test.py
+++ b/tests/dynamic_api_test.py
@@ -1487,6 +1487,7 @@ def f(i):
 class JumbleTest(jtu.JaxTestCase):
 
   def setUp(self):
+    super().setUp()
     if jax.config.x64_enabled: raise unittest.SkipTest()
 
   @parameterized.parameters((True,), (False,))
diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py
index e8b1afc224b7..d5878fa50a5a 100644
--- a/tests/export_harnesses_multi_platform_test.py
+++ b/tests/export_harnesses_multi_platform_test.py
@@ -48,11 +48,11 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:
 
 class PrimitiveTest(jtu.JaxTestCase):
 
-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     # Pick one device from each available platform
-    cls.devices = []
-    cls.platforms = []
+    self.devices = []
+    self.platforms = []
     for backend in ["cpu", "gpu", "tpu"]:
       try:
         devices = jax.devices(backend)
@@ -60,10 +60,9 @@ def setUpClass(cls):
         devices = []
 
       for d in devices:
-        if d.platform not in cls.platforms:
-          cls.platforms.append(d.platform)
-          cls.devices.append(d)
-    super().setUpClass()
+        if d.platform not in self.platforms:
+          self.platforms.append(d.platform)
+          self.devices.append(d)
 
   # For each primitive we export for all platforms that are available and
   # compare the results of running the exported code and running the native
@@ -128,7 +127,7 @@ def export_and_compare_to_native(
       tol: float | None = None):
     devices = [
         d
-        for d in self.__class__.devices
+        for d in self.devices
         if d.platform not in unimplemented_platforms
     ]
     logging.info("Using devices %s", [str(d) for d in devices])
diff --git a/tests/export_test.py b/tests/export_test.py
index da0e9daf2f00..b13cf3a623e6 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -56,14 +56,8 @@
   CAN_SERIALIZE = False
 
 config.parse_flags_with_absl()
+jtu.request_cpu_devices(8)
 
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(2))
-
-def tearDownModule():
-  _exit_stack.close()
 
 ### Setup for testing lowering with effects
 @dataclasses.dataclass(frozen=True)
@@ -165,17 +159,16 @@ def serde_exported(*fun_args, **fun_kwargs):
 @jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version)
 class JaxExportTest(jtu.JaxTestCase):
 
-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     # Find the available platforms
-    cls.platforms = []
+    self.platforms = []
     for backend in ["cpu", "gpu", "tpu"]:
       try:
         jax.devices(backend)
       except RuntimeError:
         continue
-      cls.platforms.append(backend)
-    super().setUpClass()
+      self.platforms.append(backend)
 
   def test_basic_export_only(self):
     @jax.jit
@@ -1505,7 +1498,7 @@ def test_multi_platform(self):
                   module_str)
 
     # Call with argument placed on different plaforms
-    for platform in self.__class__.platforms:
+    for platform in self.platforms:
       x_device = jax.device_put(x, jax.devices(platform)[0])
       res_exp = exp.call(x_device)
       self.assertAllClose(
@@ -1530,7 +1523,7 @@ def test_multi_platform_nested(self):
     self.assertEqual(1, count_sine)
 
     # Call with argument placed on different plaforms
-    for platform in self.__class__.platforms:
+    for platform in self.platforms:
       if platform == "tpu": continue
       x_device = jax.device_put(x, jax.devices(platform)[0])
       res_exp = exp2.call(x_device)
@@ -1674,7 +1667,7 @@ def f_jax(b):  # b: f32[16 // DEVICES, 4]
     exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a)
 
     # Call with argument placed on different plaforms
-    for platform in self.__class__.platforms:
+    for platform in self.platforms:
       run_devices = jax.devices(platform)[0:len(export_devices)]
       if len(run_devices) != len(export_devices):
         continue
diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py
index 2e91792aa950..922b37ffa440 100644
--- a/tests/jaxpr_effects_test.py
+++ b/tests/jaxpr_effects_test.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import contextlib
+
 import threading
 import unittest
 
@@ -34,6 +34,7 @@
 import numpy as np
 
 config.parse_flags_with_absl()
+jtu.request_cpu_devices(2)
 
 effect_p = core.Primitive('effect')
 effect_p.multiple_results = True
@@ -132,15 +133,6 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
 mlir.register_lowering(callback_p, callback_effect_lowering)
 
 
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(2))
-
-def tearDownModule():
-  _exit_stack.close()
-
-
 class JaxprEffectsTest(jtu.JaxTestCase):
 
   def test_trivial_jaxpr_has_no_effects(self):
diff --git a/tests/layout_test.py b/tests/layout_test.py
index f958de5cf5bc..903b17886283 100644
--- a/tests/layout_test.py
+++ b/tests/layout_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
 import math
 from functools import partial
 from absl.testing import absltest
@@ -28,14 +27,7 @@
 from jax.experimental.compute_on import compute_on
 
 config.parse_flags_with_absl()
-
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-def tearDownModule():
-  _exit_stack.close()
+jtu.request_cpu_devices(8)
 
 
 class LayoutTest(jtu.JaxTestCase):
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 0da09e232deb..65f7c8145138 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -1686,7 +1686,7 @@ def testTriangularSolveSingularBatched(self):
 
   @jtu.sample_product(
     n=[1, 4, 5, 20, 50, 100],
-    batch_size=[(), (2,), (3, 4)] if scipy_version >= (1, 9, 0) else [()],
+    batch_size=[(), (2,), (3, 4)],
     dtype=int_types + float_types + complex_types
   )
   def testExpm(self, n, batch_size, dtype):
diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py
index b84903618fab..7fb87086d9e6 100644
--- a/tests/mock_gpu_test.py
+++ b/tests/mock_gpu_test.py
@@ -32,9 +32,9 @@
 class MockGPUTest(jtu.JaxTestCase):
 
   def setUp(self):
+    super().setUp()
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Mocking devices only works on the GPU backend.")
-    super().setUp()
 
   @jtu.skip_under_pytest("Test must run in an isolated process")
   def testMockDeviceCount(self):
diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py
index 44ec4e2f9529..71ce8f1dde1c 100644
--- a/tests/mock_gpu_topology_test.py
+++ b/tests/mock_gpu_topology_test.py
@@ -31,9 +31,9 @@
 class MockGPUTopologyTest(jtu.JaxTestCase):
 
   def setUp(self):
+    super().setUp()
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Mocking devices only works on the GPU backend.")
-    super().setUp()
 
   @jtu.skip_under_pytest("Test must run in an isolated process")
   def testMockDeviceCount(self):
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index dd41b264aa1c..d8ae6d3c2ff6 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -171,7 +171,7 @@ def setUp(self):
     self.context = mlir.make_ir_context()
     if mgpu_dialect is not None:
       mgpu_dialect.register_dialect(self.context)
-    self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
+    self.enter_context(config.traceback_filtering("off"))
     self.enter_context(self.context)
     self.enter_context(ir.Location.unknown())
 
@@ -1756,13 +1756,13 @@ def kernel(ctx, src, dst, _):
 
 class TorchTest(TestCase):
 
-  @classmethod
-  def setUpClass(cls):
+  def setUp(self):
+    super().setUp()
     try:
       import torch
     except ImportError:
       raise unittest.SkipTest("Test requires PyTorch")
-    cls.torch = torch
+    self.torch = torch
 
   def test_basic(self):
     def kernel(ctx, i_gmem, o_gmem, _):
diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py
index 057731cb5d55..1fc6fe1e9298 100644
--- a/tests/multi_device_test.py
+++ b/tests/multi_device_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
 from unittest import SkipTest
 import tracemalloc as tm
 
@@ -25,15 +24,7 @@
 from jax._src import test_util as jtu
 
 jax.config.parse_flags_with_absl()
-
-# Run all tests with 8 CPU devices.
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-def tearDownModule():
-  _exit_stack.close()
+jtu.request_cpu_devices(8)
 
 
 class MultiDeviceTest(jtu.JaxTestCase):
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 4bd64a5a8dce..a9d5361e7c10 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -1481,14 +1481,11 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem):
                                       index_map=lambda i, j: (0, 0))
                    ],
     )
-    mesh = plgpu.GPUMesh(
-        grid=(1,),
-        num_threads=3,
-        axis_names=("_", "wg",),
-        approx_math=True,
-    )
+    mesh = plgpu.GPUMesh(grid=(1,), num_threads=3, axis_names=("_", "wg"))
     def run(refs):
-      @pl.core_map(mesh)
+      @pl.core_map(
+          mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
+      )
       def _kernel_entry():
         pipeline(*refs)
     @jax.jit
@@ -1535,13 +1532,12 @@ def tiled_add_kernel(x_smem, y_smem, o_smem):
                 transforms=[])],
     )
     mesh = plgpu.GPUMesh(
-        grid=(1,),
-        num_threads=num_compute_wgs + 1,
-        axis_names=("_", "wg",),
-        approx_math=True,
+        grid=(1,), num_threads=num_compute_wgs + 1, axis_names=("_", "wg")
     )
     def run(refs):
-      @pl.core_map(mesh)
+      @pl.core_map(
+          mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
+      )
       def _kernel_entry():
         pipeline(*refs)
     @jax.jit
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index ec4e5805d056..290d35490bef 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -1116,18 +1116,36 @@ def kernel(x_ref, y_ref, out_ref):
   @parameterized.parameters(
       ("int32", "float32"),
       ("float32", "float32"),
+      ("bfloat16", "bfloat16"),
   )
   def test_true_divide(self, dtype, out_dtype):
+    if jtu.test_device_matches(["tpu"]):
+      if out_dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6):
+        self.skipTest("bfloat16 is not supported on older TPU generations")
+      if not jtu.if_cloud_tpu_at_least(2025, 1, 9):
+        self.skipTest("Requires libtpu built after 2025-01-09")
+    elif jtu.test_device_matches(["gpu"]):
+      if dtype == "bfloat16":
+        self.skipTest("bfloat16 not supported")
+
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((8,), out_dtype),
+        out_shape=jax.ShapeDtypeStruct((8, 8), out_dtype),
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
 
     x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype)
     y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype)
-    np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y))
+    x = jnp.repeat(x, 8, axis=0).reshape(8, 8)
+    y = jnp.tile(y, 8).reshape(8, 8)
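+    # bfloat16 has ~3 significant digits, so use a looser tolerance.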
+    rtol = 8e-3 if dtype == "bfloat16" else 1e-6
+    np.testing.assert_allclose(
+        jnp.true_divide(x, y).astype(jnp.float32),
+        kernel(x, y).astype(jnp.float32),
+        rtol=rtol,
+    )
 
   @parameterized.parameters("float16", "bfloat16")
   def test_true_divide_unsupported(self, dtype):
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 373388d97691..b53a057f46fd 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -30,10 +30,12 @@
 from jax._src import api_util
 from jax._src import checkify
 from jax._src import config
+from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import test_util as jtu
 from jax._src.lax.control_flow.for_loop import for_loop
 from jax._src.pallas import core as pallas_core
+from jax._src.pallas import pallas_call
 from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
 from jax.experimental import pallas as pl
 import jax.numpy as jnp
@@ -41,8 +43,10 @@
 
 if sys.platform != "win32":
   from jax.experimental.pallas import tpu as pltpu
+  from jax.experimental.pallas import triton as plgpu
 else:
   pltpu = None
+  plgpu = None
 
 
 # TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
@@ -2361,5 +2365,49 @@ class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest):
   INTERPRET = True
 
 
+def _find_pallas_call_in_jaxpr(
+    jaxpr: jax_core.Jaxpr) -> jax_core.JaxprEqn | None:
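+  """Finds the first pallas_call equation, recursing into nested jaxprs."""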
+  for eqn in jaxpr.eqns:
+    call_eqn = None
+    if eqn.primitive == pallas_call.pallas_call_p:
+      call_eqn = eqn
+    elif 'jaxpr' in eqn.params:
+      call_eqn = _find_pallas_call_in_jaxpr(eqn.params['jaxpr'])
+    if call_eqn is not None:
+      return call_eqn
+  return None
+
+
+class PallasCompilerParamsTest(PallasBaseTest):
+  def test_triton_params_consistent_across_double_jit(self):
+    # Test for https://github.com/jax-ml/jax/issues/25714
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Triton backend only works on GPU.")
+    params = plgpu.TritonCompilerParams(num_warps=8)
+
+    @jax.jit
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
+        compiler_params=params)
+    def copy_kernel(x_ref, o_ref):
+      o_ref[...] = x_ref[...]
+
+    @functools.partial(jax.jit, static_argnames=["z"])
+    def plus_z(x, z):
+      return copy_kernel(x+z)
+
+    x = 0.
+    extracted_params = _find_pallas_call_in_jaxpr(
+        plus_z.trace(x, 1).jaxpr).params["compiler_params"]
+    self.assertEqual(plus_z(0., 1.), 1.)
+    self.assertEqual(extracted_params["triton"]["num_warps"], 8)
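+    # Re-tracing with a different static argument must preserve num_warps.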
+    extracted_params = _find_pallas_call_in_jaxpr(
+        plus_z.trace(x, 2).jaxpr).params["compiler_params"]
+    self.assertEqual(plus_z(0., 2.), 2.)
+    self.assertEqual(extracted_params["triton"]["num_warps"], 8)
+
+
 if __name__ == "__main__":
   absltest.main()
diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py
index cb92794b2758..8f948050b8a5 100644
--- a/tests/pallas/tpu_ops_test.py
+++ b/tests/pallas/tpu_ops_test.py
@@ -176,6 +176,24 @@ def kernel(x_ref, y_ref, out_ref):
     )(x, y)
     np.testing.assert_array_equal(out, inp.reshape(m * 2, n))
 
+  @parameterized.parameters([jnp.int32, jnp.int16, jnp.int8])
+  def test_row_broadcast(self, dtype):
+    if not jtu.if_cloud_tpu_at_least(2025, 1, 9):
+      self.skipTest("Requires libtpu built after 2025-01-09")
+    if not self.INTERPRET and jtu.get_tpu_version() < 5:
+      self.skipTest("Requires TPUv5+")
+    def kernel(x_ref, y_ref):
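+      # Broadcast row 3 of the input across every row of the output.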
+      y_ref[...] = jnp.broadcast_to(x_ref[pl.ds(3, 1)], y_ref.shape)
+    m, n = 4, 1024
+    x = jax.random.randint(
+        jax.random.key(12), (m, n), minval=-1000, maxval=1000, dtype=jnp.int32
+    ).astype(dtype)
+    y = self.pallas_call(
+        kernel, out_shape=jax.ShapeDtypeStruct((m, n), dtype)
+    )(x)
+    np.testing.assert_array_equal(y, jnp.broadcast_to(x[3:4], y.shape))
+
   def test_tpu_unsigned_int(self):
     def body(x_ref, o_ref):
       # Test cast from uint16 -> uint32
diff --git a/tests/pgle_test.py b/tests/pgle_test.py
index 05154ae0376c..46add86e2784 100644
--- a/tests/pgle_test.py
+++ b/tests/pgle_test.py
@@ -37,6 +37,7 @@
 import jax.numpy as jnp
 from jax.sharding import NamedSharding, PartitionSpec
 import numpy as np
+from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member
 
 jax.config.parse_flags_with_absl()
 
@@ -89,15 +90,20 @@ def testPGLEProfilerGetFDOProfileLarge(self):
     mesh = jtu.create_mesh((2,), ('x',))
     its = 500
 
+    compiler_options = {
+        'xla_gpu_enable_latency_hiding_scheduler': 'True',
+    }
+    # TODO(b/376647494): Remove this flag once the bug is fixed.
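+    # Either branch effectively disables command buffers.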
+    if xla_extension_version > 302:
+      compiler_options['xla_gpu_enable_command_buffer'] = ''
+    else:
+      compiler_options['xla_gpu_graph_min_graph_size'] = '100000'
     @partial(
         jax.jit,
         in_shardings=NamedSharding(mesh, PartitionSpec('x')),
         out_shardings=NamedSharding(mesh, PartitionSpec('x')),
-        compiler_options={
-            'xla_gpu_enable_latency_hiding_scheduler': 'True',
-            # TODO(b/37664749): Remove this flag once the bug is fixed.
-            'xla_gpu_enable_command_buffer': '',
-        },
+        compiler_options=compiler_options,
     )
     def f(x):
       agg = x
@@ -127,15 +132,19 @@ def testAutoPgle(self):
     mesh = jtu.create_mesh((2,), ('x',))
 
     with tempfile.TemporaryDirectory() as dump_dir:
+      compile_options = {
+          'xla_gpu_enable_latency_hiding_scheduler': 'True',
+          'xla_dump_to': dump_dir,
+          'xla_gpu_experimental_dump_fdo_profiles': 'True',
+      }
+      # TODO(b/376647494): Remove this flag once the bug is fixed.
+      if xla_extension_version <= 302:
+        compile_options['xla_gpu_graph_min_graph_size'] = '100000'
       @partial(
           jax.jit,
           in_shardings=NamedSharding(mesh, PartitionSpec('x')),
           out_shardings=NamedSharding(mesh, PartitionSpec('x')),
-          compiler_options={
-              'xla_gpu_enable_latency_hiding_scheduler': 'True',
-              'xla_dump_to': dump_dir,
-              'xla_gpu_experimental_dump_fdo_profiles': 'True'
-          },
+          compiler_options=compile_options,
       )
       def f(x):
         return x * 2
@@ -209,15 +218,19 @@ def testAutoPgleWithPersistentCache(self):
     mesh = jtu.create_mesh((2,), ('x',))
 
     with tempfile.TemporaryDirectory() as dump_dir:
+      compiler_options = {
+          'xla_gpu_enable_latency_hiding_scheduler': 'True',
+          'xla_dump_to': dump_dir,
+          'xla_gpu_experimental_dump_fdo_profiles': 'True',
+      }
+      # TODO(b/376647494): Remove this flag once the bug is fixed.
+      if xla_extension_version <= 302:
+        compiler_options['xla_gpu_graph_min_graph_size'] = '100000'
       @partial(
           jax.jit,
           in_shardings=NamedSharding(mesh, PartitionSpec('x')),
           out_shardings=NamedSharding(mesh, PartitionSpec('x')),
-          compiler_options={
-              'xla_gpu_enable_latency_hiding_scheduler': 'True',
-              'xla_dump_to': dump_dir,
-              'xla_gpu_experimental_dump_fdo_profiles': 'True'
-          },
+          compiler_options=compiler_options,
       )
       def f(x):
         agg = x
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 63620e5ad0c9..3fcc5c81ad91 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from collections import OrderedDict, namedtuple
-import contextlib
 import re
 from functools import partial
 import logging
@@ -64,14 +63,7 @@
 
 config.parse_flags_with_absl()
 
-# Run all tests with 8 CPU devices.
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-def tearDownModule():
-  _exit_stack.close()
+jtu.request_cpu_devices(8)
 
 def create_array(global_shape, global_mesh, mesh_axes, global_data=None,
                  dtype=np.float32):
@@ -5589,11 +5581,11 @@ def test_only_auto(self, mesh):
     @jax.jit
     def f(x, x2):
       y = x * 2
-      self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, None))
+      self.assertEqual(y.sharding.spec, P(None, None))
       z = jnp.sin(y)
-      self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, None))
+      self.assertEqual(z.sharding.spec, P(None, None))
       a = z @ x2
-      self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
+      self.assertEqual(a.sharding.spec, P(None, None))
       return a
 
     out = f(arr, arr.T)
@@ -5626,9 +5618,9 @@ def f(x, x2):
       arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
       arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
       out = f(arr, arr2)
-      self.assertEqual(out.sharding, NamedSharding(mesh2, P('x', None)))
+      self.assertEqual(out.sharding, NamedSharding(mesh2, P('x',)))
       lowered_text = f.lower(arr, arr2).as_text()
-      self.assertTrue(lowered_text.count("unspecified_dims") == 3)
+      self.assertTrue(lowered_text.count("unspecified_dims") == 5)
 
     mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
                             axis_types={mesh_lib.AxisTypes.User: 'y',
@@ -5669,11 +5661,11 @@ def f(x):
           {mesh_lib.AxisTypes.Auto: ('x', 'y')})
       y = sharding_cast(y, y.sharding.with_mesh(auto_mesh))
       with mesh_lib.set_abstract_mesh(auto_mesh):
-        self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
+        self.assertEqual(y.sharding.spec, P(None, None))
         z = jnp.sin(y)
-        self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
+        self.assertEqual(z.sharding.spec, P(None, None))
         a = z @ z.T
-        self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
+        self.assertEqual(a.sharding.spec, P(None, None))
       a = sharding_cast(
           a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
       self.assertEqual(a.sharding.spec, P('x', None))
@@ -5707,7 +5699,7 @@ def f(x):
         self.assertEqual(a.sharding.spec, P(None, None))
       a = sharding_cast(
           a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
-      self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, None))
+      self.assertEqual(a.sharding.spec, P(None, None))
       return a
 
     out = f(arr)
@@ -5729,11 +5721,11 @@ def f(x):
           {mesh_lib.AxisTypes.Auto: 'x', mesh_lib.AxisTypes.User: 'y'})
       y = sharding_cast(y, y.sharding.with_mesh(mix_mesh))
       with mesh_lib.set_abstract_mesh(mix_mesh):
-        self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, 'y'))
+        self.assertEqual(y.sharding.spec, P(None, 'y'))
         z = jnp.sin(y)
-        self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, 'y'))
+        self.assertEqual(z.sharding.spec, P(None, 'y'))
         a = z @ z.T
-        self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
+        self.assertEqual(a.sharding.spec, P(None, None))
       a = sharding_cast(
           a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None)))
       self.assertEqual(a.sharding.spec, P('x', None))
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index a9de8c896414..795f7d4bf9e8 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -15,7 +15,6 @@
 from __future__ import annotations
 
 from concurrent.futures import ThreadPoolExecutor
-import contextlib
 from functools import partial
 import itertools as it
 import gc
@@ -54,15 +53,8 @@
 from jax._src.util import safe_map, safe_zip
 
 config.parse_flags_with_absl()
+jtu.request_cpu_devices(8)
 
-# Run all tests with 8 CPU devices.
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-def tearDownModule():
-  _exit_stack.close()
 
 compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
 
diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py
index 199b90fe524e..efa877fd3a91 100644
--- a/tests/python_callback_test.py
+++ b/tests/python_callback_test.py
@@ -36,14 +36,7 @@
 import numpy as np
 
 config.parse_flags_with_absl()
-
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(2))
-
-def tearDownModule():
-  _exit_stack.close()
+jtu.request_cpu_devices(2)
 
 map, unsafe_map = util.safe_map, map
 
diff --git a/tests/roofline_test.py b/tests/roofline_test.py
index e5003947181b..aec34ff22a57 100644
--- a/tests/roofline_test.py
+++ b/tests/roofline_test.py
@@ -14,7 +14,6 @@
 from __future__ import annotations
 
 from functools import partial
-import contextlib
 
 from absl.testing import absltest
 from jax.sharding import PartitionSpec as P
@@ -28,6 +27,7 @@
 
 
 jax.config.parse_flags_with_absl()
+jtu.request_cpu_devices(8)
 
 
 def create_inputs(
@@ -45,18 +45,6 @@ def create_inputs(
   return mesh, tuple(arrays)
 
 
-# Run all tests with 8 CPU devices.
-_exit_stack = contextlib.ExitStack()
-
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-
-def tearDownModule():
-  _exit_stack.close()
-
-
 class RooflineTest(jtu.JaxTestCase):
   def test_scalar_collectives(self):
     a_spec = P("z", ("x", "y"))
diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py
index 383746899570..25d46c5add2e 100644
--- a/tests/shard_alike_test.py
+++ b/tests/shard_alike_test.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
-
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -24,15 +22,7 @@
 from jax.experimental.shard_map import shard_map
 
 jax.config.parse_flags_with_absl()
-
-# Run all tests with 8 CPU devices.
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-def tearDownModule():
-  _exit_stack.close()
+jtu.request_cpu_devices(8)
 
 
 class ShardAlikeDownstreamTest(jtu.JaxTestCase):
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 74fdb7a47888..19cc870881cf 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -15,7 +15,6 @@
 from __future__ import annotations
 
 from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
-import contextlib
 from functools import partial
 import itertools as it
 import math
@@ -53,6 +52,7 @@
 from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member
 
 config.parse_flags_with_absl()
+jtu.request_cpu_devices(8)
 
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
@@ -70,16 +70,6 @@ def create_inputs(a_sharding, b_sharding):
   return mesh, m1, m2
 
 
-# Run all tests with 8 CPU devices.
-_exit_stack = contextlib.ExitStack()
-
-def setUpModule():
-  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-
-def tearDownModule():
-  _exit_stack.close()
-
-
 class ShardMapTest(jtu.JaxTestCase):
 
   def test_identity(self):
@@ -1925,7 +1915,7 @@ def f(x):
     self.assertAllClose(v*v, f(v), check_dtypes=False)
 
   def test_partial_auto_propagate_through(self):
-    mesh = jtu.create_mesh((2, 2), ('i', 'j'))
+    mesh = jtu.create_mesh((2, 2, 2), ('i', 'j', 'k'))
     sharding = jax.sharding.NamedSharding(mesh, P('i'))
 
     def g(x):
@@ -1943,16 +1933,17 @@ def f(x):
       )(x)
 
     v = jnp.arange(32.0).reshape(4, 8)
-    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i')))
+    v = jax.device_put(v, sharding)
     if config.use_shardy_partitioner.value:
       self.assertIn(
           'in_shardings=[<@mesh, [{?}, {?}]>]'
-          ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j"}',
+          ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j", "k"}',
           f.lower(v).as_text(),
       )
     else:
       self.assertIn(
-          'sharding={devices=[1,1,2,2]<=[2,2]T(1,0) last_tile_dims={manual, replicated}}',
+          'sharding={devices=[1,1,4,2]<=[2,4]T(1,0) last_tile_dims={manual,'
+          ' replicated}}',
           f.lower(v).as_text('hlo'),
       )
     actual = f(v)