From 307ea87a8d0311e8fb7b27cd99475009a6056c4e Mon Sep 17 00:00:00 2001 From: kaixih Date: Tue, 29 Oct 2024 22:30:10 +0000 Subject: [PATCH 01/45] support head size of 256 Test large head size only on hopper+ gpus Test large head size only on cudnn 9.5+ --- jax/_src/cudnn/fused_attention_stablehlo.py | 19 +++++--- tests/fused_attention_stablehlo_test.py | 49 ++++++++++++++++++++- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index a5a605002849..0963d762a1d1 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -348,11 +348,20 @@ def check_is_flash_attention( ) else: # Regular attention conditions - if not ((H <= 128 and H % 8 == 0) and - (not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)): - raise NotImplementedError( - f"Unsupported sequence length Q {T}, KV {S} and head dim {H}." - ) + # Check the head dim. + is_on_hopper = check_compute_capability("9.0") + H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128 + if not (H <= H_max and H % 8 == 0): + raise NotImplementedError( + f"The head dim must be <= {H_max} and a mutiple of 8, " + f"but got {H}." + ) + + # Check patterns with bias, seqlen should be divisible by 2 + if (is_training and has_bias and (T % 2 != 0 or S % 2 != 0)): + raise NotImplementedError( + f"Unsupported sequence length Q {T}, KV {S}." + ) def check_cudnn_version(): # check if cuDNN is installed diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 95ec4ce72eb4..c5cfb9d7daf7 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -254,8 +254,6 @@ def dot_product_attention_fp8(query, key, value, fp8_metas): class DotProductAttentionTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if jax.device_count() < 4: - self.skipTest("Requires more than 4 devices.") try: cudnn_version = check_cudnn_version() except RuntimeError as e: @@ -366,6 +364,8 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, @jtu.run_on_devices("cuda") def test_sdpa_inference(self): + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") k1, k2, k3 = jax.random.split(jax.random.key(0), 3) query = jax.random.normal( k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) @@ -407,6 +407,8 @@ def test_sdpa_inference(self): @jtu.run_on_devices("cuda") def test_sdpa_var_seq(self): + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") self.skipTest("Skip before fixed.") k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) query = jax.random.normal( @@ -438,6 +440,8 @@ def test_sdpa_var_seq(self): @jtu.run_on_devices("cuda") def test_sdpa_broadcast_bias_and_dbias(self): + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") try: cudnn_version = check_cudnn_version() except RuntimeError as e: @@ -504,6 +508,8 @@ def test_sdpa_broadcast_bias_and_dbias(self): ) @jtu.run_on_devices("cuda") def test_sdpa_dbias(self, batch_size: int): + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") # cuDNN only supports dbias when batch size is 1. If the batch size is # greater, dbias is silently set to all zeros. This test verifies this # behavior for both vmap and regular use cases. @@ -540,6 +546,8 @@ def attn_vjp(x, bias, mask, target_fn): @jtu.run_on_devices("cuda") def test_sdpa_sliding_window_length(self): + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) query = jax.random.normal( k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) @@ -571,8 +579,43 @@ def test_sdpa_sliding_window_length(self): self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + @jtu.run_on_devices("cuda") + def test_sdpa_large_head_size(self): + try: + cudnn_version = check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + return + if cudnn_version < 90500: + self.skipTest("Requires >= cuDNN 9.5.0") + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Requires at least Hopper arch") + + B, T, N, H = 2, 64, 2, 256 + bf16 = jnp.bfloat16 + keys = jax.random.split(jax.random.key(0), 4) + query = jax.random.normal(keys[0], (B, T, N, H), dtype=bf16) + key = jax.random.normal(keys[1], (B, T, N, H), dtype=bf16) + value = jax.random.normal(keys[2], (B, T, N, H), dtype=bf16) + grad = jax.random.normal(keys[3], (B, T, N, H), dtype=bf16) + sdpa_train_ans = jax.jit(partial( + sdpa_train, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0) + ) + sdpa_train_rfc = jax.jit(partial( + sdpa_train_ref, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0) + ) + + out_ans, grads_ans = sdpa_train_ans(query, key, value, grad, None, None) + out_ref, grads_ref = sdpa_train_rfc(query, key, value, grad, None, None) + self.assertArraysAllClose(out_ref, out_ans) + self.assertArraysAllClose(grads_ref[0], grads_ans[0]) + self.assertArraysAllClose(grads_ref[1], grads_ans[1]) + self.assertArraysAllClose(grads_ref[2], grads_ans[2]) + @jtu.run_on_devices("cuda") def test_layouts(self): + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") dtype = "bfloat16" B, T, N, H = 4, 1024, 8, 128 S = T @@ -600,6 +643,8 @@ def _cvt_back(x): self.assertArraysAllClose(dv_ref, _cvt_back(dv)) def test_sdpa_utils(self): + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") test_cases = [ (1, 257, 64, 8905, False, True, True), (1, 1024, 64, 8905, False, False, True), From 75b56548e28ec5cfca8eada4a8a892e299da7dee Mon Sep 17 00:00:00 2001 From: liblaf <30631553+liblaf@users.noreply.github.com> Date: Mon, 23 Dec 2024 17:18:45 +0800 Subject: [PATCH 02/45] Fix a typo in documentation for `pinv` function. --- jax/_src/numpy/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 02ca6f6ebab2..ff4e4e07e0e6 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -924,7 +924,7 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None, - :func:`jax.numpy.linalg.inv`: multiplicative inverse of a square matrix. Notes: - :func:`jax.numpy.linalg.prng` differs from :func:`numpy.linalg.prng` in the + :func:`jax.numpy.linalg.pinv` differs from :func:`numpy.linalg.pinv` in the default value of `rcond``: in NumPy, the default is `1e-15`. In JAX, the default is ``10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps``. From 3e7f48114ccbc69c24f2ac01895037e4d9e7da16 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 23 Dec 2024 03:13:51 -0800 Subject: [PATCH 03/45] [pallas:mosaic_gpu] Updated the lowering following the changes in in Mosaic GPU internals PiperOrigin-RevId: 709009048 --- jax/_src/pallas/mosaic_gpu/lowering.py | 39 +++++++++++++++----------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index d492e4e1d86d..8fe7f6a15442 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -797,25 +797,30 @@ def _(step, carry): # Each range is 2 events, each event is 4 bytes. prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4) prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) - module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel( - body, - grid=parallel_grid, - cluster=(), - block=block, - in_shapes=in_structs_gmem, - out_shape=out_structs_gmem, - smem_scratch_shape=( - (*in_structs_smem, *out_structs_smem), - *extra_smem_scratch, - ( - mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps), - rs.barriers, - extra_barriers, + module, out_structs_gmem, _, launch_ctx, scratch_arr = ( + mgpu_core._lower_as_gpu_kernel( + body, + grid=parallel_grid, + cluster=(), + block=block, + in_shapes=in_structs_gmem, + out_shape=out_structs_gmem, + smem_scratch_shape=( + (*in_structs_smem, *out_structs_smem), + *extra_smem_scratch, + ( + mgpu.Barrier( + arrival_count=1, num_barriers=max_concurrent_steps + ), + rs.barriers, + extra_barriers, + ), ), - ), - module_name=name_and_src_info.name, - prof_spec=prof_spec, + module_name=name_and_src_info.name, + prof_spec=prof_spec, + ) ) + mgpu_core._initialize_scratch(launch_ctx, scratch_arr) return LoweringResult( module, parallel_grid, block, out_structs_gmem, prof_ctx From a51d6279410a604b5283de8d8718f959d8f8cbf4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 23 Dec 2024 05:03:34 -0800 Subject: [PATCH 04/45] [pallas:mosaic_gpu] Reduced duplication between `_ensure_fa` and `_ensure_ir_value` PiperOrigin-RevId: 709030824 --- jax/_src/pallas/mosaic_gpu/lowering.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 8fe7f6a15442..dc2d63f0b6ec 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1787,29 +1787,21 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: if isinstance(x, mgpu.FragmentedArray): assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype) return x - elif isinstance(x, (np.number, np.ndarray, int, float)): - return mgpu.FragmentedArray.splat( - _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)), - (), - is_signed=mgpu_utils.is_signed(dtype), - ) - elif isinstance(x, ir.Value): - if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)): - assert x.type == mgpu_utils.dtype_to_ir_type(dtype) - return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype)) - raise NotImplementedError(f"Unsupported type: {type(x)}") + return mgpu.FragmentedArray.splat( + _ensure_ir_value(x, dtype), (), is_signed=mgpu_utils.is_signed(dtype) + ) def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: if isinstance(x, ir.Value): assert x.type == mgpu_utils.dtype_to_ir_type(dtype) return x - elif isinstance(x, (np.number, np.ndarray, int, float)): - return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) elif isinstance(x, mgpu.FragmentedArray): + assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype) if isinstance(x.layout, mgpu.WGSplatFragLayout): return x.registers.item() - raise NotImplementedError(f"Unsupported type: {type(x)}") + raise NotImplementedError(f"Unsupported layout: {x.layout}") + return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) def _ir_constant(v: object, t: ir.Type) -> ir.Value: From 83e60a9697ec20023f4e11169edf64e910b93031 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 23 Dec 2024 05:12:11 -0800 Subject: [PATCH 05/45] [pallas:triton] Add support for lowering `int4` load. PiperOrigin-RevId: 709032308 --- jax/_src/pallas/triton/lowering.py | 44 ++++++++++++++++++++++++------ tests/pallas/pallas_test.py | 14 ++++++++++ 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index eb614e3e882f..a87c8990e05d 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1652,8 +1652,8 @@ def _reshape_lowering_rule( ) -def _compute_pointers_from_indices( - root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer +def _compute_offsets_from_indices( + block_info: BlockInfo, nd_indexer: NDIndexer ) -> ir.Value: full_shape = block_info.full_shape_dtype.shape num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape) @@ -1732,7 +1732,14 @@ def _compute_pointers_from_indices( dim_offsets = _mul(dim_offsets, _full(dim_offsets.type, dim_stride)) offsets = _add(offsets, dim_offsets) - return _add(_bcast_to(root_ptr, indexer_shape), offsets) + return offsets + + +def _compute_pointers_from_indices( + root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer +) -> ir.Value: + offsets = _compute_offsets_from_indices(block_info, nd_indexer) + return _add(_bcast_to(root_ptr, nd_indexer.get_indexer_shape()), offsets) @register_lowering(sp.get_p) @@ -1848,14 +1855,20 @@ def _masked_load_lowering_rule( if not tt_dialect.PointerType.isinstance(ptr.type): assert len(ctx.avals_in) == 1 return ptr - ptr = _compute_pointers_from_indices(ptr, block_info, idx) + + offsets = _compute_offsets_from_indices(block_info, idx) + ptr_offsets = offsets + + if block_info.full_shape_dtype.dtype in (jnp.int4, jnp.uint4): + ptr_offsets = _floordiv(offsets, _full(offsets.type, 2), signed=False) + + shape = idx.get_indexer_shape() + ptr = _add(_bcast_to(ptr, shape), ptr_offsets) if mask is not None: - mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape()) + mask = _bcast_to(_ensure_ir_value(mask, mask_aval), shape) if other is not None: - other = _bcast_to( - _ensure_ir_value(other, other_aval), idx.get_indexer_shape() - ) - return _load( + other = _bcast_to(_ensure_ir_value(other, other_aval), shape) + values = _load( ptr, mask=mask, other=other, @@ -1864,6 +1877,19 @@ def _masked_load_lowering_rule( eviction_policy=eviction_policy, ) + if block_info.full_shape_dtype.dtype not in (jnp.int4, jnp.uint4): + return values + + # XLA packs pairs of `[u]int4` values into a `uint8` value with the first + # in the most significant bits and the second in the least significant. + offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False) + in_lsb = _mod(offsets, _full(offsets.type, 2), signed=False) + in_msb = arith_dialect.xori(in_lsb, _full(in_lsb.type, 1)) + shift = _mul(in_msb, _full(in_msb.type, 4)) + shift = _ir_cast(shift, values.type, signed=False) + values = arith_dialect.shrui(values, shift) + return _ir_cast(values, ir.IntegerType.get_signless(4), signed=False) + @register_lowering(sp.swap_p) def _swap_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 6e4928082ac6..bdae8d44b926 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -725,6 +725,20 @@ def dot_kernel(x_ref, y_ref, o_ref): ) self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3) + @parameterized.parameters(jnp.int4, jnp.uint4) + def test_subbyte_load(self, dtype): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("`[u]int4` loads only supported on GPU.") + + x = jnp.arange(-128, 128, dtype=jnp.int8) + + @functools.partial(self.pallas_call, out_shape=x) + def copy_kernel(x_ref, o_ref): + o_ref[()] = x_ref[()].astype(jnp.int8) + + expected = x.astype(dtype).astype(jnp.int8) + self.assertAllClose(copy_kernel(x.astype(dtype)), expected) + class PallasCallInterpretTest(PallasCallTest): INTERPRET = True From 8987867faaffadb145922eca87617a7f0a4aa5f3 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 23 Dec 2024 13:46:25 +0000 Subject: [PATCH 06/45] [mosaic_gpu] Include Mosaic GPU dialect fiels into jaxlib --- jaxlib/mosaic/BUILD | 1 + jaxlib/setup.py | 1 + jaxlib/tools/build_wheel.py | 12 ++++++++++++ 3 files changed, 14 insertions(+) diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 238bf42d9651..62cffd26f829 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -28,6 +28,7 @@ package( py_library( name = "mosaic", deps = [ + "//jaxlib/mosaic/python:gpu_dialect", "//jaxlib/mosaic/python:tpu_dialect", ], ) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 4370aa3176aa..c2efd3d7b7a7 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -83,6 +83,7 @@ def has_ext_modules(self): 'cuda/*', 'cuda/nvvm/libdevice/libdevice*', 'mosaic/*.py', + 'mosaic/dialect/gpu/*.py', 'mosaic/gpu/*.so', 'mosaic/python/*.py', 'mosaic/python/*.so', diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index b46a50961169..4b71bd5de2d8 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -218,6 +218,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): dst_dir=mosaic_python_dir, src_files=[ "__main__/jaxlib/mosaic/python/layout_defs.py", + "__main__/jaxlib/mosaic/python/mosaic_gpu.py", "__main__/jaxlib/mosaic/python/tpu.py", ], ) @@ -225,6 +226,16 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): patch_copy_mlir_import( "__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir ) + mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu" + os.makedirs(mosaic_gpu_dir) + patch_copy_mlir_import( + "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py", + dst_dir=mosaic_gpu_dir, + ) + patch_copy_mlir_import( + "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py", + dst_dir=mosaic_gpu_dir, + ) copy_runfiles( dst_dir=jaxlib_dir / "mlir", @@ -316,6 +327,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", From cb10710c926817635c755440fe2d805a30d488f9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 23 Dec 2024 07:33:49 -0800 Subject: [PATCH 07/45] Remove casting from jax.nn.one_hot This change was made after the most recent release, so is safe to remove. Casting float to int potentially changes intentional beavior: e.g. NaN casts to 0. Some downstream users currently use NaN to mark rows which should have no one-hot entry. --- jax/_src/nn/functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 301ebb181056..7566e6bf32fc 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -710,7 +710,6 @@ def one_hot(x: Any, num_classes: int, *, 'jax-nn-one-hot-float-input', f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}", stacklevel=1) - x_arr = x_arr.astype('int32') return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis) From 68ec202d452e24eaf0e9ab330d8fc3058ee85ec5 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 23 Dec 2024 07:34:04 -0800 Subject: [PATCH 08/45] Use the right include for gmock and gtest PiperOrigin-RevId: 709058082 --- jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index 34f6241661d5..e2e1b623b624 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "testing/base/public/gmock.h" -#include "testing/base/public/gunit.h" +#include +#include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" From ccc3a29537aae0d19ba88933a6b675cb9da25077 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 23 Dec 2024 08:44:35 -0800 Subject: [PATCH 09/45] Internal: use a single registry for abstractify APIs --- jax/_src/abstract_arrays.py | 4 -- jax/_src/api.py | 1 - jax/_src/array.py | 1 - jax/_src/core.py | 89 ++++++++++++++------------- jax/_src/earray.py | 1 - jax/_src/export/shape_poly.py | 1 - jax/_src/interpreters/partial_eval.py | 5 +- jax/_src/numpy/lax_numpy.py | 1 - jax/_src/prng.py | 2 - jax/core.py | 4 +- 10 files changed, 50 insertions(+), 59 deletions(-) diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 2502b705b8fa..8ddc33fd8983 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -49,7 +49,6 @@ def masked_array_error(*args, **kwargs): "Use arr.filled() to convert the value to a standard numpy array.") core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error -core.shaped_abstractify_handlers[np.ma.MaskedArray] = masked_array_error def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: @@ -58,7 +57,6 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype)) core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array -core.shaped_abstractify_handlers[np.ndarray] = _make_shaped_array_for_numpy_array def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: @@ -68,7 +66,6 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: for t in numpy_scalar_types: core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar - core.shaped_abstractify_handlers[t] = _make_shaped_array_for_numpy_scalar core.literalable_types.update(array_types) @@ -81,6 +78,5 @@ def _make_abstract_python_scalar(typ, val): for t in dtypes.python_scalar_dtypes: core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t) - core.shaped_abstractify_handlers[t] = partial(_make_abstract_python_scalar, t) core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) diff --git a/jax/_src/api.py b/jax/_src/api.py index 38ba4fd2d381..4bf964a72239 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2564,7 +2564,6 @@ def _sds_aval_mapping(x): x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), weak_type=x.weak_type) core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping -core.shaped_abstractify_handlers[ShapeDtypeStruct] = _sds_aval_mapping @api_boundary diff --git a/jax/_src/array.py b/jax/_src/array.py index 1ce8e7786bb2..2ee8b01c77d4 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1035,7 +1035,6 @@ def _get_aval_array(self): else: return self.aval -core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array core.pytype_aval_mappings[ArrayImpl] = _get_aval_array # TODO(jakevdp) replace this with true inheritance at the C++ level. diff --git a/jax/_src/core.py b/jax/_src/core.py index 5f351bd46883..5d5173f3922a 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -656,6 +656,13 @@ def check_bool_conversion(arr: Array): " is ambiguous. Use a.any() or a.all()") +pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} + +def _str_abstractify(x): + raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type") +pytype_aval_mappings[str] = _str_abstractify + + def _aval_property(name): return property(lambda self: getattr(self.aval, name)) @@ -918,6 +925,8 @@ def unsafe_buffer_pointer(self): aval_property = namedtuple("aval_property", ["fget"]) aval_method = namedtuple("aval_method", ["fun"]) +pytype_aval_mappings[Tracer] = lambda x: x.aval + def check_eval_args(args): for arg in args: if isinstance(arg, Tracer): @@ -1400,45 +1409,51 @@ def check_valid_jaxtype(x): f"Value {x!r} of type {type(x)} is not a valid JAX type") -def _shaped_abstractify_slow(x): - try: - return x if isinstance(x, AbstractValue) else get_aval(x) - except TypeError: - pass - - weak_type = getattr(x, 'weak_type', False) - if hasattr(x, 'dtype'): - dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) - else: - raise TypeError( - f"Cannot interpret value of type {type(x)} as an abstract array; it " - "does not have a dtype attribute") - return ShapedArray(np.shape(x), dtype, weak_type=weak_type) +# We have three flavors of abstractification APIs here which each used to have +# their own separate implementation. Now they're effectively the same, with the +# following differences: +# +# - abstractify returns avals for non-traced array-like objects. +# - get_aval is like abstractify, but also accepts tracers. +# - shaped_abstractify is like get_aval, but also accepts duck-typed arrays. +# +# TODO(jakevdp): can these be unified further? -# TODO(jakevdp): deduplicate this with abstractify def shaped_abstractify(x): - # This was originally api_util.shaped_abstractify; temporarily moved - # here in order to facilitate combining it with abstractify. - handler = shaped_abstractify_handlers.get(type(x), None) - return handler(x) if handler is not None else _shaped_abstractify_slow(x) + typ = type(x) + if (aval_fn := pytype_aval_mappings.get(typ)): # fast path + return aval_fn(x) + for t in typ.__mro__[1:]: + if (aval_fn := pytype_aval_mappings.get(t)): + return aval_fn(x) + if isinstance(x, AbstractValue): + return x + if hasattr(x, '__jax_array__'): + return shaped_abstractify(x.__jax_array__()) + if hasattr(x, 'dtype'): + return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False)) + raise TypeError( + f"Cannot interpret value of type {typ} as an abstract array; it " + "does not have a dtype attribute") def abstractify(x): - for typ in type(x).__mro__: - aval_fn = pytype_aval_mappings.get(typ) - if aval_fn: return aval_fn(x) - if hasattr(x, '__jax_array__'): - return abstractify(x.__jax_array__()) - raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") + if isinstance(x, Tracer): + raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type") + return get_aval(x) def get_aval(x): - if isinstance(x, Tracer): - return x.aval - else: - return abstractify(x) + typ = type(x) + if (aval_fn := pytype_aval_mappings.get(typ)): # fast path + return aval_fn(x) + for t in typ.__mro__[1:]: + if (aval_fn := pytype_aval_mappings.get(t)): + return aval_fn(x) + if hasattr(x, '__jax_array__'): + return get_aval(x.__jax_array__()) + raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type") -get_type = get_aval def is_concrete(x): return to_concrete_value(x) is not None @@ -1831,13 +1846,6 @@ def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) -pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} -shaped_abstractify_handlers: dict[Any, Callable[[Any], AbstractValue]] = {} - -def _str_abstractify(x): - raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type") -pytype_aval_mappings[str] = _str_abstractify -shaped_abstractify_handlers[str] = _str_abstractify class DArray: _aval: DShapedArray @@ -1894,7 +1902,6 @@ def _darray_aval(x): return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) pytype_aval_mappings[DArray] = _darray_aval -shaped_abstractify_handlers[DArray] = _darray_aval @dataclass(frozen=True) @@ -1924,11 +1931,10 @@ def __init__(self, aval, buf): aval = property(lambda self: self._aval) shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) - def __getitem__(self, idx): return get_aval(self)._getitem(self, idx) - def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x) + def __getitem__(self, idx): return self._aval._getitem(self, idx) + def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) pytype_aval_mappings[MutableArray] = lambda x: x._aval -shaped_abstractify_handlers[MutableArray] = lambda x: x._aval def mutable_array(init_val): return mutable_array_p.bind(init_val) @@ -1984,7 +1990,6 @@ def __init__(self, buf): def block_until_ready(self): self._buf.block_until_ready() pytype_aval_mappings[Token] = lambda _: abstract_token -shaped_abstractify_handlers[Token] = lambda _: abstract_token # TODO(dougalm): Deprecate these. They're just here for backwards compat. diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 25c2bc2bf7ec..98a0a863981e 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -115,7 +115,6 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics): return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler -core.shaped_abstractify_handlers[EArray] = lambda self: self.aval core.pytype_aval_mappings[EArray] = lambda x: x.aval xla.canonicalize_dtype_handlers[EArray] = lambda x: x tree_util.dispatch_registry.register_node( diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index b82890cab682..5462723c8335 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -1205,7 +1205,6 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool: f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}") core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval -core.shaped_abstractify_handlers[_DimExpr] = _DimExpr._get_aval dtypes._weak_types.append(_DimExpr) def _convertible_to_int(p: DimSize) -> bool: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 154b5e972682..ac0ae3a13967 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1569,10 +1569,7 @@ def get_referent(self): val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) return self if val is None else get_referent(val) - -def _dynamic_jaxpr_tracer_shaped_abstractify(x): - return x.aval -core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify +core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: sentinel = object() diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 83ede1e48c3f..259c47948a9d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -192,7 +192,6 @@ def __instancecheck__(self, instance: Any) -> bool: def _abstractify_scalar_meta(x): raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.") core.pytype_aval_mappings[_ScalarMeta] = _abstractify_scalar_meta -core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: meta = _ScalarMeta(np_scalar_type.__name__, (object,), diff --git a/jax/_src/prng.py b/jax/_src/prng.py index d29bad5d5304..4f43b54bb478 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -461,8 +461,6 @@ def __hash__(self) -> int: core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval -core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval') - xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x diff --git a/jax/core.py b/jax/core.py index 54bbdac51c87..ef1551b2f1ba 100644 --- a/jax/core.py +++ b/jax/core.py @@ -128,7 +128,7 @@ _src_core.escaped_tracer_error), "extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.", _src_core.extend_axis_env_nd), - "get_type": ("jax.core.get_type is deprecated.", _src_core.get_type), + "get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval), "get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent), "join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects), "leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.", @@ -212,7 +212,7 @@ escaped_tracer_error = _src_core.escaped_tracer_error extend_axis_env_nd = _src_core.extend_axis_env_nd full_lower = _src_core.full_lower - get_type = _src_core.get_type + get_type = _src_core.get_aval get_referent = _src_core.get_referent jaxpr_as_fun = _src_core.jaxpr_as_fun join_effects = _src_core.join_effects From 23965b74f608ce268bac9b2195b7a6ab86308f15 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 23 Dec 2024 09:24:33 -0800 Subject: [PATCH 10/45] Update XLA dependency to use revision http://github.com/openxla/xla/commit/7e03b71f8abdf58bf6ec966821619e8dcf76175a. PiperOrigin-RevId: 709080323 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index c88e78f1c053..06d0f7c47ecb 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "dc7aaf834a0bb5a543f6cf98626284783a4a921c" -XLA_SHA256 = "eda76cce64b33c00139120d6b4d4c2167d9f99dc957da54225a67ddb7ec7cb23" +XLA_COMMIT = "7e03b71f8abdf58bf6ec966821619e8dcf76175a" +XLA_SHA256 = "eff3f8bf78c1b254b72502973047937652569c84bfb3b4d753049c07afdca7ed" def repo(): tf_http_archive( From c206ae7fe8146905da73e8047cfefbf5f00cb18c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 23 Dec 2024 09:39:45 -0800 Subject: [PATCH 11/45] changelog: link to api compatibility & python version docs --- CHANGELOG.md | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e86dece51013..db9d05088af5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). For the changes specific to the experimental Pallas APIs, see {ref}`pallas-changelog`. +JAX follows Effort-based versioning; for a discussion of this and JAX's API +compatibility policy, refer to {ref}`api-compatibility`. For the Python and +NumPy version support policy, refer to {ref}`version-support-policy`. +