From 538c64b15ad78d6a680ceadc883c2ad2e5153406 Mon Sep 17 00:00:00 2001 From: Ram Rachum Date: Tue, 23 Jul 2024 10:19:15 +0300 Subject: [PATCH 001/702] Use pathlib for profiler log_dir --- jax/_src/profiler.py | 35 ++++++++++++++++------------------- jax/collect_profile.py | 4 ++-- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 1330a21486b4..4752106c7688 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -17,12 +17,12 @@ from collections.abc import Callable from contextlib import contextmanager from functools import wraps -import glob import gzip import http.server import json import logging import os +import pathlib import socketserver import threading from typing import Any @@ -88,7 +88,7 @@ def reset(self): _profile_state = _ProfileState() -def start_trace(log_dir, create_perfetto_link: bool = False, +def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, create_perfetto_trace: bool = False) -> None: """Starts a profiler trace. @@ -129,20 +129,18 @@ def start_trace(log_dir, create_perfetto_link: bool = False, _profile_state.create_perfetto_link = create_perfetto_link _profile_state.create_perfetto_trace = ( create_perfetto_trace or create_perfetto_link) - _profile_state.log_dir = str(log_dir) + _profile_state.log_dir = pathlib.Path(log_dir) -def _write_perfetto_trace_file(log_dir): +def _write_perfetto_trace_file(log_dir: os.PathLike | str): # Navigate to folder with the latest trace dump to find `trace.json.jz` - curr_path = os.path.abspath(log_dir) - root_trace_folder = os.path.join(curr_path, "plugins", "profile") - trace_folders = [os.path.join(root_trace_folder, trace_folder) for - trace_folder in os.listdir(root_trace_folder)] - latest_folder = max(trace_folders, key=os.path.getmtime) - trace_jsons = glob.glob(os.path.join(latest_folder, "*.trace.json.gz")) - if len(trace_jsons) != 1: - raise ValueError(f"Invalid trace folder: {latest_folder}") - trace_json, = trace_jsons + trace_folders = (pathlib.Path(log_dir).absolute() / "plugins" / "profile").iterdir() + latest_trace_folder = max(trace_folders, key=os.path.getmtime) + trace_jsons = latest_trace_folder.glob("*.trace.json.gz") + try: + trace_json, = trace_jsons + except ValueError as value_error: + raise ValueError(f"Invalid trace folder: {latest_trace_folder}") from value_error logger.info("Loading trace.json.gz and removing its metadata...") # Perfetto doesn't like the `metadata` field in `trace.json` so we remove @@ -152,8 +150,7 @@ def _write_perfetto_trace_file(log_dir): with gzip.open(trace_json, "rb") as fp: trace = json.load(fp) del trace["metadata"] - filename = "perfetto_trace.json.gz" - perfetto_trace = os.path.join(latest_folder, filename) + perfetto_trace = latest_trace_folder / "perfetto_trace.json.gz" logger.info("Writing perfetto_trace.json.gz...") with gzip.open(perfetto_trace, "w") as fp: fp.write(json.dumps(trace).encode("utf-8")) @@ -173,11 +170,11 @@ def do_GET(self): def do_POST(self): self.send_error(404, "File not found") -def _host_perfetto_trace_file(path): +def _host_perfetto_trace_file(path: os.PathLike | str): # ui.perfetto.dev looks for files hosted on `127.0.0.1:9001`. We set up a # TCP server that is hosting the `perfetto_trace.json.gz` file. port = 9001 - orig_directory = os.path.abspath(os.getcwd()) + orig_directory = pathlib.Path.cwd() directory, filename = os.path.split(path) try: os.chdir(directory) @@ -203,7 +200,7 @@ def stop_trace(): if _profile_state.profile_session is None: raise RuntimeError("No profile started") sess = _profile_state.profile_session - sess.export(sess.stop(), _profile_state.log_dir) + sess.export(sess.stop(), str(_profile_state.log_dir)) if _profile_state.create_perfetto_trace: abs_filename = _write_perfetto_trace_file(_profile_state.log_dir) if _profile_state.create_perfetto_link: @@ -227,7 +224,7 @@ def stop_and_get_fdo_profile() -> bytes | str: @contextmanager -def trace(log_dir, create_perfetto_link=False, create_perfetto_trace=False): +def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfetto_trace=False): """Context manager to take a profiler trace. The trace will capture CPU, GPU, and/or TPU activity, including Python diff --git a/jax/collect_profile.py b/jax/collect_profile.py index a7777085ce90..d1309e0c5bca 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -66,7 +66,7 @@ help="Profiler Python tracer level", type=int) def collect_profile(port: int, duration_in_ms: int, host: str, - log_dir: str | None, host_tracer_level: int, + log_dir: os.PathLike | str | None, host_tracer_level: int, device_tracer_level: int, python_tracer_level: int, no_perfetto_link: bool): options = profiler.ProfilerOptions( @@ -97,7 +97,7 @@ def collect_profile(port: int, duration_in_ms: int, host: str, fp.write(result.encode("utf-8")) if not no_perfetto_link: - path = jax_profiler._write_perfetto_trace_file(str(log_dir_)) + path = jax_profiler._write_perfetto_trace_file(log_dir_) jax_profiler._host_perfetto_trace_file(path) def main(args): From 181d17e9bef0fc3bae03847e6f5ae938ce17b4f6 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Thu, 1 Aug 2024 17:18:55 +0000 Subject: [PATCH 002/702] Faster MHA backwards pass --- jax/experimental/pallas/ops/gpu/attention.py | 272 +++++++++++-------- tests/pallas/gpu_ops_test.py | 4 +- 2 files changed, 164 insertions(+), 112 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index a0221ebf6f74..647310dcaacc 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -70,11 +70,6 @@ def body(start_k, carry): curr_k_slice = pl.dslice(start_k * block_k, block_k) k = pl.load(k_ref, (curr_k_slice, slice(None))) - kv_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (curr_k_slice,)) - ) qk = pl.dot(q, k.T) # [block_q, block_k] if sm_scale != 1.: qk *= sm_scale # [block_q, block_k] @@ -87,6 +82,7 @@ def body(start_k, carry): if causal or segment_ids_ref is not None: mask = None if segment_ids_ref is not None: + kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,)) mask = segment_mask(q_segment_ids, kv_segment_ids) if causal: span_q = start_q * block_q + jnp.arange(block_q) @@ -354,6 +350,9 @@ def _preprocess_backward(out, do, l, block_q: int, return do_scaled, delta +# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence +# length. +# Inspired by the triton tutorial: https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py def mha_backward_kernel( # Inputs q_ref, @@ -365,7 +364,6 @@ def mha_backward_kernel( l_ref, m_ref, delta_ref, - _, # Outputs dq_ref, dk_ref, @@ -373,84 +371,141 @@ def mha_backward_kernel( *, sm_scale: float, causal: bool, - block_q: int, + block_q1: int, + block_k1: int, + block_q2: int, + block_k2: int, block_d: int, - block_k: int, ): del out_ref, l_ref # Not needed seq_len = q_ref.shape[0] - def outer_loop(start_k, _): - - dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) - dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) - k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) - v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) - span_k = start_k * block_k + jnp.arange(block_k) - kv_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (pl.ds(start_k * block_k, block_k),)) - ) - - def inner_loop(start_q, carry): - dv, dk = carry - q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) - qk = pl.dot(q, k.T) - qk = qk.astype(q_ref.dtype) - qk = qk.astype(jnp.float32) - if sm_scale != 1.0: - qk *= sm_scale - - q_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (pl.ds(start_q * block_q, block_q),)) - ) - - if causal or segment_ids_ref is not None: - mask = None - if segment_ids_ref is not None: - mask = segment_mask(q_segment_ids, kv_segment_ids) - - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - causal_mask = span_q[:, None] >= span_k[None, :] - mask = ( - causal_mask - if mask is None - else jnp.logical_and(mask, causal_mask) - ) - qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) - - m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) - p = jnp.exp(qk - m[:, None]) - do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) - dv = dv + pl.dot(p.astype(do.dtype).T, do) - di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) - dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] - dp = dp + pl.dot(do, v.T) - ds = p * dp - if sm_scale != 1.0: - ds = ds * sm_scale - dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) - dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q), - slice(None)), eviction_policy="evict_last") - dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) - pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), - slice(None)), dq, eviction_policy="evict_last") - return dv, dk - if causal: - lower_bound = lax.div(start_k * block_k, block_q) - else: - lower_bound = 0 - dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop, - (dv, dk)) - pl.store(dv_ref, (pl.ds(start_k * block_k, block_k), - slice(None)), dv.astype(dv_ref.dtype)) - pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), - slice(None)), dk.astype(dk_ref.dtype)) - lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None) + # Scan #1: dK and dV + # 1. Load a block of K and V of size (block_k1, head_dim) in SMEM. + # 2. Iterate through Q in chunks of (block_q1, head_dim) to accumulate + # dK and dV. + start_k = pl.program_id(2) + curr_k_slice = pl.dslice(start_k * block_k1, block_k1) + + dv = jnp.zeros([block_k1, block_d], dtype=jnp.float32) + dk = jnp.zeros([block_k1, block_d], dtype=jnp.float32) + + v = pl.load(v_ref, (curr_k_slice, slice(None))) + k = pl.load(k_ref, (curr_k_slice, slice(None))) + span_k = start_k * block_k1 + jnp.arange(block_k1) + kv_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (curr_k_slice,)) + ) + + def inner_loop_dkdv(start_q, carry): + dv, dk = carry + curr_q_slice = pl.dslice(start_q * block_q1, block_q1) + + q = pl.load(q_ref, (curr_q_slice, slice(None))) + qk = pl.dot(q, k.T) + if sm_scale != 1.0: + qk *= sm_scale + + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: + q_segment_ids = pl.load(segment_ids_ref, (curr_q_slice,)) + mask = segment_mask(q_segment_ids, kv_segment_ids) + + if causal: + span_q = start_q * block_q1 + jnp.arange(block_q1) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + + m = pl.load(m_ref, (curr_q_slice,)) + di = pl.load(delta_ref, (curr_q_slice,)) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + + p = jnp.exp(qk - m[:, None]) + dv = dv + pl.dot(p.astype(do.dtype).T, do) + dp = jnp.zeros((block_q1, block_k1), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if sm_scale != 1.0: + ds = ds * sm_scale + dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) + + return dv, dk + + lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0 + dv, dk = lax.fori_loop( + lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk) + ) + pl.store(dv_ref, (curr_k_slice, slice(None)), dv.astype(dv_ref.dtype)) + pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype)) + + del dv, dk + + # Scan #2: dQ + # 1. Load a block of Q of size (block_q2, head_dim) in SMEM. + # 2. Iterate through K and V in chunks of (block_k2, head_dim) to + # accumulate dQ. + start_q = pl.program_id(2) + curr_q_slice = pl.ds(start_q * block_q2, block_q2) + span_q = start_q * block_q2 + jnp.arange(block_q2) + dq = jnp.zeros([block_q2, block_d], dtype=jnp.float32) + + q = pl.load(q_ref, (curr_q_slice, slice(None))) + q_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (curr_q_slice,)) + ) + m = pl.load(m_ref, (curr_q_slice,)) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + di = pl.load(delta_ref, (curr_q_slice,)) + + def inner_loop_dq(start_k, dq): + curr_k_slice = pl.dslice(start_k * block_k2, block_k2) + k = pl.load(k_ref, (curr_k_slice, slice(None))) + v = pl.load(v_ref, (curr_k_slice, slice(None))) + + qk = pl.dot(q, k.T) + if sm_scale != 1.0: + qk *= sm_scale + + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: + kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,)) + mask = segment_mask(q_segment_ids, kv_segment_ids) + + if causal: + span_k = start_k * block_k2 + jnp.arange(block_k2) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + + p = jnp.exp(qk - m[:, None]) + dp = jnp.zeros((block_q2, block_k2), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if sm_scale != 1.0: + ds = ds * sm_scale + + dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) + + return dq + + if causal: + upper_bound = lax.div((start_q + 1) * block_q2, block_k2) + else: + upper_bound = pl.cdiv(seq_len, block_k2) + + dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) + pl.store(dq_ref, (curr_q_slice, slice(None)), dq.astype(dq_ref.dtype)) def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, @@ -473,75 +528,72 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) - # We accumulate into dq so we need to initialize it to zeros. - dq = jnp.zeros(q.shape, jnp.float32) out_shapes = [ - jax.ShapeDtypeStruct(dq.shape, dq.dtype), - jax.ShapeDtypeStruct(k.shape, k.dtype), - jax.ShapeDtypeStruct(v.shape, v.dtype), + jax.ShapeDtypeStruct(q.shape, q.dtype), + jax.ShapeDtypeStruct(k.shape, k.dtype), + jax.ShapeDtypeStruct(v.shape, v.dtype), ] in_specs = [ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) - ), - pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), - pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), + pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), + pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), + pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] - input_output_aliases = {8: 0} else: - in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda j, k: (j, 0))) - input_output_aliases = {9: 0} - grid = (batch_size, num_heads) - # TODO(sharadmv): figure out why num_warps=8 doesn't work! + in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0))) + + grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k)) num_warps = 8 dq, dk, dv = pl.pallas_call( functools.partial( mha_backward_kernel, - block_q=block_q, - block_d=head_dim, - block_k=block_k, sm_scale=sm_scale, causal=causal, + block_q1=block_q, + block_k1=block_k, + block_q2=block_q, + block_k2=block_k, + block_d=head_dim, ), - grid=grid, out_shape=out_shapes, in_specs=in_specs, + grid=grid, out_specs=[ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), + lambda i, j, _: (i, 0, j, 0), # dq ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), + lambda i, j, _: (i, 0, j, 0), # dk ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), + lambda i, j, _: (i, 0, j, 0), # dv ), ], name="mha_backward", debug=debug, interpret=interpret, - compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)), - input_output_aliases=input_output_aliases, - )(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq) + compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=2)), + )(q, k, v, segment_ids, out, do_scaled, l, m, delta) else: raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") return dq.astype(q.dtype), dk, dv, None diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index e7b7a4daac3d..4ab957a3a2a1 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -252,8 +252,8 @@ def impl(q, k, v): (1, 384, 1, 32, False, False), (2, 384, 2, 32, False, True), (2, 384, 2, 32, False, False), - # TODO(b/283035396): (1, 384, 1, 32, True, True), - # TODO(b/283035396): (2, 384, 2, 32, True, True), + (1, 384, 1, 32, True, True), + (2, 384, 2, 32, True, True), ] ] ) From 6ff6501aa229871a21bbe9ed04673320126dcf34 Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 1 Aug 2024 19:39:34 +0000 Subject: [PATCH 003/702] Init commit --- jax/_src/cudnn/fused_attention_stablehlo.py | 24 ++++++----- jax/_src/nn/functions.py | 46 ++++++++++++--------- tests/nn_test.py | 22 ++++++---- 3 files changed, 54 insertions(+), 38 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 262b8e2c140a..51a86fdcb978 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -618,11 +618,12 @@ def _dot_product_attention_fwd_batcher( *_, S, _, _ = key.shape B = math.prod(Bs) has_bias, _ = variadic_args + original_shape = query.shape # reshape to 4D shape query = jnp.reshape(query, (B,) + query.shape[-3:]) key = jnp.reshape(key, (B,) + key.shape[-3:]) value = jnp.reshape(value, (B,) + key.shape[-3:]) - if has_bias: + if has_bias and batch_dims[3] is not None: bias = jnp.reshape(bias, (B, N, T, S)) if has_padding(mask_type): q_seqlen = jnp.reshape(q_seqlen, (B, )) @@ -635,7 +636,7 @@ def _dot_product_attention_fwd_batcher( # reshape to original shape output = outputs[0] - output = jnp.reshape(output, query.shape) + output = jnp.reshape(output, original_shape) if is_training: activation = outputs[1] activation = jnp.reshape(activation, (*Bs, N, T)) @@ -660,11 +661,15 @@ def _dot_product_attention_bwd_batcher( *_, S, _, _ = key.shape B = math.prod(Bs) has_bias, has_dbias = variadic_args + original_query_shape = query.shape + original_key_shape = key.shape + original_value_shape = value.shape + original_bias_shape = bias.shape if has_bias else None # reshape to 4D shape query = jnp.reshape(query, (B,) + query.shape[-3:]) key = jnp.reshape(key, (B,) + key.shape[-3:]) value = jnp.reshape(value, (B,) + key.shape[-3:]) - if has_bias: + if has_bias and batch_dims[3] is not None: bias = jnp.reshape(bias, (B, N, T, S)) if has_padding(mask_type): q_seqlen = jnp.reshape(q_seqlen, (B, )) @@ -681,15 +686,14 @@ def _dot_product_attention_bwd_batcher( mask_type=mask_type, layout=layout, ) - grad_query, grad_key, grad_value = grads[:3] # reshape to original shape - grad_query = jnp.reshape(grad_query, query.shape) - grad_key = jnp.reshape(grad_key, key.shape) - grad_value = jnp.reshape(grad_value, value.shape) + grads[0] = jnp.reshape(grads[0], original_query_shape) + grads[1] = jnp.reshape(grads[1], original_key_shape) + grads[2] = jnp.reshape(grads[2], original_value_shape) if has_dbias: - grad_bias = grads[3] - grad_bias = jnp.reshape(grad_bias, bias.shape) - return grads + (grad_bias,), out_bdims + (query_bdim,) + assert has_bias + grads[3] = jnp.reshape(grads[3], original_bias_shape) + out_bdims += (batch_dims[3],) return grads, out_bdims # custom partitioning diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 32d543a27966..5d7c941615ea 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -853,9 +853,9 @@ def dot_product_attention( query: ArrayLike, key: ArrayLike, value: ArrayLike, - *, bias: ArrayLike | None = None, mask: ArrayLike | None = None, + *, scale: float | None = None, is_causal: bool = False, implementation: Literal['xla', 'cudnn'] | None = None) -> Array: @@ -882,20 +882,20 @@ def dot_product_attention( G = number of groups, which equals to N // K Args: - query: query array; shape :code:`(BTNH)` - key: key array: shape :code:`(BSKH)`. When `K` equals `N`, multi-headed - attention (MHA: https://arxiv.org/abs/1706.03762) is performed. Otherwise, - grouped query attention (GQA: https://arxiv.org/abs/2305.13245) is performed - if `N` is a multiple of `K`, and multi-query attention (MQA: - https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case - of GQA). + query: query array; shape :code:`(BTNH|TNH)` + key: key array: shape :code:`(BSKH|SKH)`. When `K` equals `N`, multi-headed + attention (MHA https://arxiv.org/abs/1706.03762) is performed. Otherwise, + grouped query attention (GQA https://arxiv.org/abs/2305.13245) is + performed if `N` is a multiple of `K`, and multi-query attention (MQA + https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case + of GQA). value: value array, should have the same shape as the `key` array. bias: optional, bias array to be added to logits; The shape must be 4D and - be broadcastable to :code:`(BNTS)`. + be broadcastable to :code:`(BNTS|NTS)`. mask: optional, mask array used to filter out logits. It is a boolean mask where `True` indicates the element should take part in attention. For an additive mask, users should pass it to `bias`. The shape must be 4D and be - broadcastable to :code:`(BNTS)`. + broadcastable to :code:`(BNTS|NTS)`. scale: scale for the logits. If None, the scale will be set to 1 divided by the square root of query's head dimension (i.e. H). is_causal: If true, causal attention will be applied. Note, some @@ -912,6 +912,18 @@ def dot_product_attention( Returns: An array of the attention output with the same shape as :code:`query`. """ + original_shape = jnp.asarray(query).shape + def _preprocess_array(t): + if t is None: + return t + t = jnp.asarray(t) + return t[None, ...] if t.ndim == 3 else t + query = _preprocess_array(query) + key = _preprocess_array(key) + value = _preprocess_array(value) + bias = _preprocess_array(bias) + mask = _preprocess_array(mask) + def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: if t.ndim != len(shape): raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}") @@ -919,12 +931,6 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: if shape[i] != -1 and t.shape[i] != shape[i]: raise ValueError(f"{name} shape should be {shape}: but got {t.shape}") - query = jnp.asarray(query) - key = jnp.asarray(key) - value = jnp.asarray(value) - bias = bias if bias is None else jnp.asarray(bias) - mask = mask if mask is None else jnp.asarray(mask) - B, S, K, H = key.shape _check_has_shape(value, [B, S, K, H], 'value') _check_has_shape(query, [B, -1, -1, H], 'query') @@ -944,19 +950,21 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: match implementation: case 'xla': - return _dot_product_attention_xla( + out = _dot_product_attention_xla( query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, ) case 'cudnn': mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK - return cudnn_dot_product_attention( + out = cudnn_dot_product_attention( query, key, value, bias, mask, scale=scale_val, mask_type=mask_type ) case None: # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select # best backend. - return _dot_product_attention_xla( + out = _dot_product_attention_xla( query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, ) case _: raise ValueError(f"Unsupported implementation option: {implementation}") + + return jnp.reshape(out, original_shape) diff --git a/tests/nn_test.py b/tests/nn_test.py index 455f04e5fd12..802ed1b2f1e2 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -57,10 +57,11 @@ class NNFunctionsTest(jtu.JaxTestCase): use_bias=[False, True], causal_mode=[None, 'is_causal', 'is_mask'], group_num=[1, 2, 4], + use_vmap=[False, True], impl=['xla', 'cudnn'], ) def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, - group_num, impl): + group_num, use_vmap, impl): if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: @@ -84,15 +85,17 @@ def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) if impl == 'cudnn': - lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias=bias, mask=causal_mask) + lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask) hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) self.assertIn('__cudnn$fmha', hlo) + if use_vmap: + sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V - out_ref = sdpa_ref(Q, K_ref, V_ref, bias=bias, mask=causal_mask) + out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask) - out_ans = sdpa_ans(Q, K, V, bias=bias, mask=causal_mask) + out_ans = sdpa_ans(Q, K, V, bias, causal_mask) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) @parameterized.product( @@ -100,10 +103,11 @@ def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, use_bias=[False, True], causal_mode=[None, 'is_causal', 'is_mask'], group_num=[1, 2, 4], + use_vmap=[False, True], impl=['xla', 'cudnn'], ) def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, - group_num, impl): + group_num, use_vmap, impl): if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: @@ -127,16 +131,16 @@ def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) - fn_ref = lambda q, k, v, b, m: sdpa_ref(q, k, v, bias=b, mask=m) - _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask) + _, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, causal_mask) dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad) if G != 1: dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) - fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m) - _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask) + if use_vmap: + sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) + _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask) dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad) if impl == 'cudnn': From 9f9e3e6d4e7a2955bca3b8d98ecd3c863700179b Mon Sep 17 00:00:00 2001 From: kaixih Date: Fri, 2 Aug 2024 19:55:28 +0000 Subject: [PATCH 004/702] Address comments --- jax/_src/nn/functions.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 5d7c941615ea..4aabf9521340 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -912,17 +912,19 @@ def dot_product_attention( Returns: An array of the attention output with the same shape as :code:`query`. """ - original_shape = jnp.asarray(query).shape - def _preprocess_array(t): - if t is None: - return t + output_shape = jnp.asarray(query).shape + def _ensure_4d(t): t = jnp.asarray(t) - return t[None, ...] if t.ndim == 3 else t - query = _preprocess_array(query) - key = _preprocess_array(key) - value = _preprocess_array(value) - bias = _preprocess_array(bias) - mask = _preprocess_array(mask) + dims_to_add = 4 - t.ndim + if dims_to_add > 0: + return jnp.expand_dims(t, axis=tuple(range(dims_to_add))) + return t + + query = _ensure_4d(query) + key = _ensure_4d(key) + value = _ensure_4d(value) + bias = _ensure_4d(bias) if bias is not None else None + mask = _ensure_4d(mask) if mask is not None else None def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: if t.ndim != len(shape): @@ -967,4 +969,4 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: case _: raise ValueError(f"Unsupported implementation option: {implementation}") - return jnp.reshape(out, original_shape) + return jnp.reshape(out, output_shape) From 1d425b2d30fec2d60d579d7f1bc17952db889a8d Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 5 Aug 2024 17:51:47 +0100 Subject: [PATCH 005/702] Small tweaks to custom_vmap UI. I'm working on some extensions to `custom_vmap` and came across these small UI improvements (I think!). This includes the two changes: 1. A weakening of the `kwargs` check to be consistent with the one in `custom_vjp`/`custom_jvp`, and 2. An improved error message when `def_vmap` isn't called. --- jax/_src/api_util.py | 11 +++++++++++ jax/_src/custom_batching.py | 9 +++++++-- jax/_src/custom_derivatives.py | 21 +++++---------------- tests/api_test.py | 26 ++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 18 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index dd1cdcbe6bb8..481dec0065a5 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -556,6 +556,17 @@ def _assert_no_intersection(static_argnames, donate_argnames): f"{out} appear in both static_argnames and donate_argnames") +def resolve_kwargs(fun: Callable, args, kwargs): + if isinstance(fun, partial): + fun = lambda *args, **kwargs: None + ba = inspect.signature(fun).bind(*args, **kwargs) + ba.apply_defaults() + if ba.kwargs: + raise TypeError("keyword arguments could not be resolved to positions") + else: + return ba.args + + def _dtype(x): try: return dtypes.result_type(x) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 4d41849b75d3..4b859e910165 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -27,7 +27,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src.api_util import flatten_fun_nokwargs +from jax._src.api_util import flatten_fun_nokwargs, resolve_kwargs from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters.batching import not_mapped @@ -64,7 +64,12 @@ def def_vmap(self, vmap_rule: Callable) -> Callable: @traceback_util.api_boundary def __call__(self, *args, **kwargs): - assert not kwargs + fun_name = getattr(self.fun, "__name__", str(self.fun)) + if not self.vmap_rule: + raise AttributeError( + f"No batching rule defined for custom_vmap function {fun_name} " + "using def_vmap.") + args = resolve_kwargs(self.fun, args, kwargs) args_flat, in_tree = tree_flatten(args) flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree) in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index d27b0efc7e5e..bc9f7a687dcb 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -17,7 +17,6 @@ from collections.abc import Callable, Sequence import dataclasses from functools import update_wrapper, reduce, partial, wraps -import inspect from typing import Any, Generic, TypeVar from jax._src import config @@ -30,7 +29,8 @@ from jax._src import traceback_util from jax._src.ad_util import ( stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) -from jax._src.api_util import argnums_partial, flatten_fun_nokwargs +from jax._src.api_util import ( + argnums_partial, flatten_fun_nokwargs, resolve_kwargs) from jax._src.core import raise_to_shaped from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad @@ -56,17 +56,6 @@ ### util -def _resolve_kwargs(fun, args, kwargs): - if isinstance(fun, partial): - # functools.partial should have an opaque signature. - fun = lambda *args, **kwargs: None - ba = inspect.signature(fun).bind(*args, **kwargs) - ba.apply_defaults() - if ba.kwargs: - raise TypeError("keyword arguments could not be resolved to positions") - else: - return ba.args - def _initial_style_jaxpr(fun, in_avals): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) return jaxpr, consts @@ -240,7 +229,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp." raise AttributeError(msg) jvp_name = getattr(self.jvp, '__name__', str(self.jvp)) - args = _resolve_kwargs(self.fun, args, kwargs) + args = resolve_kwargs(self.fun, args, kwargs) if self.nondiff_argnums: nondiff_argnums = set(self.nondiff_argnums) args = tuple(_stop_gradient(x) if i in nondiff_argnums else x @@ -599,7 +588,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp." raise AttributeError(msg) fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) - args = _resolve_kwargs(self.fun, args, kwargs) + args = resolve_kwargs(self.fun, args, kwargs) if self.optimize_remat: fwd = optimize_remat_of_custom_vjp_fwd( self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums, @@ -1451,7 +1440,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: # above and it would be good to consolidate it. primal_name = getattr(fun, "__name__", str(fun)) fwd_name = getattr(fwd, "__name__", str(fwd)) - args = _resolve_kwargs(fwd, args, kwargs) + args = resolve_kwargs(fwd, args, kwargs) if nondiff_argnums: for i in nondiff_argnums: _check_for_tracers(args[i]) nondiff_argnums_ = set(nondiff_argnums) diff --git a/tests/api_test.py b/tests/api_test.py index cb0d7c0d40c7..4aafc42b7a0e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -10798,6 +10798,32 @@ def g(x, a): self.assertAllClose(y, (x + a)**2) + def test_kwargs(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x=x) + self.assertAllClose(y, jnp.sin(x)) + ys = api.vmap(f)(x=xs) + self.assertAllClose(ys, jnp.cos(xs)) + + def test_undefined_rule(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + with self.assertRaisesRegex( + AttributeError, "No batching rule defined for custom_vmap function f"): + f(0.5) + + class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" From 4d9d622dda1e228e197464edb2e449d9b7c078ad Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 5 Aug 2024 15:04:31 -0700 Subject: [PATCH 006/702] Fix check_is_flash_attention --- jax/_src/cudnn/fused_attention_stablehlo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 262b8e2c140a..eb2ed2ff8bd9 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -285,7 +285,7 @@ def check_eq(a, b, c, msg): def check_is_flash_attention( query, key, layout, cudnn_version, has_bias, is_training): - if layout == AttentionLayout.BNTH: + if layout == AttentionLayout.BNTH.value: _, _, T, H = query.shape _, _, S, _ = key.shape else: From b45f0fe50faa67107cb8d085b831f3f5fbd5d742 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 6 Aug 2024 09:56:03 -0700 Subject: [PATCH 007/702] Support empty boolean indexing --- jax/_src/numpy/lax_numpy.py | 5 ++++- tests/lax_numpy_indexing_test.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3af51e30585d..7ef27b1eea66 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -8340,7 +8340,10 @@ def _expand_bool_indices(idx, shape): i_shape = _shape(i) start = len(out) + ellipsis_offset - newaxis_offset expected_shape = shape[start: start + _ndim(i)] - if i_shape != expected_shape: + if len(i_shape) != len(expected_shape): + raise IndexError(f"too many boolean indices at index {dim_number}: got mask of shape " + f"{i_shape}, but only {len(expected_shape)} dimensions remain.") + if not all(s1 in (0, s2) for s1, s2 in zip(i_shape, expected_shape)): raise IndexError("boolean index did not match shape of indexed array in index " f"{dim_number}: got {i_shape}, expected {expected_shape}") out.extend(np.where(i)) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index cbb8e92ed603..bf2785f62d68 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1030,6 +1030,23 @@ def testNontrivialBooleanIndexing(self): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.parameters( + [(3,), (0,)], + [(3, 4), (0,)], + [(3, 4), (0, 4)], + [(3, 4), (3, 0)], + [(3, 4, 5), (3, 0)], + ) + def testEmptyBooleanIndexing(self, x_shape, m_shape): + # Regression test for https://github.com/google/jax/issues/22886 + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(x_shape, np.int32), np.empty(m_shape, dtype=bool)] + + np_fun = lambda x, m: np.asarray(x)[np.asarray(m)] + jnp_fun = lambda x, m: jnp.asarray(x)[jnp.asarray(m)] + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + @jtu.sample_product( shape=[(2, 3, 4, 5)], idx=[ From 803453ed742c93e360d5de8341ba4a94f2c5f9b1 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 6 Aug 2024 22:32:46 -0700 Subject: [PATCH 008/702] [Pallas TPU] Close over consts in while_loop lowering to avoid passing refs in/out of loop PiperOrigin-RevId: 660238073 --- jax/_src/pallas/mosaic/lowering.py | 40 ++++++++---------------------- 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 082927677c73..7299f9929fe3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2160,7 +2160,7 @@ def _run_body(i, args): def _scan_lowering_rule( ctx: LoweringRuleContext, *args, - jaxpr: jax_core.Jaxpr, + jaxpr: jax_core.ClosedJaxpr, linear: tuple[bool, ...], length: int, reverse: bool, @@ -2241,7 +2241,7 @@ def _while_lowering_rule( body_jaxpr, ): # First try to lower via a simpler fori loop, which may optimize better. - fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop( + fori_jaxpr, _ = pallas_utils.pattern_match_while_to_fori_loop( cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts ) if fori_jaxpr is not None: @@ -2262,19 +2262,12 @@ def _while_lowering_rule( cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = ( split_list(ctx.block_shapes, [cond_nconsts, body_nconsts]) ) - cond_const_types = [a.type for a in cond_consts] - body_const_types = [a.type for a in body_consts] carry_types = [a.type for a in carry] - all_types = [*cond_const_types, *body_const_types, *carry_types] - while_op = scf.WhileOp(all_types, args) + while_op = scf.WhileOp(carry_types, carry) - before_block = while_op.before.blocks.append(*all_types) - cond_consts_, _, carry_ = split_list( - before_block.arguments, - [cond_nconsts, body_nconsts], - ) - cond_args = [*cond_consts_, *carry_] + before_block = while_op.before.blocks.append(*carry_types) with ir.InsertionPoint.at_block_begin(before_block): + cond_args = [*cond_consts, *before_block.arguments] [cond] = jaxpr_subcomp( ctx.lowering_context.replace( block_shapes=[*cond_const_block_shapes, *carry_block_shapes] @@ -2284,30 +2277,19 @@ def _while_lowering_rule( ) scf.condition(cond, before_block.arguments) - after_block = while_op.after.blocks.append(*all_types) - cond_consts_, body_consts_, carry_ = split_list( - after_block.arguments, - [cond_nconsts, body_nconsts], - ) - all_args = [*cond_consts_, *body_consts_, *carry_] - cond_const_args, body_const_args, carry_args = split_list( - all_args, [cond_nconsts, body_nconsts] - ) + after_block = while_op.after.blocks.append(*carry_types) with ir.InsertionPoint.at_block_begin(after_block): + body_args = [*body_consts, *after_block.arguments] loop_out = jaxpr_subcomp( ctx.lowering_context.replace( block_shapes=[*body_const_block_shapes, *carry_block_shapes], ), body_jaxpr.jaxpr, - *body_const_args, - *carry_args, + *body_args, ) - all_handles = [*cond_const_args, *body_const_args, *loop_out] - if all_handles: - scf.yield_(all_handles) - - all_out = list(while_op.results_) - return all_out[cond_nconsts + body_nconsts :] + if loop_out: + scf.yield_(loop_out) + return list(while_op.results) lowering_rules[lax.while_p] = _while_lowering_rule From 64eb8e9639ce26316be7e432dd7a6d43a5749085 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 7 Aug 2024 08:38:56 +0300 Subject: [PATCH 009/702] [pallas] Add a warning message about experimental and incomplete status --- docs/pallas/index.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index 403a8ce9c620..467f375d0e43 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -3,6 +3,9 @@ Pallas: a JAX kernel language ============================= Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. +It aims to provide fine-grained control over the generated code, combined with +the high-level ergonomics of JAX tracing and the `jax.numpy` API. + This section contains tutorials, guides and examples for using Pallas. See also the :class:`jax.experimental.pallas` module API documentation. @@ -10,6 +13,10 @@ See also the :class:`jax.experimental.pallas` module API documentation. Pallas is experimental and is changing frequently. See the :ref:`pallas-changelog` for the recent changes. + You can expect to encounter errors and unimplemented cases, e.g., when + lowering of high-level JAX concepts that would require emulation, + or simply because Pallas is still under development. + .. toctree:: :caption: Guides :maxdepth: 2 From 28ca734d9b32803aa176ef13044d246d8fb76006 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 7 Aug 2024 03:15:56 -0700 Subject: [PATCH 010/702] Added another boxDim check to mosaic_gpu_init_tma_desc PiperOrigin-RevId: 660314586 --- jax/experimental/mosaic/gpu/__init__.py | 10 ++++++++++ jaxlib/mosaic/gpu/runtime.cc | 7 +++++++ tests/mosaic/gpu_test.py | 18 ++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 9d4068745d3a..5ffc149bc76f 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -442,6 +442,16 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): rank = len(slice_shape) if rank > 5: # TODO: apaszke - Implement stride compression raise ValueError("Async copies only support striding up to 5 dimensions") + if max(slice_shape) > 256: + raise ValueError( + "Async copies only support copying <=256 elements along each" + " dimension" + ) + if (zeroth_bw := slice_shape[-1] * element_bytewidth) % 16 != 0: + raise ValueError( + "Async copies require the number of bytes copied along the last" + f" dimension to be divisible by 16, but got {zeroth_bw}" + ) if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth: raise ValueError( f"Async copies with {swizzle=} require last dimension of the slice to" diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index 4acb9c3dbf83..82659e45bef1 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -88,6 +88,13 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, tma_window_shape_i, rank - i - 1); abort(); } + if (i == 0 && (tma_window_shape_i * elem_bytewidth) % 16 != 0) { + fprintf(stderr, + "The last dimension of window shape must have a bytewidth " + "divisible by 16, but got %d*%ld at index %ld\n", + tma_window_shape_i, elem_bytewidth, rank - i - 1); + abort(); + } tma_window_shape[i] = tma_window_shape_i; } cuuint32_t element_strides[5] = {1, 1, 1, 1, 1}; diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f4fb6761ce41..dec9452fd9d7 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -961,6 +961,24 @@ def kernel(ctx, src, dst, tmp): y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) np.testing.assert_array_equal(y, x) + def test_tma_invalid(self): + def kernel(ctx, src, dst, tmp): + copy(src, tmp) + ctx.async_copy(src_ref=tmp, dst_ref=dst) + ctx.await_async_copy(0) + + def run_kernel(shape): + x = np.arange(np.prod(shape)).reshape(shape) + _ = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + + with self.assertRaisesRegex(ValueError, "only support striding up to 5"): + run_kernel([1] * 6) + + with self.assertRaisesRegex( + ValueError, "last dimension to be divisible by 16" + ): + run_kernel([23]) + class FragmentedArrayTest(TestCase): From 3e5e9475429237b6eef6ddfd242ef2dfbbff931a Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 7 Aug 2024 04:59:19 -0700 Subject: [PATCH 011/702] Move some backwards compatibility tests from jax_triton to jax/pallas. While doing this I moved `matmul.py` to `jax/experimental/pallas/ops/tpu` PiperOrigin-RevId: 660341331 --- .../pallas/mosaic_matmul.py | 339 ++++++++++++++++++ .../pallas/mosaic_semaphore_dma.py | 95 +++++ .../{cuda_add_one.py => triton_add_one.py} | 0 jax/experimental/pallas/ops/tpu/matmul.py | 85 +++++ tests/pallas/BUILD | 1 + .../pallas/export_back_compat_pallas_test.py | 61 +++- 6 files changed, 574 insertions(+), 7 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py rename jax/_src/internal_test_util/export_back_compat_test_data/pallas/{cuda_add_one.py => triton_add_one.py} (100%) create mode 100644 jax/experimental/pallas/ops/tpu/matmul.py diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py new file mode 100644 index 000000000000..065db82453f3 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py @@ -0,0 +1,339 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from numpy import array, float32 + + +# Pasted from the test output (see back_compat_test_util.py module docstring) +data_2023_09_22 = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['tpu_custom_call'], + serialized_date=datetime.date(2023, 9, 22), + inputs=(), + expected_outputs=(array([[ 90458.2 , 90470.875, 90480.85 , 90491.11 , + 90500.945, 90510.95 , 90521.18 , 90530.95 , + 90540.78 , 90551.16 , 90560.68 , 90570.734, + 90580.73 , 90590.58 , 90600.66 , 90610.61 ], + [ 643341.75 , 643434.25 , 643509.75 , 643587.06 , + 643660.1 , 643735.9 , 643813.5 , 643886. , + 643960.6 , 644039.56 , 644110.25 , 644186.75 , + 644262.5 , 644336.06 , 644412.9 , 644488.4 ], + [ 1196323.2 , 1196495.6 , 1196636.8 , 1196781. , + 1196917.5 , 1197059. , 1197203.9 , 1197339.2 , + 1197478.5 , 1197625.8 , 1197757.8 , 1197900.5 , + 1198042. , 1198179.4 , 1198323. , 1198464. ], + [ 1749075.5 , 1749327.9 , 1749534.4 , 1749745.9 , + 1749945.5 , 1750152.8 , 1750365.1 , 1750563.1 , + 1750767.1 , 1750983.1 , 1751176.2 , 1751385.4 , + 1751592.8 , 1751793.8 , 1752004.2 , 1752210.8 ], + [ 2302500.5 , 2302832.5 , 2303104.8 , 2303383.5 , + 2303646.2 , 2303919.5 , 2304199. , 2304459.8 , + 2304728.5 , 2305013. , 2305267.2 , 2305543. , + 2305816.2 , 2306081. , 2306358.5 , 2306630.5 ], + [ 2855440.2 , 2855852.5 , 2856190.2 , 2856535.5 , + 2856861.5 , 2857200.5 , 2857547.2 , 2857870.5 , + 2858204.5 , 2858557. , 2858872.5 , 2859214.5 , + 2859553.2 , 2859882. , 2860226. , 2860563.5 ], + [ 3407472. , 3407964.2 , 3408367.5 , 3408780.2 , + 3409169.5 , 3409574.5 , 3409988.5 , 3410374.5 , + 3410773. , 3411194. , 3411570.5 , 3411979. , + 3412383.5 , 3412776. , 3413186.5 , 3413590. ], + [ 3959847.5 , 3960419. , 3960888. , 3961367.8 , + 3961820.2 , 3962290.8 , 3962772.5 , 3963221.2 , + 3963684.8 , 3964174.2 , 3964612.2 , 3965086.8 , + 3965557.2 , 3966013.2 , 3966491. , 3966959.5 ], + [ 4515869.5 , 4516521.5 , 4517056. , 4517602. , + 4518118. , 4518654.5 , 4519203. , 4519715. , + 4520243. , 4520801. , 4521300. , 4521841. , + 4522378. , 4522897. , 4523441.5 , 4523975.5 ], + [ 5061659. , 5062390. , 5062990. , 5063603.5 , + 5064182. , 5064784.5 , 5065401. , 5065975. , + 5066567.5 , 5067194. , 5067754. , 5068362. , + 5068964. , 5069547. , 5070159. , 5070759. ], + [ 5621329. , 5622141. , 5622806.5 , 5623487.5 , + 5624129.5 , 5624797. , 5625481. , 5626118. , + 5626775. , 5627470.5 , 5628092. , 5628765. , + 5629433.5 , 5630080.5 , 5630758.5 , 5631424. ], + [ 6172821. , 6173712. , 6174443. , 6175191. , + 6175896. , 6176630. , 6177381. , 6178080.5 , + 6178803. , 6179566. , 6180248.5 , 6180988. , + 6181722. , 6182432.5 , 6183178. , 6183908. ], + [ 6723343.5 , 6724315. , 6725111.5 , 6725927. , + 6726696. , 6727495.5 , 6728313.5 , 6729076.5 , + 6729863.5 , 6730696. , 6731440. , 6732246. , + 6733046. , 6733820.5 , 6734632. , 6735428.5 ], + [ 7280537. , 7281587.5 , 7282449.5 , 7283331.5 , + 7284163.5 , 7285028.5 , 7285914. , 7286739.5 , + 7287591. , 7288492. , 7289296.5 , 7290169.5 , + 7291035. , 7291873.5 , 7292752.5 , 7293614. ], + [ 7828292. , 7829423. , 7830350. , 7831299.5 , + 7832194.5 , 7833125.5 , 7834078.5 , 7834966. , + 7835883. , 7836852. , 7837717.5 , 7838657. , + 7839588. , 7840490. , 7841436. , 7842363.5 ], + [ 8384808.5 , 8386019.5 , 8387012.5 , 8388029.5 , + 8388988. , 8389985. , 8391005. , 8391956. , + 8392937. , 8393974. , 8394902. , 8395907. , + 8396904. , 8397870. , 8398882. , 8399875. ], + [ 8928697. , 8929987. , 8931044. , 8932126. , + 8933146. , 8934208. , 8935294. , 8936306. , + 8937351. , 8938455. , 8939443. , 8940514. , + 8941574. , 8942604. , 8943682. , 8944738. ], + [ 9501496. , 9502866. , 9503990. , 9505141. , + 9506226. , 9507354. , 9508508. , 9509584. , + 9510695. , 9511870. , 9512919. , 9514058. , + 9515186. , 9516279. , 9517425. , 9518549. ], + [10055416. , 10056868. , 10058060. , 10059279. , + 10060428. , 10061624. , 10062848. , 10063988. , + 10065166. , 10066410. , 10067522. , 10068729. , + 10069924. , 10071083. , 10072298. , 10073489. ], + [10595886. , 10597417. , 10598673. , 10599958. , + 10601170. , 10602431. , 10603721. , 10604923. , + 10606164. , 10607477. , 10608649. , 10609921. , + 10611182. , 10612404. , 10613684. , 10614941. ], + [11135804. , 11137412. , 11138732. , 11140083. , + 11141357. , 11142682. , 11144038. , 11145301. , + 11146606. , 11147985. , 11149218. , 11150554. , + 11151880. , 11153163. , 11154509. , 11155829. ], + [11686791. , 11688480. , 11689864. , 11691282. , + 11692618. , 11694007. , 11695430. , 11696756. , + 11698123. , 11699571. , 11700864. , 11702265. , + 11703656. , 11705003. , 11706414. , 11707799. ], + [12263420. , 12265190. , 12266642. , 12268128. , + 12269529. , 12270986. , 12272478. , 12273868. , + 12275303. , 12276820. , 12278176. , 12279646. , + 12281104. , 12282516. , 12283996. , 12285447. ], + [12821178. , 12823029. , 12824548. , 12826102. , + 12827567. , 12829092. , 12830652. , 12832105. , + 12833606. , 12835193. , 12836610. , 12838149. , + 12839673. , 12841150. , 12842699. , 12844217. ], + [13362964. , 13364895. , 13366479. , 13368100. , + 13369628. , 13371218. , 13372846. , 13374362. , + 13375927. , 13377582. , 13379061. , 13380665. , + 13382255. , 13383796. , 13385411. , 13386995. ], + [13902882. , 13904891. , 13906539. , 13908225. , + 13909815. , 13911470. , 13913163. , 13914740. , + 13916369. , 13918091. , 13919629. , 13921298. , + 13922953. , 13924556. , 13926236. , 13927884. ], + [14443848. , 14445934. , 14447646. , 14449398. , + 14451050. , 14452769. , 14454528. , 14456166. , + 14457858. , 14459647. , 14461245. , 14462979. , + 14464698. , 14466363. , 14468108. , 14469820. ], + [15024407. , 15026576. , 15028355. , 15030176. , + 15031893. , 15033679. , 15035507. , 15037210. , + 15038969. , 15040827. , 15042490. , 15044291. , + 15046077. , 15047808. , 15049621. , 15051400. ], + [15586096. , 15588347. , 15590193. , 15592082. , + 15593863. , 15595716. , 15597613. , 15599380. , + 15601204. , 15603133. , 15604857. , 15606726. , + 15608579. , 15610375. , 15612257. , 15614103. ], + [16130043. , 16132373. , 16134285. , 16136242. , + 16138087. , 16140006. , 16141970. , 16143800. , + 16145690. , 16147688. , 16149473. , 16151409. , + 16153328. , 16155188. , 16157138. , 16159049. ], + [16669961. , 16672369. , 16674345. , 16676367. , + 16678274. , 16680257. , 16682287. , 16684178. , + 16686131. , 16688196. , 16690041. , 16692042. , + 16694026. , 16695948. , 16697962. , 16699938. ], + [17209878. , 17212364. , 17214404. , 17216492. , + 17218460. , 17220508. , 17222604. , 17224556. , + 17226572. , 17228704. , 17230608. , 17232676. , + 17234724. , 17236708. , 17238788. , 17240828. ], + [17817286. , 17819860. , 17821972. , 17824132. , + 17826172. , 17828292. , 17830460. , 17832482. , + 17834570. , 17836776. , 17838748. , 17840888. , + 17843008. , 17845062. , 17847216. , 17849328. ], + [18357204. , 18359856. , 18362032. , 18364258. , + 18366358. , 18368542. , 18370778. , 18372860. , + 18375012. , 18377284. , 18379316. , 18381520. , + 18383704. , 18385820. , 18388040. , 18390216. ], + [18897120. , 18899852. , 18902092. , 18904384. , + 18906544. , 18908794. , 18911096. , 18913240. , + 18915452. , 18917792. , 18919884. , 18922152. , + 18924402. , 18926580. , 18928864. , 18931104. ], + [19437040. , 19439848. , 19442152. , 19444508. , + 19446732. , 19449044. , 19451412. , 19453616. , + 19455894. , 19458302. , 19460452. , 19462786. , + 19465100. , 19467340. , 19469688. , 19471992. ], + [19976956. , 19979844. , 19982212. , 19984634. , + 19986920. , 19989296. , 19991728. , 19993996. , + 19996336. , 19998810. , 20001020. , 20003420. , + 20005796. , 20008100. , 20010514. , 20012882. ], + [20516874. , 20519838. , 20522270. , 20524760. , + 20527106. , 20529548. , 20532046. , 20534374. , + 20536776. , 20539318. , 20541588. , 20544052. , + 20546492. , 20548860. , 20551338. , 20553770. ], + [21056792. , 21059834. , 21062330. , 21064884. , + 21067292. , 21069798. , 21072364. , 21074752. , + 21077218. , 21079826. , 21082156. , 21084684. , + 21087190. , 21089618. , 21092162. , 21094658. ], + [21596710. , 21599830. , 21602390. , 21605010. , + 21607480. , 21610050. , 21612680. , 21615130. , + 21617660. , 21620336. , 21622724. , 21625318. , + 21627888. , 21630378. , 21632988. , 21635548. ], + [22218698. , 22221906. , 22224536. , 22227228. , + 22229768. , 22232408. , 22235108. , 22237628. , + 22240228. , 22242976. , 22245434. , 22248094. , + 22250734. , 22253292. , 22255972. , 22258602. ], + [22802946. , 22806238. , 22808938. , 22811700. , + 22814306. , 22817016. , 22819790. , 22822374. , + 22825044. , 22827864. , 22830386. , 22833120. , + 22835830. , 22838456. , 22841208. , 22843906. ], + [23351442. , 23354816. , 23357584. , 23360416. , + 23363088. , 23365866. , 23368710. , 23371360. , + 23374094. , 23376988. , 23379572. , 23382374. , + 23385154. , 23387846. , 23390668. , 23393436. ], + [23891360. , 23894812. , 23897644. , 23900542. , + 23903274. , 23906118. , 23909028. , 23911738. , + 23914536. , 23917496. , 23920140. , 23923008. , + 23925850. , 23928606. , 23931492. , 23934324. ], + [24431278. , 24434808. , 24437704. , 24440668. , + 24443462. , 24446368. , 24449344. , 24452116. , + 24454978. , 24458004. , 24460708. , 24463640. , + 24466548. , 24469364. , 24472318. , 24475214. ], + [24971196. , 24974804. , 24977764. , 24980792. , + 24983648. , 24986620. , 24989662. , 24992494. , + 24995420. , 24998512. , 25001276. , 25004274. , + 25007244. , 25010124. , 25013142. , 25016102. ], + [25511114. , 25514800. , 25517824. , 25520918. , + 25523836. , 25526872. , 25529978. , 25532872. , + 25535860. , 25539020. , 25541844. , 25544906. , + 25547942. , 25550884. , 25553966. , 25556990. ], + [26051032. , 26054796. , 26057884. , 26061044. , + 26064022. , 26067122. , 26070296. , 26073250. , + 26076302. , 26079530. , 26082412. , 26085540. , + 26088640. , 26091642. , 26094792. , 26097880. ], + [26590950. , 26594790. , 26597942. , 26601168. , + 26604210. , 26607374. , 26610612. , 26613628. , + 26616744. , 26620038. , 26622980. , 26626172. , + 26629336. , 26632402. , 26635616. , 26638768. ], + [27130866. , 27134786. , 27138002. , 27141294. , + 27144396. , 27147626. , 27150930. , 27154008. , + 27157186. , 27160546. , 27163548. , 27166806. , + 27170034. , 27173162. , 27176440. , 27179656. ], + [27723244. , 27727248. , 27730532. , 27733892. , + 27737062. , 27740358. , 27743732. , 27746876. , + 27750120. , 27753552. , 27756618. , 27759944. , + 27763240. , 27766436. , 27769782. , 27773064. ], + [28323220. , 28327310. , 28330664. , 28334094. , + 28337330. , 28340696. , 28344142. , 28347352. , + 28350664. , 28354168. , 28357300. , 28360696. , + 28364062. , 28367324. , 28370744. , 28374096. ], + [28885444. , 28889618. , 28893040. , 28896544. , + 28899848. , 28903284. , 28906802. , 28910078. , + 28913462. , 28917038. , 28920234. , 28923702. , + 28927138. , 28930468. , 28933958. , 28937382. ], + [29425518. , 29429768. , 29433256. , 29436826. , + 29440192. , 29443694. , 29447276. , 29450614. , + 29454062. , 29457706. , 29460962. , 29464496. , + 29467996. , 29471390. , 29474946. , 29478434. ], + [29965436. , 29969764. , 29973316. , 29976952. , + 29980378. , 29983944. , 29987594. , 29990992. , + 29994504. , 29998214. , 30001532. , 30005128. , + 30008694. , 30012148. , 30015770. , 30019322. ], + [30505352. , 30509760. , 30513376. , 30517076. , + 30520566. , 30524196. , 30527910. , 30531372. , + 30534944. , 30538724. , 30542100. , 30545760. , + 30549392. , 30552908. , 30556594. , 30560210. ], + [31045270. , 31049756. , 31053436. , 31057202. , + 31060752. , 31064446. , 31068228. , 31071750. , + 31075386. , 31079232. , 31082668. , 31086394. , + 31090088. , 31093668. , 31097420. , 31101100. ], + [31585188. , 31589752. , 31593496. , 31597328. , + 31600940. , 31604698. , 31608544. , 31612128. , + 31615828. , 31619740. , 31623236. , 31627026. , + 31630786. , 31634428. , 31638244. , 31641988. ], + [32125106. , 32129748. , 32133556. , 32137452. , + 32141126. , 32144950. , 32148862. , 32152506. , + 32156270. , 32160248. , 32163804. , 32167660. , + 32171482. , 32175186. , 32179068. , 32182876. ], + [32665024. , 32669742. , 32673614. , 32677578. , + 32681314. , 32685200. , 32689178. , 32692884. , + 32696710. , 32700756. , 32704372. , 32708292. , + 32712180. , 32715946. , 32719894. , 32723766. ], + [33221238. , 33226038. , 33229974. , 33234004. , + 33237804. , 33241756. , 33245802. , 33249570. , + 33253460. , 33257576. , 33261252. , 33265238. , + 33269192. , 33273022. , 33277034. , 33280972. ], + [33836944. , 33841824. , 33845832. , 33849936. , + 33853804. , 33857824. , 33861940. , 33865776. , + 33869736. , 33873920. , 33877664. , 33881720. , + 33885744. , 33889640. , 33893724. , 33897732. ], + [34414896. , 34419864. , 34423944. , 34428112. , + 34432048. , 34436140. , 34440328. , 34444232. , + 34448260. , 34452520. , 34456324. , 34460456. , + 34464548. , 34468512. , 34472672. , 34476744. ], + [34824696. , 34829728. , 34833856. , 34838080. , + 34842064. , 34846208. , 34850448. , 34854396. , + 34858476. , 34862792. , 34866644. , 34870824. , + 34874968. , 34878984. , 34883192. , 34887320. ]], + dtype=float32),), + mlir_module_text=r""" +#loc4 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":33:0) +#loc11 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=matmul keep_unused=False inline=False]"(#loc4)) +#loc16 = loc("jit(func)/jit(main)/jit(matmul)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=wrapped keep_unused=False inline=False]"(#loc4)) +#loc17 = loc("jit(func)/jit(main)/jit(matmul)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=apply_kernel keep_unused=False inline=False]"(#loc4)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<64x16xf32> {jax.result_info = ""}) { + %0 = stablehlo.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512, 528, 544, 560, 576, 592, 608, 624, 640, 656, 672, 688, 704, 720, 736, 752, 768, 784, 800, 816, 832, 848, 864, 880, 896, 912, 928, 944, 960, 976, 992, 1008]> : tensor<64xi32> loc(#loc) + %1 = stablehlo.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : tensor<16xi32> loc(#loc) + %2 = stablehlo.iota dim = 0 : tensor<524288xf32> loc(#loc6) + %3 = stablehlo.reshape %2 : (tensor<524288xf32>) -> tensor<1024x512xf32> loc(#loc7) + %4 = stablehlo.constant dense<1.000000e-03> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1024x512xf32> loc(#loc8) + %6 = stablehlo.multiply %5, %3 : tensor<1024x512xf32> loc(#loc8) + %7 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc) + %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1024x512xf32> loc(#loc9) + %9 = stablehlo.add %8, %6 : tensor<1024x512xf32> loc(#loc9) + %10 = stablehlo.slice %9 [0:512, 0:256] : (tensor<1024x512xf32>) -> tensor<512x256xf32> loc(#loc10) + %11 = call @matmul(%9, %10) : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc11) + %12 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<64xi32>) -> tensor<64x16x1xi32> loc(#loc12) + %13 = stablehlo.broadcast_in_dim %1, dims = [1] : (tensor<16xi32>) -> tensor<64x16x1xi32> loc(#loc13) + %14 = stablehlo.concatenate %12, %13, dim = 2 : (tensor<64x16x1xi32>, tensor<64x16x1xi32>) -> tensor<64x16x2xi32> loc(#loc14) + %15 = "stablehlo.gather"(%11, %14) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<1024x256xf32>, tensor<64x16x2xi32>) -> tensor<64x16xf32> loc(#loc15) + return %15 : tensor<64x16xf32> loc(#loc) + } loc(#loc) + func.func private @matmul(%arg0: tensor<1024x512xf32> loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=matmul keep_unused=False inline=False]"(#loc4)), %arg1: tensor<512x256xf32> loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=matmul keep_unused=False inline=False]"(#loc4))) -> tensor<1024x256xf32> { + %0 = call @wrapped(%arg0, %arg1) : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc16) + return %0 : tensor<1024x256xf32> loc(#loc11) + } loc(#loc11) + func.func private @wrapped(%arg0: tensor<1024x512xf32> loc("jit(func)/jit(main)/jit(matmul)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=wrapped keep_unused=False inline=False]"(#loc4)), %arg1: tensor<512x256xf32> loc("jit(func)/jit(main)/jit(matmul)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=wrapped keep_unused=False inline=False]"(#loc4))) -> tensor<1024x256xf32> { + %0 = call @apply_kernel(%arg0, %arg1) : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc17) + return %0 : tensor<1024x256xf32> loc(#loc16) + } loc(#loc16) + func.func private @apply_kernel(%arg0: tensor<1024x512xf32> loc("jit(func)/jit(main)/jit(matmul)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=apply_kernel keep_unused=False inline=False]"(#loc4)), %arg1: tensor<512x256xf32> loc("jit(func)/jit(main)/jit(matmul)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=apply_kernel keep_unused=False inline=False]"(#loc4))) -> tensor<1024x256xf32> { + %0 = stablehlo.custom_call @tpu_custom_call(%arg0, %arg1) {backend_config = "{\22custom_call_config\22: {\22body\22: \22TUzvUgFNTElSZ29vZ2xlMy10cnVuawABLwkBAwUHAQMJAwUDCwUNDQ8RExUXBwMZA44DIgMhAfkbDw8LKxMTBxcjEwsLCwsTCwsLhQsLCxsLMwsPEw87CxMLC1MLDwsLFxsLUxsLUxsLUxsbGw8TEwsLExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTCxMTExMXBQthkWlpeQGPExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXDxcPFw8XDxcPFw8XDxcTHxMfDxcfCwsLCwtTCxMBIRsHHw8HHw8nJycLIx8nJycCwhEDBTEzNTcd7R8dcR8FGwMHEgMWAzEzNTcdGgMfHQIDHx8DAwYDOQMFCgM7DgM7AwMXOQUdBR8FIQEBF3NDAQUjBSUNGWFmZmluZV9tYXA8KGQwLCBkMSkgLT4gKGQwLCBkMSk+AAUnBSkFKwMFFx0HawUtIxUREQEBAQEBAQEBBS8RBwUBAwICERUAAw0/QRlDRUdJSxtNT1EFMQEF+/sNFwUzIw0FIQQAAAAAAAAAAQAAAAAAAAAFNRENAQU3BTkBB1NZXwMFIVUjVwkpIw0FIQABAAAAAAAAAAIAAAAAAAADBSFbI10JKyMNBSEAAgAAAAAAAAABAAAAAAAAAwUhYSNjCS0jDQUhAAEAAAAAAAAAAQAAAAAAAAMFGSUbKQMFGSUbKwMFGSUbLREHAQMDB28RA8IPBTsFPQMDB3cRA4IPAwMHexEDQg8DAwd/EQMCDwMDB4MRA8IOAwMHhxEDgg4DAweLEQNCDgMDB48RAwIOAwMHkxEDwg0DAweXEQOCDQMDB5sRA0INAwMHnxEDAg0DAwejEQPCDAMDB6cRA4IMAwMHqxEDQgwDAwevEQPCCwMDB7MRA4ILAwMHtxEDQgsDAwe7EQMCCwMDB78RA8IKAwMHwxEDggoDAwfHEQNCCgMDB8sRAwIKAwMHzxEDwgkDAwfTEQOCCQMDB9cRA0IJAwMH2xEDAgkDAwffEQPCCAMDB+MRA4IIAwMH5xEDQggDAwfrEQMCDAU/AwMH8REDwgcDAwf1EQOCBwMDBwYCI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS50aWxlZDwoOCwxMjgpLFsyLDFdPgAjdHB1LnRpbGVkPCg4LDEyOCksWzQsMV0+ACN0cHUudnBhZDwiMzIsezAsMH0sKDgsMTI4KSI+ABEDQgcDAwcOAhEDAgcDAwcWAhEDwgYDAwceAhEDggYDAwcmAhEDQgYDAwcuAhEDAgYDAwc2AhEDwgUDAwc+AhEDggUDAwdGAhEDQgUDAwdOAhEDAgUDAwdWAhEDwgQDAwdeAhEDggQDAwdmAhEDQgQDAwduAhEDwgMDAwd2AhEDggMDAwd+AhEDQgMDAweGAhEDAgMDAweOAhEDwgIDAweWAhEDggIDAweeAhEDQgIDAwemAhEDAgIDAweuAhED4QMDB7YCEQPBAwMHvgIRA6EDAwfGAhEDgQMDB84CEQNhAwMH1gIRA0EDAwfeAhEDIQMDB+YCEQMCBAMFFx0H7gIRAwIIAwUXHQf2AhEDAQMDB/4CJQEJAAAAAAVBBUMFRQVHBUkjBwkhAQAAAAEAAAACAAAAAAAAAAVLAwMXHScFIQIECQMnBQIIAgQJAQICCycFAgQCBAkBAgQX+QUCCAIQCf8X+QUCEAIICf0X+QUCCAIICf0BCQULBwcPERMBBQUHBwUHBxf5BQIIAhAJJxf5BQIQAggJJxf5BQIIAggJJwR6UQUBEA8HAwERAxEPPQcDlglODQsHDwcPDw8RDxMPEwMFbQMDEwMFdQMDEwMFeQMDEwMFfQMDEwMFgQMDEwMFhQMDEwMFiQMDEwMFjQMDEwMFkQMDEwMFlQMDEwMFmQMDEwMFnQMDEwMFoQMDEwMFpQMDEwMFqQMDEwMFrQMDEwMFsQMDEwMFtQMDEwMFuQMDEwMFvQMDEwMFwQMDEwMFxQMDEwMFyQMDEwMFzQMDEwMF0QMDEwMF1QMDEwMF2QMDEwMF3QMDEwMF4QMDEwMF5QMDEwMD6QMDEwMD7wMDEwMD8wMDEwMD9wMDEwMDCgIDAxMDAxICAwMTAwMaAgMDEwMDIgIDAxMDAyoCAwMTAwMyAgMDEwMDOgIDAxMDA0ICAwMTAwNKAgMDEwMDUgIDAxMDA1oCAwMTAwNiAgMDEwMDagIDAxMDA3ICAwMTAwN6AgMDEwMDggIDAxMDA4oCAwMTAwOSAgMDEwMDmgIDAxMDA6ICAwMTAwOqAgMDEwMDsgIDAxMDA7oCAwMTAwPCAgMDEwMDygIDAxMDA9ICAwMTAwPaAgMDEwMD4gIDAxMDA+oCAwMTAwPyAgMDEwMN+gIDAREGDwMbAwURBg8DHQMHEQYPAx8DCQcHAwEDAQeNiYkHBwMBAwEHjYmFBwcDAQMBB42DiQcHAwEDAQeNg4UHBwMBAwEHjYGJBwcDAQMBB42BhQcHAwEDAQeNf4kHBwMBAwEHjX+FBwcDAQMBB419iQcHAwEDAQeNfYUHBwMBAwEHjXuJBwcDAQMBB417hQcHAwEDAQeNeYkHBwMBAwEHjXmFBwcDAQMBB413iQcHAwEDAQeNd4UHBwMBAwEHjXWJBwcDAQMBB411hQcHAwEDAQeNc4kHBwMBAwEHjXOFBwcDAQMBB41xiQcHAwEDAQeNcYUHBwMBAwEHjW+JBwcDAQMBB41vhQcHAwEDAQeNbYkHBwMBAwEHjW2FBwcDAQMBB41riQcHAwEDAQeNa4UHBwMBAwEHjWmJBwcDAQMBB41phQcHAwEDAQeNZ4kHBwMBAwEHjWeFBwcDAQMBB42FiQcHAwEDAQeNhYUHBwMBAwEHjWWJBwcDAQMBB41lhQcHAwEDAQeNY4kHBwMBAwEHjWOFBwcDAQMBB41hiQcHAwEDAQeNYYUHBwMBAwEHjV+JBwcDAQMBB41fhQcHAwEDAQeNXYkHBwMBAwEHjV2FBwcDAQMBB41biQcHAwEDAQeNW4UHBwMBAwEHjVmJBwcDAQMBB41ZhQcHAwEDAQeNV4kHBwMBAwEHjVeFBwcDAQMBB41ViQcHAwEDAQeNVYUHBwMBAwEHjVOJBwcDAQMBB41ThQcHAwEDAQeNUYkHBwMBAwEHjVGFBwcDAQMBB41PiQcHAwEDAQeNT4UHBwMBAwEHjU2JBwcDAQMBB41NhQcHAwEDAQeNS4kHBwMBAwEHjUuFBwcDAQMBB41JiQcHAwEDAQeNSYUHBwUBAwEHj4mJBwcFAQMBB4+JhQcHBQEDAQePg4kHBwUBAwEHj4OFBwcFAQMBB4+BiQcHBQEDAQePgYUHBwUBAwEHj3+JBwcFAQMBB49/hQcHBQEDAQePfYkHBwUBAwEHj32FBwcFAQMBB497iQcHBQEDAQePe4UHBwUBAwEHj3mJBwcFAQMBB495hQcHBQEDAQePd4kHBwUBAwEHj3eFBwcFAQMBB491iQcHBQEDAQePdYUHBwUBAwEHj3OJBwcFAQMBB49zhQcHBQEDAQePcYkHBwUBAwEHj3GFBwcFAQMBB49viQcHBQEDAQePb4UHBwUBAwEHj22JBwcFAQMBB49thQcHBQEDAQePa4kHBwUBAwEHj2uFBwcFAQMBB49piQcHBQEDAQePaYUHBwUBAwEHj2eJBwcFAQMBB49nhQcHBQEDAQePhYkHBwUBAwEHj4WFBwcFAQMBB49liQcHBQEDAQePZYUHBwUBAwEHj2OJBwcFAQMBB49jhQcHBQEDAQePYYkHBwUBAwEHj2GFBwcFAQMBB49fiQcHBQEDAQePX4UHBwUBAwEHj12JBwcFAQMBB49dhQcHBQEDAQePW4kHBwUBAwEHj1uFBwcFAQMBB49ZiQcHBQEDAQePWYUHBwUBAwEHj1eJBwcFAQMBB49XhQcHBQEDAQePVYkHBwUBAwEHj1WFBwcFAQMBB49TiQcHBQEDAQePU4UHBwUBAwEHj1GJBwcFAQMBB49RhQcHBQEDAQePT4kHBwUBAwEHj0+FBwcFAQMBB49NiQcHBQEDAQePTYUHBwUBAwEHj0uJBwcFAQMBB49LhQcHBQEDAQePSYkHBwUBAwEHj0mFBwcDAQMBB42JhwcHAwEDAQeNiUcHBwMBAwEHjYOHBwcDAQMBB42DRwcHAwEDAQeNgYcHBwMBAwEHjYFHBwcDAQMBB41/hwcHAwEDAQeNf0cHBwMBAwEHjX2HBwcDAQMBB419RwcHAwEDAQeNe4cHBwMBAwEHjXtHBwcDAQMBB415hwcHAwEDAQeNeUcHBwMBAwEHjXeHBwcDAQMBB413RwcHAwEDAQeNdYcHBwMBAwEHjXVHBwcDAQMBB41zhwcHAwEDAQeNc0cHBwMBAwEHjXGHBwcDAQMBB41xRwcHAwEDAQeNb4cHBwMBAwEHjW9HBwcDAQMBB41thwcHAwEDAQeNbUcHBwMBAwEHjWuHBwcDAQMBB41rRwcHAwEDAQeNaYcHBwMBAwEHjWlHBwcDAQMBB41nhwcHAwEDAQeNZ0cHBwMBAwEHjYWHBwcDAQMBB42FRwcHAwEDAQeNZYcHBwMBAwEHjWVHBwcDAQMBB41jhwcHAwEDAQeNY0cHBwMBAwEHjWGHBwcDAQMBB41hRwcHAwEDAQeNX4cHBwMBAwEHjV9HBwcDAQMBB41dhwcHAwEDAQeNXUcHBwMBAwEHjVuHBwcDAQMBB41bRwcHAwEDAQeNWYcHBwMBAwEHjVlHBwcDAQMBB41XhwcHAwEDAQeNV0cHBwMBAwEHjVWHBwcDAQMBB41VRwcHAwEDAQeNU4cHBwMBAwEHjVNHBwcDAQMBB41RhwcHAwEDAQeNUUcHBwMBAwEHjU+HBwcDAQMBB41PRwcHAwEDAQeNTYcHBwMBAwEHjU1HBwcDAQMBB41LhwcHAwEDAQeNS0cHBwMBAwEHjUmHBwcDAQMBB41JRwcHBQEDAQePh4kHBwUBAwEHj4eFBwcFAQMBB49FiQcHBQEDAQePRYUHBwUBAwEHj0OJBwcFAQMBB49DhQcHBQEDAQePQYkHBwUBAwEHj0GFBwcFAQMBB48/iQcHBQEDAQePP4UHBwUBAwEHjz2JBwcFAQMBB489hQcHBQEDAQePO4kHBwUBAwEHjzuFBwcFAQMBB485iQcHBQEDAQePOYUHBwUBAwEHjzeJBwcFAQMBB483hQcHBQEDAQePNYkHBwUBAwEHjzWFBwcFAQMBB48ziQcHBQEDAQePM4UHBwUBAwEHjzGJBwcFAQMBB48xhQcHBQEDAQePL4kHBwUBAwEHjy+FBwcFAQMBB48tiQcHBQEDAQePLYUHBwUBAwEHjyuJBwcFAQMBB48rhQcHBQEDAQePKYkHBwUBAwEHjymFBwcFAQMBB49HiQcHBQEDAQePR4UHBwUBAwEHjyeJBwcFAQMBB48nhQcHBQEDAQePJYkHBwUBAwEHjyWFBwcFAQMBB48jiQcHBQEDAQePI4UHBwUBAwEHjyGJBwcFAQMBB48hhQcHBQEDAQePH4kHBwUBAwEHjx+FBwcFAQMBB48diQcHBQEDAQePHYUHBwUBAwEHjxuJBwcFAQMBB48bhQcHBQEDAQePGYkHBwUBAwEHjxmFBwcFAQMBB48XiQcHBQEDAQePF4UHBwUBAwEHjxWJBwcFAQMBB48VhQcHBQEDAQePE4kHBwUBAwEHjxOFBwcFAQMBB48RiQcHBQEDAQePEYUHBwUBAwEHjw+JBwcFAQMBB48PhQcHBQEDAQePDYkHBwUBAwEHjw2FBwcFAQMBB48LiQcHBQEDAQePC4ULBw0RAwVBJgMuAzYDPgNGA04DVgNeA2YDbgN2A34DhgOOA5YDngOmA64DtgO+A8YDzgPWA94D5gPuA/YD/gMGBA4EFgQeBAsHDREDBUEqAzIDOgNCA0oDUgNaA2IDagNyA3oDggOKA5IDmgOiA6oDsgO6A8IDygPSA9oD4gPqA/ID+gMCBAoEEgQaBCIECwcNEQMLISYELgQ2BD4ERgROBFYEXgRmBG4EdgR+BIYEjgSWBJ4ECwcNEQMFQYuLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLDQcNEwMFByYFLgUyBQ8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEDNgULBw0RAwshpgSuBLYEvgTGBM4E1gTeBOYE7gT2BP4EBgUOBRYFHgULBw0RAwVBOgU+BUIFRgVKBU4FUgVWBVoFXgViBWYFagVuBXIFdgV6BX4FggWGBYoFjgWSBZYFmgWeBaIFpgWqBa4FsgW2BQ0HDRMDBQcqBboFvgUPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA8IFCwcNEQMLISoEMgQ6BEIESgRSBFoEYgRqBHIEegSCBIoEkgSaBKIECwcNEQMFQYuLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLDQcNEwMFByYFRgZKBg8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEDTgYLBw0RAwshqgSyBLoEwgTKBNIE2gTiBOoE8gT6BAIFCgUSBRoFIgULBw0RAwVBUgZWBloGXgZiBmYGagZuBnIGdgZ6Bn4GggaGBooGjgaSBpYGmgaeBqIGpgaqBq4Gsga2BroGvgbCBsYGygbOBg0HDRMDBQcqBdIG1gYPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA9oGCwcNEQMFQZOXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/8GAg4CFgIeAgsHDREDBUGVmZ2hpamtsbW5vcHFyc3R1dnd4eXp7fH1+f0CAgoCEgIaAiICCwcNEQMLISYCLgI2Aj4CRgJOAlYCXgJmAm4CdgJ+AoYCjgKWAp4CCwcNEQMFQcYFygXOBdIF1gXaBd4F4gXmBeoF7gXyBfYF+gX+BQIGBgYKBg4GEgYWBhoGHgYiBiYGKgYuBjIGNgY6Bj4GQgYNBw0TAwUHXgdmB2oHDwcNFUEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQNuBwsHDREDCyGmAq4CtgK+AsYCzgLWAt4C5gLuAvYC/gIGAw4DFgMeAwsHDREDBUFyB3YHegd+B4IHhgeKB44HkgeWB5oHngeiB6YHqgeuB7IHtge6B74HwgfGB8oHzgfSB9YH2gfeB+IH5gfqB+4HDQcNEwMFB2IH8gf2Bw8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQED+gcLBw0RAwshKgIyAjoCQgJKAlICWgJiAmoCcgJ6AoICigKSApoCogILBw0RAwVB3gbiBuYG6gbuBvIG9gb6Bv4GAgcGBwoHDgcSBxYHGgceByIHJgcqBy4HMgc2BzoHPgdCB0YHSgdOB1IHVgdaBw0HDRMDBQdeB34IgggPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA4YICwcNEQMLIaoCsgK6AsICygLSAtoC4gLqAvIC+gICAwoDEgMaAyIDCwcNEQMFQYoIjgiSCJYImgieCKIIpgiqCK4Isgi2CLoIvgjCCMYIygjOCNII1gjaCN4I4gjmCOoI7gjyCPYI+gj+CAIJBgkNBw0TAwUHYgcKCQ4JDwcNFUEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQMSCQkFCwkJ/geRiYkJBQsJCRYJkYmFCQULCQkCCJGDiQkFCwkJGgmRg4UJBQsJCQYIkYGJCQULCQkeCZGBhQkFCwkJCgiRf4kJBQsJCSIJkX+FCQULCQkOCJF9iQkFCwkJJgmRfYUJBQsJCRIIkXuJCQULCQkqCZF7hQkFCwkJFgiReYkJBQsJCS4JkXmFCQULCQkaCJF3iQkFCwkJMgmRd4UJBQsJCR4IkXWJCQULCQk2CZF1hQkFCwkJIgiRc4kJBQsJCToJkXOFCQULCQkmCJFxiQkFCwkJPgmRcYUJBQsJCSoIkW+JCQULCQlCCZFvhQkFCwkJLgiRbYkJBQsJCUYJkW2FCQULCQkyCJFriQkFCwkJSgmRa4UJBQsJCTYIkWmJCQULCQlOCZFphQkFCwkJOgiRZ4kJBQsJCVIJkWeFCQULCQk+CJGFiQkFCwkJVgmRhYUJBQsJCUIIkWWJCQULCQlaCZFlhQkFCwkJRgiRY4kJBQsJCV4JkWOFCQULCQlKCJFhiQkFCwkJYgmRYYUJBQsJCU4IkV+JCQULCQlmCZFfhQkFCwkJUgiRXYkJBQsJCWoJkV2FCQULCQlWCJFbiQkFCwkJbgmRW4UJBQsJCVoIkVmJCQULCQlyCZFZhQkFCwkJXgiRV4kJBQsJCXYJkVeFCQULCQliCJFViQkFCwkJegmRVYUJBQsJCWYIkVOJCQULCQl+CZFThQkFCwkJagiRUYkJBQsJCYIJkVGFCQULCQluCJFPiQkFCwkJhgmRT4UJBQsJCXIIkU2JCQULCQmKCZFNhQkFCwkJdgiRS4kJBQsJCY4JkUuFCQULCQl6CJFJiQkFCwkJkgmRSYUFAQ8eAwMRD2UHAwcLBQcPBw8TAw8vAwcFBA8FAQUDEQ9nBwMHCwUHDwcPEwMPLwMHBQQPBQUDAxEPaQcDBQcFBw8HDwUEDwUBAwYDAQUBAE4VTXYDKR0dF04C/gOB/gMdCyEjKR8bGRkZHSUTHRUNEykfDxsNCw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAbW9kdWxlAHJldHVybgBsb2FkAHN0b3JlAHJvbGxfdmVjdG9ycwBtYXRtdWwAdW5yb2xsX3ZlY3RvcnMAZXJhc2VfbWVtcmVmX2xheW91dABjb25zdGFudAB2YWx1ZQBpbl9sYXlvdXQAZnVuY3Rpb25fdHlwZQBzeW1fbmFtZQB0cmFuc2Zvcm1faW5kaWNlcwB3aW5kb3dfYm91bmRzAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHN1YmxhbmVfbWFzawBzdWJsYW5lX3N0cmlkZQBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAG1haW4Ad2luZG93X3BhcmFtcwAvbWFza2VkX2xvYWRbbWFza2VkPUZhbHNlIGNhY2hlX21vZGlmaWVyPSBldmljdGlvbl9wb2xpY3k9IGlzX3ZvbGF0aWxlPUZhbHNlIGFyZ3NfdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhbXSwgUHlUcmVlRGVmKFtDdXN0b21Ob2RlKFNsaWNlWyhGYWxzZSwgMjU2KV0sIFsqXSksIEN1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSldKSwgW1RydWUsIFRydWVdLCAoNTEyLCAyNTYpLCAoKSldLCBbKl0pLCkpXQB0aGlyZF9wYXJ0eS9weS9qYXhfdHJpdG9uL2dvb2dsZS9wYWxsYXNfdHB1L2JhY2tfY29tcGF0X3Rlc3QucHkAL21hc2tlZF9sb2FkW21hc2tlZD1GYWxzZSBjYWNoZV9tb2RpZmllcj0gZXZpY3Rpb25fcG9saWN5PSBpc192b2xhdGlsZT1GYWxzZSBhcmdzX3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoW10sIFB5VHJlZURlZihbQ3VzdG9tTm9kZShTbGljZVsoVHJ1ZSwgMCwgMjU2KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoRmFsc2UsIDI1NildLCBbKl0pXSksIFtUcnVlLCBUcnVlXSwgKDI1NiwgNTEyKSwgKCkpXSwgWypdKSwpKV0AL2RvdF9nZW5lcmFsW2RpbWVuc2lvbl9udW1iZXJzPSgoKDEsKSwgKDAsKSksICgoKSwgKCkpKSBwcmVjaXNpb249KDxQcmVjaXNpb24uREVGQVVMVDogMD4sIDxQcmVjaXNpb24uREVGQVVMVDogMD4pIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0Ab3V0X2xheW91dAB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAb3BlcmFuZFNlZ21lbnRTaXplcwAvbWFza2VkX3N3YXBbbWFza2VkPUZhbHNlIGV2aWN0aW9uX3BvbGljeT0gYXJnc190cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFtdLCBQeVRyZWVEZWYoW0N1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSksIEN1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSldKSwgW1RydWUsIFRydWVdLCAoMjU2LCAyNTYpLCAoKSldLCBbXSksKSldAA==\22}}", kernel_name = "func", operand_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc18) + return %0 : tensor<1024x256xf32> loc(#loc17) + } loc(#loc17) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":30:0) +#loc2 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":31:0) +#loc3 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":32:0) +#loc5 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":35:0) +#loc6 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(524288,) dimension=0]"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/reshape[new_sizes=(1024, 512) dimensions=None]"(#loc2)) +#loc8 = loc("jit(func)/jit(main)/mul"(#loc1)) +#loc9 = loc("jit(func)/jit(main)/add"(#loc1)) +#loc10 = loc("jit(func)/jit(main)/slice[start_indices=(0, 0) limit_indices=(512, 256) strides=None]"(#loc3)) +#loc12 = loc("jit(func)/jit(main)/broadcast_in_dim[shape=(64, 16, 1) broadcast_dimensions=(0,)]"(#loc5)) +#loc13 = loc("jit(func)/jit(main)/broadcast_in_dim[shape=(64, 16, 1) broadcast_dimensions=(1,)]"(#loc5)) +#loc14 = loc("jit(func)/jit(main)/concatenate[dimension=2]"(#loc5)) +#loc15 = loc("jit(func)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1)) slice_sizes=(1, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc5)) +#loc18 = loc("jit(func)/jit(main)/jit(matmul)/jit(wrapped)/jit(apply_kernel)/tpu_custom_call[config=CustomCallBackendConfig() kernel_name=func kernel_regeneration_metadata=None out_avals=(ShapedArray(float32[1024,256]),)]"(#loc4)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01+\x05\x01\x03\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03~\x02\xfb-\x01\xaf\x07\x0b\x0f\x0b\x0f\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0b\x13\x0b\x0f\x13\x0f\x0f+\x0b\x0f\x0b\x0b\x0b33\x0b3\x0b3\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x13\x13\x13\x13\x13\x0b\x0f\x0b\x0f\x0b\x13\x13\x0b\x13\x0b#\x0b\x0b\x0b\x0f\x0b\x13\x13\x13\x0f\x0b\x13\x0f\x0b\x13\x0b\x0f\x0b;\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x03M\x0b\x13\x0b\x0b\x0f\x0bO\x0b\x0b\x0b\x0f/\x0fO\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f&\x08\x1e\x02\x0f\x1f\x1fO///\x0b\x01\x05\x0b\x0f\x03)\x1f\x07\x1f\x1f\x07\x07\x0f\x13\x1b\x17\x13\x1f\x13\x13\x1b\x13\x07\x1b\x13\x1f\x022\x0e\x1f\x05!\x1d9\x15\x05#\x1d=\x15\x1dA\x15\x05%\x05\'\x05)\x05+\x17\x07C\x01\x05-\x17\x07G\x01\x05/\x17\x07=\x01\x051\x11\x03\x05\x03\x03\x1f\xc3\x1ds\x1d\x1dw\x1d\x03\t+-/!1!\x033\x053\x11\x01\x00\x055\x057\x059\x03\x0b\r\xaf\x0f\xcb\x11\xcd\x03\xd5\x13\xd7\x03\x0b\r\xb1\x0f\xb5\x11\xb7\x03\xbd\x13\xb9\x05;\x03\x0b\r\xb1\x0f\xb5\x11\xb7\x03\xbf\x13\xb9\x05=\x03\x0b\r\xb1\x0f\xb5\x11\xb7\x03\xc1\x13\xb9\x05?\x03\x13E\xd9G\xdbI\xddK\xafM\xdfO\xe1Q\xe3S\xafU\xe5\x05A\x05C\x05E\x05G\x05I\x05K\x05M\x05O\x05Q\x1dY\x15\x05S\x03\x03\x1b\xc1\x03\x03\x1b\xbf\x03\x03\x17\xe7\x03\x03\x17\xe9\x03\x03e\xeb\x05U\x1di\x1d\x05W\x1dmo\x05Y\x17\x07?\x01\x03\x03\x17\xed\x05[\x03\x03\x17\xef\x05]\x03\x07{\xf1}\xf3\x7f\xc5\x05_\x05a\x05c\x1d\x83\x85\x05e\x17\x07A\x01\x03\x03\x1b\xbd\x03\x03\x1f\xf5\x1d\x8d\x19\x05g\x03\x03\x1f\xf7\x1d\x93\x19\x05i\x03\x03\x97\xc7\x05k\x1d\x9b\x19\x05m\x03\r\x9f\xc9\xa1\xc7\xa3\xf9\xa5\xc3\xa7\xc5\xa9\xc9\x05o\x05q\x05s\x05u\x05w\x05y\x1d\xad\x19\x05{\x03\x01\x03\x05\xb3\xb3\r\x01#!\x03\x03\xb3\x1d}\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x7f\x1d\x81\x1d\x83\x1f)\x01\x1f\x13\x11\x01\x00\x00\x00\x00\x00\x00\x00\x13\r\t\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x03\xcf\r\x03\xd1\xd3\x1d\x85\x1d\x87\x1d\x89\x1d\x8b\x0b\x03\x1d\x8d\x1d\x8f\x05\x01\x1d\x91\x03\x05\xbb\xbb\x03\x03\xbb\x1f\x17\x02\x04\x00\x00\x00\x00\x10\x00\x00\x00 \x00\x00\x000\x00\x00\x00@\x00\x00\x00P\x00\x00\x00`\x00\x00\x00p\x00\x00\x00\x80\x00\x00\x00\x90\x00\x00\x00\xa0\x00\x00\x00\xb0\x00\x00\x00\xc0\x00\x00\x00\xd0\x00\x00\x00\xe0\x00\x00\x00\xf0\x00\x00\x00\x00\x01\x00\x00\x10\x01\x00\x00 \x01\x00\x000\x01\x00\x00@\x01\x00\x00P\x01\x00\x00`\x01\x00\x00p\x01\x00\x00\x80\x01\x00\x00\x90\x01\x00\x00\xa0\x01\x00\x00\xb0\x01\x00\x00\xc0\x01\x00\x00\xd0\x01\x00\x00\xe0\x01\x00\x00\xf0\x01\x00\x00\x00\x02\x00\x00\x10\x02\x00\x00 \x02\x00\x000\x02\x00\x00@\x02\x00\x00P\x02\x00\x00`\x02\x00\x00p\x02\x00\x00\x80\x02\x00\x00\x90\x02\x00\x00\xa0\x02\x00\x00\xb0\x02\x00\x00\xc0\x02\x00\x00\xd0\x02\x00\x00\xe0\x02\x00\x00\xf0\x02\x00\x00\x00\x03\x00\x00\x10\x03\x00\x00 \x03\x00\x000\x03\x00\x00@\x03\x00\x00P\x03\x00\x00`\x03\x00\x00p\x03\x00\x00\x80\x03\x00\x00\x90\x03\x00\x00\xa0\x03\x00\x00\xb0\x03\x00\x00\xc0\x03\x00\x00\xd0\x03\x00\x00\xe0\x03\x00\x00\xf0\x03\x00\x00\x1f\x19\x81\x00\x00\x00\x00\x10\x00\x00\x00 \x00\x00\x000\x00\x00\x00@\x00\x00\x00P\x00\x00\x00`\x00\x00\x00p\x00\x00\x00\x80\x00\x00\x00\x90\x00\x00\x00\xa0\x00\x00\x00\xb0\x00\x00\x00\xc0\x00\x00\x00\xd0\x00\x00\x00\xe0\x00\x00\x00\xf0\x00\x00\x00\x13\r\x01\x1f\x11\to\x12\x83:\x1f\x11\t\x00\x00\x80?\x1f\x13!\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00\x05\x03\x01\t\x01\x02\x02)\x05\x02 \x02\x10\x07\t)\x05\x02\x10\x02\x08\x07)\x05\x02 \x02\x08\x07\x1d\x1b)\x01\x07)\x03\t\r)\x05\x02\x02A\x07)\x03\x02\x02\x0f)\x03A\x0f)\x07\x02\x02A\x05\x0f)\x03\x05\r\x11\x01\x03\x15\x11\x05\x05\t\x03\x0b)\x03\t%\x13)\x03\x04\x00\x80\x07)\x03\x01\r)\x07\x02\x02A\t\x0f\x04~\x03\x05\x01\x11\x01)\x07\x03\x01\x11\x03\x11\x015\x07\x03!E\x07\x03\x01_\x03\x17\x07\x03\x01a\x03\x19\x0f\x03gc\x03\'\x11\x06k\x03\x05\x03\x05\x07\x03\x01q\x03\x11\t\x07%#\x03\x05\x03\t\x13\x06%\x03\x05\x05\x0b\x07\x07\x03\x01u\x03\x11\t\x07\'#\x03\x05\x03\x0f\x15\x06\'\x03\x05\x05\x11\r\x17\x07\x81y\x03\t\x03\x13\x0b\x07\x05\x87\x03\x0b\x05\x13\x15\t\x07\x8b\x89\x03\x1b\x03\x01\t\x07\x91\x8f\x03\x1b\x03\x03\x19\x07\x99\x95\x03+\x05\x19\x1b\x1b\x07\xab\x9d\x03\x15\x05\x17\x1d\x05\x04\x01\x03\x1f\x03\x11\x057\x07\x03\x07\x0b\x05\x05\x05\t\x05\x0b\x07\t]\x03\x0b\x05\x01\x03\x05\x04\x05\x03\x05\x03\x11\t;\x07\x03\x07\x0b\x05\x05\t\t\t\x0b\x07\x0b[\x03\x0b\x05\x01\x03\x05\x04\t\x03\x05\x03\x11\x0b?\x07\x03\x07\x0b\x05\x05\x0b\t\x0b\r\x07WC\x03\x0b\x05\x01\x03\x05\x04\x0b\x03\x05\x06\x03\x01\x05\x01\x00\xee\xcd\x93\x0b!f\xa7\x0f\x0b\x03!\x1b\x11\x0f\x11\n\x04!\x19\x19\'#+[\x15\xa5\xa5\xad\x11\x1d\x1d11\x87\x89\x1ff\x03\x1f/!\x19!)#\x1f\x19\xa2\x03Z\x03&\x03\x13%)9+\x0f\r\x1f\x15\x1d\x15\x81\x13\x15\x1f\x13\x0f\x19\x17\x11\x1f\x11)\x19\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00constant_v1\x00broadcast_in_dim_v1\x00call_v1\x00custom_call_v1\x00iota_v1\x00reshape_v1\x00multiply_v1\x00add_v1\x00slice_v1\x00concatenate_v1\x00gather_v1\x00sym_name\x00third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00value\x00callee\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=matmul keep_unused=False inline=False]\x00jit(func)/jit(main)/jit(matmul)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=wrapped keep_unused=False inline=False]\x00jit(func)/jit(main)/jit(matmul)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=apply_kernel keep_unused=False inline=False]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00kernel_name\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/jit(matmul)/jit(wrapped)/jit(apply_kernel)/tpu_custom_call[config=CustomCallBackendConfig() kernel_name=func kernel_regeneration_metadata=None out_avals=(ShapedArray(float32[1024,256]),)]\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(524288,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(1024, 512) dimensions=None]\x00jit(func)/jit(main)/mul\x00jit(func)/jit(main)/add\x00limit_indices\x00start_indices\x00strides\x00jit(func)/jit(main)/slice[start_indices=(0, 0) limit_indices=(512, 256) strides=None]\x00jit(func)/jit(main)/broadcast_in_dim[shape=(64, 16, 1) broadcast_dimensions=(0,)]\x00jit(func)/jit(main)/broadcast_in_dim[shape=(64, 16, 1) broadcast_dimensions=(1,)]\x00dimension\x00jit(func)/jit(main)/concatenate[dimension=2]\x00collapsed_slice_dims\x00index_vector_dim\x00indices_are_sorted\x00offset_dims\x00slice_sizes\x00start_index_map\x00jit(func)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1)) slice_sizes=(1, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00private\x00matmul\x00wrapped\x00apply_kernel\x00jax.result_info\x00\x00main\x00public\x00{"custom_call_config": {"body": "TUzvUgFNTElSZ29vZ2xlMy10cnVuawABLwkBAwUHAQMJAwUDCwUNDQ8RExUXBwMZA44DIgMhAfkbDw8LKxMTBxcjEwsLCwsTCwsLhQsLCxsLMwsPEw87CxMLC1MLDwsLFxsLUxsLUxsLUxsbGw8TEwsLExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTCxMTExMXBQthkWlpeQGPExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXDxcPFw8XDxcPFw8XDxcTHxMfDxcfCwsLCwtTCxMBIRsHHw8HHw8nJycLIx8nJycCwhEDBTEzNTcd7R8dcR8FGwMHEgMWAzEzNTcdGgMfHQIDHx8DAwYDOQMFCgM7DgM7AwMXOQUdBR8FIQEBF3NDAQUjBSUNGWFmZmluZV9tYXA8KGQwLCBkMSkgLT4gKGQwLCBkMSk+AAUnBSkFKwMFFx0HawUtIxUREQEBAQEBAQEBBS8RBwUBAwICERUAAw0/QRlDRUdJSxtNT1EFMQEF+/sNFwUzIw0FIQQAAAAAAAAAAQAAAAAAAAAFNRENAQU3BTkBB1NZXwMFIVUjVwkpIw0FIQABAAAAAAAAAAIAAAAAAAADBSFbI10JKyMNBSEAAgAAAAAAAAABAAAAAAAAAwUhYSNjCS0jDQUhAAEAAAAAAAAAAQAAAAAAAAMFGSUbKQMFGSUbKwMFGSUbLREHAQMDB28RA8IPBTsFPQMDB3cRA4IPAwMHexEDQg8DAwd/EQMCDwMDB4MRA8IOAwMHhxEDgg4DAweLEQNCDgMDB48RAwIOAwMHkxEDwg0DAweXEQOCDQMDB5sRA0INAwMHnxEDAg0DAwejEQPCDAMDB6cRA4IMAwMHqxEDQgwDAwevEQPCCwMDB7MRA4ILAwMHtxEDQgsDAwe7EQMCCwMDB78RA8IKAwMHwxEDggoDAwfHEQNCCgMDB8sRAwIKAwMHzxEDwgkDAwfTEQOCCQMDB9cRA0IJAwMH2xEDAgkDAwffEQPCCAMDB+MRA4IIAwMH5xEDQggDAwfrEQMCDAU/AwMH8REDwgcDAwf1EQOCBwMDBwYCI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS50aWxlZDwoOCwxMjgpLFsyLDFdPgAjdHB1LnRpbGVkPCg4LDEyOCksWzQsMV0+ACN0cHUudnBhZDwiMzIsezAsMH0sKDgsMTI4KSI+ABEDQgcDAwcOAhEDAgcDAwcWAhEDwgYDAwceAhEDggYDAwcmAhEDQgYDAwcuAhEDAgYDAwc2AhEDwgUDAwc+AhEDggUDAwdGAhEDQgUDAwdOAhEDAgUDAwdWAhEDwgQDAwdeAhEDggQDAwdmAhEDQgQDAwduAhEDwgMDAwd2AhEDggMDAwd+AhEDQgMDAweGAhEDAgMDAweOAhEDwgIDAweWAhEDggIDAweeAhEDQgIDAwemAhEDAgIDAweuAhED4QMDB7YCEQPBAwMHvgIRA6EDAwfGAhEDgQMDB84CEQNhAwMH1gIRA0EDAwfeAhEDIQMDB+YCEQMCBAMFFx0H7gIRAwIIAwUXHQf2AhEDAQMDB/4CJQEJAAAAAAVBBUMFRQVHBUkjBwkhAQAAAAEAAAACAAAAAAAAAAVLAwMXHScFIQIECQMnBQIIAgQJAQICCycFAgQCBAkBAgQX+QUCCAIQCf8X+QUCEAIICf0X+QUCCAIICf0BCQULBwcPERMBBQUHBwUHBxf5BQIIAhAJJxf5BQIQAggJJxf5BQIIAggJJwR6UQUBEA8HAwERAxEPPQcDlglODQsHDwcPDw8RDxMPEwMFbQMDEwMFdQMDEwMFeQMDEwMFfQMDEwMFgQMDEwMFhQMDEwMFiQMDEwMFjQMDEwMFkQMDEwMFlQMDEwMFmQMDEwMFnQMDEwMFoQMDEwMFpQMDEwMFqQMDEwMFrQMDEwMFsQMDEwMFtQMDEwMFuQMDEwMFvQMDEwMFwQMDEwMFxQMDEwMFyQMDEwMFzQMDEwMF0QMDEwMF1QMDEwMF2QMDEwMF3QMDEwMF4QMDEwMF5QMDEwMD6QMDEwMD7wMDEwMD8wMDEwMD9wMDEwMDCgIDAxMDAxICAwMTAwMaAgMDEwMDIgIDAxMDAyoCAwMTAwMyAgMDEwMDOgIDAxMDA0ICAwMTAwNKAgMDEwMDUgIDAxMDA1oCAwMTAwNiAgMDEwMDagIDAxMDA3ICAwMTAwN6AgMDEwMDggIDAxMDA4oCAwMTAwOSAgMDEwMDmgIDAxMDA6ICAwMTAwOqAgMDEwMDsgIDAxMDA7oCAwMTAwPCAgMDEwMDygIDAxMDA9ICAwMTAwPaAgMDEwMD4gIDAxMDA+oCAwMTAwPyAgMDEwMN+gIDAREGDwMbAwURBg8DHQMHEQYPAx8DCQcHAwEDAQeNiYkHBwMBAwEHjYmFBwcDAQMBB42DiQcHAwEDAQeNg4UHBwMBAwEHjYGJBwcDAQMBB42BhQcHAwEDAQeNf4kHBwMBAwEHjX+FBwcDAQMBB419iQcHAwEDAQeNfYUHBwMBAwEHjXuJBwcDAQMBB417hQcHAwEDAQeNeYkHBwMBAwEHjXmFBwcDAQMBB413iQcHAwEDAQeNd4UHBwMBAwEHjXWJBwcDAQMBB411hQcHAwEDAQeNc4kHBwMBAwEHjXOFBwcDAQMBB41xiQcHAwEDAQeNcYUHBwMBAwEHjW+JBwcDAQMBB41vhQcHAwEDAQeNbYkHBwMBAwEHjW2FBwcDAQMBB41riQcHAwEDAQeNa4UHBwMBAwEHjWmJBwcDAQMBB41phQcHAwEDAQeNZ4kHBwMBAwEHjWeFBwcDAQMBB42FiQcHAwEDAQeNhYUHBwMBAwEHjWWJBwcDAQMBB41lhQcHAwEDAQeNY4kHBwMBAwEHjWOFBwcDAQMBB41hiQcHAwEDAQeNYYUHBwMBAwEHjV+JBwcDAQMBB41fhQcHAwEDAQeNXYkHBwMBAwEHjV2FBwcDAQMBB41biQcHAwEDAQeNW4UHBwMBAwEHjVmJBwcDAQMBB41ZhQcHAwEDAQeNV4kHBwMBAwEHjVeFBwcDAQMBB41ViQcHAwEDAQeNVYUHBwMBAwEHjVOJBwcDAQMBB41ThQcHAwEDAQeNUYkHBwMBAwEHjVGFBwcDAQMBB41PiQcHAwEDAQeNT4UHBwMBAwEHjU2JBwcDAQMBB41NhQcHAwEDAQeNS4kHBwMBAwEHjUuFBwcDAQMBB41JiQcHAwEDAQeNSYUHBwUBAwEHj4mJBwcFAQMBB4+JhQcHBQEDAQePg4kHBwUBAwEHj4OFBwcFAQMBB4+BiQcHBQEDAQePgYUHBwUBAwEHj3+JBwcFAQMBB49/hQcHBQEDAQePfYkHBwUBAwEHj32FBwcFAQMBB497iQcHBQEDAQePe4UHBwUBAwEHj3mJBwcFAQMBB495hQcHBQEDAQePd4kHBwUBAwEHj3eFBwcFAQMBB491iQcHBQEDAQePdYUHBwUBAwEHj3OJBwcFAQMBB49zhQcHBQEDAQePcYkHBwUBAwEHj3GFBwcFAQMBB49viQcHBQEDAQePb4UHBwUBAwEHj22JBwcFAQMBB49thQcHBQEDAQePa4kHBwUBAwEHj2uFBwcFAQMBB49piQcHBQEDAQePaYUHBwUBAwEHj2eJBwcFAQMBB49nhQcHBQEDAQePhYkHBwUBAwEHj4WFBwcFAQMBB49liQcHBQEDAQePZYUHBwUBAwEHj2OJBwcFAQMBB49jhQcHBQEDAQePYYkHBwUBAwEHj2GFBwcFAQMBB49fiQcHBQEDAQePX4UHBwUBAwEHj12JBwcFAQMBB49dhQcHBQEDAQePW4kHBwUBAwEHj1uFBwcFAQMBB49ZiQcHBQEDAQePWYUHBwUBAwEHj1eJBwcFAQMBB49XhQcHBQEDAQePVYkHBwUBAwEHj1WFBwcFAQMBB49TiQcHBQEDAQePU4UHBwUBAwEHj1GJBwcFAQMBB49RhQcHBQEDAQePT4kHBwUBAwEHj0+FBwcFAQMBB49NiQcHBQEDAQePTYUHBwUBAwEHj0uJBwcFAQMBB49LhQcHBQEDAQePSYkHBwUBAwEHj0mFBwcDAQMBB42JhwcHAwEDAQeNiUcHBwMBAwEHjYOHBwcDAQMBB42DRwcHAwEDAQeNgYcHBwMBAwEHjYFHBwcDAQMBB41/hwcHAwEDAQeNf0cHBwMBAwEHjX2HBwcDAQMBB419RwcHAwEDAQeNe4cHBwMBAwEHjXtHBwcDAQMBB415hwcHAwEDAQeNeUcHBwMBAwEHjXeHBwcDAQMBB413RwcHAwEDAQeNdYcHBwMBAwEHjXVHBwcDAQMBB41zhwcHAwEDAQeNc0cHBwMBAwEHjXGHBwcDAQMBB41xRwcHAwEDAQeNb4cHBwMBAwEHjW9HBwcDAQMBB41thwcHAwEDAQeNbUcHBwMBAwEHjWuHBwcDAQMBB41rRwcHAwEDAQeNaYcHBwMBAwEHjWlHBwcDAQMBB41nhwcHAwEDAQeNZ0cHBwMBAwEHjYWHBwcDAQMBB42FRwcHAwEDAQeNZYcHBwMBAwEHjWVHBwcDAQMBB41jhwcHAwEDAQeNY0cHBwMBAwEHjWGHBwcDAQMBB41hRwcHAwEDAQeNX4cHBwMBAwEHjV9HBwcDAQMBB41dhwcHAwEDAQeNXUcHBwMBAwEHjVuHBwcDAQMBB41bRwcHAwEDAQeNWYcHBwMBAwEHjVlHBwcDAQMBB41XhwcHAwEDAQeNV0cHBwMBAwEHjVWHBwcDAQMBB41VRwcHAwEDAQeNU4cHBwMBAwEHjVNHBwcDAQMBB41RhwcHAwEDAQeNUUcHBwMBAwEHjU+HBwcDAQMBB41PRwcHAwEDAQeNTYcHBwMBAwEHjU1HBwcDAQMBB41LhwcHAwEDAQeNS0cHBwMBAwEHjUmHBwcDAQMBB41JRwcHBQEDAQePh4kHBwUBAwEHj4eFBwcFAQMBB49FiQcHBQEDAQePRYUHBwUBAwEHj0OJBwcFAQMBB49DhQcHBQEDAQePQYkHBwUBAwEHj0GFBwcFAQMBB48/iQcHBQEDAQePP4UHBwUBAwEHjz2JBwcFAQMBB489hQcHBQEDAQePO4kHBwUBAwEHjzuFBwcFAQMBB485iQcHBQEDAQePOYUHBwUBAwEHjzeJBwcFAQMBB483hQcHBQEDAQePNYkHBwUBAwEHjzWFBwcFAQMBB48ziQcHBQEDAQePM4UHBwUBAwEHjzGJBwcFAQMBB48xhQcHBQEDAQePL4kHBwUBAwEHjy+FBwcFAQMBB48tiQcHBQEDAQePLYUHBwUBAwEHjyuJBwcFAQMBB48rhQcHBQEDAQePKYkHBwUBAwEHjymFBwcFAQMBB49HiQcHBQEDAQePR4UHBwUBAwEHjyeJBwcFAQMBB48nhQcHBQEDAQePJYkHBwUBAwEHjyWFBwcFAQMBB48jiQcHBQEDAQePI4UHBwUBAwEHjyGJBwcFAQMBB48hhQcHBQEDAQePH4kHBwUBAwEHjx+FBwcFAQMBB48diQcHBQEDAQePHYUHBwUBAwEHjxuJBwcFAQMBB48bhQcHBQEDAQePGYkHBwUBAwEHjxmFBwcFAQMBB48XiQcHBQEDAQePF4UHBwUBAwEHjxWJBwcFAQMBB48VhQcHBQEDAQePE4kHBwUBAwEHjxOFBwcFAQMBB48RiQcHBQEDAQePEYUHBwUBAwEHjw+JBwcFAQMBB48PhQcHBQEDAQePDYkHBwUBAwEHjw2FBwcFAQMBB48LiQcHBQEDAQePC4ULBw0RAwVBJgMuAzYDPgNGA04DVgNeA2YDbgN2A34DhgOOA5YDngOmA64DtgO+A8YDzgPWA94D5gPuA/YD/gMGBA4EFgQeBAsHDREDBUEqAzIDOgNCA0oDUgNaA2IDagNyA3oDggOKA5IDmgOiA6oDsgO6A8IDygPSA9oD4gPqA/ID+gMCBAoEEgQaBCIECwcNEQMLISYELgQ2BD4ERgROBFYEXgRmBG4EdgR+BIYEjgSWBJ4ECwcNEQMFQYuLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLDQcNEwMFByYFLgUyBQ8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEDNgULBw0RAwshpgSuBLYEvgTGBM4E1gTeBOYE7gT2BP4EBgUOBRYFHgULBw0RAwVBOgU+BUIFRgVKBU4FUgVWBVoFXgViBWYFagVuBXIFdgV6BX4FggWGBYoFjgWSBZYFmgWeBaIFpgWqBa4FsgW2BQ0HDRMDBQcqBboFvgUPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA8IFCwcNEQMLISoEMgQ6BEIESgRSBFoEYgRqBHIEegSCBIoEkgSaBKIECwcNEQMFQYuLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLDQcNEwMFByYFRgZKBg8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEDTgYLBw0RAwshqgSyBLoEwgTKBNIE2gTiBOoE8gT6BAIFCgUSBRoFIgULBw0RAwVBUgZWBloGXgZiBmYGagZuBnIGdgZ6Bn4GggaGBooGjgaSBpYGmgaeBqIGpgaqBq4Gsga2BroGvgbCBsYGygbOBg0HDRMDBQcqBdIG1gYPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA9oGCwcNEQMFQZOXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/8GAg4CFgIeAgsHDREDBUGVmZ2hpamtsbW5vcHFyc3R1dnd4eXp7fH1+f0CAgoCEgIaAiICCwcNEQMLISYCLgI2Aj4CRgJOAlYCXgJmAm4CdgJ+AoYCjgKWAp4CCwcNEQMFQcYFygXOBdIF1gXaBd4F4gXmBeoF7gXyBfYF+gX+BQIGBgYKBg4GEgYWBhoGHgYiBiYGKgYuBjIGNgY6Bj4GQgYNBw0TAwUHXgdmB2oHDwcNFUEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQNuBwsHDREDCyGmAq4CtgK+AsYCzgLWAt4C5gLuAvYC/gIGAw4DFgMeAwsHDREDBUFyB3YHegd+B4IHhgeKB44HkgeWB5oHngeiB6YHqgeuB7IHtge6B74HwgfGB8oHzgfSB9YH2gfeB+IH5gfqB+4HDQcNEwMFB2IH8gf2Bw8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQED+gcLBw0RAwshKgIyAjoCQgJKAlICWgJiAmoCcgJ6AoICigKSApoCogILBw0RAwVB3gbiBuYG6gbuBvIG9gb6Bv4GAgcGBwoHDgcSBxYHGgceByIHJgcqBy4HMgc2BzoHPgdCB0YHSgdOB1IHVgdaBw0HDRMDBQdeB34IgggPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA4YICwcNEQMLIaoCsgK6AsICygLSAtoC4gLqAvIC+gICAwoDEgMaAyIDCwcNEQMFQYoIjgiSCJYImgieCKIIpgiqCK4Isgi2CLoIvgjCCMYIygjOCNII1gjaCN4I4gjmCOoI7gjyCPYI+gj+CAIJBgkNBw0TAwUHYgcKCQ4JDwcNFUEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQMSCQkFCwkJ/geRiYkJBQsJCRYJkYmFCQULCQkCCJGDiQkFCwkJGgmRg4UJBQsJCQYIkYGJCQULCQkeCZGBhQkFCwkJCgiRf4kJBQsJCSIJkX+FCQULCQkOCJF9iQkFCwkJJgmRfYUJBQsJCRIIkXuJCQULCQkqCZF7hQkFCwkJFgiReYkJBQsJCS4JkXmFCQULCQkaCJF3iQkFCwkJMgmRd4UJBQsJCR4IkXWJCQULCQk2CZF1hQkFCwkJIgiRc4kJBQsJCToJkXOFCQULCQkmCJFxiQkFCwkJPgmRcYUJBQsJCSoIkW+JCQULCQlCCZFvhQkFCwkJLgiRbYkJBQsJCUYJkW2FCQULCQkyCJFriQkFCwkJSgmRa4UJBQsJCTYIkWmJCQULCQlOCZFphQkFCwkJOgiRZ4kJBQsJCVIJkWeFCQULCQk+CJGFiQkFCwkJVgmRhYUJBQsJCUIIkWWJCQULCQlaCZFlhQkFCwkJRgiRY4kJBQsJCV4JkWOFCQULCQlKCJFhiQkFCwkJYgmRYYUJBQsJCU4IkV+JCQULCQlmCZFfhQkFCwkJUgiRXYkJBQsJCWoJkV2FCQULCQlWCJFbiQkFCwkJbgmRW4UJBQsJCVoIkVmJCQULCQlyCZFZhQkFCwkJXgiRV4kJBQsJCXYJkVeFCQULCQliCJFViQkFCwkJegmRVYUJBQsJCWYIkVOJCQULCQl+CZFThQkFCwkJagiRUYkJBQsJCYIJkVGFCQULCQluCJFPiQkFCwkJhgmRT4UJBQsJCXIIkU2JCQULCQmKCZFNhQkFCwkJdgiRS4kJBQsJCY4JkUuFCQULCQl6CJFJiQkFCwkJkgmRSYUFAQ8eAwMRD2UHAwcLBQcPBw8TAw8vAwcFBA8FAQUDEQ9nBwMHCwUHDwcPEwMPLwMHBQQPBQUDAxEPaQcDBQcFBw8HDwUEDwUBAwYDAQUBAE4VTXYDKR0dF04C/gOB/gMdCyEjKR8bGRkZHSUTHRUNEykfDxsNCw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAbW9kdWxlAHJldHVybgBsb2FkAHN0b3JlAHJvbGxfdmVjdG9ycwBtYXRtdWwAdW5yb2xsX3ZlY3RvcnMAZXJhc2VfbWVtcmVmX2xheW91dABjb25zdGFudAB2YWx1ZQBpbl9sYXlvdXQAZnVuY3Rpb25fdHlwZQBzeW1fbmFtZQB0cmFuc2Zvcm1faW5kaWNlcwB3aW5kb3dfYm91bmRzAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHN1YmxhbmVfbWFzawBzdWJsYW5lX3N0cmlkZQBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAG1haW4Ad2luZG93X3BhcmFtcwAvbWFza2VkX2xvYWRbbWFza2VkPUZhbHNlIGNhY2hlX21vZGlmaWVyPSBldmljdGlvbl9wb2xpY3k9IGlzX3ZvbGF0aWxlPUZhbHNlIGFyZ3NfdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhbXSwgUHlUcmVlRGVmKFtDdXN0b21Ob2RlKFNsaWNlWyhGYWxzZSwgMjU2KV0sIFsqXSksIEN1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSldKSwgW1RydWUsIFRydWVdLCAoNTEyLCAyNTYpLCAoKSldLCBbKl0pLCkpXQB0aGlyZF9wYXJ0eS9weS9qYXhfdHJpdG9uL2dvb2dsZS9wYWxsYXNfdHB1L2JhY2tfY29tcGF0X3Rlc3QucHkAL21hc2tlZF9sb2FkW21hc2tlZD1GYWxzZSBjYWNoZV9tb2RpZmllcj0gZXZpY3Rpb25fcG9saWN5PSBpc192b2xhdGlsZT1GYWxzZSBhcmdzX3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoW10sIFB5VHJlZURlZihbQ3VzdG9tTm9kZShTbGljZVsoVHJ1ZSwgMCwgMjU2KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoRmFsc2UsIDI1NildLCBbKl0pXSksIFtUcnVlLCBUcnVlXSwgKDI1NiwgNTEyKSwgKCkpXSwgWypdKSwpKV0AL2RvdF9nZW5lcmFsW2RpbWVuc2lvbl9udW1iZXJzPSgoKDEsKSwgKDAsKSksICgoKSwgKCkpKSBwcmVjaXNpb249KDxQcmVjaXNpb24uREVGQVVMVDogMD4sIDxQcmVjaXNpb24uREVGQVVMVDogMD4pIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0Ab3V0X2xheW91dAB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAb3BlcmFuZFNlZ21lbnRTaXplcwAvbWFza2VkX3N3YXBbbWFza2VkPUZhbHNlIGV2aWN0aW9uX3BvbGljeT0gYXJnc190cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFtdLCBQeVRyZWVEZWYoW0N1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSksIEN1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSldKSwgW1RydWUsIFRydWVdLCAoMjU2LCAyNTYpLCAoKSldLCBbXSksKSldAA=="}}\x00tpu_custom_call\x00func\x00', + xla_call_module_version=7, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py new file mode 100644 index 000000000000..a44e92846b98 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py @@ -0,0 +1,95 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from numpy import array, float32 + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +semaphore_and_dma_2024_04_22 = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['tpu_custom_call'], + serialized_date=datetime.date(2024, 4, 22), + inputs=(), + expected_outputs=(array(1., dtype=float32),), + mlir_module_text=r""" +#loc2 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":60:4) +#loc3 = loc("third_party/py/absl/testing/absltest.py":2718:19) +#loc4 = loc("third_party/py/absl/testing/absltest.py":2754:35) +#loc5 = loc("third_party/py/absl/testing/absltest.py":2298:6) +#loc6 = loc("third_party/py/absl/app.py":395:13) +#loc7 = loc("third_party/py/absl/app.py":473:6) +#loc8 = loc("third_party/py/absl/testing/absltest.py":2300:4) +#loc9 = loc("third_party/py/absl/testing/absltest.py":2182:2) +#loc10 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":64:2) +#loc11 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":57:10) +#loc14 = loc("PallasKernelTest.test_semaphore_and_dma_22_04_2024"(#loc2)) +#loc15 = loc("_run_and_get_tests_result"(#loc3)) +#loc16 = loc("run_tests"(#loc4)) +#loc17 = loc("_run_in_app..main_function"(#loc5)) +#loc18 = loc("_run_main"(#loc6)) +#loc19 = loc("run"(#loc7)) +#loc20 = loc("_run_in_app"(#loc8)) +#loc21 = loc("main"(#loc9)) +#loc22 = loc(""(#loc10)) +#loc23 = loc("PallasKernelTest.test_semaphore_and_dma_22_04_2024..func"(#loc11)) +#loc25 = loc(callsite(#loc21 at #loc22)) +#loc26 = loc(callsite(#loc20 at #loc25)) +#loc27 = loc(callsite(#loc19 at #loc26)) +#loc28 = loc(callsite(#loc18 at #loc27)) +#loc29 = loc(callsite(#loc17 at #loc28)) +#loc30 = loc(callsite(#loc16 at #loc29)) +#loc31 = loc(callsite(#loc15 at #loc30)) +#loc32 = loc(callsite(#loc14 at #loc31)) +#loc34 = loc(callsite(#loc23 at #loc32)) +#loc38 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc34)) +#loc42 = loc("jit(func)/jit(main)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=apply_kernel keep_unused=False inline=False]"(#loc34)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16384xf32> loc(#loc36) + %1 = stablehlo.reshape %0 : (tensor<16384xf32>) -> tensor<128x128xf32> loc(#loc37) + %2 = call @wrapped(%1) : (tensor<128x128xf32>) -> tensor<128x128xf32> loc(#loc38) + %3 = stablehlo.compare EQ, %1, %2, FLOAT : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xi1> loc(#loc39) + %c = stablehlo.constant dense : tensor loc(#loc40) + %4 = stablehlo.reduce(%3 init: %c) applies stablehlo.and across dimensions = [0, 1] : (tensor<128x128xi1>, tensor) -> tensor loc(#loc40) + %5 = stablehlo.convert %4 : (tensor) -> tensor loc(#loc41) + return %5 : tensor loc(#loc) + } loc(#loc) + func.func private @wrapped(%arg0: tensor<128x128xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc34))) -> (tensor<128x128xf32> {mhlo.layout_mode = "default"}) { + %0 = call @apply_kernel(%arg0) : (tensor<128x128xf32>) -> tensor<128x128xf32> loc(#loc42) + return %0 : tensor<128x128xf32> loc(#loc38) + } loc(#loc38) + func.func private @apply_kernel(%arg0: tensor<128x128xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=apply_kernel keep_unused=False inline=False]"(#loc34))) -> (tensor<128x128xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @tpu_custom_call(%arg0) {backend_config = "{\22custom_call_config\22: {\22body\22: \22TUzvUgFNTElSZ29vZ2xlMy10cnVuawABJwcBAwUBAwcDFQkLDQ8RExUXGRsD27UTAbELBwsPCw8PCw8PDw8PDw8LDw9VDxMPDxMLDzMLCwsLhQsLCwsPCxMPCxMPCxMPCxcPCxcPCxcPCxcPCxcPCxcPFw8LDxMPDw8PDw8PDwsLDwsPDxMLDw8TBQWFYQEPJw8PFwcXFwUFTT0CzgYFHR8FHx1HSQUhFRGLEQUBBSMdS00dUVMdV1kdXV8dY2UdaWsdb3EFJR11dx17fWFmZmluZV9tYXA8KCkgLT4gKCk+ABWHCwMDnZ8doaMdqasDAzEzBScRBQUDCzc5Oz1BDUMNRQ8FKQEBBSsNB2FmZmluZV9tYXA8KGQwLCBkMSkgLT4gKGQwLCBkMSk+AAUtBS8FMQUzFRFPBTUXAWsRFRNVBTcXAXMVFRVbBTkXAXkJFRdhBTsXBXoqJxUZZwU9FwUKK0cVG20FPxcF6iMNFR1zBUEXHy4GGxUheQVDFx9mBw0VI38FRRcF8iMJHQ+BFwUaIgUdhScFRx0JiRcBZRUVE40VFY8VF5EVGZMVG5UVHZcVISMdmycFSQVLEQMFBU0VpQsdCacXAWcVBU8VrQsdCa8XAWkVI3RwdS5tZW1vcnlfc3BhY2U8c2VtYXBob3JlX21lbT4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF7MFAgQCBAk/AQICAQIEBQUBAQELF7EBDyUXsQERJSF0cHUuZG1hX3NlbWFwaG9yZQAhdHB1LnNlbWFwaG9yZQAEpQUBEQMvBwMBBQcRAzUHAwULBQEDAQMJEAcFAwklAwIHAwsDAgcDDQ0EgwcBAwUPBJkFBQMFAyspAwMRBCsFBwkFAy0pAwMTBC0FBwsVAAcLAAMGAwEFAQA+ElFjtQ2XyxkJFUcVNWeDqxkTIyEdKS03C8dRgRUbHxshGRcVHx0PCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBtb2R1bGUAdHB1LnNlbV9hbGxvYwBhcml0aC5jb25zdGFudABmdW5jLmZ1bmMAdHB1LnJlZ2lvbgBmdW5jLnJldHVybgB0cHUuZW5xdWV1ZV9kbWEAdHB1LndhaXRfZG1hAHRwdS5zZW1fc2lnbmFsAHRwdS5zZW1fd2FpdAB0cHUueWllbGQAdGhpcmRfcGFydHkvcHkvamF4X3RyaXRvbi9nb29nbGUvcGFsbGFzX3RwdS9iYWNrX2NvbXBhdF90ZXN0LnB5AHRoaXJkX3BhcnR5L3B5L2Fic2wvdGVzdGluZy9hYnNsdGVzdC5weQBQYWxsYXNLZXJuZWxUZXN0LnRlc3Rfc2VtYXBob3JlX2FuZF9kbWFfMjJfMDRfMjAyNC48bG9jYWxzPi5mdW5jLjxsb2NhbHM+LmRtYV9rZXJuZWwuPGxvY2Fscz4uYm9keQBtYWluAHRoaXJkX3BhcnR5L3B5L2Fic2wvYXBwLnB5AHN0YWJsZV9tb3NhaWMudmVyc2lvbgBkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAL3J1bl9zY29wZWQAUGFsbGFzS2VybmVsVGVzdC50ZXN0X3NlbWFwaG9yZV9hbmRfZG1hXzIyXzA0XzIwMjQuPGxvY2Fscz4uZnVuYy48bG9jYWxzPi5kbWFfa2VybmVsAFBhbGxhc0tlcm5lbFRlc3QudGVzdF9zZW1hcGhvcmVfYW5kX2RtYV8yMl8wNF8yMDI0Ljxsb2NhbHM+LmZ1bmMAUGFsbGFzS2VybmVsVGVzdC50ZXN0X3NlbWFwaG9yZV9hbmRfZG1hXzIyXzA0XzIwMjQAX3J1bl9hbmRfZ2V0X3Rlc3RzX3Jlc3VsdABydW5fdGVzdHMAX3J1bl9pbl9hcHAuPGxvY2Fscz4ubWFpbl9mdW5jdGlvbgBfcnVuX21haW4AcnVuAF9ydW5faW5fYXBwAC9kbWFfc3RhcnRbdHJlZT1QeVRyZWVEZWYoKCosICgpLCAqLCAoKSwgKiwgKCksIE5vbmUsIE5vbmUsIE5vbmUpKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AL2RtYV93YWl0W3RyZWU9UHlUcmVlRGVmKCgqLCAoKSwgKiwgKCkpKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AdmFsdWUAL3NlbWFwaG9yZV9zaWduYWxbYXJnc190cmVlPVB5VHJlZURlZihbKiwgKCksICosIE5vbmVdKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AL3NlbWFwaG9yZV93YWl0W2FyZ3NfdHJlZT1QeVRyZWVEZWYoWyosICgpLCAqXSldAA==\22, \22serialization_format\22: 1, \22needs_layout_passes\22: true}, \22implicit_sharding\22: {\22type\22: \22MANUAL\22}}", kernel_name = "dma_kernel", operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<128x128xf32>) -> tensor<128x128xf32> loc(#loc43) + return %0 : tensor<128x128xf32> loc(#loc42) + } loc(#loc42) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":56:10) +#loc12 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":58:13) +#loc13 = loc("PallasKernelTest.test_semaphore_and_dma_22_04_2024..func"(#loc1)) +#loc24 = loc("PallasKernelTest.test_semaphore_and_dma_22_04_2024..func"(#loc12)) +#loc33 = loc(callsite(#loc13 at #loc32)) +#loc35 = loc(callsite(#loc24 at #loc32)) +#loc36 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(16384,) dimension=0]"(#loc33)) +#loc37 = loc("jit(func)/jit(main)/reshape[new_sizes=(128, 128) dimensions=None]"(#loc33)) +#loc39 = loc("jit(func)/jit(main)/eq"(#loc35)) +#loc40 = loc("jit(func)/jit(main)/reduce_and[axes=(0, 1)]"(#loc35)) +#loc41 = loc("jit(func)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]"(#loc35)) +#loc43 = loc("jit(func)/jit(main)/jit(wrapped)/jit(apply_kernel)/tpu_custom_call[config=CustomCallBackendConfig() kernel_name=dma_kernel kernel_regeneration_metadata=None out_avals=(ShapedArray(float32[128,128]),) input_output_aliases=()]"(#loc34)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\'\x05\x01\x03\x01\x03\x05\x03\x17\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x03z\x02\n\x02\x1f\x01\xc9\x0f\x0b\x0b\x0b\x0f\x0f\x07\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x0f\x0b\x0b\x0f+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x13\x0f\x0b\x13\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0b\x133\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x13\x13\x0b\x0f\x0b\x0f\x13\x0f\x0b\x13\x1b\x0b\x0b\x0f\x0b\x0f\x13\x13\x0b\x0b\x13\x0b\x039\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0bO\x0f\x0b\x0b\x13O\x01\x05\x13\x0b\x01\x05\x0b\x0f\x03\x1b\x1f\x0f\x07\x0f\x07\x07\x13\x17\x13\x07\x1b\x1f\x13\x02\xb2\x07\x1d\xc3\x1d\x05\x1d\x05\x1f\x05!\x1d7\x17\x1d\x83\x17\x1f\x05#\x05%\x05\'\x05)\x159\x1b\x05+\x15=C\x15\xbb\x1b\x11\x03\x05\x05-\x05/\x15\xa7\x1b\x03\t)+-\x1f/\x1f\x071\x051\x11\x01\x00\x053\x055\x057\x03\x0b\x0f\xcb\x11\xdb\x13\xdd\x07\xe5\x15\xe7\x03\x0b\x0f\xc9\x11\xd1\x13\xc9\x07\xd3\x15\xd5\x059\x1d\x19;\x17\x03s\x15\x1d?A\x05;\x17\x03y\t\x15EK\x1dGI\x05=\x17\x05z*\'\x15MS\x1dOQ\x05?\x17\x05\n+G\x15U[\x1dWY\x05A\x17\x05\xea#\r\x15]c\x1d_a\x05C\x17!.\x06\x1b\x15ek\x1dgi\x05E\x17!f\x07\r\x15ms\x1doq\x05G\x17\x05\xf2#\t\x15u{\x1dwy\x05I\x17\x05\x1a"\x05\x1d}\x7f\x05K\x17\x03\x81\x05\x03\x0b\x0f\xc9\x11\xd1\x13\xc9\x07\xd7\x15\xd5\x05M\x03\x13\x87\xeb\x89\xed\x8b\xef\x8d\xcb\x8f\xf1\x91\xf3\x93\xd9\x95\xcb\x97\xd9\x05O\x05Q\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x1d\x9b\x17\x05a\x03\x03#\xd7\x03\x03\xa1\xf7\x05c\x1d\xa5%\x05e\x1d\x19\xa9\x17\x03q\x15\x1d\xad%\x05g\x03\x03#\xd3\x03\x05\xb3\xf9\xb5\xfb\x05i\x05k\x1d\xb9\x1d\x05m\x1d\x19\xbd\x17\x03u\x1b\x03\x03\xc1\xfd\x05o\x05q\x03\x03\xc7\xff\x05s\x03\x03\xe9\x03\x01\x1du\x1dw#\x13\x1dy\x1d{\x1d}\x03\x03\xf5#\x11\x03\x03\xdf\r\x05\xe1\xe3\xcd\xcf\x1d\x7f\x1d\x81\x1dI\x1d\x83\r\x03\xcd\xcf\x0b\x03\x1d\x85\x1d\x87\x05\x01\x1d\x89\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x13\r\x01\t\x03\x07\x01\x1f\x07\x03\xff\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d\x06\x02\x1d\x05\x8b\x01\t\x01\x02\x02)\x05\x02\x04\x02\x04\t)\x01\x0f\t)\x01\t\x1d\x01\x11\x01\x03\x0b\x11\x03\x05\x03\x05)\x03\t\x17\x13)\x03\x04\x00\x04\t)\x05\x02\x04\x02\x04\x0f)\x03\t\r\x04F\x02\x05\x01\x11\r\'\x07\x03\x01\r\x05\x11\r3\x07\x03\x0f!\x0b\x03\xa3\x9f\x03\x19\r\x06\xab\x03\x05\x03\x01\x07\x07\t\xaf\x03\x05\x03\x03\x0f\x07\xb7\xb1\x03\x1b\x05\x03\x05\x11\x03\x01\xbf\x03\x07\x13\x17\x01\xc5\x03\x07\x05\x07\t\x07\x03\x07\x0b\x05\x07\x01\x07\x01\x17\x06\x01\x03\x07\x05\x01\x03\x03\x04\x01\x03\x05\x15\x06\x02\x02\x03\x0b\x03\x0b\x03\x04\r\x03\r\x05\x11\t5\x07\x03\x05\x0b\x03\x05\t\x07\x07\x0b\x9d\x03\x05\x03\x01\x03\x04\t\x03\x03\x05\x11\x0b\x81\x07\x03\x05\x0b\x03\x05\x0b\t\x07\x99\x85\x03\x05\x03\x01\x03\x04\x0b\x03\x03\x06\x03\x01\x05\x01\x00\xbeG\x8d\x99\x17!\xba(\x0f\x03!\x1b\x11\x11\x11#\x17Y\r/+\x1b\x85\x87\x1f\xaa\x03\x1f/!\x19!)#\x1f\x19\xb2\x03\x13\x0b\x19\t\x15G\x155gj\x03\x13%)9\x0f7\x83\x1f\x15\x1d\x15\x13Q\x81\x0f\x17\x15\x19\x17\x17\x11\x1f\x11\x11\x15\x0f\x0b\x11builtin\x00vhlo\x00module\x00return_v1\x00func_v1\x00call_v1\x00custom_call_v1\x00iota_v1\x00reshape_v1\x00compare_v1\x00constant_v1\x00reduce_v1\x00convert_v1\x00and_v1\x00third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py\x00third_party/py/absl/testing/absltest.py\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00PallasKernelTest.test_semaphore_and_dma_22_04_2024..func\x00third_party/py/absl/app.py\x00callee\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]\x00PallasKernelTest.test_semaphore_and_dma_22_04_2024\x00_run_and_get_tests_result\x00run_tests\x00_run_in_app..main_function\x00_run_main\x00run\x00_run_in_app\x00main\x00\x00jit(func)/jit(main)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=apply_kernel keep_unused=False inline=False]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00kernel_name\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/jit(wrapped)/jit(apply_kernel)/tpu_custom_call[config=CustomCallBackendConfig() kernel_name=dma_kernel kernel_regeneration_metadata=None out_avals=(ShapedArray(float32[128,128]),) input_output_aliases=()]\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(16384,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(128, 128) dimensions=None]\x00compare_type\x00comparison_direction\x00jit(func)/jit(main)/eq\x00value\x00jit(func)/jit(main)/reduce_and[axes=(0, 1)]\x00dimensions\x00mhlo.layout_mode\x00default\x00wrapped\x00private\x00apply_kernel\x00jax.result_info\x00\x00public\x00{"custom_call_config": {"body": "TUzvUgFNTElSZ29vZ2xlMy10cnVuawABJwcBAwUBAwcDFQkLDQ8RExUXGRsD27UTAbELBwsPCw8PCw8PDw8PDw8LDw9VDxMPDxMLDzMLCwsLhQsLCwsPCxMPCxMPCxMPCxcPCxcPCxcPCxcPCxcPCxcPFw8LDxMPDw8PDw8PDwsLDwsPDxMLDw8TBQWFYQEPJw8PFwcXFwUFTT0CzgYFHR8FHx1HSQUhFRGLEQUBBSMdS00dUVMdV1kdXV8dY2UdaWsdb3EFJR11dx17fWFmZmluZV9tYXA8KCkgLT4gKCk+ABWHCwMDnZ8doaMdqasDAzEzBScRBQUDCzc5Oz1BDUMNRQ8FKQEBBSsNB2FmZmluZV9tYXA8KGQwLCBkMSkgLT4gKGQwLCBkMSk+AAUtBS8FMQUzFRFPBTUXAWsRFRNVBTcXAXMVFRVbBTkXAXkJFRdhBTsXBXoqJxUZZwU9FwUKK0cVG20FPxcF6iMNFR1zBUEXHy4GGxUheQVDFx9mBw0VI38FRRcF8iMJHQ+BFwUaIgUdhScFRx0JiRcBZRUVE40VFY8VF5EVGZMVG5UVHZcVISMdmycFSQVLEQMFBU0VpQsdCacXAWcVBU8VrQsdCa8XAWkVI3RwdS5tZW1vcnlfc3BhY2U8c2VtYXBob3JlX21lbT4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF7MFAgQCBAk/AQICAQIEBQUBAQELF7EBDyUXsQERJSF0cHUuZG1hX3NlbWFwaG9yZQAhdHB1LnNlbWFwaG9yZQAEpQUBEQMvBwMBBQcRAzUHAwULBQEDAQMJEAcFAwklAwIHAwsDAgcDDQ0EgwcBAwUPBJkFBQMFAyspAwMRBCsFBwkFAy0pAwMTBC0FBwsVAAcLAAMGAwEFAQA+ElFjtQ2XyxkJFUcVNWeDqxkTIyEdKS03C8dRgRUbHxshGRcVHx0PCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBtb2R1bGUAdHB1LnNlbV9hbGxvYwBhcml0aC5jb25zdGFudABmdW5jLmZ1bmMAdHB1LnJlZ2lvbgBmdW5jLnJldHVybgB0cHUuZW5xdWV1ZV9kbWEAdHB1LndhaXRfZG1hAHRwdS5zZW1fc2lnbmFsAHRwdS5zZW1fd2FpdAB0cHUueWllbGQAdGhpcmRfcGFydHkvcHkvamF4X3RyaXRvbi9nb29nbGUvcGFsbGFzX3RwdS9iYWNrX2NvbXBhdF90ZXN0LnB5AHRoaXJkX3BhcnR5L3B5L2Fic2wvdGVzdGluZy9hYnNsdGVzdC5weQBQYWxsYXNLZXJuZWxUZXN0LnRlc3Rfc2VtYXBob3JlX2FuZF9kbWFfMjJfMDRfMjAyNC48bG9jYWxzPi5mdW5jLjxsb2NhbHM+LmRtYV9rZXJuZWwuPGxvY2Fscz4uYm9keQBtYWluAHRoaXJkX3BhcnR5L3B5L2Fic2wvYXBwLnB5AHN0YWJsZV9tb3NhaWMudmVyc2lvbgBkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAL3J1bl9zY29wZWQAUGFsbGFzS2VybmVsVGVzdC50ZXN0X3NlbWFwaG9yZV9hbmRfZG1hXzIyXzA0XzIwMjQuPGxvY2Fscz4uZnVuYy48bG9jYWxzPi5kbWFfa2VybmVsAFBhbGxhc0tlcm5lbFRlc3QudGVzdF9zZW1hcGhvcmVfYW5kX2RtYV8yMl8wNF8yMDI0Ljxsb2NhbHM+LmZ1bmMAUGFsbGFzS2VybmVsVGVzdC50ZXN0X3NlbWFwaG9yZV9hbmRfZG1hXzIyXzA0XzIwMjQAX3J1bl9hbmRfZ2V0X3Rlc3RzX3Jlc3VsdABydW5fdGVzdHMAX3J1bl9pbl9hcHAuPGxvY2Fscz4ubWFpbl9mdW5jdGlvbgBfcnVuX21haW4AcnVuAF9ydW5faW5fYXBwAC9kbWFfc3RhcnRbdHJlZT1QeVRyZWVEZWYoKCosICgpLCAqLCAoKSwgKiwgKCksIE5vbmUsIE5vbmUsIE5vbmUpKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AL2RtYV93YWl0W3RyZWU9UHlUcmVlRGVmKCgqLCAoKSwgKiwgKCkpKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AdmFsdWUAL3NlbWFwaG9yZV9zaWduYWxbYXJnc190cmVlPVB5VHJlZURlZihbKiwgKCksICosIE5vbmVdKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AL3NlbWFwaG9yZV93YWl0W2FyZ3NfdHJlZT1QeVRyZWVEZWYoWyosICgpLCAqXSldAA==", "serialization_format": 1, "needs_layout_passes": true}, "implicit_sharding": {"type": "MANUAL"}}\x00tpu_custom_call\x00dma_kernel\x00jit(func)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/triton_add_one.py similarity index 100% rename from jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py rename to jax/_src/internal_test_util/export_back_compat_test_data/pallas/triton_add_one.py diff --git a/jax/experimental/pallas/ops/tpu/matmul.py b/jax/experimental/pallas/ops/tpu/matmul.py new file mode 100644 index 000000000000..2145fbc95b55 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/matmul.py @@ -0,0 +1,85 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example matmul TPU kernel. + +See discussion in https://jax.readthedocs.io/en/latest/pallas/tpu/matmul.html. +""" + +import functools + +import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp + + +def matmul_kernel(x_tile_ref, y_tile_ref, o_tile_ref, acc_ref): + @pl.when(pl.program_id(2) == 0) + def init(): + acc_ref[...] = jnp.zeros_like(acc_ref) + + acc_ref[...] = acc_ref[...] + jnp.dot( + x_tile_ref[...], + y_tile_ref[...], + preferred_element_type=acc_ref.dtype, + ) + # It is possible to make this conditional but in general this bundle packs + # quite well for a simple matmul kernel + o_tile_ref[...] = acc_ref[...].astype(o_tile_ref.dtype) + + +@functools.partial( + jax.jit, static_argnames=["block_shape", "block_k", "debug", "out_dtype"] +) +def matmul( + x: jax.Array, + y: jax.Array, + *, + block_shape, + block_k: int = 256, + out_dtype: jnp.dtype | None = None, + debug: bool = False, +) -> jax.Array: + if out_dtype is None: + if x.dtype != y.dtype: + # TODO(tlongeri): Maybe we could use a deduction similar to jnp.dot + raise TypeError( + f"Cannot deduce output dtype for different input dtypes: {x.dtype}," + f" {y.dtype}" + ) + out_dtype = x.dtype + acc_dtype = jnp.float32 + if x.dtype in [jnp.int8, jnp.int4, jnp.uint8, jnp.uint4]: + acc_dtype = jnp.int32 + + l, r = block_shape + return pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), out_dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec((l, block_k), lambda i, _, k: (i, k)), + pl.BlockSpec((block_k, r), lambda _, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((l, r), lambda i, j, k: (i, j)), + grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k), + scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)], + ), + compiler_params=dict( + mosaic=dict(dimension_semantics=("parallel", "parallel", "arbitrary")) + ), + debug=debug, + )(x, y) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 3e1fd863a3ab..2076519f1af3 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -203,6 +203,7 @@ jax_test( "//jax:internal_export_back_compat_test_util", "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_tpu_ops", # build_cleaner: keep ], ) diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index 8cf3f9708e38..0804cf04af9b 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -17,15 +17,21 @@ update these tests. """ -from absl.testing import absltest +import math +from absl.testing import absltest import jax -import jax.numpy as jnp from jax._src import config from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu -from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one +from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_matmul +from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_semaphore_dma +from jax._src.internal_test_util.export_back_compat_test_data.pallas import triton_add_one from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu import matmul +import jax.numpy as jnp + config.parse_flags_with_absl() @@ -36,14 +42,12 @@ class CompatTest(bctu.CompatTestBase): def setUp(self): if jax.config.x64_enabled: self.skipTest("Only works in 32-bit") - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Only works on GPU") if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPUs with capability >= sm80") super().setUp() - def test_cuda_add_one(self): + def test_triton_add_one(self): def func(x): def add_one(x_ref, o_ref): o_ref[0] = x_ref[0] + 1 @@ -52,8 +56,51 @@ def add_one(x_ref, o_ref): in_specs=[pl.BlockSpec((1,), lambda i: i)], out_specs=pl.BlockSpec((1,), lambda i: i), grid=8)(x) - data = self.load_testdata(cuda_add_one.data_2024_05_02) + data = self.load_testdata(triton_add_one.data_2024_05_02) + + self.run_one_test(func, data) + + @jax.default_matmul_precision("bfloat16") + def test_mosaic_matmul(self): + dtype = jnp.float32 + def func(): + # Build the inputs here, to reduce the size of the golden inputs. + x_shape = (1024, 512) + bias = 1.0 + scale = 1e-3 + x = bias + scale * jnp.arange( + math.prod(x_shape), dtype=dtype).reshape(x_shape) + y = x[:512, :256] + res = matmul.matmul(x, y, block_shape=(256, 256)) + # Keep only slices of the output, to reduce the size of the goldens. + return res[::16, ::16] + + data = self.load_testdata(mosaic_matmul.data_2023_09_22) + self.run_one_test(func, data, rtol=2e-7) + + def test_mosaic_semaphore_dma(self): + if not (jtu.test_device_matches(["tpu"]) and + jtu.is_device_tpu_at_least(4)): + # TODO: crashes during compilation on TPU v4 + self.skipTest("Only works on TPU v5+") + + # The signatures of TPU ops for semaphore and DMA have changed. + # This test ensures that the new signatures are backwards compatible. + def func(): + def dma_kernel(x, y): + def body(dma_sem, sem): + pltpu.async_copy(x, y, dma_sem).wait() + pltpu.semaphore_signal(sem) + pltpu.semaphore_wait(sem) + pl.run_scoped( + body, pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.REGULAR + ) + x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + y = pl.pallas_call(dma_kernel, out_shape=x)(x) + return jnp.array_equal(x, y).astype(jnp.float32) + data = self.load_testdata( + mosaic_semaphore_dma.semaphore_and_dma_2024_04_22) self.run_one_test(func, data) From 3095c570b842816ca32a210bda6c2d02ccc09631 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 7 Aug 2024 17:59:53 +0530 Subject: [PATCH 012/702] Better docs for jnp.fft.rfft2 and jnp.fft.irfft2 --- jax/_src/numpy/fft.py | 154 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 148 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index b65f1ee589cc..e246b1fb6929 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -465,12 +465,12 @@ def rfft(a: ArrayLike, n: int | None = None, def irfft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: - r"""Compute a one-dimensional inverse discrete Fourier transform for real input. + """Compute a real-valued one-dimensional inverse discrete Fourier transform. JAX implementation of :func:`numpy.fft.irfft`. Args: - a: real-valued input array. + a: input array. n: int. Specifies the dimension of the result along ``axis``. If not specified, ``n = 2*(m-1)``, where ``m`` is the dimension of ``a`` along ``axis``. axis: int, default=-1. Specifies the axis along which the transform is computed. @@ -479,8 +479,8 @@ def irfft(a: ArrayLike, n: int | None = None, supported. Returns: - An array containing the one-dimensional inverse discrete Fourier transform - of ``a``, with a dimension of ``n`` along ``axis``. + A real-valued array containing the one-dimensional inverse discrete Fourier + transform of ``a``, with a dimension of ``n`` along ``axis``. See also: - :func:`jax.numpy.fft.ifft`: Computes a one-dimensional inverse discrete @@ -826,15 +826,157 @@ def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes, norm=norm) -@implements(np.fft.rfft2) + def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: + """Compute a two-dimensional discrete Fourier transform of a real-valued array. + + JAX implementation of :func:`numpy.fft.rfft2`. + + Args: + a: real-valued input array. Must have ``a.ndim >= 2``. + s: optional length-2 sequence of integers. Specifies the effective size of the + output along each specified axis. If not specified, it will default to the + dimension of input along ``axes``. + axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the + axes along which the transform is computed. + norm: string, default="backward". The normalization mode. "backward", "ortho" + and "forward" are supported. + + Returns: + An array containing the two-dimensional discrete Fourier transform of ``a``. + The size of the output along the axis ``axes[1]`` is ``(s[1]/2)+1``, if ``s[1]`` + is even and ``(s[1]+1)/2``, if ``s[1]`` is odd. The size of the output along + the axis ``axes[0]`` is ``s[0]``. + + See also: + - :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier + transform of real-valued array. + - :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier + transform of real-valued array. + - :func:`jax.numpy.fft.irfft2`: Computes a real-valued two-dimensional inverse + discrete Fourier transform. + + Examples: + ``jnp.fft.rfft2`` computes the transform along the last two axes by default. + + >>> x = jnp.array([[[1, 3, 5], + ... [2, 4, 6]], + ... [[7, 9, 11], + ... [8, 10, 12]]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfft2(x) + Array([[[21.+0.j , -6.+3.46j], + [-3.+0.j , 0.+0.j ]], + + [[57.+0.j , -6.+3.46j], + [-3.+0.j , 0.+0.j ]]], dtype=complex64) + + When ``s=[2, 4]``, dimension of the transform along ``axis -2`` will be + ``2``, along ``axis -1`` will be ``(4/2)+1) = 3`` and dimension along other + axes will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfft2(x, s=[2, 4]) + Array([[[21. +0.j, -8. -7.j, 7. +0.j], + [-3. +0.j, 0. +1.j, -1. +0.j]], + + [[57. +0.j, -8.-19.j, 19. +0.j], + [-3. +0.j, 0. +1.j, -1. +0.j]]], dtype=complex64) + + When ``s=[3, 5]`` and ``axes=(0, 1)``, shape of the transform along ``axis 0`` + will be ``3``, along ``axis 1`` will be ``(5+1)/2 = 3`` and dimension along + other axes will be same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfft2(x, s=[3, 5], axes=(0, 1)) + Array([[[ 18. +0.j , 26. +0.j , 34. +0.j ], + [ 11.09 -9.51j, 16.33-13.31j, 21.56-17.12j], + [ -0.09 -5.88j, 0.67 -8.23j, 1.44-10.58j]], + + [[ -4.5 -12.99j, -2.5 -16.45j, -0.5 -19.92j], + [ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j], + [ -4.95 +0.72j, -5.78 -0.2j , -6.61 -1.12j]], + + [[ -4.5 +12.99j, -2.5 +16.45j, -0.5 +19.92j], + [ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j], + [ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64) + """ return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes, norm=norm) -@implements(np.fft.irfft2) + def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: + """Compute a real-valued two-dimensional inverse discrete Fourier transform. + + JAX implementation of :func:`numpy.fft.irfft2`. + + Args: + a: input array. Must have ``a.ndim >= 2``. + s: optional length-2 sequence of integers. Specifies the size of the output + in each specified axis. If not specified, the dimension of output along + axis ``axes[1]`` is ``2*(m-1)``, ``m`` is the size of input along axis + ``axes[1]`` and the dimension along other axes will be the same as that of + input. + axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the + axes along which the transform is computed. + norm: string, default="backward". The normalization mode. "backward", "ortho" + and "forward" are supported. + + Returns: + A real-valued array containing the two-dimensional inverse discrete Fourier + transform of ``a``. + + See also: + - :func:`jax.numpy.fft.rfft2`: Computes a two-dimensional discrete Fourier + transform of a real-valued array. + - :func:`jax.numpy.fft.irfft`: Computes a real-valued one-dimensional inverse + discrete Fourier transform. + - :func:`jax.numpy.fft.irfftn`: Computes a real-valued multidimensional inverse + discrete Fourier transform. + + Examples: + ``jnp.fft.ifft2`` computes the transform along the last two axes by default. + + >>> x = jnp.array([[[1, 3, 5], + ... [2, 4, 6]], + ... [[7, 9, 11], + ... [8, 10, 12]]]) + >>> jnp.fft.irfft2(x) + Array([[[ 3.5, -1. , 0. , -1. ], + [-0.5, 0. , 0. , 0. ]], + + [[ 9.5, -1. , 0. , -1. ], + [-0.5, 0. , 0. , 0. ]]], dtype=float32) + + When ``s=[3, 3]``, dimension of the transform along ``axes (-2, -1)`` will be + ``(3, 3)`` and dimension along other axes will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.irfft2(x, s=[3, 3]) + Array([[[ 1.89, -0.44, -0.44], + [ 0.22, -0.78, 0.56], + [ 0.22, 0.56, -0.78]], + + [[ 5.89, -0.44, -0.44], + [ 1.22, -1.78, 1.56], + [ 1.22, 1.56, -1.78]]], dtype=float32) + + When ``s=[2, 3]`` and ``axes=(0, 1)``, shape of the transform along + ``axes (0, 1)`` will be ``(2, 3)`` and dimension along other axes will be + same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.irfft2(x, s=[2, 3], axes=(0, 1)) + Array([[[ 4.67, 6.67, 8.67], + [-0.33, -0.33, -0.33], + [-0.33, -0.33, -0.33]], + + [[-3. , -3. , -3. ], + [ 0. , 0. , 0. ], + [ 0. , 0. , 0. ]]], dtype=float32) + """ return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes, norm=norm) From 3a1567f57adfeae7ab36ff078c06b2dfbf12fdb8 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 7 Aug 2024 07:13:25 -0700 Subject: [PATCH 013/702] Do not run nn_test under asan -- it times out PiperOrigin-RevId: 660377176 --- tests/BUILD | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index adea49cac293..b5a99b254c16 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -688,6 +688,11 @@ jax_test( jax_test( name = "nn_test", srcs = ["nn_test.py"], + backend_tags = { + "gpu": [ + "noasan", # Times out under asan. + ], + }, shard_count = { "cpu": 10, "tpu": 10, From 6fc57c0eb6f06b2da20c94f5f127fe4a551bda09 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 7 Aug 2024 10:17:04 -0700 Subject: [PATCH 014/702] Rolling forward #22836 This version, proposed by @dfm, does not have a custom JVP for the whole logsumexp and instead fixes #22398 directly. Reverts e416c6675acfd82866a6e83e8c221640c4d02f29 PiperOrigin-RevId: 660438802 --- jax/_src/ops/special.py | 36 +++++++++++++++++++++++++++++------- tests/lax_scipy_test.py | 10 +++++++++- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index 59ad594ef2bc..45b26b0de4d3 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -14,12 +14,12 @@ from __future__ import annotations -from typing import overload, Literal +from typing import Literal, overload import jax from jax import lax from jax import numpy as jnp -from jax._src.numpy.reductions import _reduction_dims, Axis +from jax._src.numpy.reductions import Axis, _reduction_dims from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike import numpy as np @@ -40,6 +40,7 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: ... + def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: r"""Log-sum-exp reduction. @@ -71,18 +72,22 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, """ if b is not None: a_arr, b_arr = promote_args_inexact("logsumexp", a, b) - a_arr = jnp.where(b_arr != 0, a_arr, -jnp.inf) + a_masked = jnp.where(b_arr != 0, a_arr, -jnp.inf) else: a_arr, = promote_args_inexact("logsumexp", a) b_arr = a_arr # for type checking + a_masked = a_arr pos_dims, dims = _reduction_dims(a_arr, axis) - amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-jnp.inf) + amax = jnp.max( + a_masked.real, axis=dims, keepdims=keepdims, where=where, initial=-jnp.inf + ) amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims) - exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) - if b is not None: - exp_a = lax.mul(exp_a, b_arr) + if b is None: + exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) + else: + exp_a = _stable_mulexp(a_arr - amax_with_dims.astype(a_arr.dtype), b_arr) sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where) sign = lax.sign(sumexp) if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating): @@ -95,3 +100,20 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, with jax.debug_nans(False): out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out) return out + + +@jax.custom_jvp +def _stable_mulexp(a_scaled: Array, b: Array) -> Array: + # This helper ensures that the output of logsumexp depends on b for b == 0. + # See https://github.com/google/jax/issues/22398. + a_scaled = jnp.where(b != 0, a_scaled, -jnp.inf) + return lax.mul(lax.exp(a_scaled), b) + + +@_stable_mulexp.defjvp +def _stable_mulexp_jvp(primals, tangents): + a_scaled, b = primals + da, db = tangents + out = _stable_mulexp(a_scaled, b) + dout = _stable_mulexp(a_scaled, db) + da * out + return out, dout diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 1ed410cbaed8..66d84c427fea 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -20,9 +20,9 @@ from absl.testing import absltest import numpy as np +import scipy.cluster as osp_cluster import scipy.integrate import scipy.special as osp_special -import scipy.cluster as osp_cluster import jax import jax.dtypes @@ -202,6 +202,14 @@ def testLogSumExpWhere(self, shape, dtype): y_actual = lsp_special.logsumexp(x, where=mask) self.assertAllClose(y_expected, y_actual, check_dtypes=False) + def testLogSumExpZerosJac(self): + # Regression test for https://github.com/google/jax/issues/22398 + fun = lambda b: lsp_special.logsumexp(jnp.zeros(2), axis=0, b=b) + np.testing.assert_array_equal( + jax.jacfwd(fun)(jnp.array([1.0, 0.0])), + jnp.ones(2), + ) + @jtu.sample_product( shape=all_shapes, dtype=float_dtypes, From 53af0d4d90e3f6159b2de70961eec31e5b3dcb85 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 7 Aug 2024 15:15:45 -0700 Subject: [PATCH 015/702] CI: fix mypy errors --- jax/_src/nn/functions.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 4aabf9521340..bfbee04b6a83 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -919,10 +919,10 @@ def _ensure_4d(t): if dims_to_add > 0: return jnp.expand_dims(t, axis=tuple(range(dims_to_add))) return t - - query = _ensure_4d(query) - key = _ensure_4d(key) - value = _ensure_4d(value) + + query_arr = _ensure_4d(query) + key_arr = _ensure_4d(key) + value_arr = _ensure_4d(value) bias = _ensure_4d(bias) if bias is not None else None mask = _ensure_4d(mask) if mask is not None else None @@ -933,15 +933,15 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: if shape[i] != -1 and t.shape[i] != shape[i]: raise ValueError(f"{name} shape should be {shape}: but got {t.shape}") - B, S, K, H = key.shape - _check_has_shape(value, [B, S, K, H], 'value') - _check_has_shape(query, [B, -1, -1, H], 'query') - if query.shape[-2] % K != 0: + B, S, K, H = key_arr.shape + _check_has_shape(value_arr, [B, S, K, H], 'value') + _check_has_shape(query_arr, [B, -1, -1, H], 'query') + if query_arr.shape[-2] % K != 0: raise ValueError(f"The number of query heads must to a multiple of " - f"key/value heads, but got {query.shape[-2]} vs {K}") - if not (query.dtype == key.dtype == value.dtype): + f"key/value heads, but got {query_arr.shape[-2]} vs {K}") + if not (query_arr.dtype == key_arr.dtype == value_arr.dtype): raise ValueError(f"query/key/value should have the same shape, but got " - f"{query.shape} vs {key.shape} vs {value.shape}.") + f"{query_arr.shape} vs {key_arr.shape} vs {value_arr.shape}.") if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4: raise ValueError(f"Mask must be a 4D boolean tensor, but got " f"rank={mask.ndim}, dtype={mask.dtype}.") @@ -953,18 +953,18 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: match implementation: case 'xla': out = _dot_product_attention_xla( - query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, + query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val, ) case 'cudnn': mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK out = cudnn_dot_product_attention( - query, key, value, bias, mask, scale=scale_val, mask_type=mask_type + query_arr, key_arr, value_arr, bias, mask, scale=scale_val, mask_type=mask_type ) case None: # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select # best backend. out = _dot_product_attention_xla( - query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, + query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val, ) case _: raise ValueError(f"Unsupported implementation option: {implementation}") From a57d6591eee7e6c18e97081d3e00432728ab4c62 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 7 Aug 2024 15:57:00 -0700 Subject: [PATCH 016/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/3bf7e1ae488174aa6b29cc3f2c216785dd161af8. PiperOrigin-RevId: 660570144 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f49cf2eb34ba..b885cad3f539 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "08b8d938eb56928970e65639b126794c01b75c3d" -XLA_SHA256 = "365d9b42b6da10c9f0b53f01e075e8b1513e431ad596183d8ec1c2c27e1d7973" +XLA_COMMIT = "3bf7e1ae488174aa6b29cc3f2c216785dd161af8" +XLA_SHA256 = "6f11fc246856472069926e5de3506c740fb9af750e74819e39737e1c8460da78" def repo(): tf_http_archive( From e6425a2c67253862f9fa3cf436c05f5d6dca0534 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Wed, 7 Aug 2024 23:20:19 +0000 Subject: [PATCH 017/702] Small performance improvement to pallas MHA --- jax/experimental/pallas/ops/gpu/attention.py | 89 +++++++++----------- 1 file changed, 39 insertions(+), 50 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 647310dcaacc..1cf8349e7da2 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -103,9 +103,7 @@ def body(start_k, carry): ) # Use m_next instead of m_curr to avoid a correction on l_curr l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr - l_next_rcp = 1. / l_next - s_curr = s_curr * l_next_rcp[:, None] - o_prev_corr = (l_prev_corr * l_next_rcp)[:, None] * o_prev + o_prev_corr = correction[:, None] * o_prev v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d))) o_curr = pl.dot(s_curr.astype(v.dtype), v) @@ -118,10 +116,15 @@ def body(start_k, carry): upper_bound = pl.cdiv(seq_len, block_k) o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i)) + # We keep an unscaled version of o during the scan over seq_len. Scaling it + # by the last l_i gives us the correct final output. See section 3.1.1 in the + # FlashAttention-2 paper: https://arxiv.org/pdf/2307.08691. + o /= l_i[:, None] + if residual_refs: - l_ref, m_ref = residual_refs - pl.store(l_ref, (curr_q_slice,), l_i) - pl.store(m_ref, (curr_q_slice,), m_i) + lse_ref = residual_refs[0] + lse_i = m_i + jnp.log(l_i) + pl.store(lse_ref, (curr_q_slice,), lse_i) # Write output to dram. o = o.astype(o_ref.dtype) pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o) @@ -258,11 +261,10 @@ def _mha_forward( sm_scale=sm_scale, causal=causal, block_q=block_q, block_k=block_k, block_d=head_dim) out_shape = [ - jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out - jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # l - dtype=jnp.float32), - jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # m - dtype=jnp.float32) + jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out + jax.ShapeDtypeStruct( + shape=(batch_size, num_heads, seq_len), dtype=jnp.float32 # lse + ), ] in_specs = [ pl.BlockSpec( @@ -280,7 +282,7 @@ def _mha_forward( if segment_ids is None else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0)) ) - out, l, m = pl.pallas_call( + out, lse = pl.pallas_call( kernel, grid=grid_, in_specs=in_specs, @@ -289,7 +291,6 @@ def _mha_forward( (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), ], compiler_params=dict( triton=dict(num_warps=num_warps_, num_stages=num_stages) @@ -299,55 +300,45 @@ def _mha_forward( interpret=interpret, name="mha_forward", )(q, k, v, segment_ids) - return out, (q, k, v, segment_ids, out, l, m) + return out, (q, k, v, segment_ids, out, lse) -def _preprocess_backward_kernel(out_ref, dout_ref, l_ref, - new_dout_ref, delta_ref, *, - block_q: int): +def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, *, block_q: int): pid_m = pl.program_id(0) off_m = pl.ds(pid_m * block_q, block_q) # load o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32) do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32) - denom = pl.load(l_ref, (off_m,)).astype(jnp.float32) # compute - do = do / denom[:, None] delta = jnp.sum(o * do, axis=1) # write-back - pl.store(new_dout_ref, (off_m, slice(None)), - do.astype(new_dout_ref.dtype)) pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype)) @jax.named_scope("preprocess_backward") -def _preprocess_backward(out, do, l, block_q: int, +def _preprocess_backward(out, do, lse, block_q: int, debug: bool, interpret: bool): batch_size, seq_len, num_heads, head_dim = out.shape - out_shape = [ - jax.ShapeDtypeStruct(do.shape, do.dtype), - jax.ShapeDtypeStruct(l.shape, l.dtype), - ] - do_scaled, delta = pl.pallas_call( + out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype) + delta = pl.pallas_call( functools.partial(_preprocess_backward_kernel, block_q=block_q), grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), - pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), - ], - out_specs=[ - pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), + pl.BlockSpec( + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + ), + pl.BlockSpec( + (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + ), ], - compiler_params=dict( - triton=dict(num_warps=4, num_stages=3) - ), + out_specs=pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), + compiler_params=dict(triton=dict(num_warps=4, num_stages=3)), out_shape=out_shape, debug=debug, interpret=interpret, - name="mha_preprocess_backward")(out, do, l) - return do_scaled, delta + name="mha_preprocess_backward", + )(out, do) + return delta # This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence @@ -361,8 +352,7 @@ def mha_backward_kernel( segment_ids_ref: jax.Array | None, out_ref, do_scaled_ref, - l_ref, - m_ref, + lse_ref, delta_ref, # Outputs dq_ref, @@ -377,7 +367,7 @@ def mha_backward_kernel( block_k2: int, block_d: int, ): - del out_ref, l_ref # Not needed + del out_ref # Not needed seq_len = q_ref.shape[0] # Scan #1: dK and dV @@ -422,11 +412,11 @@ def inner_loop_dkdv(start_q, carry): ) qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) - m = pl.load(m_ref, (curr_q_slice,)) + lse = pl.load(lse_ref, (curr_q_slice,)) di = pl.load(delta_ref, (curr_q_slice,)) do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) - p = jnp.exp(qk - m[:, None]) + p = jnp.exp(qk - lse[:, None]) dv = dv + pl.dot(p.astype(do.dtype).T, do) dp = jnp.zeros((block_q1, block_k1), dtype=jnp.float32) - di[:, None] dp = dp + pl.dot(do, v.T) @@ -461,7 +451,7 @@ def inner_loop_dkdv(start_q, carry): if segment_ids_ref is None else pl.load(segment_ids_ref, (curr_q_slice,)) ) - m = pl.load(m_ref, (curr_q_slice,)) + lse = pl.load(lse_ref, (curr_q_slice,)) do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) di = pl.load(delta_ref, (curr_q_slice,)) @@ -488,7 +478,7 @@ def inner_loop_dq(start_k, dq): ) qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) - p = jnp.exp(qk - m[:, None]) + p = jnp.exp(qk - lse[:, None]) dp = jnp.zeros((block_q2, block_k2), dtype=jnp.float32) - di[:, None] dp = dp + pl.dot(do, v.T) ds = p * dp @@ -513,7 +503,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, num_stages: int, grid: Any, interpret: bool, debug: bool, res, do): del num_warps, num_stages, grid - q, k, v, segment_ids, out, l, m = res + q, k, v, segment_ids, out, lse = res if backward_pass_impl == "xla": return jax.vjp( @@ -527,7 +517,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, batch_size, seq_len, num_heads, head_dim = q.shape block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) - do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) + delta = _preprocess_backward(out, do, lse, block_q, debug, interpret) out_shapes = [ jax.ShapeDtypeStruct(q.shape, q.dtype), jax.ShapeDtypeStruct(k.shape, k.dtype), @@ -552,7 +542,6 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, ), pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), - pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] @@ -593,7 +582,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, debug=debug, interpret=interpret, compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=2)), - )(q, k, v, segment_ids, out, do_scaled, l, m, delta) + )(q, k, v, segment_ids, out, do, lse, delta) else: raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") return dq.astype(q.dtype), dk, dv, None From be53ee10b10d63f15ad912f23098532c120b5791 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 7 Aug 2024 16:24:42 -0700 Subject: [PATCH 018/702] Set `jax_enable_memories` flag to `True` by default PiperOrigin-RevId: 660579462 --- CHANGELOG.md | 1 + jax/_src/config.py | 2 +- jax/_src/dispatch.py | 6 ------ .../array_serialization/serialization_test.py | 2 ++ tests/pjit_test.py | 13 ------------- 5 files changed, 4 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fbe947fa7de..038c0131ad12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.32 * Changes + * `jax_enable_memories` flag is set to `True` by default. * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information. * Computations on the CPU backend may now be dispatched asynchronously in diff --git a/jax/_src/config.py b/jax/_src/config.py index 5b4226f8fa33..46b3273278e0 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1054,7 +1054,7 @@ def _update_jax_memories_thread_local(val): enable_memories = bool_state( 'jax_enable_memories', - default=False, + default=True, upgrade=True, update_global_hook=_update_jax_memories_global, update_thread_local_hook=_update_jax_memories_thread_local, diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index e7fd8657ccdb..6c0b46077dcc 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -565,12 +565,6 @@ def lower(x, device, src, aval, out_aval): def _common_device_put_lowering(ctx, *xs, devices, srcs): - for device in devices: - if (isinstance(device, (Sharding, TransferToMemoryKind)) and - device.memory_kind is not None): - raise NotImplementedError( - "Passing memory_kind to device_put via Shardings is not supported on" - f" platforms {ctx.module_context.platforms}") return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 04a64fe55e25..2712e2b4a819 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -579,6 +579,8 @@ def test_load_with_layout(self): self.assertArraysEqual(s.data, np_inp[s.index]) def test_deserialization_with_int4(self): + if jtu.test_device_matches(['gpu']): + self.skipTest("Fails on GPU. Enable after it's fixed") dtype = jnp.int4 shape = (8, 2) arr = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 516d1fec7ff9..df87fed4bb7d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3824,19 +3824,6 @@ def f(inp): ' manager.*SingleDeviceSharding'): jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr) - @jtu.skip_on_devices("tpu", "gpu") - def test_device_put_memory_kind_not_tpu_gpu(self): - @jax.jit - def f(x): - y = x * 2 - return jax.device_put(y, sharding_impls.TransferToMemoryKind('unpinned_host')) - - with self.assertRaisesRegex( - NotImplementedError, - 'Passing memory_kind to device_put via Shardings is not supported on' - ' platform.*'): - f(jnp.arange(8)) - def test_no_output_multiple_devices(self): mesh = jtu.create_global_mesh((2,), ('x',)) From 7f8a4c84d3ee4c4e0ce0aaa462990556902e811b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 7 Aug 2024 20:54:58 -0700 Subject: [PATCH 019/702] Remove PositionalSharding from distributed array doc --- ...arrays_and_automatic_parallelization.ipynb | 2155 +++++++++-------- ...ed_arrays_and_automatic_parallelization.md | 635 +++-- 2 files changed, 1507 insertions(+), 1283 deletions(-) diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 2face1d4a0b2..09fbbd3a74c3 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "id": "FNxScTfq3vGF" }, @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "IZMLqOUV3vGG" }, @@ -70,7 +70,7 @@ "source": [ "## Intro and a quick example\n", "\n", - "By reading this tutorial notebook, you'll learn about `jax.Array`, a unified\n", + "By reading this tutorial notebook, you'll learn about `jax.Array`, a unified \n", "datatype for representing arrays, even with physical storage spanning multiple\n", "devices. You'll also learn about how using `jax.Array`s together with `jax.jit`\n", "can provide automatic compiler-based parallelization.\n", @@ -81,57 +81,76 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "id": "Gf2lO4ii3vGG" }, "outputs": [], "source": [ "from jax.experimental import mesh_utils\n", - "from jax.sharding import PositionalSharding" + "from jax.sharding import Mesh, PartitionSpec as P, NamedSharding" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "id": "q-XBTEoy3vGG" }, "outputs": [], "source": [ "# Create a Sharding object to distribute a value across devices:\n", - "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))" + "mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),\n", + " axis_names=('x', 'y'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "vI39znW93vGH", - "outputId": "3b518df8-5c29-4848-acc3-e41df939f30b" + "outputId": "4f702753-8add-4b65-a4af-0f18f098cc46" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "# Create an array of random values:\n", "x = jax.random.normal(jax.random.key(0), (8192, 8192))\n", "# and use jax.device_put to distribute it across devices:\n", - "y = jax.device_put(x, sharding.reshape(4, 2))\n", + "y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))\n", "jax.debug.visualize_array_sharding(y)" ] }, @@ -147,26 +166,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "-qCnHZl83vGI", - "outputId": "9da9c29e-ce88-4425-e1ec-e93e5bcf3106" + "outputId": "0e131c23-5765-43ae-f232-6417ae1acbb2" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -186,18 +223,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "_VTzN0r03vGI", - "outputId": "c9208010-984b-442b-d105-c8c6a3a010e6" + "outputId": "c03eecab-4c86-4dac-d776-5fc72cbb5273" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached \n", - "5 loops, best of 5: 9.69 ms per loop\n" + "The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.\n", + "25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -208,17 +248,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "QuzhU1g63vGI", - "outputId": "d48fc76e-79a7-47b9-d392-b18a1c33c798" + "outputId": "8135cca0-871b-4b6a-a7e5-02e78c2028c7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 1.86 ms per loop\n" + "2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -245,7 +288,7 @@ "id": "W6HsXauGxL6w" }, "source": [ - "### Sharding basics, and the `PositionalSharding` subclass" + "### Sharding basics, and the `NamedSharding` subclass" ] }, { @@ -263,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "id": "VmoX4SUp3vGJ" }, @@ -275,511 +318,109 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "vNRabO2J3vGJ", - "outputId": "73db7b6e-c2e7-467d-a0ef-c35e29e582dd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ TPU 0 │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] - } - ], - "source": [ - "jax.debug.visualize_array_sharding(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HhCjhK0zXIqX" - }, - "source": [ - "Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n", - "\n", - "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VUIEIzRp3vGK" - }, - "outputs": [], - "source": [ - "from jax.experimental import mesh_utils\n", - "devices = mesh_utils.create_device_mesh((8,))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lbOKFWmBX1iv" - }, - "source": [ - "Then, we create a `PositionalSharding` and use it with `device_put`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jwrWfZeB3vGK", - "outputId": "e6f126bd-f6bd-48c7-c130-6f02757e3342" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ TPU 0 │\n", - "├───────────────────────┤\n", - "│ TPU 1 │\n", - "├───────────────────────┤\n", - "│ TPU 2 │\n", - "├───────────────────────┤\n", - "│ TPU 3 │\n", - "├───────────────────────┤\n", - "│ TPU 6 │\n", - "├───────────────────────┤\n", - "│ TPU 7 │\n", - "├───────────────────────┤\n", - "│ TPU 4 │\n", - "├───────────────────────┤\n", - "│ TPU 5 │\n", - "└───────────────────────┘\n" - ] - } - ], - "source": [ - "from jax.sharding import PositionalSharding\n", - "\n", - "sharding = PositionalSharding(devices)\n", - "\n", - "x = jax.device_put(x, sharding.reshape(8, 1))\n", - "jax.debug.visualize_array_sharding(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TUu69IWXZdTm" - }, - "source": [ - "Here `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zxWB82Kz3vGK", - "outputId": "11384a6b-fabc-4c4c-bcad-a3be51eb0465" + "outputId": "40fd7172-a16c-4dd8-e2e1-17bb3afe5409" }, "outputs": [ { "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│         TPU 0         │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], "text/plain": [ - "PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}])" + "┌───────────────────────┐\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" ] }, - "execution_count": 13, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "sharding" + "jax.debug.visualize_array_sharding(x)" ] }, { "cell_type": "markdown", "metadata": { - "id": "uRLpOcmNj_Vt" + "id": "HhCjhK0zXIqX" }, "source": [ - "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", + "Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n", "\n", - "By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PLsnpSzc3vGL", - "outputId": "9f4db733-cafe-46ae-c057-dc31046a6f66" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "PositionalSharding([[{TPU 0}]\n", - " [{TPU 1}]\n", - " [{TPU 2}]\n", - " [{TPU 3}]\n", - " [{TPU 6}]\n", - " [{TPU 7}]\n", - " [{TPU 4}]\n", - " [{TPU 5}]])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sharding.reshape(8, 1)" + "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": { - "id": "iqKdI4LO3vGL", - "outputId": "6aa10fc2-cec4-4401-a0df-343e71646e0a" + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, + "id": "zpB1JxyK3vGN", + "outputId": "8e385462-1c2c-4256-c38a-84299d3bd02c" }, "outputs": [ { "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], "text/plain": [ - "PositionalSharding([[{TPU 0} {TPU 1}]\n", - " [{TPU 2} {TPU 3}]\n", - " [{TPU 6} {TPU 7}]\n", - " [{TPU 4} {TPU 5}]])" + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" ] }, - "execution_count": 15, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "sharding.reshape(4, 2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KBu6WLfhm7ra" - }, - "source": [ - "To use `device_put` with a data array `x`, we can reshape the `sharding` into a shape that is _congruent_ with `x.shape`, meaning a shape with the same length as `x.shape` and where each element evenly divides the corresponding element of `x.shape`:\n", - "```python\n", - "def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:\n", - " return (len(x_shape) == len(sharding_shape) and\n", - " all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape)))\n", - "```\n", - "\n", - "For example, we can reshape `sharding` to have shape `(4, 2)`, then use it in a `device_put`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "SELr4xNi3vGL", - "outputId": "b2f4acec-0cd3-4829-ca16-cae2e0e8ca60" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "PositionalSharding([[{TPU 0} {TPU 1}]\n", - " [{TPU 2} {TPU 3}]\n", - " [{TPU 6} {TPU 7}]\n", - " [{TPU 4} {TPU 5}]])\n" - ] - } - ], - "source": [ - "sharding = sharding.reshape(4, 2)\n", - "print(sharding)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8IVIsqfX3vGL", - "outputId": "033d0e02-a643-4f4c-9d24-9cd8465bc69a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] - } - ], - "source": [ - "y = jax.device_put(x, sharding)\n", - "jax.debug.visualize_array_sharding(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tyg9F-UIsU__" - }, - "source": [ - "Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are stored in different devices' memories.\n", - "\n", - "Different `PositionalSharding` shapes result in different distributed layouts (i.e. shardings) of the result:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cCjt6QCz3vGM", - "outputId": "4ad8a611-596d-424f-b6c5-fc00f1adc306" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]])\n" - ] - } - ], - "source": [ - "sharding = sharding.reshape(1, 8)\n", - "print(sharding)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yTK4Nz3u3vGM", - "outputId": "e445c6bc-4fe3-4e9d-cc9e-d82858f58312" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" - ] - } - ], - "source": [ - "y = jax.device_put(x, sharding)\n", - "jax.debug.visualize_array_sharding(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0PuamOvXubcf" - }, - "source": [ - "In some cases, we don't just want to store each slice of `x` in a single device's memory; we might want to _replicate_ some slices, meaning storing copies of a slice's values in multiple devices' memories.\n", - "\n", - "With `PositionalSharding`, we can express replication by calling the reducer method `replicate`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_jr6XYKx3vGM", - "outputId": "59c8b9a4-b8af-493a-ba8d-da5931e88f93" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]])\n" - ] - } - ], - "source": [ - "sharding = sharding.reshape(4, 2)\n", - "print(sharding.replicate(axis=0, keepdims=True))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "S5vzjFuH3vGN", - "outputId": "b6ce2675-7261-4e57-fa8c-b4e87abf7e52" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n" - ] - } - ], - "source": [ - "y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))\n", - "jax.debug.visualize_array_sharding(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FzeP0kpTvJv-" - }, - "source": [ - "Here the visualization shows that `x` is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories).\n", - "\n", - "The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DR7VV-6e3vGN", - "outputId": "f879fc2c-5723-4199-b306-295bc1b3681e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 2)\n", - "(4, 1)\n" - ] - } - ], - "source": [ - "print(sharding.replicate(0).shape)\n", - "print(sharding.replicate(1).shape)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "agUtVUVx3vGN", - "outputId": "0e9789ef-ce52-4ed6-8bd5-c876b95f66e6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ TPU 0,1 │\n", - "├───────────────────────┤\n", - "│ TPU 2,3 │\n", - "├───────────────────────┤\n", - "│ TPU 6,7 │\n", - "├───────────────────────┤\n", - "│ TPU 4,5 │\n", - "└───────────────────────┘\n" - ] - } - ], - "source": [ - "y = jax.device_put(x, sharding.replicate(1))\n", - "jax.debug.visualize_array_sharding(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "D31t5POXxHHJ" - }, - "source": [ - "### `NamedSharding` gives a way to express shardings with names" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ayMKWeTmxl-X" - }, - "source": [ - "So far we've worked with `PositionalSharding`, but there are alternative ways to express shardings. In fact, `Sharding` is an interface, and any class that implements that interface can be used with functions like `device_put`.\n", - "\n", - "Another convenient way to express sharding is with the `NamedSharding`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zpB1JxyK3vGN", - "outputId": "46d5da37-840c-49d8-8380-a162811bae8a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] - } - ], - "source": [ - "from jax.sharding import Mesh\n", - "from jax.sharding import PartitionSpec\n", - "from jax.sharding import NamedSharding\n", + "from jax.sharding import Mesh, PartitionSpec, NamedSharding\n", "from jax.experimental import mesh_utils\n", "\n", "P = PartitionSpec\n", @@ -801,7 +442,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": { "id": "8g0Md2Gd3vGO" }, @@ -820,26 +461,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "zp3MfS4Y3vGO", - "outputId": "2c2f7201-c2c1-49e5-f8a5-0730c124d89a" + "outputId": "032fdd7e-19a1-45da-e1ad-b3227fa43ee6" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -858,28 +517,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "FigK5Zsa3vGO", - "outputId": "eca784e8-33fe-4e9b-a41d-21e9ee781a35" + "outputId": "e488d073-9d02-4376-a6af-19d6d5509c7d" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │\n", - "│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "├───────┼───────┼───────┼───────┤\n", - "│ │ │ │ │\n", - "│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "└───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┐\n",
+       "│       │       │       │       │\n",
+       "│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "├───────┼───────┼───────┼───────┤\n",
+       "│       │       │       │       │\n",
+       "│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "└───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┐\n", + "│ │ │ │ │\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "├───────┼───────┼───────┼───────┤\n", + "│ │ │ │ │\n", + "│ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "└───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -889,26 +568,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "hI-HD0xN3vGO", - "outputId": "c3e7dc3c-4048-448a-ef0b-50683532fcdc" + "outputId": "b0c2e863-3aee-4417-b45f-21b2187f6ef7" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ TPU 0,1 │\n", - "├───────────────────────┤\n", - "│ TPU 2,3 │\n", - "├───────────────────────┤\n", - "│ TPU 6,7 │\n", - "├───────────────────────┤\n", - "│ TPU 4,5 │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│        TPU 0,1        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 2,3        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 6,7        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 4,5        │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -932,28 +629,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "EXBExMQC3vGP", - "outputId": "fe1c8d7e-3345-4438-b9d2-780e7854b4eb" + "outputId": "c80e6177-12a6-40ef-b4e4-934dad22da3d" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────┬───────────┐\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "└───────────┴───────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────┬───────────┐\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "└───────────┴───────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -963,28 +680,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "PjUpG8uz3vGP", - "outputId": "64d8224d-15d9-4ad4-d613-f7f85b1dc1af" + "outputId": "a0f59dc5-b509-4b8b-bd22-bcd69f696763" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "└───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┐\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "└───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┐\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "└───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1003,34 +740,60 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 298 + }, "id": "fVcPbDUA3vGP", - "outputId": "7f524ba5-a6d8-4490-cda9-685ad11416f9" + "outputId": "da3f435d-dfc1-4a41-ec90-691cd7c748a0" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ TPU 0 │\n", - "├───────────────────────┤\n", - "│ TPU 1 │\n", - "├───────────────────────┤\n", - "│ TPU 2 │\n", - "├───────────────────────┤\n", - "│ TPU 3 │\n", - "├───────────────────────┤\n", - "│ TPU 6 │\n", - "├───────────────────────┤\n", - "│ TPU 7 │\n", - "├───────────────────────┤\n", - "│ TPU 4 │\n", - "├───────────────────────┤\n", - "│ TPU 5 │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│         TPU 0         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 1         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 2         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 3         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 6         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 7         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 4         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 5         │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m1\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m3\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m7\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m5\u001b[0m │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1069,54 +832,103 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": { "id": "_EmQwggc3vGQ" }, "outputs": [], "source": [ - "from jax.experimental import mesh_utils\n", - "from jax.sharding import PositionalSharding\n", - "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))" + "devices = mesh_utils.create_device_mesh((4, 2))\n", + "mesh = Mesh(devices, axis_names=('a', 'b'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 349 + }, "id": "LnT0vWjc3vGQ", - "outputId": "8089effc-aa4c-49e3-dd19-7064881dbad0" + "outputId": "8e642049-61eb-458d-af79-ac449b58d11b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "input sharding:\n", - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n", - "output sharding:\n", - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" + "input sharding:\n" + ] + }, + { + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "output sharding:\n" ] + }, + { + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "x = jax.device_put(x, sharding.reshape(4, 2))\n", + "x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))\n", "print('input sharding:')\n", "jax.debug.visualize_array_sharding(x)\n", "\n", @@ -1140,54 +952,132 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 548 + }, "id": "Dq043GkP3vGQ", - "outputId": "350219a8-1e4a-4404-fe14-50f97ea3e7ba" + "outputId": "3eff7b67-d7f0-4212-c9d3-2cc271ac1f98" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "lhs sharding:\n", - "┌───────────────────────┐\n", - "│ TPU 0,1 │\n", - "├───────────────────────┤\n", - "│ TPU 2,3 │\n", - "├───────────────────────┤\n", - "│ TPU 6,7 │\n", - "├───────────────────────┤\n", - "│ TPU 4,5 │\n", - "└───────────────────────┘\n", - "rhs sharding:\n", - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n", - "out sharding:\n", - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" + "lhs sharding:\n" + ] + }, + { + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│        TPU 0,1        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 2,3        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 6,7        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 4,5        │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rhs sharding:\n" + ] + }, + { + "data": { + "text/html": [ + "
┌───────────┬───────────┐\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "└───────────┴───────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────┬───────────┐\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "└───────────┴───────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "out sharding:\n" ] + }, + { + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))\n", - "z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))\n", + "y = jax.device_put(x, NamedSharding(mesh, P('a', None)))\n", + "z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))\n", "print('lhs sharding:')\n", "jax.debug.visualize_array_sharding(y)\n", "print('rhs sharding:')\n", @@ -1211,28 +1101,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "QjQ5u8qh3vGQ", - "outputId": "bd29edcd-b87c-486e-c568-906f06ae16be" + "outputId": "0aefc170-833c-4a6a-e003-5990d3db31d9" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ TPU 0 │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│         TPU 0         │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1242,10 +1152,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "8tn8lOj73vGR", - "outputId": "5809b3c8-7333-4cd3-db97-a7aede943dce" + "outputId": "d9898c93-7afc-416b-8c40-4d9551613cd0" }, "outputs": [ { @@ -1254,7 +1167,7 @@ "True" ] }, - "execution_count": 36, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1266,17 +1179,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 29, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "D7PpZwhR3vGR", - "outputId": "4f0bd43d-0b32-4089-d3da-c8f1449e3526" + "outputId": "4901a11b-2354-4d26-a897-b88def07a716" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 19.3 ms per loop\n" + "49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -1286,17 +1202,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 30, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "rgo_yVHF3vGR", - "outputId": "97f19052-f1c9-4d30-f453-07b3a7208aa9" + "outputId": "e51216cf-b073-4250-d422-67f9fd72f6aa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 3.25 ms per loop\n" + "7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -1315,26 +1234,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 31, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "f1Zw-2lH3vGR", - "outputId": "a796bed4-07b0-497d-8fd8-31a22ab9762e" + "outputId": "43d7a642-fde4-47a6-901f-dfdc64d6a613" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1365,7 +1302,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 94, "metadata": { "id": "1vAkZAOY3vGR" }, @@ -1375,54 +1312,63 @@ "from termcolor import colored\n", "\n", "def print_exception(e):\n", - " name = colored(f'{type(e).__name__}', 'red')\n", + " name = colored(f'{type(e).__name__}', 'red', force_color=True)\n", " print(textwrap.fill(f'{name}: {str(e)}'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 95, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "DHh0N3vn3vGS", - "outputId": "e7741882-0ebf-4237-e5d1-e48c9b9c178f" + "outputId": "8c4652f7-c484-423b-ad78-182134280187" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[31mValueError\u001b[0m: Devices of all `Array` inputs and outputs should\n", - "be the same. Got array device ids [0, 1, 2, 3] on platform TPU and\n", - "another array's device ids [4, 5, 6, 7] on platform TPU\n" + "\u001b[31mValueError\u001b[0m: Received incompatible devices for jitted\n", + "computation. Got argument x1 of jax.numpy.add with shape int32[24] and\n", + "device ids [0, 1, 2, 3] on platform TPU and argument x2 of\n", + "jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on\n", + "platform TPU\n" ] } ], "source": [ - "sharding1 = PositionalSharding(jax.devices()[:4])\n", - "sharding2 = PositionalSharding(jax.devices()[4:])\n", + "sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))\n", + "sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))\n", "\n", - "y = jax.device_put(x, sharding1.reshape(2, 2))\n", - "z = jax.device_put(x, sharding2.reshape(2, 2))\n", + "y = jax.device_put(x, sharding1)\n", + "z = jax.device_put(x, sharding2)\n", "try: y + z\n", "except ValueError as e: print_exception(e)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 96, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "Im7DkoOl3vGS", - "outputId": "3adfe1cb-db52-4a9d-e98e-62c6455c3100" + "outputId": "1b6fcd7a-762b-4366-a96d-aea63bad7fe0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[31mValueError\u001b[0m: Devices of all `Array` inputs and outputs should\n", - "be the same. Got array device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform\n", - "TPU and another array's device ids [0, 1, 2, 3, 6, 7, 4, 5] on\n", - "platform TPU\n" + "\u001b[31mValueError\u001b[0m: Received incompatible devices for jitted\n", + "computation. Got argument x1 of jax.numpy.add with shape int32[24] and\n", + "device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of\n", + "jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,\n", + "4, 5] on platform TPU\n" ] } ], @@ -1430,11 +1376,11 @@ "devices = jax.devices()\n", "permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]\n", "\n", - "sharding1 = PositionalSharding(devices)\n", - "sharding2 = PositionalSharding(permuted_devices)\n", + "sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))\n", + "sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))\n", "\n", - "y = jax.device_put(x, sharding1.reshape(4, 2))\n", - "z = jax.device_put(x, sharding2.reshape(4, 2))\n", + "y = jax.device_put(x, sharding1)\n", + "z = jax.device_put(x, sharding2)\n", "try: y + z\n", "except ValueError as e: print_exception(e)" ] @@ -1455,10 +1401,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "_QvtKL8r3vGS", - "outputId": "e0078805-bdfd-436e-f94f-7cd256d2574f" + "outputId": "761b1208-fe4b-4c09-a7d2-f62152183ef0" }, "outputs": [ { @@ -1470,7 +1419,7 @@ } ], "source": [ - "y = jax.device_put(x, sharding1.reshape(4, 2))\n", + "y = jax.device_put(x, sharding1)\n", "y + jnp.ones_like(y)\n", "y + jnp.arange(y.size).reshape(y.shape)\n", "print('no error!')" @@ -1496,30 +1445,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 41, "metadata": { "id": "jniSFm5V3vGT" }, "outputs": [], "source": [ - "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))" + "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 42, "metadata": { "id": "Q1wuDp-L3vGT" }, "outputs": [], "source": [ "x = jax.random.normal(jax.random.key(0), (8192, 8192))\n", - "x = jax.device_put(x, sharding.reshape(4, 2))" + "x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 44, "metadata": { "id": "rqEDj0wB3vGT" }, @@ -1528,43 +1477,83 @@ "@jax.jit\n", "def f(x):\n", " x = x + 1\n", - " y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4))\n", + " y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))\n", " return y" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 45, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 347 + }, "id": "zYFS-n4r3vGT", - "outputId": "d23a7938-cb7d-44b4-b9c7-83edf1d1145e" + "outputId": "0ac96b8f-ed23-4413-aed9-edd00a841c37" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n", - "┌───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │\n", - "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "├───────┼───────┼───────┼───────┤\n", - "│ │ │ │ │\n", - "│ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "└───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┐\n",
+       "│       │       │       │       │\n",
+       "│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "├───────┼───────┼───────┼───────┤\n",
+       "│       │       │       │       │\n",
+       "│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "└───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┐\n", + "│ │ │ │ │\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "├───────┼───────┼───────┼───────┤\n", + "│ │ │ │ │\n", + "│ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "└───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1575,7 +1564,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 46, "metadata": { "id": "8g_2Y8wp3vGT" }, @@ -1584,43 +1573,83 @@ "@jax.jit\n", "def f(x):\n", " x = x + 1\n", - " y = jax.lax.with_sharding_constraint(x, sharding.replicate())\n", + " y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))\n", " return y" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 47, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 347 + }, "id": "AiRFtVsR3vGT", - "outputId": "f3e28a70-46cf-46fb-c801-82f0ddb447e4" + "outputId": "2edacc2c-ac80-4519-c9d1-bee364a22b31" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n", - "┌───────────────────────┐\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ TPU 0,1,2,3,4,5,6,7 │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│  TPU 0,1,2,3,4,5,6,7  │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1669,7 +1698,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 48, "metadata": { "id": "mEKF3zIF3vGU" }, @@ -1681,7 +1710,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 49, "metadata": { "id": "Mocs3oGe3vGU" }, @@ -1701,7 +1730,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 50, "metadata": { "id": "glBB8tzW3vGU" }, @@ -1713,7 +1742,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 51, "metadata": { "id": "R0x62AIa3vGU" }, @@ -1752,33 +1781,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 52, + "metadata": { + "id": "mJLqRPpSDX0i" + }, + "outputs": [], + "source": [ + "mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')" + ] + }, + { + "cell_type": "code", + "execution_count": 54, "metadata": { "id": "_Q5NbdOn3vGV" }, "outputs": [], "source": [ - "sharding = PositionalSharding(jax.devices()).reshape(8, 1)" + "sharding = NamedSharding(mesh, P('batch'))\n", + "replicated_sharding = NamedSharding(mesh, P())" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 55, "metadata": { "id": "3KC6ieEe3vGV" }, "outputs": [], "source": [ "batch = jax.device_put(batch, sharding)\n", - "params = jax.device_put(params, sharding.replicate())" + "params = jax.device_put(params, replicated_sharding)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 56, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "MUb-QE2b3vGV", - "outputId": "1f831ea5-5a30-49ad-8195-977ff7ed476a" + "outputId": "5a27f007-c572-44f8-9f49-6e745ee739e8" }, "outputs": [ { @@ -1787,7 +1831,7 @@ "Array(23.469475, dtype=float32)" ] }, - "execution_count": 57, + "execution_count": 56, "metadata": {}, "output_type": "execute_result" } @@ -1798,17 +1842,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 57, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "HUkw0u413vGV", - "outputId": "dfa2599c-9440-4657-9035-0dc3bbf625e1" + "outputId": "07e481a1-97fb-4bd0-d754-cb6d8317bff6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10.760101\n" + "10.760109\n" ] } ], @@ -1825,17 +1872,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 58, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "paCw6Zaj3vGV", - "outputId": "8ab1c32c-f2b1-465c-df71-f5a599e7f19e" + "outputId": "ad4cce34-3a6a-4d44-9a86-477a7fee4841" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 26.3 ms per loop\n" + "53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -1845,7 +1895,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 59, "metadata": { "id": "BF86UWpg3vGV" }, @@ -1857,17 +1907,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 60, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "Z1wgUKXk3vGV", - "outputId": "74df8892-c349-41dc-cb1b-e0843ec5c994" + "outputId": "d66767b7-3f17-482f-b811-919bb1793277" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 122 ms per loop\n" + "351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -1886,50 +1939,88 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 61, "metadata": { - "id": "N5-zzgW03vGW" + "id": "k1hxOfgRDwo0" }, "outputs": [], "source": [ - "sharding = sharding.reshape(4, 2)" + "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 62, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 314 + }, "id": "sgIWCjJK3vGW", - "outputId": "b2fdc556-05cc-4e68-fa04-48643d194dee" + "outputId": "8cb0f19f-3942-415c-c57a-31bb81784f46" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┐\n", - "│TPU 0,1│\n", - "├───────┤\n", - "│TPU 2,3│\n", - "├───────┤\n", - "│TPU 4,5│\n", - "├───────┤\n", - "│TPU 6,7│\n", - "└───────┘\n", - "┌───────┐\n", - "│TPU 0,1│\n", - "├───────┤\n", - "│TPU 2,3│\n", - "├───────┤\n", - "│TPU 4,5│\n", - "├───────┤\n", - "│TPU 6,7│\n", - "└───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┐\n",
+       "│TPU 0,1│\n",
+       "├───────┤\n",
+       "│TPU 2,3│\n",
+       "├───────┤\n",
+       "│TPU 6,7│\n",
+       "├───────┤\n",
+       "│TPU 4,5│\n",
+       "└───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┐\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n", + "└───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┌───────┐\n",
+       "│TPU 0,1│\n",
+       "├───────┤\n",
+       "│TPU 2,3│\n",
+       "├───────┤\n",
+       "│TPU 6,7│\n",
+       "├───────┤\n",
+       "│TPU 4,5│\n",
+       "└───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┐\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n", + "└───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "batch = jax.device_put(batch, sharding.replicate(1))\n", + "batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))\n", "jax.debug.visualize_array_sharding(batch[0])\n", "jax.debug.visualize_array_sharding(batch[1])" ] @@ -1937,6 +2028,17 @@ { "cell_type": "code", "execution_count": null, + "metadata": { + "id": "q9PQP-0eEAO6" + }, + "outputs": [], + "source": [ + "replicated_sharding = NamedSharding(mesh, P())" + ] + }, + { + "cell_type": "code", + "execution_count": 67, "metadata": { "id": "BqCjYCgg3vGW" }, @@ -1944,45 +2046,65 @@ "source": [ "(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params\n", "\n", - "W1 = jax.device_put(W1, sharding.replicate())\n", - "b1 = jax.device_put(b1, sharding.replicate())\n", + "W1 = jax.device_put(W1, replicated_sharding)\n", + "b1 = jax.device_put(b1, replicated_sharding)\n", "\n", - "W2 = jax.device_put(W2, sharding.replicate(0))\n", - "b2 = jax.device_put(b2, sharding.replicate(0))\n", + "W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))\n", + "b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))\n", "\n", - "W3 = jax.device_put(W3, sharding.replicate(0).T)\n", - "b3 = jax.device_put(b3, sharding.replicate())\n", + "W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))\n", + "b3 = jax.device_put(b3, replicated_sharding)\n", "\n", - "W4 = jax.device_put(W4, sharding.replicate())\n", - "b4 = jax.device_put(b4, sharding.replicate())\n", + "W4 = jax.device_put(W4, replicated_sharding)\n", + "b4 = jax.device_put(b4, replicated_sharding)\n", "\n", "params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 68, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "_lSJ63sh3vGW", - "outputId": "5b37aa8b-3226-4805-8282-876e8d06edda" + "outputId": "bcd3e33e-36b5-4787-9cd2-60623fd6e5fa" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────┬───────────┐\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "└───────────┴───────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────┬───────────┐\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "└───────────┴───────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1991,28 +2113,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 69, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "fxkfWYkk3vGW", - "outputId": "8a1063c3-540b-47c1-d990-a6845da861f7" + "outputId": "59e60b16-fe37-47d4-8214-96096ffbd79c" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ │\n", - "│ TPU 0,2,4,6 │\n", - "│ │\n", - "│ │\n", - "├───────────────────────┤\n", - "│ │\n", - "│ TPU 1,3,5,7 │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│      TPU 0,2,4,6      │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "├───────────────────────┤\n",
+       "│                       │\n",
+       "│      TPU 1,3,5,7      │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m │\n", + "│ │\n", + "│ │\n", + "├───────────────────────┤\n", + "│ │\n", + "│ TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -2021,17 +2163,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 70, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "uPCVs-_k3vGW", - "outputId": "de01cdfc-36cb-4823-c692-22c692ef4220" + "outputId": "618516e9-9736-4ca0-dd22-09d094ce57a2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10.760103\n" + "10.760109\n" ] } ], @@ -2041,7 +2186,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 71, "metadata": { "id": "L9JebLK_3vGW" }, @@ -2057,17 +2202,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 72, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "c9Sbl69e3vGX", - "outputId": "8272c5fa-e59f-4953-c2d5-658c42a28712" + "outputId": "2ee3d432-7172-46ca-e01a-614e83345808" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10.752466\n" + "10.752513\n" ] } ], @@ -2077,39 +2225,81 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 73, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 380 + }, "id": "lkAF0dAb3vGX", - "outputId": "acf0df31-c5e1-4683-b73f-b0cd1b0929f8" + "outputId": "6c1e317e-cded-4af4-8080-0de835fa4c71" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n", - "┌───────────────────────┐\n", - "│ │\n", - "│ TPU 0,2,4,6 │\n", - "│ │\n", - "│ │\n", - "├───────────────────────┤\n", - "│ │\n", - "│ TPU 1,3,5,7 │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────┬───────────┐\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "└───────────┴───────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────┬───────────┐\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "└───────────┴───────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│      TPU 0,2,4,6      │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "├───────────────────────┤\n",
+       "│                       │\n",
+       "│      TPU 1,3,5,7      │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m │\n", + "│ │\n", + "│ │\n", + "├───────────────────────┤\n", + "│ │\n", + "│ TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -2120,17 +2310,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 74, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "I1Npor3i3vGX", - "outputId": "4099f6dd-7b46-4123-c1cb-5173c3d3278e" + "outputId": "479c4d81-cb0b-40a5-89ba-394c10dc3297" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10 loops, best of 10: 30.5 ms per loop\n" + "51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)\n" ] } ], @@ -2173,7 +2366,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 75, "metadata": { "id": "kwS-aQE_3vGX" }, @@ -2185,7 +2378,8 @@ " return x + numbers\n", "\n", "key = jax.random.key(42)\n", - "x_sharding = jax.sharding.PositionalSharding(jax.devices())\n", + "mesh = Mesh(jax.devices(), 'x')\n", + "x_sharding = NamedSharding(mesh, P('x'))\n", "x = jax.device_put(jnp.arange(24), x_sharding)" ] }, @@ -2200,20 +2394,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 76, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 67 + }, "id": "Oi97rpLz3vGY", - "outputId": "204a7e8d-dc88-4b77-b7e3-0e72f306c5d3" + "outputId": "9dd63254-a483-4847-c0f5-5a4367bf08e9" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", - "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n", - "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
+       "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
+       "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -2231,10 +2437,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 77, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "64wIZuSJ3vGY", - "outputId": "1054fe99-0476-44ec-9693-b0d8f98bf6a8" + "outputId": "fa166d45-ca9c-457a-be84-bcc9236d0730" }, "outputs": [ { @@ -2261,10 +2470,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 78, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "1I7bqxA63vGY", - "outputId": "ec4c579d-f446-4b48-ceda-785c09ba299b" + "outputId": "756e0a36-ff14-438f-bbd4-3ef03f97a47b" }, "outputs": [ { @@ -2292,20 +2504,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 79, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 67 + }, "id": "zHPJzdn23vGY", - "outputId": "a8904d20-4d04-4f59-8eae-281e47d29246" + "outputId": "3332de0f-4827-4f0b-b9ef-69249b7c6bc6" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", - "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n", - "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
+       "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
+       "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -2323,10 +2547,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 80, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "nBUHBBal3vGY", - "outputId": "f194c213-0688-4b7a-ffb8-c4453b82b1f1" + "outputId": "4b9be948-ccab-4a31-a06f-37ec9c7b5235" }, "outputs": [ { @@ -2371,10 +2598,10 @@ "metadata": { "accelerator": "TPU", "colab": { + "gpuType": "V28", "provenance": [], "toc_visible": true }, - "gpuClass": "standard", "jupytext": { "formats": "ipynb,md:myst" }, diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index b9ec9dc694d2..43c14bc41da4 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -52,7 +52,7 @@ if len(jax.local_devices()) < 8: ## Intro and a quick example -By reading this tutorial notebook, you'll learn about `jax.Array`, a unified +By reading this tutorial notebook, you'll learn about `jax.Array`, a unified datatype for representing arrays, even with physical storage spanning multiple devices. You'll also learn about how using `jax.Array`s together with `jax.jit` can provide automatic compiler-based parallelization. @@ -64,24 +64,29 @@ First, we'll create a `jax.Array` sharded across multiple devices: :id: Gf2lO4ii3vGG from jax.experimental import mesh_utils -from jax.sharding import PositionalSharding +from jax.sharding import Mesh, PartitionSpec as P, NamedSharding ``` ```{code-cell} :id: q-XBTEoy3vGG # Create a Sharding object to distribute a value across devices: -sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) +mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)), + axis_names=('x', 'y')) ``` ```{code-cell} -:id: vI39znW93vGH -:outputId: 3b518df8-5c29-4848-acc3-e41df939f30b - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: vI39znW93vGH +outputId: 4f702753-8add-4b65-a4af-0f18f098cc46 +--- # Create an array of random values: x = jax.random.normal(jax.random.key(0), (8192, 8192)) # and use jax.device_put to distribute it across devices: -y = jax.device_put(x, sharding.reshape(4, 2)) +y = jax.device_put(x, NamedSharding(mesh, P('x', 'y'))) jax.debug.visualize_array_sharding(y) ``` @@ -91,9 +96,13 @@ Next, we'll apply a computation to it and visualize how the result values are stored across multiple devices too: ```{code-cell} -:id: -qCnHZl83vGI -:outputId: 9da9c29e-ce88-4425-e1ec-e93e5bcf3106 - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: -qCnHZl83vGI +outputId: 0e131c23-5765-43ae-f232-6417ae1acbb2 +--- z = jnp.sin(y) jax.debug.visualize_array_sharding(z) ``` @@ -104,17 +113,23 @@ The evaluation of the `jnp.sin` application was automatically parallelized across the devices on which the input values (and output values) are stored: ```{code-cell} -:id: _VTzN0r03vGI -:outputId: c9208010-984b-442b-d105-c8c6a3a010e6 - +--- +colab: + base_uri: https://localhost:8080/ +id: _VTzN0r03vGI +outputId: c03eecab-4c86-4dac-d776-5fc72cbb5273 +--- # `x` is present on a single device %timeit -n 5 -r 5 jnp.sin(x).block_until_ready() ``` ```{code-cell} -:id: QuzhU1g63vGI -:outputId: d48fc76e-79a7-47b9-d392-b18a1c33c798 - +--- +colab: + base_uri: https://localhost:8080/ +id: QuzhU1g63vGI +outputId: 8135cca0-871b-4b6a-a7e5-02e78c2028c7 +--- # `y` is sharded across 8 devices. %timeit -n 5 -r 5 jnp.sin(y).block_until_ready() ``` @@ -128,7 +143,7 @@ Now let's look at each of these pieces in more detail! +++ {"id": "W6HsXauGxL6w"} -### Sharding basics, and the `PositionalSharding` subclass +### Sharding basics, and the `NamedSharding` subclass +++ {"id": "NWDyp_EjVHkg"} @@ -146,9 +161,13 @@ x = jax.random.normal(jax.random.key(0), (8192, 8192)) ``` ```{code-cell} -:id: vNRabO2J3vGJ -:outputId: 73db7b6e-c2e7-467d-a0ef-c35e29e582dd - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: vNRabO2J3vGJ +outputId: 40fd7172-a16c-4dd8-e2e1-17bb3afe5409 +--- jax.debug.visualize_array_sharding(x) ``` @@ -159,169 +178,14 @@ Here, we're using the `jax.debug.visualize_array_sharding` function to show wher But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order: ```{code-cell} -:id: VUIEIzRp3vGK - -from jax.experimental import mesh_utils -devices = mesh_utils.create_device_mesh((8,)) -``` - -+++ {"id": "lbOKFWmBX1iv"} - -Then, we create a `PositionalSharding` and use it with `device_put`: - -```{code-cell} -:id: jwrWfZeB3vGK -:outputId: e6f126bd-f6bd-48c7-c130-6f02757e3342 - -from jax.sharding import PositionalSharding - -sharding = PositionalSharding(devices) - -x = jax.device_put(x, sharding.reshape(8, 1)) -jax.debug.visualize_array_sharding(x) -``` - -+++ {"id": "TUu69IWXZdTm"} - -Here `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements: - -```{code-cell} -:id: zxWB82Kz3vGK -:outputId: 11384a6b-fabc-4c4c-bcad-a3be51eb0465 - -sharding -``` - -+++ {"id": "uRLpOcmNj_Vt"} - -The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device. - -By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it: - -```{code-cell} -:id: PLsnpSzc3vGL -:outputId: 9f4db733-cafe-46ae-c057-dc31046a6f66 - -sharding.reshape(8, 1) -``` - -```{code-cell} -:id: iqKdI4LO3vGL -:outputId: 6aa10fc2-cec4-4401-a0df-343e71646e0a - -sharding.reshape(4, 2) -``` - -+++ {"id": "KBu6WLfhm7ra"} - -To use `device_put` with a data array `x`, we can reshape the `sharding` into a shape that is _congruent_ with `x.shape`, meaning a shape with the same length as `x.shape` and where each element evenly divides the corresponding element of `x.shape`: -```python -def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool: - return (len(x_shape) == len(sharding_shape) and - all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape))) -``` - -For example, we can reshape `sharding` to have shape `(4, 2)`, then use it in a `device_put`: - -```{code-cell} -:id: SELr4xNi3vGL -:outputId: b2f4acec-0cd3-4829-ca16-cae2e0e8ca60 - -sharding = sharding.reshape(4, 2) -print(sharding) -``` - -```{code-cell} -:id: 8IVIsqfX3vGL -:outputId: 033d0e02-a643-4f4c-9d24-9cd8465bc69a - -y = jax.device_put(x, sharding) -jax.debug.visualize_array_sharding(y) -``` - -+++ {"id": "tyg9F-UIsU__"} - -Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are stored in different devices' memories. - -Different `PositionalSharding` shapes result in different distributed layouts (i.e. shardings) of the result: - -```{code-cell} -:id: cCjt6QCz3vGM -:outputId: 4ad8a611-596d-424f-b6c5-fc00f1adc306 - -sharding = sharding.reshape(1, 8) -print(sharding) -``` - -```{code-cell} -:id: yTK4Nz3u3vGM -:outputId: e445c6bc-4fe3-4e9d-cc9e-d82858f58312 - -y = jax.device_put(x, sharding) -jax.debug.visualize_array_sharding(y) -``` - -+++ {"id": "0PuamOvXubcf"} - -In some cases, we don't just want to store each slice of `x` in a single device's memory; we might want to _replicate_ some slices, meaning storing copies of a slice's values in multiple devices' memories. - -With `PositionalSharding`, we can express replication by calling the reducer method `replicate`: - -```{code-cell} -:id: _jr6XYKx3vGM -:outputId: 59c8b9a4-b8af-493a-ba8d-da5931e88f93 - -sharding = sharding.reshape(4, 2) -print(sharding.replicate(axis=0, keepdims=True)) -``` - -```{code-cell} -:id: S5vzjFuH3vGN -:outputId: b6ce2675-7261-4e57-fa8c-b4e87abf7e52 - -y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True)) -jax.debug.visualize_array_sharding(y) -``` - -+++ {"id": "FzeP0kpTvJv-"} - -Here the visualization shows that `x` is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories). - -The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed: - -```{code-cell} -:id: DR7VV-6e3vGN -:outputId: f879fc2c-5723-4199-b306-295bc1b3681e - -print(sharding.replicate(0).shape) -print(sharding.replicate(1).shape) -``` - -```{code-cell} -:id: agUtVUVx3vGN -:outputId: 0e9789ef-ce52-4ed6-8bd5-c876b95f66e6 - -y = jax.device_put(x, sharding.replicate(1)) -jax.debug.visualize_array_sharding(y) -``` - -+++ {"id": "D31t5POXxHHJ"} - -### `NamedSharding` gives a way to express shardings with names - -+++ {"id": "ayMKWeTmxl-X"} - -So far we've worked with `PositionalSharding`, but there are alternative ways to express shardings. In fact, `Sharding` is an interface, and any class that implements that interface can be used with functions like `device_put`. - -Another convenient way to express sharding is with the `NamedSharding`: - -```{code-cell} -:id: zpB1JxyK3vGN -:outputId: 46d5da37-840c-49d8-8380-a162811bae8a - -from jax.sharding import Mesh -from jax.sharding import PartitionSpec -from jax.sharding import NamedSharding +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: zpB1JxyK3vGN +outputId: 8e385462-1c2c-4256-c38a-84299d3bd02c +--- +from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax.experimental import mesh_utils P = PartitionSpec @@ -351,9 +215,13 @@ def mesh_sharding( ``` ```{code-cell} -:id: zp3MfS4Y3vGO -:outputId: 2c2f7201-c2c1-49e5-f8a5-0730c124d89a - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: zp3MfS4Y3vGO +outputId: 032fdd7e-19a1-45da-e1ad-b3227fa43ee6 +--- y = jax.device_put(x, mesh_sharding(P('a', 'b'))) jax.debug.visualize_array_sharding(y) ``` @@ -363,17 +231,25 @@ jax.debug.visualize_array_sharding(y) Here, we use `P('a', 'b')` to express that the first and second axes of `x` should be sharded over the device mesh axes `'a'` and `'b'`, respectively. We can easily switch to `P('b', 'a')` to shard the axes of `x` over different devices: ```{code-cell} -:id: FigK5Zsa3vGO -:outputId: eca784e8-33fe-4e9b-a41d-21e9ee781a35 - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: FigK5Zsa3vGO +outputId: e488d073-9d02-4376-a6af-19d6d5509c7d +--- y = jax.device_put(x, mesh_sharding(P('b', 'a'))) jax.debug.visualize_array_sharding(y) ``` ```{code-cell} -:id: hI-HD0xN3vGO -:outputId: c3e7dc3c-4048-448a-ef0b-50683532fcdc - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: hI-HD0xN3vGO +outputId: b0c2e863-3aee-4417-b45f-21b2187f6ef7 +--- # This `None` means that `x` is not sharded on its second dimension, # and since the Mesh axis name 'b' is not mentioned, shards are # replicated across it. @@ -388,17 +264,25 @@ Here, because `P('a', None)` doesn't mention the `Mesh` axis name `'b'`, we get To shard only over the second axis of `x`, we can use a `None` placeholder in the `PartitionSpec`: ```{code-cell} -:id: EXBExMQC3vGP -:outputId: fe1c8d7e-3345-4438-b9d2-780e7854b4eb - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: EXBExMQC3vGP +outputId: c80e6177-12a6-40ef-b4e4-934dad22da3d +--- y = jax.device_put(x, mesh_sharding(P(None, 'b'))) jax.debug.visualize_array_sharding(y) ``` ```{code-cell} -:id: PjUpG8uz3vGP -:outputId: 64d8224d-15d9-4ad4-d613-f7f85b1dc1af - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: PjUpG8uz3vGP +outputId: a0f59dc5-b509-4b8b-bd22-bcd69f696763 +--- y = jax.device_put(x, mesh_sharding(P(None, 'a'))) jax.debug.visualize_array_sharding(y) ``` @@ -408,9 +292,13 @@ jax.debug.visualize_array_sharding(y) For a fixed mesh, we can even partition one logical axis of `x` over multiple device mesh axes: ```{code-cell} -:id: fVcPbDUA3vGP -:outputId: 7f524ba5-a6d8-4490-cda9-685ad11416f9 - +--- +colab: + base_uri: https://localhost:8080/ + height: 298 +id: fVcPbDUA3vGP +outputId: da3f435d-dfc1-4a41-ec90-691cd7c748a0 +--- y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None))) jax.debug.visualize_array_sharding(y) ``` @@ -432,16 +320,19 @@ For example, the simplest computation is an elementwise one: ```{code-cell} :id: _EmQwggc3vGQ -from jax.experimental import mesh_utils -from jax.sharding import PositionalSharding -sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) +devices = mesh_utils.create_device_mesh((4, 2)) +mesh = Mesh(devices, axis_names=('a', 'b')) ``` ```{code-cell} -:id: LnT0vWjc3vGQ -:outputId: 8089effc-aa4c-49e3-dd19-7064881dbad0 - -x = jax.device_put(x, sharding.reshape(4, 2)) +--- +colab: + base_uri: https://localhost:8080/ + height: 349 +id: LnT0vWjc3vGQ +outputId: 8e642049-61eb-458d-af79-ac449b58d11b +--- +x = jax.device_put(x, NamedSharding(mesh, P('a', 'b'))) print('input sharding:') jax.debug.visualize_array_sharding(x) @@ -459,11 +350,15 @@ In other words, even though we wrote the `jnp.sin` computation as if a single ma We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs: ```{code-cell} -:id: Dq043GkP3vGQ -:outputId: 350219a8-1e4a-4404-fe14-50f97ea3e7ba - -y = jax.device_put(x, sharding.reshape(4, 2).replicate(1)) -z = jax.device_put(x, sharding.reshape(4, 2).replicate(0)) +--- +colab: + base_uri: https://localhost:8080/ + height: 548 +id: Dq043GkP3vGQ +outputId: 3eff7b67-d7f0-4212-c9d3-2cc271ac1f98 +--- +y = jax.device_put(x, NamedSharding(mesh, P('a', None))) +z = jax.device_put(x, NamedSharding(mesh, P(None, 'b'))) print('lhs sharding:') jax.debug.visualize_array_sharding(y) print('rhs sharding:') @@ -481,32 +376,45 @@ Here the compiler chose the output sharding so that it could maximally paralleli How can we be sure it's actually running in parallel? We can do a simple timing experiment: ```{code-cell} -:id: QjQ5u8qh3vGQ -:outputId: bd29edcd-b87c-486e-c568-906f06ae16be - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: QjQ5u8qh3vGQ +outputId: 0aefc170-833c-4a6a-e003-5990d3db31d9 +--- x_single = jax.device_put(x, jax.devices()[0]) jax.debug.visualize_array_sharding(x_single) ``` ```{code-cell} -:id: 8tn8lOj73vGR -:outputId: 5809b3c8-7333-4cd3-db97-a7aede943dce - +--- +colab: + base_uri: https://localhost:8080/ +id: 8tn8lOj73vGR +outputId: d9898c93-7afc-416b-8c40-4d9551613cd0 +--- np.allclose(jnp.dot(x_single, x_single), jnp.dot(y, z)) ``` ```{code-cell} -:id: D7PpZwhR3vGR -:outputId: 4f0bd43d-0b32-4089-d3da-c8f1449e3526 - +--- +colab: + base_uri: https://localhost:8080/ +id: D7PpZwhR3vGR +outputId: 4901a11b-2354-4d26-a897-b88def07a716 +--- %timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready() ``` ```{code-cell} -:id: rgo_yVHF3vGR -:outputId: 97f19052-f1c9-4d30-f453-07b3a7208aa9 - +--- +colab: + base_uri: https://localhost:8080/ +id: rgo_yVHF3vGR +outputId: e51216cf-b073-4250-d422-67f9fd72f6aa +--- %timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready() ``` @@ -515,9 +423,13 @@ np.allclose(jnp.dot(x_single, x_single), Even copying a sharded `Array` produces a result with the sharding of the input: ```{code-cell} -:id: f1Zw-2lH3vGR -:outputId: a796bed4-07b0-497d-8fd8-31a22ab9762e - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: f1Zw-2lH3vGR +outputId: 43d7a642-fde4-47a6-901f-dfdc64d6a613 +--- w_copy = jnp.copy(w) jax.debug.visualize_array_sharding(w_copy) ``` @@ -540,35 +452,41 @@ import textwrap from termcolor import colored def print_exception(e): - name = colored(f'{type(e).__name__}', 'red') + name = colored(f'{type(e).__name__}', 'red', force_color=True) print(textwrap.fill(f'{name}: {str(e)}')) ``` ```{code-cell} -:id: DHh0N3vn3vGS -:outputId: e7741882-0ebf-4237-e5d1-e48c9b9c178f - -sharding1 = PositionalSharding(jax.devices()[:4]) -sharding2 = PositionalSharding(jax.devices()[4:]) +--- +colab: + base_uri: https://localhost:8080/ +id: DHh0N3vn3vGS +outputId: 8c4652f7-c484-423b-ad78-182134280187 +--- +sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x')) +sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x')) -y = jax.device_put(x, sharding1.reshape(2, 2)) -z = jax.device_put(x, sharding2.reshape(2, 2)) +y = jax.device_put(x, sharding1) +z = jax.device_put(x, sharding2) try: y + z except ValueError as e: print_exception(e) ``` ```{code-cell} -:id: Im7DkoOl3vGS -:outputId: 3adfe1cb-db52-4a9d-e98e-62c6455c3100 - +--- +colab: + base_uri: https://localhost:8080/ +id: Im7DkoOl3vGS +outputId: 1b6fcd7a-762b-4366-a96d-aea63bad7fe0 +--- devices = jax.devices() permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]] -sharding1 = PositionalSharding(devices) -sharding2 = PositionalSharding(permuted_devices) +sharding1 = NamedSharding(Mesh(devices, 'x'), P('x')) +sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x')) -y = jax.device_put(x, sharding1.reshape(4, 2)) -z = jax.device_put(x, sharding2.reshape(4, 2)) +y = jax.device_put(x, sharding1) +z = jax.device_put(x, sharding2) try: y + z except ValueError as e: print_exception(e) ``` @@ -583,10 +501,13 @@ Unlike committed arrays, uncommitted arrays can be moved and resharded automatic For example, the output of `jnp.zeros`, `jnp.arange`, and `jnp.array` are uncommitted: ```{code-cell} -:id: _QvtKL8r3vGS -:outputId: e0078805-bdfd-436e-f94f-7cd256d2574f - -y = jax.device_put(x, sharding1.reshape(4, 2)) +--- +colab: + base_uri: https://localhost:8080/ +id: _QvtKL8r3vGS +outputId: 761b1208-fe4b-4c09-a7d2-f62152183ef0 +--- +y = jax.device_put(x, sharding1) y + jnp.ones_like(y) y + jnp.arange(y.size).reshape(y.shape) print('no error!') @@ -603,14 +524,14 @@ While the compiler will attempt to decide how a function's intermediate values a ```{code-cell} :id: jniSFm5V3vGT -sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) +mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y')) ``` ```{code-cell} :id: Q1wuDp-L3vGT x = jax.random.normal(jax.random.key(0), (8192, 8192)) -x = jax.device_put(x, sharding.reshape(4, 2)) +x = jax.device_put(x, NamedSharding(mesh, P('x', 'y'))) ``` ```{code-cell} @@ -619,14 +540,18 @@ x = jax.device_put(x, sharding.reshape(4, 2)) @jax.jit def f(x): x = x + 1 - y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4)) + y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x'))) return y ``` ```{code-cell} -:id: zYFS-n4r3vGT -:outputId: d23a7938-cb7d-44b4-b9c7-83edf1d1145e - +--- +colab: + base_uri: https://localhost:8080/ + height: 347 +id: zYFS-n4r3vGT +outputId: 0ac96b8f-ed23-4413-aed9-edd00a841c37 +--- jax.debug.visualize_array_sharding(x) y = f(x) jax.debug.visualize_array_sharding(y) @@ -638,14 +563,18 @@ jax.debug.visualize_array_sharding(y) @jax.jit def f(x): x = x + 1 - y = jax.lax.with_sharding_constraint(x, sharding.replicate()) + y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P())) return y ``` ```{code-cell} -:id: AiRFtVsR3vGT -:outputId: f3e28a70-46cf-46fb-c801-82f0ddb447e4 - +--- +colab: + base_uri: https://localhost:8080/ + height: 347 +id: AiRFtVsR3vGT +outputId: 2edacc2c-ac80-4519-c9d1-bee364a22b31 +--- jax.debug.visualize_array_sharding(x) y = f(x) jax.debug.visualize_array_sharding(y) @@ -727,30 +656,43 @@ params, batch = init_model(jax.random.key(0), layer_sizes, batch_size) ### 8-way batch data parallelism +```{code-cell} +:id: mJLqRPpSDX0i + +mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch') +``` + ```{code-cell} :id: _Q5NbdOn3vGV -sharding = PositionalSharding(jax.devices()).reshape(8, 1) +sharding = NamedSharding(mesh, P('batch')) +replicated_sharding = NamedSharding(mesh, P()) ``` ```{code-cell} :id: 3KC6ieEe3vGV batch = jax.device_put(batch, sharding) -params = jax.device_put(params, sharding.replicate()) +params = jax.device_put(params, replicated_sharding) ``` ```{code-cell} -:id: MUb-QE2b3vGV -:outputId: 1f831ea5-5a30-49ad-8195-977ff7ed476a - +--- +colab: + base_uri: https://localhost:8080/ +id: MUb-QE2b3vGV +outputId: 5a27f007-c572-44f8-9f49-6e745ee739e8 +--- loss_jit(params, batch) ``` ```{code-cell} -:id: HUkw0u413vGV -:outputId: dfa2599c-9440-4657-9035-0dc3bbf625e1 - +--- +colab: + base_uri: https://localhost:8080/ +id: HUkw0u413vGV +outputId: 07e481a1-97fb-4bd0-d754-cb6d8317bff6 +--- step_size = 1e-5 for _ in range(30): @@ -762,9 +704,12 @@ print(loss_jit(params, batch)) ``` ```{code-cell} -:id: paCw6Zaj3vGV -:outputId: 8ab1c32c-f2b1-465c-df71-f5a599e7f19e - +--- +colab: + base_uri: https://localhost:8080/ +id: paCw6Zaj3vGV +outputId: ad4cce34-3a6a-4d44-9a86-477a7fee4841 +--- %timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready() ``` @@ -776,9 +721,12 @@ params_single = jax.device_put(params, jax.devices()[0]) ``` ```{code-cell} -:id: Z1wgUKXk3vGV -:outputId: 74df8892-c349-41dc-cb1b-e0843ec5c994 - +--- +colab: + base_uri: https://localhost:8080/ +id: Z1wgUKXk3vGV +outputId: d66767b7-3f17-482f-b811-919bb1793277 +--- %timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready() ``` @@ -787,58 +735,79 @@ params_single = jax.device_put(params, jax.devices()[0]) ### 4-way batch data parallelism and 2-way model tensor parallelism ```{code-cell} -:id: N5-zzgW03vGW +:id: k1hxOfgRDwo0 -sharding = sharding.reshape(4, 2) +mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model')) ``` ```{code-cell} -:id: sgIWCjJK3vGW -:outputId: b2fdc556-05cc-4e68-fa04-48643d194dee - -batch = jax.device_put(batch, sharding.replicate(1)) +--- +colab: + base_uri: https://localhost:8080/ + height: 314 +id: sgIWCjJK3vGW +outputId: 8cb0f19f-3942-415c-c57a-31bb81784f46 +--- +batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None))) jax.debug.visualize_array_sharding(batch[0]) jax.debug.visualize_array_sharding(batch[1]) ``` +```{code-cell} +:id: q9PQP-0eEAO6 + +replicated_sharding = NamedSharding(mesh, P()) +``` + ```{code-cell} :id: BqCjYCgg3vGW (W1, b1), (W2, b2), (W3, b3), (W4, b4) = params -W1 = jax.device_put(W1, sharding.replicate()) -b1 = jax.device_put(b1, sharding.replicate()) +W1 = jax.device_put(W1, replicated_sharding) +b1 = jax.device_put(b1, replicated_sharding) -W2 = jax.device_put(W2, sharding.replicate(0)) -b2 = jax.device_put(b2, sharding.replicate(0)) +W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model'))) +b2 = jax.device_put(b2, NamedSharding(mesh, P('model'))) -W3 = jax.device_put(W3, sharding.replicate(0).T) -b3 = jax.device_put(b3, sharding.replicate()) +W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None))) +b3 = jax.device_put(b3, replicated_sharding) -W4 = jax.device_put(W4, sharding.replicate()) -b4 = jax.device_put(b4, sharding.replicate()) +W4 = jax.device_put(W4, replicated_sharding) +b4 = jax.device_put(b4, replicated_sharding) params = (W1, b1), (W2, b2), (W3, b3), (W4, b4) ``` ```{code-cell} -:id: _lSJ63sh3vGW -:outputId: 5b37aa8b-3226-4805-8282-876e8d06edda - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: _lSJ63sh3vGW +outputId: bcd3e33e-36b5-4787-9cd2-60623fd6e5fa +--- jax.debug.visualize_array_sharding(W2) ``` ```{code-cell} -:id: fxkfWYkk3vGW -:outputId: 8a1063c3-540b-47c1-d990-a6845da861f7 - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: fxkfWYkk3vGW +outputId: 59e60b16-fe37-47d4-8214-96096ffbd79c +--- jax.debug.visualize_array_sharding(W3) ``` ```{code-cell} -:id: uPCVs-_k3vGW -:outputId: de01cdfc-36cb-4823-c692-22c692ef4220 - +--- +colab: + base_uri: https://localhost:8080/ +id: uPCVs-_k3vGW +outputId: 618516e9-9736-4ca0-dd22-09d094ce57a2 +--- print(loss_jit(params, batch)) ``` @@ -854,25 +823,35 @@ for _ in range(30): ``` ```{code-cell} -:id: c9Sbl69e3vGX -:outputId: 8272c5fa-e59f-4953-c2d5-658c42a28712 - +--- +colab: + base_uri: https://localhost:8080/ +id: c9Sbl69e3vGX +outputId: 2ee3d432-7172-46ca-e01a-614e83345808 +--- print(loss_jit(params, batch)) ``` ```{code-cell} -:id: lkAF0dAb3vGX -:outputId: acf0df31-c5e1-4683-b73f-b0cd1b0929f8 - +--- +colab: + base_uri: https://localhost:8080/ + height: 380 +id: lkAF0dAb3vGX +outputId: 6c1e317e-cded-4af4-8080-0de835fa4c71 +--- (W1, b1), (W2, b2), (W3, b3), (W4, b4) = params jax.debug.visualize_array_sharding(W2) jax.debug.visualize_array_sharding(W3) ``` ```{code-cell} -:id: I1Npor3i3vGX -:outputId: 4099f6dd-7b46-4123-c1cb-5173c3d3278e - +--- +colab: + base_uri: https://localhost:8080/ +id: I1Npor3i3vGX +outputId: 479c4d81-cb0b-40a5-89ba-394c10dc3297 +--- %timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready() ``` @@ -903,7 +882,8 @@ def f(key, x): return x + numbers key = jax.random.key(42) -x_sharding = jax.sharding.PositionalSharding(jax.devices()) +mesh = Mesh(jax.devices(), 'x') +x_sharding = NamedSharding(mesh, P('x')) x = jax.device_put(jnp.arange(24), x_sharding) ``` @@ -912,9 +892,13 @@ x = jax.device_put(jnp.arange(24), x_sharding) On a partitioned input, the function `f` produces output that is also partitioned: ```{code-cell} -:id: Oi97rpLz3vGY -:outputId: 204a7e8d-dc88-4b77-b7e3-0e72f306c5d3 - +--- +colab: + base_uri: https://localhost:8080/ + height: 67 +id: Oi97rpLz3vGY +outputId: 9dd63254-a483-4847-c0f5-5a4367bf08e9 +--- jax.debug.visualize_array_sharding(f(key, x)) ``` @@ -923,9 +907,12 @@ jax.debug.visualize_array_sharding(f(key, x)) But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication: ```{code-cell} -:id: 64wIZuSJ3vGY -:outputId: 1054fe99-0476-44ec-9693-b0d8f98bf6a8 - +--- +colab: + base_uri: https://localhost:8080/ +id: 64wIZuSJ3vGY +outputId: fa166d45-ca9c-457a-be84-bcc9236d0730 +--- f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text()) ``` @@ -935,9 +922,12 @@ print('Communicating?', 'collective-permute' in f_exe.as_text()) One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the "collective permute" operation is now gone from the compiled computation: ```{code-cell} -:id: 1I7bqxA63vGY -:outputId: ec4c579d-f446-4b48-ceda-785c09ba299b - +--- +colab: + base_uri: https://localhost:8080/ +id: 1I7bqxA63vGY +outputId: 756e0a36-ff14-438f-bbd4-3ef03f97a47b +--- jax.config.update('jax_threefry_partitionable', True) f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text()) @@ -948,9 +938,13 @@ print('Communicating?', 'collective-permute' in f_exe.as_text()) The output is still partitioned: ```{code-cell} -:id: zHPJzdn23vGY -:outputId: a8904d20-4d04-4f59-8eae-281e47d29246 - +--- +colab: + base_uri: https://localhost:8080/ + height: 67 +id: zHPJzdn23vGY +outputId: 3332de0f-4827-4f0b-b9ef-69249b7c6bc6 +--- jax.debug.visualize_array_sharding(f(key, x)) ``` @@ -959,9 +953,12 @@ jax.debug.visualize_array_sharding(f(key, x)) One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key: ```{code-cell} -:id: nBUHBBal3vGY -:outputId: f194c213-0688-4b7a-ffb8-c4453b82b1f1 - +--- +colab: + base_uri: https://localhost:8080/ +id: nBUHBBal3vGY +outputId: 4b9be948-ccab-4a31-a06f-37ec9c7b5235 +--- jax.config.update('jax_threefry_partitionable', False) print('Stable:') print(f(key, x)) From 42fe45f34b2f8a5a184d815c509d6910321c1d0f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 8 Aug 2024 00:59:51 -0700 Subject: [PATCH 020/702] [Mosaic TPU] Add support for removal of implicit 2nd minor for all 32-bit tilings PiperOrigin-RevId: 660724215 --- .../tpu/transforms/apply_vector_layout.cc | 33 ++++++++++++++----- tests/pallas/tpu_pallas_test.py | 2 -- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 0e5dfa7b51f4..0670ef1e2f09 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -5396,12 +5396,17 @@ FailureOr>> changeImplicitDim( src_candidate.tileArrayImplicitShape(vty.getShape(), target_shape)); return std::make_pair(src_candidate, vregs); } - // Remove second minor implicit dim, for values that have (8, 128) tiling. - // TODO(apaszke): We should allow replicated dst_offset_hints[0]. + // Remove second minor implicit dim, for values that have (m, 128) tiling (for + // m that is a power of 2). if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && dst_implicit_dim == VectorLayout::ImplicitDim::kNone && - src.bitwidth() == 32 && src.tiling() == std::array{8, 128} && - dst_offset_hints[0]) { + src.bitwidth() == 32 && src.tiling()[1] == target_shape[1] && + llvm::isPowerOf2_32(src.tiling()[0])) { + // We should never see a replicated offset here. We're removing the implicit + // dim so the only case when this can happen is when its size is 1 (or else + // we can't prove replication in the logical value). But in that case, the + // equivalentTo case above triggers and we never reach this branch. + CHECK(dst_offset_hints[0].has_value()); int64_t dst_sublane_offset = *dst_offset_hints[0]; VectorLayout dst(src.bitwidth(), {dst_sublane_offset, src.offsets()[1]}, src.tiling(), dst_implicit_dim); @@ -5414,15 +5419,25 @@ FailureOr>> changeImplicitDim( src.insertImplicit(src_idx, 0); const int dst_sl_start = idx[dst_2nd_minor_idx] == 0 ? dst_sublane_offset : 0; - src_idx[dst_2nd_minor_idx] = target_shape[0] * idx[dst_2nd_minor_idx] + + // This could be optimized further to take offsets[1] into account. + // For example, extended offsets allow us to skip copies of low sublanes + // in tiles with idx.back() == 0. + const int tiles_per_vreg = src.tilesPerVreg(target_shape); + const int sublanes_per_tile = src.sublanesPerTile(target_shape); + src_idx[dst_2nd_minor_idx] = src.tiling()[0] * idx[dst_2nd_minor_idx] + dst_sl_start - dst_sublane_offset; for (int dst_sl_idx = dst_sl_start; - dst_sl_idx < target_shape[0] && + dst_sl_idx < src.tiling()[0] && src_idx[dst_2nd_minor_idx] < vregs.dim(dst_2nd_minor_idx); ++dst_sl_idx, ++src_idx[dst_2nd_minor_idx]) { - *tile = copy_one_sublane(builder, vregs(src_idx), - src.offsets()[0].value_or(dst_sl_idx), *tile, - dst_sl_idx, target_shape); + // This could be optimized further by copying multiple sublanes at once. + for (int tile_idx = 0; tile_idx < tiles_per_vreg; ++tile_idx) { + int tile_off = tile_idx * sublanes_per_tile; + *tile = + copy_one_sublane(builder, vregs(src_idx), + tile_off + src.offsets()[0].value_or(dst_sl_idx), + *tile, tile_off + dst_sl_idx, target_shape); + } } }); return std::make_pair(dst, new_vregs); diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 481e301e2db9..2f813afeab14 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2358,9 +2358,7 @@ def kernel(x_ref, y_ref, out_ref): np.testing.assert_array_equal(out, np.zeros((8, 128), dtype=jnp.float32)) - @only_passes_in_interpret() def test_sum(self): - """b/356467588""" x = np.zeros((8, 2, 8, 128), dtype=jnp.float32) def kernel(x_ref, out_ref): From 4ca341701f9752de9510141467d1b91abd701ec7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 8 Aug 2024 05:53:01 -0700 Subject: [PATCH 021/702] Improve documentation for jnp.piecewise & jnp.select --- jax/_src/numpy/lax_numpy.py | 122 +++++++++++++++++++++++++++++++++--- 1 file changed, 115 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 7ef27b1eea66..49bd8ca412fe 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2126,12 +2126,60 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): return util._where(condition, x, y) -@util.implements(np.select) def select( condlist: Sequence[ArrayLike], choicelist: Sequence[ArrayLike], default: ArrayLike = 0, ) -> Array: + """Select values based on a series of conditions. + + JAX implementation of :func:`numpy.select`, implemented in terms + of :func:`jax.lax.select_n` + + Args: + condlist: sequence of array-like conditions. All entries must be mutually + broadcast-compatible. + choicelist: sequence of array-like values to choose. Must have the same length + as ``condlist``, and all entries must be broadcast-compatible with entries + of ``condlist``. + default: value to return when every condition is False (default: 0). + + Returns: + Array of selected values from ``choicelist`` corresponding to the first + ``True`` entry in ``condlist`` at each location. + + See also: + - :func:`jax.numpy.where`: select between two values based on a single condition. + - :func:`jax.lax.select_n`: select between *N* values based on an index. + + Examples: + >>> condlist = [ + ... jnp.array([False, True, False, False]), + ... jnp.array([True, False, False, False]), + ... jnp.array([False, True, True, False]), + ... ] + >>> choicelist = [ + ... jnp.array([1, 2, 3, 4]), + ... jnp.array([10, 20, 30, 40]), + ... jnp.array([100, 200, 300, 400]), + ... ] + >>> jnp.select(condlist, choicelist, default=0) + Array([ 10, 2, 300, 0], dtype=int32) + + This is logically equivalent to the following nested ``where`` statement: + + >>> default = 0 + >>> jnp.where(condlist[0], + ... choicelist[0], + ... jnp.where(condlist[1], + ... choicelist[1], + ... jnp.where(condlist[2], + ... choicelist[2], + ... default))) + Array([ 10, 2, 300, 0], dtype=int32) + + However, for efficiency it is implemented in terms of :func:`jax.lax.select_n`. + """ if len(condlist) != len(choicelist): msg = "condlist must have length equal to choicelist ({} vs {})" raise ValueError(msg.format(len(condlist), len(choicelist))) @@ -8937,16 +8985,76 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: len(bins_arr) - searchsorted(bins_arr[::-1], x, side=side) ) -_PIECEWISE_DOC = """\ -Unlike `np.piecewise`, :py:func:`jax.numpy.piecewise` requires functions in -`funclist` to be traceable by JAX, as it is implemented via :func:`jax.lax.switch`. -See the :func:`jax.lax.switch` documentation for more information. -""" -@util.implements(np.piecewise, lax_description=_PIECEWISE_DOC) def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], funclist: list[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: + """Evaluate a function defined piecewise across the domain. + + JAX implementation of :func:`numpy.piecewise`, in terms of :func:`jax.lax.switch`. + + Note: + Unlike :func:`numpy.piecewise`, :func:`jax.numpy.piecewise` requires functions + in ``funclist`` to be traceable by JAX, as it is implemented via + :func:`jax.lax.switch`. + + Args: + x: array of input values. + condlist: boolean array or sequence of boolean arrays corresponding to the + functions in ``funclist``. If a sequence of arrays, the length of each + array must match the length of ``x`` + funclist: list of arrays or functions; must either be the same length as + ``condlist``, or have length ``len(condlist) + 1``, in which case the + last entry is the default applied when none of the conditions are True. + Alternatively, entries of ``funclist`` may be numerical values, in which + case they indicate a constant function. + args, kwargs: additional arguments are passed to each function in + ``funclist``. + + Returns: + An array which is the result of evaluating the functions on ``x`` at + the specified conditions. + + See also: + - :func:`jax.lax.switch`: choose between *N* functions based on an index. + - :func:`jax.lax.cond`: choose between two functions based on a boolean condition. + - :func:`jax.numpy.where`: choose between two results based on a boolean mask. + - :func:`jax.lax.select`: choose between two results based on a boolean mask. + - :func:`jax.lax.select_n`: choose between *N* results based on a boolean mask. + + Examples: + Here's an example of a function which is zero for negative values, and linear + for positive values: + + >>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4]) + + >>> condlist = [x < 0, x >= 0] + >>> funclist = [lambda x: 0 * x, lambda x: x] + >>> jnp.piecewise(x, condlist, funclist) + Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32) + + ``funclist`` can also contain a simple scalar value for constant functions: + + >>> condlist = [x < 0, x >= 0] + >>> funclist = [0, lambda x: x] + >>> jnp.piecewise(x, condlist, funclist) + Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32) + + You can specify a default value by appending an extra condition to ``funclist``: + + >>> condlist = [x < -1, x > 1] + >>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0] + >>> jnp.piecewise(x, condlist, funclist) + Array([-3, -2, -1, 0, 0, 0, 1, 2, 3], dtype=int32) + + ``condlist`` may also be a simple array of scalar conditions, in which case + the associated function applies to the whole range + + >>> condlist = jnp.array([False, True, False]) + >>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100] + >>> jnp.piecewise(x, condlist, funclist) + Array([-40, -30, -20, -10, 0, 10, 20, 30, 40], dtype=int32) + """ util.check_arraylike("piecewise", x) nc, nf = len(condlist), len(funclist) if nf == nc + 1: From 551f72979c85b27ca77ce5851a9fa6e5757bf157 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 8 Aug 2024 05:59:19 -0700 Subject: [PATCH 022/702] Rollback of #22869 This is causing breakages due to overly-restrictive checks on kwargs Reverts 893ae6eb800851b1c17c437982608bb59d3bc6be PiperOrigin-RevId: 660803968 --- jax/_src/api_util.py | 11 ----------- jax/_src/custom_batching.py | 9 ++------- jax/_src/custom_derivatives.py | 21 ++++++++++++++++----- tests/api_test.py | 26 -------------------------- 4 files changed, 18 insertions(+), 49 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 481dec0065a5..dd1cdcbe6bb8 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -556,17 +556,6 @@ def _assert_no_intersection(static_argnames, donate_argnames): f"{out} appear in both static_argnames and donate_argnames") -def resolve_kwargs(fun: Callable, args, kwargs): - if isinstance(fun, partial): - fun = lambda *args, **kwargs: None - ba = inspect.signature(fun).bind(*args, **kwargs) - ba.apply_defaults() - if ba.kwargs: - raise TypeError("keyword arguments could not be resolved to positions") - else: - return ba.args - - def _dtype(x): try: return dtypes.result_type(x) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 4b859e910165..4d41849b75d3 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -27,7 +27,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src.api_util import flatten_fun_nokwargs, resolve_kwargs +from jax._src.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters.batching import not_mapped @@ -64,12 +64,7 @@ def def_vmap(self, vmap_rule: Callable) -> Callable: @traceback_util.api_boundary def __call__(self, *args, **kwargs): - fun_name = getattr(self.fun, "__name__", str(self.fun)) - if not self.vmap_rule: - raise AttributeError( - f"No batching rule defined for custom_vmap function {fun_name} " - "using def_vmap.") - args = resolve_kwargs(self.fun, args, kwargs) + assert not kwargs args_flat, in_tree = tree_flatten(args) flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree) in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index bc9f7a687dcb..d27b0efc7e5e 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -17,6 +17,7 @@ from collections.abc import Callable, Sequence import dataclasses from functools import update_wrapper, reduce, partial, wraps +import inspect from typing import Any, Generic, TypeVar from jax._src import config @@ -29,8 +30,7 @@ from jax._src import traceback_util from jax._src.ad_util import ( stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) -from jax._src.api_util import ( - argnums_partial, flatten_fun_nokwargs, resolve_kwargs) +from jax._src.api_util import argnums_partial, flatten_fun_nokwargs from jax._src.core import raise_to_shaped from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad @@ -56,6 +56,17 @@ ### util +def _resolve_kwargs(fun, args, kwargs): + if isinstance(fun, partial): + # functools.partial should have an opaque signature. + fun = lambda *args, **kwargs: None + ba = inspect.signature(fun).bind(*args, **kwargs) + ba.apply_defaults() + if ba.kwargs: + raise TypeError("keyword arguments could not be resolved to positions") + else: + return ba.args + def _initial_style_jaxpr(fun, in_avals): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) return jaxpr, consts @@ -229,7 +240,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp." raise AttributeError(msg) jvp_name = getattr(self.jvp, '__name__', str(self.jvp)) - args = resolve_kwargs(self.fun, args, kwargs) + args = _resolve_kwargs(self.fun, args, kwargs) if self.nondiff_argnums: nondiff_argnums = set(self.nondiff_argnums) args = tuple(_stop_gradient(x) if i in nondiff_argnums else x @@ -588,7 +599,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp." raise AttributeError(msg) fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) - args = resolve_kwargs(self.fun, args, kwargs) + args = _resolve_kwargs(self.fun, args, kwargs) if self.optimize_remat: fwd = optimize_remat_of_custom_vjp_fwd( self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums, @@ -1440,7 +1451,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: # above and it would be good to consolidate it. primal_name = getattr(fun, "__name__", str(fun)) fwd_name = getattr(fwd, "__name__", str(fwd)) - args = resolve_kwargs(fwd, args, kwargs) + args = _resolve_kwargs(fwd, args, kwargs) if nondiff_argnums: for i in nondiff_argnums: _check_for_tracers(args[i]) nondiff_argnums_ = set(nondiff_argnums) diff --git a/tests/api_test.py b/tests/api_test.py index 4aafc42b7a0e..cb0d7c0d40c7 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -10798,32 +10798,6 @@ def g(x, a): self.assertAllClose(y, (x + a)**2) - def test_kwargs(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x=x) - self.assertAllClose(y, jnp.sin(x)) - ys = api.vmap(f)(x=xs) - self.assertAllClose(ys, jnp.cos(xs)) - - def test_undefined_rule(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - with self.assertRaisesRegex( - AttributeError, "No batching rule defined for custom_vmap function f"): - f(0.5) - - class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" From 595ca0affad273cc4e4106226bb2511107130135 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 8 Aug 2024 14:08:51 +0100 Subject: [PATCH 023/702] Improve error message for missing vmap rule in custom_vmap. This is a partial re-land of https://github.com/google/jax/pull/22869 after it was rolled back to fix internal users. This part of the change didn't cause the issues, and I'll follow up with the rest of the changes in a second PR. --- jax/_src/custom_batching.py | 5 +++++ tests/api_test.py | 10 +++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 4d41849b75d3..1d405c4e5bbf 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -65,6 +65,11 @@ def def_vmap(self, vmap_rule: Callable) -> Callable: @traceback_util.api_boundary def __call__(self, *args, **kwargs): assert not kwargs + fun_name = getattr(self.fun, "__name__", str(self.fun)) + if not self.vmap_rule: + raise AttributeError( + f"No batching rule defined for custom_vmap function {fun_name} " + "using def_vmap.") args_flat, in_tree = tree_flatten(args) flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree) in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] diff --git a/tests/api_test.py b/tests/api_test.py index cb0d7c0d40c7..f04c7d307f97 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -10780,7 +10780,6 @@ def test_batch_map_pytrees(self, batch_size: int): ) self.assertAllClose(outputs['b'], expected) - def test_batch_divides_axis(self): def f(t): x, a = t @@ -10798,6 +10797,15 @@ def g(x, a): self.assertAllClose(y, (x + a)**2) + def test_undefined_rule(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + with self.assertRaisesRegex( + AttributeError, "No batching rule defined for custom_vmap function f"): + f(0.5) + + class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" From 04a753ad02f35f22e9ad65ce19f8c455f10ff113 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 8 Aug 2024 07:20:57 -0700 Subject: [PATCH 024/702] [Mosaic TPU] Improve an error message in case someone tries to extract a non-32-bit scalar. PiperOrigin-RevId: 660826696 --- jax/_src/pallas/mosaic/lowering.py | 6 ++++++ .../tpu/transforms/canonicalize_mosaic.cc | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 7299f9929fe3..bef39142c120 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1617,6 +1617,12 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out if not aval_out.shape: + if aval_out.dtype.itemsize != 4: + raise ValueError( + "Only arrays with 32-bit element types can be converted to scalars," + f" but got: {aval_out.dtype}. Try casting the input before squeezing" + " the scalar." + ) return vector.ExtractOp(x, [], [0] * len(aval_in.shape)).result return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 54c0776514df..93c362a43671 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -317,6 +317,21 @@ LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) { return result; } +LogicalResult canonicalize_extract(int hardware_generation, Operation &raw_op) { + auto op = dyn_cast(raw_op); + Type result_ty = op.getResult().getType(); + if (!isa(result_ty)) { + bool is_supported = result_ty.isSignlessIntOrFloat() && + result_ty.getIntOrFloatBitWidth() == 32; + if (!is_supported) { + return op.emitOpError( + "Only 32-bit scalar vector.extracts supported. Cast your input to a " + "32-bit type first."); + } + } + return success(); +} + using canonicalize_rule_type = std::function; @@ -324,6 +339,7 @@ const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ {tpu::MatmulOp::getOperationName(), canonicalize_matmul}, {vector::ContractionOp::getOperationName(), canonicalize_contraction}, + {vector::ContractionOp::getOperationName(), canonicalize_extract}, {vector::MultiDimReductionOp::getOperationName(), canonicalize_multi_dim_reduction}}; return *rules; From 11d9c2de2c7ef5089b56b28461932f2500835f8d Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 8 Aug 2024 07:35:06 -0700 Subject: [PATCH 025/702] Update GPU implementation of `lu_pivots_to_permutation` to infer the permutation size directly from the input dimensions, instead of using an input parameter. I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways. In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval. PiperOrigin-RevId: 660831000 --- jax/_src/lax/linalg.py | 5 ++-- jaxlib/gpu/linalg_kernels.cc | 47 +++++++++++++++++++++++++++--------- jaxlib/gpu_linalg.py | 30 ++++++++++------------- tests/extend_test.py | 4 +-- 4 files changed, 53 insertions(+), 33 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 3bd7e37e54ca..cecfa253ec11 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1172,10 +1172,11 @@ def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size): 'Argument to lu_pivots_to_permutation must have rank >= 1 and dtype ' 'int32. Got shape={} and dtype={}'.format(pivots.shape, pivots.dtype)) - if permutation_size < pivots.shape[-1]: + pivots_size = pivots.shape[-1] + if permutation_size < pivots_size: raise ValueError( 'Output permutation size {} has to exceed the trailing dimension of ' - 'the pivots. Got shape {}'.format(permutation_size, pivots.shape)) + 'the pivots. Got pivots size {}'.format(permutation_size, pivots_size)) batch_dims = pivots.shape[:-1] permutations = pivots.update(shape=batch_dims + (permutation_size,)) diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index 6636f5654180..6b143e893264 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include #include -#include #include #include +#include #include "absl/algorithm/container.h" #include "absl/status/status.h" @@ -60,23 +60,43 @@ void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque, } namespace { -ffi::Error LuPivotsToPermutationImpl( - gpuStream_t stream, std::int32_t permutation_size, - ffi::Buffer pivots, - ffi::Result> permutation) { - auto dims = pivots.dimensions(); - +absl::StatusOr> GetDimensions( + ffi::Span dims, const std::string& arg_name) { if (dims.size() < 1) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "pivots must have at least one dimension"); + return absl::InvalidArgumentError( + absl::StrFormat("%s must have at least one dimension", arg_name)); } - FFI_ASSIGN_OR_RETURN(std::int32_t pivot_size, - MaybeCastNoOverflow(dims.back())); std::int64_t batch_size = 1; if (dims.size() >= 2) { batch_size = absl::c_accumulate(dims.first(dims.size() - 1), 1, std::multiplies<>()); } + JAX_ASSIGN_OR_RETURN(auto size, + MaybeCastNoOverflow(dims.back())); + return std::make_pair(batch_size, size); +} + +ffi::Error LuPivotsToPermutationImpl( + gpuStream_t stream, ffi::Dictionary /* unused */, + ffi::Buffer pivots, + ffi::Result> permutation) { + FFI_ASSIGN_OR_RETURN(auto pivots_dims, + GetDimensions(pivots.dimensions(), "pivots")); + FFI_ASSIGN_OR_RETURN(auto permutation_dims, + GetDimensions(permutation->dimensions(), "permutation")); + auto [batch_size, pivot_size] = pivots_dims; + auto [permutation_batch, permutation_size] = permutation_dims; + if (permutation_batch != batch_size) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "pivots and permutation must have the same batch size."); + } + if (permutation_size < pivot_size) { + return ffi::Error( + ffi::ErrorCode::kInvalidArgument, + absl::StrFormat("Output permutation size %d must match or exceed the " + "trailing dimension of the input pivots %d.", + permutation_size, pivot_size)); + } LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size, permutation_size, pivots.typed_data(), permutation->typed_data()); @@ -88,7 +108,10 @@ ffi::Error LuPivotsToPermutationImpl( XLA_FFI_DEFINE_HANDLER_SYMBOL(LuPivotsToPermutation, LuPivotsToPermutationImpl, ffi::Ffi::Bind() .Ctx>() - .Attr("permutation_size") + // TODO(b/358275922): remove Attrs (and the + // unused Dictionary above) 12 weeks after + // release of jaxlib v0.4.32. + .Attrs() .Arg>() .Ret>()); diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index af79b3ae756f..32d31e7206a4 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -20,7 +20,7 @@ import jaxlib.mlir.ir as ir -from .hlo_helpers import custom_call +from .hlo_helpers import custom_call, mk_result_types_and_shapes from .gpu_common_utils import GpuLibNotLinkedError from jaxlib import xla_client @@ -61,39 +61,35 @@ _prod = lambda xs: functools.reduce(operator.mul, xs, 1) -def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_size): +def _lu_pivots_to_permutation_hlo(platform, pivots, *, permutation_size): """Kernel for the transformation of pivots to permutations on GPU.""" typ = ir.RankedTensorType(pivots.type) dims = typ.shape i32_type = ir.IntegerType.get_signless(32) - assert typ.element_type == i32_type, typ - if not gpu_linalg: - raise GpuLibNotLinkedError() - pivots_layout = tuple(range(len(dims) - 1, -1, -1)) permutations_layout = pivots_layout - permutations_dims = list(dims) - permutations_dims[-1] = permutation_size - permutations_type = ir.RankedTensorType.get(permutations_dims, i32_type) + permutations_dims = (*dims[:-1], permutation_size) + result_types, result_shapes = mk_result_types_and_shapes( + [(permutations_dims, i32_type)]) return custom_call( f"{platform}_lu_pivots_to_permutation", api_version=4, - result_types=[permutations_type], operands=[pivots], + operand_layouts=[pivots_layout], + result_types=result_types, + result_shapes=result_shapes, + result_layouts=[permutations_layout], + # TODO(b/358275922): remove backend_config 12 weeks after release of + # jaxlib v0.4.32. backend_config=dict( permutation_size=ir.IntegerAttr.get(i32_type, permutation_size), ), - operand_layouts=[pivots_layout], - result_layouts=[permutations_layout], ).results -cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu", - _cuda_linalg) -hip_lu_pivots_to_permutation = partial( - _lu_pivots_to_permutation_hlo, "hip", _hip_linalg) - +cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu") +hip_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "hip") def _cholesky_update_hlo(platform, gpu_linalg, r_matrix, w_vector, dtype): diff --git a/tests/extend_test.py b/tests/extend_test.py index f34d40cd3556..45c689239a1d 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -15,7 +15,8 @@ import os import numpy as np -from absl.testing import absltest, parameterized +from absl.testing import absltest +from absl.testing import parameterized import jax from jax import lax @@ -166,7 +167,6 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True) dtype=pivots.dtype, ), pivots, - permutation_size=np.int32(permutation_size), vectorized=vectorized, ) From 44ae9b30ec3eee24a819e0d48c277ee608ead829 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 8 Aug 2024 16:19:19 +0000 Subject: [PATCH 026/702] fix #22944 --- jax/_src/interpreters/partial_eval.py | 2 +- tests/api_test.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 558f735f4403..94f8918aa2ae 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2821,7 +2821,7 @@ def inline_jaxpr_into_trace( tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], [*consts, *arg_tracers])) def new_tracer(atom): - tracer = DynamicJaxprTracer(trace, atom.aval, src) + tracer = tracer_env[atom] = DynamicJaxprTracer(trace, atom.aval, src) trace.frame.tracers.append(tracer) trace.frame.tracer_to_var[id(tracer)] = env[atom] return tracer diff --git a/tests/api_test.py b/tests/api_test.py index cb0d7c0d40c7..e1162eba0f7d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4850,6 +4850,19 @@ def g(): with self.assertRaisesRegex(TracerBoolConversionError, "Attempted boolean"): f() + def test_inline_return_twice(self): + # https://github.com/google/jax/issues/22944 + @jax.jit + def add_one(x: int) -> int: + return x + 1 + + def add_one_and_dupe(x: int) -> tuple[int, int]: + y = add_one(x) + return (y, y) + + jit_add_one_dupe = jax.jit(add_one_and_dupe, inline=True) + jax.eval_shape(jit_add_one_dupe, 0) # don't crash + class RematTest(jtu.JaxTestCase): From ccc27a7a5fd60329710a9ebf9bb9aeed96a3a49c Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 8 Aug 2024 09:34:57 -0700 Subject: [PATCH 027/702] Remove PJRT version check in memories_test.py that is no longer needed. 0.43 is the version at 2024 Feb. Cloud TPU CI uses 20240228 so it should contain the PJRT C API needed for the test https://github.com/google/jax/blob/d3b6066f91b068dacede0a9f026253f21d1f731a/.github/workflows/cloud-tpu-ci-nightly.yml#L35. PiperOrigin-RevId: 660869710 --- tests/memories_test.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index 816fbeee3e3d..0d3720e807c0 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1480,8 +1480,7 @@ def g(ys, _): compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if jtu.pjrt_c_api_version_at_least(0, 43): - self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_layout_change_offloadable(self): if not jtu.test_device_matches(["tpu"]): @@ -1522,8 +1521,7 @@ def g(ys, _): compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if jtu.pjrt_c_api_version_at_least(0, 43): - self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_checkpoint_dots_with_no_batch_dims(self): policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( @@ -1554,8 +1552,7 @@ def f(x): compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if jtu.pjrt_c_api_version_at_least(0, 43): - self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From e6303244bf67e4efc2a51232287778dc81e47cbf Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 8 Aug 2024 11:23:50 -0700 Subject: [PATCH 028/702] If the memory kind is the default kind throughout the jaxpr, then revert back to the previous device_put behavior which was a no-op inside jit. This is also the same behavior for arguments and outputs, where we don't insert `mhlo.memory_kind` attributes in the stableHLO if the entire jaxpr only has the default memory kind. PiperOrigin-RevId: 660913387 --- jax/_src/dispatch.py | 5 +++++ jax/_src/interpreters/mlir.py | 8 ++++++-- jax/_src/interpreters/pxla.py | 6 +++--- tests/pjit_test.py | 14 ++++++++++++++ 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 6c0b46077dcc..8605c58a81cd 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -548,6 +548,10 @@ def _device_put_batcher(batched_args, batch_dims, **params): batching.primitive_batchers[device_put_p] = _device_put_batcher def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs): + # TODO(yashkatariya): Maybe we should add the custom calls anyways if it's + # being used inside jit? Atleast for now, this preserves the old behavior. + if ctx.module_context.all_default_mem_kind: + return xs def lower(x, device, src, aval, out_aval): if (isinstance(device, (Sharding, TransferToMemoryKind)) and device.memory_kind is not None): @@ -558,6 +562,7 @@ def lower(x, device, src, aval, out_aval): return x return x return list(map(lower, xs, devices, srcs, ctx.avals_in, ctx.avals_out)) + mlir.register_lowering( device_put_p, _tpu_gpu_device_put_lowering, platform='tpu') mlir.register_lowering( diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 31d281b88f11..3a666e357df1 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -604,6 +604,7 @@ class ModuleContext: host_callbacks: list[Any] # Keep state for the lowering of shape polymorphism shape_poly_state: ShapePolyLoweringState + all_default_mem_kind: bool # Cached primitive lowerings. cached_primitive_lowerings: dict[Any, func_dialect.FuncOp] @@ -633,7 +634,8 @@ def __init__( symbol_table: ir.SymbolTable | None = None, cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None, traceback_caches: None | TracebackCaches = None, - shape_poly_state = None): + shape_poly_state = None, + all_default_mem_kind: bool = True): self.context = context or make_ir_context() self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context)) @@ -651,6 +653,7 @@ def __init__( self.host_callbacks = host_callbacks self.shape_poly_state = ( shape_poly_state or ShapePolyLoweringState((), tuple(platforms))) + self.all_default_mem_kind = all_default_mem_kind self.lowering_parameters = lowering_parameters @property @@ -1034,7 +1037,8 @@ def lower_jaxpr_to_module( channel_iterator=channel_iter, host_callbacks=host_callbacks, lowering_parameters=lowering_parameters, - shape_poly_state=ShapePolyLoweringState(dim_vars, platforms)) + shape_poly_state=ShapePolyLoweringState(dim_vars, platforms), + all_default_mem_kind=all_default_mem_kind) with ctx.context, ir.Location.unknown(ctx.context): # Remove module name characters that XLA would alter. This ensures that # XLA computation preserves the module name. diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 88297bd9204b..2cd09299e60a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2172,8 +2172,6 @@ def lower_sharding_computation( devices_from_context) platforms = lowering_platforms or (backend.platform,) - # TODO(yashkatariya): Enable this when offload APIs are stable. - # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) committed = bool( devices_from_context or @@ -2184,10 +2182,12 @@ def lower_sharding_computation( da_object = _create_da_object(tuple(device_assignment)) + transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) all_default_mem_kind = are_all_shardings_default_mem_kind( da_object, it.chain(in_shardings, out_shardings, - [js for js, _ in unique_intermediate_shardings])) + [js for js, _ in unique_intermediate_shardings], + transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types # TODO(yashkatariya): Remove this when XLA can propagate memory kinds or when # JAX puts memory kinds in the types of jaxpr. diff --git a/tests/pjit_test.py b/tests/pjit_test.py index df87fed4bb7d..915a6d4b52d2 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3773,6 +3773,20 @@ def test_jit_in_shardings_none(self): self.assertArraysEqual(out2, np_inp * 2) self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0])) + def test_device_put_in_jit_default_mem_kind_no_op(self): + mesh = jtu.create_global_mesh((2,), 'x') + np_inp = np.arange(8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x): + y = x * 2 + return jax.device_put(y, NamedSharding(mesh, P())) + + lowered_text = f.lower(arr).as_text() + self.assertNotIn('@Sharding', lowered_text) + self.assertNotIn('@annotate_device_placement', lowered_text) + def test_jit_both_shardings_none(self): mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) From 751b5742fdb5ff3e36212e8e64ef668de41363ff Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 8 Aug 2024 11:57:18 -0700 Subject: [PATCH 029/702] Deprecate using build_cuda_plugin_from_source flag and rely on jaxlib_build config. If jaxlib needs to be built from source, cuda plugin will be built from source as well. PiperOrigin-RevId: 660926791 --- jaxlib/jax.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index d1decfd3a885..7df848aad843 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -267,10 +267,10 @@ def jax_test( deps = [ "//jax", "//jax:test_util", - ] + deps + if_building_jaxlib(["//jaxlib/cuda:gpu_only_test_deps"]) + select({ - "//jax:enable_build_cuda_plugin_from_source": ["//jax_plugins:gpu_plugin_only_test_deps"], - "//conditions:default": [], - }), + ] + deps + if_building_jaxlib([ + "//jaxlib/cuda:gpu_only_test_deps", + "//jax_plugins:gpu_plugin_only_test_deps", + ]), data = data, shard_count = test_shards, tags = test_tags, From 8105930a94fa569eee04ee7451b78e03dce1d966 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 8 Aug 2024 19:13:30 +0000 Subject: [PATCH 030/702] Add test --- jax/_src/cudnn/fused_attention_stablehlo.py | 2 +- tests/fused_attention_stablehlo_test.py | 23 ++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index eb2ed2ff8bd9..7d3ebe2595b0 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -284,7 +284,7 @@ def check_eq(a, b, c, msg): raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}") def check_is_flash_attention( - query, key, layout, cudnn_version, has_bias, is_training): + query, key, layout: int, cudnn_version, has_bias, is_training): if layout == AttentionLayout.BNTH.value: _, _, T, H = query.shape _, _, S, _ = key.shape diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index bc05c4b2e85c..662425f4b5d7 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -429,18 +429,27 @@ def _cvt_back(x): def test_sdpa_utils(self): test_cases = [ - (1, 257, 64, 8905, False, True), - (1, 1024, 64, 8905, False, False), - (1024, 1024, 64, 8905, False, False), - (1024, 1024, 128, 8905, False, False), + (1, 257, 64, 8905, False, True, True), + (1, 1024, 64, 8905, False, False, True), + (1024, 1024, 64, 8905, False, False, True), + (1024, 1024, 128, 8905, False, False, True), + (1024, 1024, 127, 8905, False, False, False), ] for k in test_cases: - sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training = k + sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training, \ + expected_pass = k query = jnp.empty((4, sql_q, 4, head_dim)) key = jnp.empty((4, sql_v, 4, head_dim)) - check_is_flash_attention( - query, key, AttentionLayout.BNTH, cudnn_version, has_bias, is_training) + if expected_pass: + check_is_flash_attention( + query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias, + is_training) + else: + with self.assertRaises(NotImplementedError): + check_is_flash_attention( + query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias, + is_training) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From d999208863d7118807e8d5f93174ce32158e805c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 8 Aug 2024 12:33:30 -0700 Subject: [PATCH 031/702] [array API] update test suite to most recent commit --- .github/workflows/jax-array-api.yml | 2 +- pyproject.toml | 3 ++- tests/array_api_skips.txt | 31 +++-------------------------- 3 files changed, 6 insertions(+), 30 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 78cddb411feb..3709f0557a46 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -25,7 +25,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: '33f2d2ea2f3dd2b3ceeeb4519d55e08096184149' # Latest commit as of 2024-05-28 + ref: 'db95e67b29235249e5776ca2b6bb4e77117e0690' # Latest commit as of 2024-08-08 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/pyproject.toml b/pyproject.toml index bc424a13e14b..193c6b9fdad0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,8 @@ filterwarnings = [ # TODO(jakevdp): remove when array_api_tests stabilize "default:.*not machine-readable.*:UserWarning", "default:Special cases found for .* but none were parsed.*:UserWarning", - "default:.*is not JSON-serializable. Using the repr instead.", + "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning", + "default:The .* method is good for exploring strategies.*", # These are transitive warnings coming from TensorFlow dependencies. # TODO(slebedev): Remove once we bump the minimum TensorFlow version. diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index f7d80d94f96f..2ac2edcdfd99 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -4,36 +4,11 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32] # Test suite attempts in-place mutation: -array_api_tests/test_special_cases.py::test_iop -array_api_tests/test_special_cases.py::test_nan_propagation array_api_tests/test_array_object.py::test_setitem +array_api_tests/test_array_object.py::test_setitem_masking -# Raises NonInteractiveExampleWarning -array_api_tests/test_special_cases.py::test_binary -array_api_tests/test_special_cases.py::test_unary - -# Pending implementation update for proper dtype promotion behavior, -# see https://github.com/data-apis/array-api-tests/issues/234 -array_api_tests/test_statistical_functions.py::test_sum -array_api_tests/test_statistical_functions.py::test_prod - -# Pending bugfix, see https://github.com/data-apis/array-api-tests/issues/256 -array_api_tests/test_signatures.py::test_func_signature[logical_and] -array_api_tests/test_signatures.py::test_func_signature[logical_or] -array_api_tests/test_signatures.py::test_func_signature[logical_xor] +# Returns wrong zero sign +array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted - -# Various info functions not yet defined -# Pending bugfix, see https://github.com/data-apis/array-api-tests/pull/262 -array_api_tests/test_has_names.py::test_has_names[info-capabilities] -array_api_tests/test_has_names.py::test_has_names[info-default_device] -array_api_tests/test_has_names.py::test_has_names[info-default_dtypes] -array_api_tests/test_has_names.py::test_has_names[info-devices] -array_api_tests/test_has_names.py::test_has_names[info-dtypes] -array_api_tests/test_signatures.py::test_func_signature[capabilities] -array_api_tests/test_signatures.py::test_func_signature[default_device] -array_api_tests/test_signatures.py::test_func_signature[default_dtypes] -array_api_tests/test_signatures.py::test_func_signature[devices] -array_api_tests/test_signatures.py::test_func_signature[dtypes] From efb77216710edb95fd2cc4de9f933e3b8681b726 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 8 Aug 2024 12:48:46 -0700 Subject: [PATCH 032/702] Remove unnecessary constraint on keyword-only arguments in `custom_vjp` with `optimize_remat=True`. PiperOrigin-RevId: 660945559 --- jax/_src/custom_derivatives.py | 4 +++- tests/api_test.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index d27b0efc7e5e..56accc273dbf 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1451,7 +1451,9 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: # above and it would be good to consolidate it. primal_name = getattr(fun, "__name__", str(fun)) fwd_name = getattr(fwd, "__name__", str(fwd)) - args = _resolve_kwargs(fwd, args, kwargs) + # Note: we use `fun` instead of `fwd` here for consistency with + # custom_vjp.__call__ above. + args = _resolve_kwargs(fun, args, kwargs) if nondiff_argnums: for i in nondiff_argnums: _check_for_tracers(args[i]) nondiff_argnums_ = set(nondiff_argnums) diff --git a/tests/api_test.py b/tests/api_test.py index 15c4c8e7ae4f..fa2f389f494b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -9762,6 +9762,23 @@ def f_bwd(res, g): x, y = 3.2, 1.0 self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y)) + def test_optimize_remat_kwargs(self): + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + def f_fwd(x, y, *, keyword=False): + del keyword + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 3.2, 1.0 + jax.grad(f)(x, y) # Doesn't error + def transpose_unary(f, x_example): def transposed(y): From d8eafc8ee3d44420331e0e15b9b93abe260c718b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 8 Aug 2024 13:01:55 -0700 Subject: [PATCH 033/702] Disabled nn_test under asan on TPU as well, since it also times out PiperOrigin-RevId: 660950262 --- tests/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index b5a99b254c16..07069b08e7b2 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -692,6 +692,9 @@ jax_test( "gpu": [ "noasan", # Times out under asan. ], + "tpu": [ + "noasan", # Times out under asan. + ], }, shard_count = { "cpu": 10, From d28d14917e265f82ec3a1665f8caf42cef697292 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Thu, 8 Aug 2024 13:29:38 -0700 Subject: [PATCH 034/702] Fix error message in dot_product_attention PiperOrigin-RevId: 660960409 --- jax/_src/nn/functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index bfbee04b6a83..821c4a413796 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -937,11 +937,11 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: _check_has_shape(value_arr, [B, S, K, H], 'value') _check_has_shape(query_arr, [B, -1, -1, H], 'query') if query_arr.shape[-2] % K != 0: - raise ValueError(f"The number of query heads must to a multiple of " + raise ValueError(f"The number of query heads must be a multiple of " f"key/value heads, but got {query_arr.shape[-2]} vs {K}") if not (query_arr.dtype == key_arr.dtype == value_arr.dtype): - raise ValueError(f"query/key/value should have the same shape, but got " - f"{query_arr.shape} vs {key_arr.shape} vs {value_arr.shape}.") + raise ValueError(f"query/key/value should have the same dtype, but got " + f"{query_arr.dtype} vs {key_arr.dtype} vs {value_arr.dtype}.") if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4: raise ValueError(f"Mask must be a 4D boolean tensor, but got " f"rank={mask.ndim}, dtype={mask.dtype}.") From 12a9c8cfd4eca9d18ce1ca6f097fe9e728d97317 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 8 Aug 2024 13:54:21 -0700 Subject: [PATCH 035/702] Pallas Mosaic GPU lowering now supports (at least the basic) pl.BlockSpecs Note that we still don't do any pipelining whatsoever, but it can be done once this change lands. PiperOrigin-RevId: 660969393 --- jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/mosaic_gpu/lowering.py | 191 ++++++++++++------ .../mosaic/gpu/fragmented_array.py | 5 +- tests/pallas/mosaic_gpu_test.py | 26 ++- 4 files changed, 156 insertions(+), 67 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 038826a663e8..8f351020a86f 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -59,6 +59,7 @@ pytype_strict_library( "//jax:mlir", "//jax:mosaic_gpu", "//jax:pallas", + "//jax:partial_eval", "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index be9168e07567..109e81306b4b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -19,6 +19,7 @@ from collections.abc import Sequence import dataclasses import functools +import itertools as it import math from typing import Any, cast @@ -27,9 +28,11 @@ from jax._src import pjit from jax._src import util from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives @@ -97,7 +100,7 @@ def scratch_view( memory_space=smem, ) views.append( - memref_dialect.view(scratch_ty, self.runtime_smem, _index(off), []) + memref_dialect.view(scratch_ty, self.runtime_smem, _as_index(off), []) ) off += math.prod(s.shape) * jnp.dtype(s.dtype).itemsize @@ -112,17 +115,16 @@ def stack_free_smem(self, bytes: int): self.smem_used_bytes -= bytes -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class LoweringRuleContext: module_context: ModuleContext avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] - block_shapes: list[tuple[int | pallas_core.Mapped, ...]] | None replace = dataclasses.replace -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class LoweringResult: module: ir.Module grid: tuple[int, ...] @@ -130,17 +132,26 @@ class LoweringResult: out_structs: tuple[jax.ShapeDtypeStruct, ...] -@dataclasses.dataclass -class BlockInfo: - full_shape_dtype: jax.ShapeDtypeStruct - start_indices: Sequence[Any] - block_shape: tuple[int, ...] - - class LoweringError(Exception): # pylint: disable=g-bad-exception-name pass +def _eval_index_map( + ctx: ModuleContext, idx, block_mapping: pallas_core.BlockMapping +) -> Sequence[ir.Value]: + block_indices = lower_jaxpr_to_mosaic_gpu( + ctx, block_mapping.index_map_jaxpr.jaxpr, idx + ) + result = [] + for i, b in zip(block_indices, block_mapping.block_shape): + if b is pallas_core.mapped: + result.append(i) + else: + # TODO(slebedev): Use a type-agnostic multiplication wrapper. + result.append(arith_dialect.muli(_as_index(i), _as_index(b))) + return tuple(result) + + def lower_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, @@ -149,10 +160,50 @@ def lower_jaxpr_to_module( cost_estimate: pallas_core.CostEstimate | None, ) -> LoweringResult: del cost_estimate # Unused. - in_structs = tuple(grid_mapping.in_shapes) - out_structs = grid_mapping.out_shapes + + in_structs_gmem = [*grid_mapping.in_shapes] + in_structs_smem = [ + jax.ShapeDtypeStruct(bm.block_shape, s.dtype) + for bm, s in zip( + grid_mapping.block_mappings[: grid_mapping.num_inputs], + grid_mapping.in_shapes, + ) + ] + out_structs_gmem = [*grid_mapping.out_shapes] + out_structs_smem = [ + jax.ShapeDtypeStruct(bm.block_shape, s.dtype) + for bm, s in zip( + grid_mapping.block_mappings[grid_mapping.num_inputs :], + grid_mapping.out_shapes, + ) + ] assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims + if len(grid_mapping.grid) > 3: + raise NotImplementedError( + "Only <=3D grids are supported in Mosaic GPU lowering." + ) + if grid_mapping.num_dynamic_grid_bounds: + raise NotImplementedError( + "Dynamic grid bounds not supported in the Mosaic GPU lowering." + ) + if grid_mapping.num_index_operands: + raise NotImplementedError( + "Scalar prefetch not supported in Mosaic GPU lowering." + ) + if not all( + isinstance(bm.indexing_mode, pallas_core.Blocked) + for bm in grid_mapping.block_mappings + ): + raise NotImplementedError( + "Only Blocked indexing mode is supported in Mosaic GPU lowering." + ) + + with grid_mapping.trace_env(): + jaxpr, _ = pe.dce_jaxpr( + jaxpr, [True] * len(jaxpr.outvars), instantiate=True + ) + grid = grid_mapping.grid if len(grid) < 3: grid += (1,) * (3 - len(grid)) @@ -161,19 +212,40 @@ def lower_jaxpr_to_module( def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): *buffers_gmem, (*buffers_smem, runtime_smem, barriers) = buffers assert len(buffers_gmem) == len(buffers_smem) - in_buffers_gmem = buffers_gmem[: len(in_structs)] - in_buffers_smem = buffers_smem[: len(in_structs)] - out_buffers_gmem = buffers_gmem[len(in_structs) :] - out_buffers_smem = buffers_smem[len(in_structs) :] + in_buffers_gmem = buffers_gmem[: len(in_structs_gmem)] + in_buffers_smem = buffers_smem[: len(in_structs_smem)] + out_buffers_gmem = buffers_gmem[len(in_structs_gmem) :] + out_buffers_smem = buffers_smem[len(in_structs_smem) :] [barrier] = cast(mgpu.BarrierRef, barriers) + module_ctx = ModuleContext( + name_and_src_info.name, grid_mapping, runtime_smem, smem_used_bytes=0 + ) + program_ids = [ + arith_dialect.index_cast( + ir.IntegerType.get_signless(32), gpu_dialect.block_id(dim) + ) + for dim in it.islice(gpu_dialect.Dimension, len(grid_mapping.grid)) + ] + start_indices = map( + functools.partial(_eval_index_map, module_ctx, program_ids), + grid_mapping.block_mappings, + ) + in_start_indices = start_indices[: len(in_structs_gmem)] + out_start_indices = start_indices[len(in_structs_gmem) :] + with mgpu.single_thread(): - for b_gmem, b_smem in zip(in_buffers_gmem, in_buffers_smem): + for start_indices, b_gmem, b_smem in zip( + in_start_indices, in_buffers_gmem, in_buffers_smem + ): # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. launch_ctx.async_copy( src_ref=b_gmem, dst_ref=b_smem, + gmem_slice=tuple( + map(mgpu.ds, start_indices, ir.MemRefType(b_smem.type).shape) + ), barrier=barrier, swizzle=None, arrive=True, @@ -182,13 +254,21 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): barrier.wait() - module_ctx = ModuleContext(name_and_src_info.name, - grid_mapping, runtime_smem, smem_used_bytes=0) - _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, None, buffers_smem) + _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, buffers_smem) + mgpu.commit_shared() - for b_gmem, b_smem in zip(out_buffers_gmem, out_buffers_smem): + for start_indices, b_gmem, b_smem in zip( + out_start_indices, out_buffers_gmem, out_buffers_smem + ): # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. - launch_ctx.async_copy(src_ref=b_smem, dst_ref=b_gmem, swizzle=None) + launch_ctx.async_copy( + src_ref=b_smem, + dst_ref=b_gmem, + gmem_slice=tuple( + map(mgpu.ds, start_indices, ir.MemRefType(b_smem.type).shape) + ), + swizzle=None, + ) launch_ctx.await_async_copy(0) @@ -200,23 +280,25 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): dtype=np.int8, ) ] - module, out_structs, gmem_scratch_bytes, _ = mosaic_gpu._lower_as_gpu_kernel( - body, - grid=grid, - cluster=(), - block=block, - in_shapes=in_structs, - out_shape=out_structs, - smem_scratch_shape=( - *in_structs, - *out_structs, - *extra_smem_scratch, - mgpu.TMABarrier(), - ), - module_name=name_and_src_info.name, + module, out_structs_smem, gmem_scratch_bytes, _ = ( + mosaic_gpu._lower_as_gpu_kernel( + body, + grid=grid, + cluster=(), + block=block, + in_shapes=in_structs_gmem, + out_shape=out_structs_gmem, + smem_scratch_shape=( + *in_structs_smem, + *out_structs_smem, + *extra_smem_scratch, + mgpu.TMABarrier(), + ), + module_name=name_and_src_info.name, + ) ) - return LoweringResult(module, grid, gmem_scratch_bytes, out_structs) + return LoweringResult(module, grid, gmem_scratch_bytes, out_structs_smem) mosaic_lowering_rules = {} @@ -233,28 +315,17 @@ def deco(fn): def lower_jaxpr_to_mosaic_gpu( ctx: ModuleContext, jaxpr: jax_core.Jaxpr, - block_infos: Sequence[BlockInfo | None] | None, - args, + args: Sequence[ir.Value], consts=(), ) -> Sequence[ir.Value]: env = {} - block_info_env = {} def read_env(atom: jax_core.Atom): return atom.val if isinstance(atom, jax_core.Literal) else env[atom] - def read_block_info_env(atom: jax_core.Atom): - if isinstance(atom, jax_core.Literal): - return None - return block_info_env.get(atom, None) - def write_env(var: jax_core.Var, val): env[var] = val - if block_infos is None: - block_infos = [None] * len(jaxpr.invars) - for invar, block_info in zip(jaxpr.invars, block_infos): - block_info_env[invar] = block_info map(write_env, jaxpr.constvars, consts) map(write_env, jaxpr.invars, args) for eqn in jaxpr.eqns: @@ -270,7 +341,6 @@ def write_env(var: jax_core.Var, val): ctx, avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], - block_shapes=map(read_block_info_env, eqn.invars), ) try: outvals = rule(rule_ctx, *invals, **eqn.params) @@ -291,10 +361,9 @@ def write_env(var: jax_core.Var, val): @register_lowering_rule(sp.get_p) def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree): - del tree, ctx # Unused. + del ctx, tree # Unused. if indexers: raise NotImplementedError("No support for indexers yet") - return mgpu.FragmentedArray.load_strided(x_smem) @@ -302,7 +371,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree): def _swap_lowering_rule( ctx: LoweringRuleContext, x_smem, value, *indexers, tree ): - del tree, ctx # Unused. + del ctx, tree # Unused. if indexers: raise NotImplementedError("No support for indexers yet") old_value = mgpu.FragmentedArray.load_strided(x_smem) @@ -314,7 +383,7 @@ def _swap_lowering_rule( def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): if jaxpr.consts: raise NotImplementedError - return lower_jaxpr_to_mosaic_gpu(ctx.module_context, jaxpr.jaxpr, None, args) + return lower_jaxpr_to_mosaic_gpu(ctx.module_context, jaxpr.jaxpr, args) @register_lowering_rule(lax.broadcast_in_dim_p) @@ -398,7 +467,7 @@ def _run_scoped_lowering_rule( [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype) for aval in in_avals] ) outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_context, jaxpr, None, input_refs, consts + ctx.module_context, jaxpr, input_refs, consts ) ctx.module_context.stack_free_smem(bytes_allocated) return outs @@ -446,7 +515,7 @@ def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray: def _ir_constant(v: object, t: ir.Type) -> ir.Value: if isinstance(v, (np.number, np.ndarray, int, float)): - if isinstance(t, ir.IntegerType): + if isinstance(t, (ir.IntegerType, ir.IndexType)): v = int(v) else: assert isinstance(t, ir.FloatType) @@ -455,5 +524,9 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value: raise NotImplementedError(f"Unsupported constant: {v!r}") -def _index(i: int) -> ir.Value: - return arith_dialect.constant(ir.IndexType.get(), int(i)) +def _as_index(v: int | ir.Value) -> ir.Value: + if isinstance(v, int): + return arith_dialect.constant(ir.IndexType.get(), v) + if ir.IndexType.isinstance(v.type): + return v + return arith_dialect.index_cast(ir.IndexType.get(), v) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 8b13a00bced9..44f6904e335e 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -15,6 +15,7 @@ """Utilities for code generator.""" import dataclasses +import math from typing import Callable import jax @@ -98,10 +99,10 @@ def from_memref_type(cls, memref_ty: ir.Type): memref_type = ir.MemRefType(memref_ty) bw = mgpu.bytewidth(memref_type.element_type) assert 8 % bw == 0 and 8 // bw != 0, bw - if np.prod(memref_type.shape) % WARPGROUP_SIZE != 0: + if math.prod(memref_type.shape) % WARPGROUP_SIZE != 0: raise ValueError( "Ref must have a number of elements that is a multiple of" - f" {WARPGROUP_SIZE}" + f" {WARPGROUP_SIZE} (got {math.prod(memref_type.shape)})" ) max_vec_size = np.prod(memref_type.shape) // WARPGROUP_SIZE return cls( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7d072678fb4c..cb2301e1a2b1 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -46,22 +46,36 @@ def test_add_one(self): pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) - def add_one(x_ref, o_ref): + def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] + 1.0 x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_array_equal(add_one(x), x + 1.0) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + def test_add_one_grid(self): + @functools.partial( + pl.pallas_call, + in_specs=[pl.BlockSpec((128,), lambda *i: i)], + out_specs=pl.BlockSpec((128,), lambda *i: i), + out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), + grid=2, + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 + + x = jnp.arange(128 * 2).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_doubled_sum(self): @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), ) - def add_one(x_ref, o_ref): + def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] + jnp.sum(x_ref[...]) + jnp.sum(x_ref[...]) - x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_array_equal(add_one(x), x + x.sum()*2) + x = jnp.arange(128).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + x.sum()*2) @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): From deefbdd62673aca65e01234ccfb2b893912b99a5 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 8 Aug 2024 14:03:21 -0700 Subject: [PATCH 036/702] Temporarily disable broken tests in tpu_pallas_pipeline_test.py PiperOrigin-RevId: 660972804 --- tests/pallas/tpu_pallas_pipeline_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index e61f0dfa56b3..ca64275d3f09 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -139,6 +139,8 @@ def setUp(self): ('hbm', pltpu.TPUMemorySpace.ANY), ) def test_pipeline_matmul(self, memory_space): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.uniform(k1, (512, 512)) y = jax.random.uniform(k2, (512, 512)) @@ -184,6 +186,8 @@ def matmul_kernel(x_ref, y_ref, z_ref): ('hbm', pltpu.TPUMemorySpace.ANY), ) def test_double_pipeline_matmul(self, memory_space): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.uniform(k1, (512, 512)) y = jax.random.uniform(k2, (512, 512)) @@ -535,6 +539,8 @@ def reference(x, y): ) def test_pipeline_throughput_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') input_dtype = out_dtype num_devices = jax.local_device_count() @@ -1065,6 +1071,8 @@ def reference(x, y): ) def test_pipeline_throughput_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') input_dtype = jnp.float32 num_devices = jax.device_count() @@ -1325,6 +1333,8 @@ def setUp(self): super().setUp() def test_can_partition_nondivisible_grid_with_dynamic_dimensions(self): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') def mul_pipeline(x_ref, y_ref): y_ref[...] = x_ref[...] * 2 @@ -1359,6 +1369,8 @@ def mul_kernel(iters_ref, x_ref, y_ref): np.testing.assert_allclose(func(jnp.array([5]), x), x * 2) def test_megacore_mul(self): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') x = jax.random.uniform(jax.random.key(0), (512, 512)) def matmul_pipeline(x_ref, y_ref): @@ -1396,6 +1408,8 @@ def matmul_kernel(x_ref, y_ref): (768, 1024, 768, 256, 512, 256), ) def test_megacore_matmul(self, m, k, n, bm, bk, bn): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') k1, k2 = jax.random.split(jax.random.key(42)) x = jax.random.uniform(k1, (m, k)) y = jax.random.uniform(k2, (k, n)) From 9c2caedab1ada42187c1a49a97947f78678307ce Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 8 Aug 2024 14:34:06 -0700 Subject: [PATCH 037/702] Add subdirectories to the output path when building editable wheels for jaxlib and GPU plugin. When `build_gpu_plugin` is true, three wheels will be produced (jaxlib, jax-cuda-pjrt and jax-cuda-plugin). If they are editable, they need to be placed in subdirectories to avoid overwrite. Tested on GPU. After the editable wheels are built, they can be installed with `pip install -e /jax/dist/jax_gpu_pjrt /jax/dist/jaxlib /jax/dist/jax_gpu_plugin`. PiperOrigin-RevId: 660984311 --- build/build.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/build/build.py b/build/build.py index 7a418c9b4d78..0db2630a3e2b 100755 --- a/build/build.py +++ b/build/build.py @@ -364,6 +364,15 @@ def add_boolean_argument(parser, name, default=False, help_str=None): group.add_argument("--no" + name, dest=name, action="store_false") +def _get_editable_output_paths(output_path): + """Returns the paths to the editable wheels.""" + return ( + os.path.join(output_path, "jaxlib"), + os.path.join(output_path, "jax_gpu_pjrt"), + os.path.join(output_path, "jax_gpu_plugin"), + ) + + def main(): cwd = os.getcwd() parser = argparse.ArgumentParser( @@ -678,11 +687,20 @@ def main(): *args.bazel_options, ) + if args.build_gpu_plugin and args.editable: + output_path_jaxlib, output_path_jax_pjrt, output_path_jax_kernel = ( + _get_editable_output_paths(output_path) + ) + else: + output_path_jaxlib = output_path + output_path_jax_pjrt = output_path + output_path_jax_kernel = output_path + if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: build_cpu_wheel_command = [ *command_base, "//jaxlib/tools:build_wheel", "--", - f"--output_path={output_path}", + f"--output_path={output_path_jaxlib}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}" ] @@ -698,7 +716,7 @@ def main(): build_gpu_kernels_command = [ *command_base, "//jaxlib/tools:build_gpu_kernels_wheel", "--", - f"--output_path={output_path}", + f"--output_path={output_path_jax_kernel}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", ] @@ -719,7 +737,7 @@ def main(): build_pjrt_plugin_command = [ *command_base, "//jaxlib/tools:build_gpu_plugin_wheel", "--", - f"--output_path={output_path}", + f"--output_path={output_path_jax_pjrt}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", ] From f2068bb4ad9f258f26bb435e5096d1a8e103218a Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 8 Aug 2024 15:18:07 -0700 Subject: [PATCH 038/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/76978f280df19d6bfcaa4559ccb5573e13367c7b. PiperOrigin-RevId: 660999885 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b885cad3f539..815b007ce5fc 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "3bf7e1ae488174aa6b29cc3f2c216785dd161af8" -XLA_SHA256 = "6f11fc246856472069926e5de3506c740fb9af750e74819e39737e1c8460da78" +XLA_COMMIT = "76978f280df19d6bfcaa4559ccb5573e13367c7b" +XLA_SHA256 = "124185ce5c8da06f7e3f48eac5e2fdfd4049d17f4ef7b0dc343bb087722037f0" def repo(): tf_http_archive( From dcd186f552d0e865a930de74ebe4e2c1f19387a7 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 5 Aug 2024 17:53:50 -0700 Subject: [PATCH 039/702] [Pallas] Add pallas distributed computation tutorial --- .../_static/pallas/distributed/all_gather.svg | 1 + .../pallas/distributed/race_condition.svg | 1 + docs/_static/pallas/distributed/rdma_recv.svg | 1 + docs/_static/pallas/distributed/rdma_send.svg | 1 + .../_static/pallas/distributed/rdma_start.svg | 1 + .../pallas/distributed/reduce_scatter_1.svg | 1 + .../pallas/distributed/reduce_scatter_2.svg | 1 + .../pallas/distributed/reduce_sum_1.svg | 1 + .../pallas/distributed/reduce_sum_2.svg | 1 + docs/conf.py | 2 + docs/pallas/tpu/distributed.ipynb | 1743 +++++++++++++++++ docs/pallas/tpu/distributed.md | 1527 +++++++++++++++ docs/pallas/tpu/index.rst | 1 + 13 files changed, 3282 insertions(+) create mode 100644 docs/_static/pallas/distributed/all_gather.svg create mode 100644 docs/_static/pallas/distributed/race_condition.svg create mode 100644 docs/_static/pallas/distributed/rdma_recv.svg create mode 100644 docs/_static/pallas/distributed/rdma_send.svg create mode 100644 docs/_static/pallas/distributed/rdma_start.svg create mode 100644 docs/_static/pallas/distributed/reduce_scatter_1.svg create mode 100644 docs/_static/pallas/distributed/reduce_scatter_2.svg create mode 100644 docs/_static/pallas/distributed/reduce_sum_1.svg create mode 100644 docs/_static/pallas/distributed/reduce_sum_2.svg create mode 100644 docs/pallas/tpu/distributed.ipynb create mode 100644 docs/pallas/tpu/distributed.md diff --git a/docs/_static/pallas/distributed/all_gather.svg b/docs/_static/pallas/distributed/all_gather.svg new file mode 100644 index 000000000000..5bbf6f70cf8f --- /dev/null +++ b/docs/_static/pallas/distributed/all_gather.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/race_condition.svg b/docs/_static/pallas/distributed/race_condition.svg new file mode 100644 index 000000000000..e4f981186dab --- /dev/null +++ b/docs/_static/pallas/distributed/race_condition.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/rdma_recv.svg b/docs/_static/pallas/distributed/rdma_recv.svg new file mode 100644 index 000000000000..d49ba5eb8541 --- /dev/null +++ b/docs/_static/pallas/distributed/rdma_recv.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/rdma_send.svg b/docs/_static/pallas/distributed/rdma_send.svg new file mode 100644 index 000000000000..579ba1323667 --- /dev/null +++ b/docs/_static/pallas/distributed/rdma_send.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/rdma_start.svg b/docs/_static/pallas/distributed/rdma_start.svg new file mode 100644 index 000000000000..f37bde6e83e1 --- /dev/null +++ b/docs/_static/pallas/distributed/rdma_start.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/reduce_scatter_1.svg b/docs/_static/pallas/distributed/reduce_scatter_1.svg new file mode 100644 index 000000000000..c66df4acf8a5 --- /dev/null +++ b/docs/_static/pallas/distributed/reduce_scatter_1.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/reduce_scatter_2.svg b/docs/_static/pallas/distributed/reduce_scatter_2.svg new file mode 100644 index 000000000000..bb4ae3496297 --- /dev/null +++ b/docs/_static/pallas/distributed/reduce_scatter_2.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/reduce_sum_1.svg b/docs/_static/pallas/distributed/reduce_sum_1.svg new file mode 100644 index 000000000000..6c397a87be88 --- /dev/null +++ b/docs/_static/pallas/distributed/reduce_sum_1.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/reduce_sum_2.svg b/docs/_static/pallas/distributed/reduce_sum_2.svg new file mode 100644 index 000000000000..ef2a76330a61 --- /dev/null +++ b/docs/_static/pallas/distributed/reduce_sum_2.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index a07ab12a96fc..06b3b179bf32 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -125,6 +125,7 @@ 'notebooks/*.md', 'pallas/quickstart.md', 'pallas/tpu/pipelining.md', + 'pallas/tpu/distributed.md', 'jep/9407-type-promotion.md', 'autodidax.md', 'sharded-computation.md', @@ -212,6 +213,7 @@ # Requires accelerators 'pallas/quickstart.*', 'pallas/tpu/pipelining.*', + 'pallas/tpu/distributed.*', 'sharded-computation.*' ] diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb new file mode 100644 index 000000000000..5209f2ff8e52 --- /dev/null +++ b/docs/pallas/tpu/distributed.ipynb @@ -0,0 +1,1743 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "zSNjLhGQJMgq" + }, + "source": [ + "# Distributed Computing in Pallas for TPUs\n", + "\n", + "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n", + "\n", + "Some recommended readings beforehand:\n", + " - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)\n", + " - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "executionInfo": { + "elapsed": 1978, + "status": "ok", + "timestamp": 1722904801801, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "PyAGnWc9yI8T", + "outputId": "1d8229bd-cab5-495f-93e9-fff2e41db480" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running with 4 TPU v5 lite devices.\n" + ] + } + ], + "source": [ + "import jax\n", + "from jax import lax\n", + "from jax import numpy as jnp\n", + "from jax.experimental import mesh_utils\n", + "from jax.experimental import pallas as pl\n", + "from jax.experimental import shard_map\n", + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "P = jax.sharding.PartitionSpec\n", + "\n", + "num_devices = jax.local_device_count()\n", + "assert num_devices > 1, \"Please run this notebook with more than one device.\"\n", + "assert \"TPU\" in jax.devices()[0].device_kind, \"Please run this notebook with TPU devices.\"\n", + "print(f\"Running with {num_devices} {jax.devices()[0].device_kind} devices.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DySMGNByclMi" + }, + "source": [ + "## TPU Topologies\n", + "\n", + "TPUs are typically deployed in pods of multiple devices connected via a high-bandwidth interchip interconnect (ICI) for communication within the pod that is much faster than a typical network connection. For example, the specifications sheet for a [TPU v5p](https://cloud.google.com/tpu/docs/v5p) states an ICI bandwidth of 4.8Tb/s per chip (for reference, TPU v5p also has 21Tb/s of *local* HBM bandwidth). The ICI allows us to implement fast and performant distributed kernels that require high-bandwidth communication within a pod, and use the datacenter network for parallelization over less bandwidth-intensive operations, such as data-parallelism over a batch dimension.\n", + "\n", + "TPUs pods are typically arranged in an ND torus topology. The following graphic gives several examples of configurations of different sizes.\n", + "\n", + "![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png)\n", + "\n", + "Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod.\n", + "\n", + "![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1Oc_WD1hChfN" + }, + "source": [ + "## Remote Direct Memory Access (RDMA) Model\n", + "\n", + "TPUs communicate via a push-only model known as a remote direct memory access (RDMA). A TPU is allowed to issue copy instruction to push from a local buffer to any buffer on another device within the same pod that executes asynchronously from the main program thread. However, a TPU can only read data that is stored locally. This is in contrast to more traditional multi-core programming where it is possible to both read from and write to values to a shared memory.\n", + "\n", + "### Async Remote Copy Operation\n", + "The `pltpu.make_async_remote_copy` function is used to create a remote DMA descriptor object which parameterizes both a \"send\" operation and a \"receive\" operation. Here's its signature:\n", + "\n", + "```python\n", + " def make_async_remote_copy(\n", + " src_ref: Ref,\n", + " dst_ref: Ref,\n", + " send_sem: Ref[SemaphoreType],\n", + " recv_sem: Ref[SemaphoreType],\n", + " device_id: int | tuple[int, ...],\n", + " device_id_type: DeviceIdType\n", + " ) -> AsyncCopyDescriptor:\n", + "```\n", + "\n", + "- `src_ref` is the local `Ref` (in any memory space) containing the data you wish to send to `dst_ref` on another device.\n", + "- `dst_ref` is the remote `Ref` (in any memory space) at which data will be copied to on the target device.\n", + "- `send_sem` is a DMA semaphore used to block until all data has been sent from `src_ref`.\n", + "- `recv_sem` is a DMA semaphore used to block until the expected number of bytes have been received at `dst_ref`. The sender of the DMA will write to the receiver's `recv_sem`.\n", + "- `device_id` is the device ID of the target device to send to.\n", + "- `device_id_type` specifies the format of `device_id`, which can either be in LOGICAL format (integer device ID), or in MESH format (an ND-tuple index into the logical device mesh). The default mode is MESH.\n", + "\n", + "`make_async_remote_copy` returns a descriptor object on which you use the `.start()` method to initiate the DMA, and the `.wait_send()` to block on `send_sem` and `.wait_recv()` to block on `recv_sem` (or `.wait()` to block on both). If a device is only expected to send data, it is sufficient to only call `.start()` and `.wait_send()`, and likewise if a device is only receiving it is sufficient to only call `.wait_recv()`. If using a SPMD pattern where all devices execute the DMA, each device will generally call both `.start()` and `.wait()`.\n", + "```python\n", + "dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)\n", + "dma_descriptor.start() # Initiate the DMA (non-blocking).\n", + "# ... do other work\n", + "dma_descriptor.wait_send() # Block until all data has been sent.\n", + "dma_descriptor.wait_recv() # Block until all data has been received.\n", + "```\n", + "\n", + "As an example, let's visualize a DMA where we consider 4 devices (indexed 0, 1, 2, 3). We consider a scheme where device 0 copies to device 1, and device 2 & 3 copy to each other. In practice, we can create such an asymmetric communication pattern by using `@pl.when` to branch on the device ID.\n", + "\n", + "(1) Each device creates the DMA descriptor. Devices 0, 2, and 3 call `.start()` to initiate the DMA from `src_ref`. Device 1 is skips the `.start()` and does nothing, e.g. by using `pl.when`.\n", + "\n", + "![rdma_start](../../_static/pallas/distributed/rdma_start.svg)\n", + "\n", + "(2) As `.start()` is non-blocking, each device is free to do other computation while the DMA is in flight. Devices 0, 2, and 3 call `.wait_send()` to wait on `send_sem` which blocks until all data has been sent.\n", + "\n", + "![rdma_send](../../_static/pallas/distributed/rdma_send.svg)\n", + "\n", + "(3) Finally, devices 1, 2, and 3 will call `.wait_recv()` to wait on `recv_sem` until all data has arrived at `dst_ref`.\n", + "\n", + "![rdma_recv](../../_static/pallas/distributed/rdma_recv.svg)\n", + "\n", + "The above communication pattern can be written as follows:\n", + "```python\n", + "def example_kernel(input_ref, output_ref, send_sem, recv_sem):\n", + " device_id = lax.axis_index('x')\n", + " copy_0_to_1 = pltpu.make_async_remote_copy(\n", + " src_ref=input_ref,\n", + " dst_ref=output_ref,\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=1,\n", + " )\n", + " copy_2_to_3 = pltpu.make_async_remote_copy(\n", + " src_ref=input_ref,\n", + " dst_ref=output_ref,\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=3,\n", + " )\n", + " copy_3_to_2 = pltpu.make_async_remote_copy(\n", + " src_ref=input_ref,\n", + " dst_ref=output_ref,\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=2,\n", + " )\n", + " @pl.when(device_id == 0)\n", + " def _():\n", + " copy_0_to_1.start()\n", + " copy_0_to_1.wait_send()\n", + " @pl.when(device_id == 1)\n", + " def _():\n", + " copy_0_to_1.wait_recv()\n", + " @pl.when(device_id == 2)\n", + " def _():\n", + " copy_2_to_3.start()\n", + " copy_2_to_3.wait_send()\n", + " copy_3_to_2.wait_recv()\n", + " @pl.when(device_id == 3)\n", + " def _():\n", + " copy_3_to_2.start()\n", + " copy_3_to_2.wait_send()\n", + " copy_2_to_3.wait_recv()\n", + "```\n", + "\n", + "### DMA Semaphores\n", + "\n", + "`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`.\n", + "\n", + "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", + "\n", + "### Routing\n", + "\n", + "A sender is allowed to send data to any receiver within the same pod, even if they do not share a direct connection (the exception to this rule is for TPU v5e, where devices can only route to a power of 2 offset from themselves). TPUs have an internal routing mechanism which can pass data along to the next device on the path to the destination. However, communicating in this way is not recommended as you have no control over network contention as a kernel writer. The examples we will cover in this tutorial minimize inefficient communication by only transferring data to neighboring devices.\n", + "\n", + "### Failure modes\n", + "\n", + "If using remote DMAs incorrectly, you may encounter several failure modes which can be difficult to debug. The general symptoms of buggy DMA usage are crashes, hanging, or silent data corruption:\n", + "- If semaphores exit the program with an invalid non-zero value, Pallas will crash and exit the program.\n", + "- If semaphores are waited on but an insufficient number of bytes are received (i.e. there is no sender, or if the sent data is less than the size of `dst_ref` on the receiving device), the program may hang indefinitely waiting for bytes that are never sent. In this case the program would need to be restarted.\n", + "- If encountering a race condition, there could be silent data corruption if two simultaneous writes or a simultaneous read and write occur.\n", + "\n", + "Some common causes of the above include:\n", + "- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang.\n", + "- If a device is sent a more bytes than it expected to receive, it may also crash due to non-zero semaphore states. If sent less, it may hang indefinitely.\n", + "- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.\n", + "- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states.\n", + "\n", + "### Megacore\n", + "\n", + "Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vpGSN1Sui0Bu" + }, + "source": [ + "### Example: Right Permute (`lax.ppermute`)\n", + "\n", + "Let's dive into a very basic example. We will implement a kernel that performs a right permutation, where each device sends its slice of the data to its right neighbor.\n", + "\n", + "Suppose we had an array with 512 elements, which we shard into slices of size 128 across 4 devices. Each device will pass its slice to the next device, and the output will consist of the same data, but with the slices rotated by 1. This is identical to the `lax.ppermute` operation where the permutation is set to `(n, (n+1) % 4)`.\n", + "\n", + "In order to call the kernel in distributed mode, we wrap the `pallas_call` in a `shard_map` transformation. From there, we can write the kernel the same way as you would write a normal single-device Pallas kernel, except we now have access to remote DMA instructions. JAX collective primitives such as `lax.axis_index` can be used to obtain a `device_id` that can be used to compute which target devices to copy to, by referencing the same named axes names passed into `shard_map`." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "executionInfo": { + "elapsed": 1606, + "status": "ok", + "timestamp": 1722904803566, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "YkyIKN2thZ-V", + "outputId": "9b7ed142-d161-4237-fed8-cbce41adc5f0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input = [0.9858954 0.11763906 0.9955574 0.775211 ]\n", + "Pallas Result = [0.775211 0.9858954 0.11763906 0.9955574 ]\n", + "lax.ppermute Result = [0.775211 0.9858954 0.11763906 0.9955574 ]\n", + "Difference |Pallas - lax.ppermute| = 0.0\n" + ] + } + ], + "source": [ + "partition = P(None, 'x')\n", + "devices = mesh_utils.create_device_mesh((1, num_devices))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "# Create an input array that shards the last dimension across\n", + "# all devices.\n", + "input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "\n", + "def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = lax.rem(my_id + 1, num_devices)\n", + " remote_copy_op = pltpu.make_async_remote_copy(\n", + " src_ref=input_ref,\n", + " dst_ref=output_ref,\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " remote_copy_op.start()\n", + " remote_copy_op.wait()\n", + "\n", + "\n", + "out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", + " in_specs=[\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " scratch_shapes=(\n", + " # We allocate DMA semaphores in scratch memory.\n", + " [pltpu.SemaphoreType.DMA] * 2\n", + " ),\n", + ")\n", + "right_permute = pl.pallas_call(\n", + " right_permute_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + ")\n", + "# Wrap the kernel within a shard_map to call.\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " right_permute,\n", + " mesh=mesh,\n", + " in_specs=partition,\n", + " out_specs=partition,\n", + " check_rep=False,\n", + " )\n", + ")(input_arr)\n", + "\n", + "# Compare Pallas result to XLA shard_map result.\n", + "perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))\n", + "\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lambda x: lax.ppermute(x, 'x', perm),\n", + " mesh=mesh, in_specs=partition, out_specs=partition)\n", + ")(input_arr)\n", + "\n", + "print('Input = ', input_arr[0, ::128])\n", + "print('Pallas Result = ', pallas_result[0, ::128])\n", + "print('lax.ppermute Result = ', xla_result[0, ::128])\n", + "print(\n", + " 'Difference |Pallas - lax.ppermute| = ',\n", + " jnp.mean(jnp.abs(pallas_result - xla_result)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iyfhdGXuUnq2" + }, + "source": [ + "### Example: All-gather (`lax.all_gather`)\n", + "\n", + "In this next example we will implement the all-gather collective operation, which has a JAX equivalent in `lax.all_gather`. In contrast with the right-permute example from above which only involves a pair of source and destination neighbors, an all-gather operation requires communication between all devices and therefore we must think about how data is routed between them. The specifics of how we implement this are dictated by the device topology, for which we assume is a ring.\n", + "\n", + "#### Ring Communication Pattern\n", + "\n", + "We will write our kernel assuming a ring topology. Rings are a natural fit for TPUs as slicing along any dimension of a torus produces a ring. When writing collectives, we often only need to think about 1D slices of our torus at a time because the different dimensions of the torus are reserved for different types of parallelism (data vs. model, for example).\n", + "\n", + "The strategy we will use is to write a looped kernel, where on each iteration a device receives one slice of the sharded array from its left neighbor, and copies the previously received slice to its right neighbor. After `num_devices` iterations, each device will have a copy of the entire array in its local HBM.\n", + "\n", + "![all_gather](../../_static/pallas/distributed/all_gather.svg)\n", + "\n", + "We can re-purpose Pallas's `grid` argument to implement the loop. Rather than iterating over tiles of an array as we have done in previous tutorials, we instead set the grid to `(num_devices,)` to indicate that we want to loop over the number of devices and use `pl.program_id` to obtain the loop iteration inside of the Pallas kernel. The following code snippet demonstrates how to implement this:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "executionInfo": { + "elapsed": 812, + "status": "ok", + "timestamp": 1722904804531, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "ojQEZB5mBRqM", + "outputId": "e1648f54-737c-4921-ca3b-b4c639a38d2b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input: (32, 128) [0.9858954 0.54248166 0.9547038 0.954962 ]\n", + "Pallas Result: (16, 8, 128) [0.9858954 0.54248166 0.9547038 0.954962 0.9858954 0.54248166\n", + " 0.9547038 0.954962 0.9858954 0.54248166 0.9547038 0.954962\n", + " 0.9858954 0.54248166 0.9547038 0.954962 ]\n", + "lax.all_gather Result: (16, 8, 128) [0.9858954 0.54248166 0.9547038 0.954962 0.9858954 0.54248166\n", + " 0.9547038 0.954962 0.9858954 0.54248166 0.9547038 0.954962\n", + " 0.9858954 0.54248166 0.9547038 0.954962 ]\n", + "Difference |Pallas - lax.all_gather| = 0.0\n" + ] + } + ], + "source": [ + "partition = P('x', None)\n", + "devices = mesh_utils.create_device_mesh((num_devices, 1))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "# Create an input array that shards the first dimension across\n", + "# all devices.\n", + "input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "\n", + "def all_gather_kernel(input_ref,\n", + " output_ref,\n", + " local_copy_sem,\n", + " send_sem,\n", + " recv_sems):\n", + " outer_step = pl.program_id(0)\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = lax.rem(my_id + 1, num_devices)\n", + " copy_slot = my_id - outer_step\n", + " copy_slot = lax.rem(copy_slot + num_devices, num_devices)\n", + "\n", + " @pl.when(outer_step == 0)\n", + " def _():\n", + " local_copy_op = pltpu.make_async_copy(\n", + " src_ref=input_ref,\n", + " dst_ref=output_ref.at[my_id],\n", + " sem=local_copy_sem,\n", + " )\n", + " local_copy_op.start()\n", + " local_copy_op.wait()\n", + "\n", + " # Copy to our right neighbor.\n", + " # Note that we will also be receiving data from our left neighbor,\n", + " # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact\n", + " # that the indices do not need to be symmetric between remote DMAs.\n", + " remote_copy_op = pltpu.make_async_remote_copy(\n", + " src_ref=output_ref.at[copy_slot],\n", + " dst_ref=output_ref.at[copy_slot],\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sems.at[outer_step],\n", + " device_id=(right_neighbor, 0),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " remote_copy_op.start()\n", + " remote_copy_op.wait()\n", + "\n", + "out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " in_specs=[\n", + " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " scratch_shapes=(\n", + " # DMA semaphores are allocated in scratch memory.\n", + " # We allocated one semaphore for a local HBM-VMEM copy,\n", + " # and one for the remote send semaphore.\n", + " [pltpu.SemaphoreType.DMA] * 2\n", + " # We additionally allocate one receive semaphore per device.\n", + " # This is to avoid situations where we have multiple\n", + " # DMAs in flight, as we do not want to share a receive\n", + " # semaphore between the DMAs.\n", + " + [pltpu.SemaphoreType.DMA((num_devices-1,))]\n", + "\n", + " ),\n", + " grid=(num_devices-1,)\n", + " )\n", + "\n", + "all_gather = pl.pallas_call(\n", + " all_gather_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + " )\n", + "\n", + "# Wrap the kernel within a shard_map to call.\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " all_gather,\n", + " mesh=mesh,\n", + " in_specs=partition,\n", + " out_specs=partition,\n", + " check_rep=False\n", + " )\n", + ")(input_arr)\n", + "\n", + "# Compare Pallas result to XLA shard_map result.\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lambda x: lax.all_gather(x, 'x'),\n", + " mesh=mesh, in_specs=partition, out_specs=partition\n", + " )\n", + ")(input_arr)\n", + "\n", + "print('Input: ', input_arr.shape, input_arr[::8, 0])\n", + "print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])\n", + "print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])\n", + "print('Difference |Pallas - lax.all_gather| = ',\n", + " jnp.mean(jnp.abs(pallas_result - xla_result)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KgU7HI2pS4om" + }, + "source": [ + "A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KgU7HI2pS4om" + }, + "source": [ + "## Advanced Techniques\n", + "\n", + "Now that we have seen how to write several basic kernels using remote DMA operations, we will go over more advanced techniques for synchronization and writing efficient kernels." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8M_kdl0FCtrL" + }, + "source": [ + "### Synchronization: Regular and Barrier Semaphores\n", + "\n", + "The examples we implemented in the basic tutorial do not require special handling of synchronization as all necessary communication writes to disjoint buffers. However, other operations may require more complex communication patterns that need additional synchronization primitives to avoid race conditions. Pallas provides two additional primitives to help with this: regular and barrier semaphores.\n", + "\n", + "#### Regular Semaphores\n", + "\n", + "Regular semaphores are the standard tool used to synchronize across multiple devices. Semaphores are fundamentally counters - they can be incremented by any device after which a device can block until the value of the semaphore reaches a specific value (and then decrement the value).\n", + "\n", + "The three main operations that can be used on regular semaphores are signal, wait, and read:\n", + "```python\n", + "def semaphore_signal(\n", + " sem: Ref[SemaphoreType],\n", + " inc: int,\n", + " device_id: int | tuple[int, ...],\n", + " device_id_type: DeviceIdType\n", + ") -> None:\n", + " ... # Increments the semaphore `sem` on the target device `device_id` by `inc`.\n", + " \n", + "def semaphore_wait(\n", + " semaphore: Ref[SemaphoreType],\n", + " value: int,\n", + ") -> None:\n", + " ... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed.\n", + " \n", + "def semaphore_read(\n", + " sem: Ref[SemaphoreType],\n", + ") -> jax.Array:\n", + " ... # Returns the current value of `sem` as an `int32[]`.\n", + "```\n", + "\n", + "In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`.\n", + "\n", + "Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen:\n", + " - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program.\n", + " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted.\n", + "\n", + "#### Barrier Semaphores\n", + "\n", + "Barrier semaphores are globally-allocated semaphores used to synchronize devices across an entire program and ensure that all devices have entered the Pallas kernel.\n", + "\n", + "If a Pallas kernel is executed within the context of a larger XLA program, we need to ensure that all devices that communicate have entered the kernel. However, DMA and regular semaphores are both locally scoped - they are only understood by other devices that have entered the kernel. Barrier semaphores serve as a globally understood semaphore that can be used for synchronization no matter where in the XLA program the device is currently executing.\n", + "\n", + "By default, if you do not specify a barrier semaphore, Pallas will automatically insert a barrier semaphore at the beginning of your program. However, it can be more efficient to write your own. Barrier semaphores are similar to regular semaphores in that they are counters that can be incremented via `semaphore_signal` and can be decremented via `semaphore_wait`. They are created by calling `get_barrier_semaphore()` within a kernel. Typically, we use barriers once at the beginning of a kernel to synchronize with all devices we are communicating with.\n", + "\n", + "```python\n", + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "def example_kernel(...):\n", + " # Use barrier semaphores at the beginning of a kernel.\n", + " # is_start_of_kernel = ...\n", + " # right_neighbor = ...\n", + " # ...\n", + " @pl.when(is_start_of_kernel)\n", + " def _():\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " # Increment the semaphore of your right neighbor.\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " device_id=right_neighbor,\n", + " device_id_type=pltpu.DeviceIdType.LOGICAL,\n", + " )\n", + " # Wait until your left neighbor has incremented your semaphore\n", + " pltpu.semaphore_wait(barrier_sem, 1)\n", + " # ...\n", + "```\n", + "\n", + "When using barrier semaphores, the `collective_id` compiler parameter must be passed to `pallas_call` to specify which barrier semaphore is being used. A TPU has a small, fixed number of barrier semaphores available (typically on the order of 20-30) and therefore they should be used sparingly. In order to ensure correctness, only kernels that share the same communication pattern should use the same `collective_id`. For example, if two kernels synchronize only with neighbors on the same mesh axis, they are allowed to share the same `collective_id`. However, if two kernels synchronize along different axes, they must have different `collective_id`s. Failure to do so may result in race conditions that are difficult to debug.\n", + "\n", + "```python\n", + "kernel = pl.pallas_call(\n", + " example_kernel,\n", + " ...,\n", + " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zy20AxN5TSLA" + }, + "source": [ + "### Double-buffering\n", + "\n", + "In order to avoid reading from a local `Ref` that is also being written into by another device and creating a race condition, a useful technique is the \"double-buffered\" strategy where we allocate a two `Ref`s for each destination value. On each iteration, one `Ref` will be designated as a \"working\" slot, and the other will be designated as a \"receiving\" slot. The device is free to use the working slot for computation, but will only copy data into its neighbor's receiving slot. The working and receiving slots alternate every iteration, so that once a copy is finished, the old receiving slot becomes the new working slot, and vice versa. Using this scheme properly, data is never read from and written to the same buffer.\n", + "\n", + "The following code skeleton demonstrates how double-buffering can be used. We keep a running iteration counter in the variable `iteration`, and the `working_slot` and `receiving_slot` alternate between 0 and 1 every iteration. `dst_ref` is allocated as a double-buffer and has the size `[2, ...]`. On each iteration, we read from the working slot using `dst_ref.at[working_slot, ...]` and use the value to perform computation. Simultaneously, we copy to our neighbor's `dst_ref.at[receiving_slot]` to avoid overwriting their `working_slot` value. By structuring our communication in this fashion it is possible to overlap the communication latency of the remote DMA with local computation while minimizing the risk of race conditions.\n", + "```python\n", + "def kernel(...):\n", + " # ...\n", + " iteration = pl.program_id(0)\n", + " working_slot = lax.rem(iteration, 2)\n", + " receiving_slot = 1 - working_slot\n", + " # ...\n", + "\n", + " local_copy_op = pltpu.make_async_copy(\n", + " src_ref=dst_ref.at[working_slot, ...],\n", + " dst_ref=local_scratch_ref,\n", + " sem=local_copy_sem,\n", + " )\n", + " local_copy_op.start()\n", + " remote_copy_op = pltpu.make_async_remote_copy(\n", + " src_ref=src_ref,\n", + " dst_ref=dst_ref.at[receiving_slot, ...],\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=target_device,\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " remote_copy_op.start()\n", + " \n", + " local_copy_op.wait()\n", + " # ... do work on local_scratch while waiting for async_copy_op to finish.\n", + " remote_copy_op.wait()\n", + "\n", + "```\n", + "\n", + "In terms of synchronization, the double-buffered construction works if all devices are executing on the same iteration. If a sender manages to get one iteration ahead of its receiver, it's `working_slot` and `receiving_slot` indices will be flipped compared to the receiver, meaning that it could be writing into the `working_slot` at the same time the receiver is reading from it. In order to avoid this, it may be necessary to use a semaphore to synchronize the sender with the receiver, or add additional buffering slots (\"triple\", \"quadruple\", or N-buffered) to allow additional run-ahead at the cost of more memory. In our previous `all_gather` example, note that the kernel contained a receiving buffer with N slots, which avoids race conditions altogether. In our next kernel, we will instead go through an example which uses a double-buffer with explicit synchronization." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Or0Itv72No5d" + }, + "source": [ + "### Example: All-Reduce Sum (`lax.psum`)\n", + "\n", + "We will now implement an all-reduce sum kernel using double-buffering and semaphores for synchronization. For those familiar with collective operations in JAX, the equivalent operation is `lax.psum`. All-reduce is a standard collective operation where the objective is to reduce along an axis of an array, but the array is sharded across multiple devices.\n", + "\n", + "![reduce_sum_1](../../_static/pallas/distributed/reduce_sum_1.svg)\n", + "\n", + "In the above example, we have the array [5, 2, 1, 3] sharded across 4 devices. An all-reduce sum operation would sum all values and replicate the result on each device, leading to the result [11, 11, 11, 11] sharded across all 4 devices.\n", + "\n", + "The naive implementation of all-reduce would be to gather all required values onto each device, and then reduce. However, we can improve the performance of this implementation by interleaving communication with computation. An interleaved, single-direction all-reduce can be visualized as follows. On each iteration, we receive an input value from our left neighbor, and concurrently pass input along to our next neighbor while incrementing it with our local accumulator. After N-1 iterations, each device will have a copy of the full sum in it's memory.\n", + "\n", + "![reduce_sum_2](../../_static/pallas/distributed/reduce_sum_2.svg)\n", + "\n", + "#### Putting it all together\n", + "\n", + "The following kernel demonstrates how to combine these principles into a functional kernel.\n", + "\n", + "The prologue (executed when `outer_step==0`) first initiates a barrier with both neighbors to ensure that they have also entered the kernel. It also handles initialization for all `Ref`s and handles the first remote copy to the right neighbor's \"working\" slot.\n", + "\n", + "The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).\n", + "\n", + "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artifically hang a device.\n", + "\n", + "Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "executionInfo": { + "elapsed": 254, + "status": "ok", + "timestamp": 1722904804952, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "XrY5bMlvBroQ", + "outputId": "77497000-4496-462e-cc3c-73fb640cc14c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input = [0.9858954 0.11763906 0.9955574 0.775211 ]\n", + "Pallas result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", + "lax.psum result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", + "Difference |Pallas - lax.psum| = 1.4959369e-08\n" + ] + } + ], + "source": [ + "partition = P(None, 'x')\n", + "devices = mesh_utils.create_device_mesh((1, num_devices))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "\n", + "def all_reduce_kernel(\n", + " x_ref,\n", + " o_ref,\n", + " hbm_scratch,\n", + " copy_sem,\n", + " remote_recv_sem,\n", + " remote_send_sem,\n", + " capacity_sem,\n", + " receive_scratch,\n", + "):\n", + " outer_step = pl.program_id(0)\n", + " working_slot = lax.rem(outer_step, 2)\n", + " receiving_slot = 1 - working_slot\n", + "\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = lax.rem(my_id + 1, num_devices)\n", + " left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)\n", + "\n", + " @pl.when(outer_step == 0)\n", + " def _():\n", + " # Barrier with both neighbors at the start, since we will be\n", + " # communicating with both.\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(barrier_sem, 2)\n", + "\n", + " # Initialize o_ref, acc_scratch, and hbm_scratch.\n", + " o_ref[...] = jnp.zeros_like(o_ref)\n", + " receive_scratch[...] = jnp.zeros_like(receive_scratch)\n", + " initial_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref,\n", + " dst_ref=hbm_scratch.at[working_slot],\n", + " send_sem=remote_send_sem,\n", + " recv_sem=remote_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " initial_copy.start()\n", + " initial_copy.wait()\n", + "\n", + " # Signal to our left neighbor that we are ready to receive.\n", + " # Without this signal, our left neighbor can be >=1 iteration ahead,\n", + " # meaning it could write into our working slot.\n", + " pltpu.semaphore_signal(\n", + " capacity_sem,\n", + " inc=1,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " # Copy the partial result our left neighbor sent to us into VMEM for\n", + " # computation.\n", + " local_copy = pltpu.make_async_copy(\n", + " src_ref=hbm_scratch.at[working_slot],\n", + " dst_ref=receive_scratch,\n", + " sem=copy_sem,\n", + " )\n", + " local_copy.start()\n", + "\n", + " # Block until our right neighbor is ready to receive.\n", + " pltpu.semaphore_wait(capacity_sem, 1)\n", + " # Pass the value to our right neighbor.\n", + " remote_copy = pltpu.make_async_remote_copy(\n", + " src_ref=hbm_scratch.at[working_slot],\n", + " dst_ref=hbm_scratch.at[receiving_slot],\n", + " send_sem=remote_send_sem,\n", + " recv_sem=remote_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " remote_copy.start()\n", + " # Finish local copy and accumulate while remote_copy is happening.\n", + " local_copy.wait()\n", + " o_ref[...] += receive_scratch[...]\n", + " # Block until remote copy finishes.\n", + " remote_copy.wait()\n", + "\n", + "\n", + "out_shape = (\n", + " jax.ShapeDtypeStruct((8, 128), jnp.float32),\n", + " # We allocate the double-buffer as a Pallas output so that it is\n", + " # resident in HBM.\n", + " jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch\n", + ")\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " in_specs=[\n", + " # Our input lives in VMEM\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " ],\n", + " out_specs=[\n", + " # Our output lives in VMEM\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " # Our double-buffer lives in HBM\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " grid=(num_devices,),\n", + " scratch_shapes=(\n", + " [pltpu.SemaphoreType.DMA] * 3\n", + " + [pltpu.SemaphoreType.REGULAR] # capacity_sem\n", + " + [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch\n", + " ),\n", + ")\n", + "\n", + "kernel = pl.pallas_call(\n", + " all_reduce_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + ")\n", + "\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " kernel,\n", + " mesh=mesh,\n", + " in_specs=partition,\n", + " out_specs=partition,\n", + " check_rep=False,\n", + " )\n", + ")(input_arr)\n", + "pallas_result = jax.block_until_ready(pallas_result)[0]\n", + "\n", + "\n", + "def lax_sum(x):\n", + " return lax.psum(x, 'x')\n", + "\n", + "\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')\n", + " )\n", + ")(input_arr)\n", + "\n", + "print('Input = ', input_arr[0, ::128])\n", + "print('Pallas result = ', pallas_result[0, ::128])\n", + "print('lax.psum result = ', xla_result[0, ::128])\n", + "difference = jnp.mean(jnp.abs(pallas_result - xla_result))\n", + "print('Difference |Pallas - lax.psum| = ', difference)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d8bsZAzQreC_" + }, + "source": [ + "### Run-ahead and Race Conditions\n", + "\n", + "As a general rule of thumb, to maximize performance we want to allow a device to run-ahead of other devices without synchronization as much as possible without sacrificing correctness of the program. While we could enforce a barrier across all devices at the beginning of each iteration, this bottlenecks the performance of the program to the slowest device on each loop. By relaxing synchronization and allowing a moderate amount of run-ahead, we can better accommodate variance in latency between iterations and devices because a device that is slow on one iteration could catch up on the next iteration.\n", + "\n", + "In the all-reduce kernel we wrote previously, we allow devices to run ahead but by less than one iteration compared to its neighbors (however, non-neighboring devices could be more than 1 iteration apart). To see why the semaphore synchronization is necessary, consider the case when one device (say device 2) hangs and falls behind the other devices. An RDMA has no \"handshake\" — only the receiver is blocked while waiting for the data to arrive. Therefore, each device can run up to one iteration ahead before it becomes blocked waiting for the next RDMA to arrive. If we have N devices, this means that the final device can be up to N iterations ahead of the first device.\n", + "\n", + "![race_condition](../../_static/pallas/distributed/race_condition.svg)\n", + "\n", + "Without adding synchronization in the other direction (forcing senders to block), device 1 could potentially run up to `N` iterations (`N = num_devices`) ahead of device 2, sending multiple writes and overwriting values in the process. To solve this in the `all_reduce` kernel we wrote previously we implemented a \"handshake\" protocol where the receiver signals back to the sender that it is ready to receive, and only then does the sender begin issuing the next RDMA." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UD8lNrqsUeXy" + }, + "source": [ + "### Bi-directional Communication\n", + "\n", + "In our previous kernels, we communicated in a single direction around a ring from left-to-right. However, as ICI connections are bi-directional, we are effectively wasting half of the total bandwidth by not sending values in the opposite direction from right-to-left. In this next kernel we will demonstrate an example which communicates in both directions to maximize ICI bandwidth." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4KjakLhbBk73" + }, + "source": [ + "### Example: Bi-directional Reduce-Scatter (`lax.psum_scatter`)\n", + "\n", + "A reduce-scatter operation is the combination of an all-reduce followed by a scatter. Or alternatively, an all-reduce is the combination of a reduce-scatter followed by all-gather.\n", + "\n", + "The following graphic depicts the semantics of this operation. We assume that each device starts with a collection of partial sums (denoted by a letter + number, such as `A0`). The goal is to reduce along one axis (numbers), while sharding along the other axis (letters).\n", + "\n", + "![reduce_scatter_1](../../_static/pallas/distributed/reduce_scatter_1.svg)\n", + "\n", + "In order to implement a bi-directional communication strategy, we slice each input block in half, and designate a direction for each half. The top half of each block will be passed from right-to-left, and the bottom half will be passed from left-to-right. A second deviation from the communication patterns of our previous all-reduce and all-gather kernels is that we will also pass around accumulators or partial sums and keep the inputs local to each device. This is in contrast to the previous examples where we passed around inputs but kept the accumulator local to the device. Passing around the accumulator is a more natural fit for this problem as in contrast to all-reduce, most of the data in the inputs are not part of the output that will be stored locally on the device. (e.g. `B0`, `C0`, and `D0` in the above graphic will not be stored on the device holding `A` at the end).\n", + "\n", + "The following diagram illustrates this communication pattern, where the colored boxes represent accumulators (not inputs!). Initially, the accumulator is simply the value that was contained in the input. At each iteration of the algorithm, we will receive a partial sum from our neighbors in each direction. We then compute the correct slice of our input to accumulate into the partial buffer, then pass the new partial sum along to our next neighbor. After N iterations, the accumulator will have passed through each device, meaning that it will hold the full sum in the end.\n", + "\n", + "![reduce_scatter_2](../../_static/pallas/distributed/reduce_scatter_2.svg)\n", + "\n", + "In terms of construction of the kernel, we introduce an additional `phase` dimension to the Pallas grid, which denotes which accumulator (left or right) we are currently computing on. We let `phase=0` denote the accumulator moving to the left, and `phase=1` denote the accumulator moving to the right. We then pipeline the two phases, such that while computing the result for one phase we are transferring our previously computed values in the opposite direction in preparation for the next phase. For example, when we are on `phase=0` (left), we first begin a DMA to transfer results we computed in the previous iteration to our right neighbor (right-DMA). Then, we accumulate into the left-buffer and save the result to HBM. We then wait for the right-DMA to complete so that it is ready for `phase=1` (right)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "executionInfo": { + "elapsed": 544, + "status": "ok", + "timestamp": 1722904805699, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "nRauUAxNHg28" + }, + "outputs": [], + "source": [ + "partition = P(None, 'x')\n", + "devices = mesh_utils.create_device_mesh((1, num_devices))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "# We need a block size of (16, 128) to ensure that a half-slice is at least\n", + "# of size (8, 128), which is the size of a VREG. This makes tiling easier\n", + "# for the compiler.\n", + "block_size = (16, 128)\n", + "input_arr = jax.random.uniform(\n", + " jax.random.key(0),\n", + " shape=(block_size[0] * num_devices, block_size[1] * num_devices),\n", + ")\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "LEFT = 0\n", + "RIGHT = 1\n", + "\n", + "\n", + "def mod(x, n):\n", + " return lax.rem(x + n, n)\n", + "\n", + "\n", + "def signal(left_or_right, semaphore):\n", + " my_id = lax.axis_index('x')\n", + " if left_or_right == LEFT:\n", + " neighbor = mod(my_id - 1, num_devices)\n", + " else:\n", + " neighbor = mod(my_id + 1, num_devices)\n", + " pltpu.semaphore_signal(\n", + " semaphore,\n", + " inc=1,\n", + " device_id=(0, neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + "\n", + "def reduce_scatter_kernel(\n", + " x_ref,\n", + " o_ref,\n", + " hbm_scratch,\n", + " local_copy_sem,\n", + " left_recv_sem,\n", + " left_send_sem,\n", + " right_recv_sem,\n", + " right_send_sem,\n", + " left_capacity_sem,\n", + " right_capacity_sem,\n", + " accum_scratch,\n", + "):\n", + " outer_step = pl.program_id(0)\n", + " phase = pl.program_id(1)\n", + " is_start = jnp.logical_and(outer_step == 0, phase == 0)\n", + " last_iteration = outer_step == pl.num_programs(0) - 1\n", + "\n", + " working_slot = lax.rem(outer_step, 2)\n", + " receiving_slot = 1 - working_slot\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = mod(my_id + 1, num_devices)\n", + " left_neighbor = mod(my_id - 1, num_devices)\n", + "\n", + " left_copy_device = mod(my_id + outer_step + 1, num_devices)\n", + " right_copy_device = mod(my_id - outer_step - 1, num_devices)\n", + " # Slices can be specified using pl.ds(start, size)\n", + " left_copy_slice = pl.ds(0, block_size[0] // 2)\n", + " right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)\n", + " current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)\n", + "\n", + " initial_left_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref.at[my_id, left_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", + " send_sem=left_send_sem,\n", + " recv_sem=left_recv_sem,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " initial_right_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref.at[my_id, right_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", + " send_sem=right_send_sem,\n", + " recv_sem=right_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " left_copy = pltpu.make_async_remote_copy(\n", + " src_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", + " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n", + " send_sem=left_send_sem,\n", + " recv_sem=left_recv_sem,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " right_copy = pltpu.make_async_remote_copy(\n", + " # Note: Right copy is flipped with regards to slots since we are copying\n", + " # to the next outer_step iteration.\n", + " src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", + " send_sem=right_send_sem,\n", + " recv_sem=right_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " # --- Prologue ---\n", + " @pl.when(is_start)\n", + " def _():\n", + " # Barrier with both neighbors at the start, since we will be\n", + " # communicating with both.\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(barrier_sem, 2)\n", + "\n", + " # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.\n", + " o_ref[...] = jnp.zeros_like(o_ref[...])\n", + " accum_scratch[...] = jnp.zeros_like(accum_scratch[...])\n", + "\n", + " initial_left_copy.start()\n", + " initial_left_copy.wait()\n", + " initial_right_copy.start()\n", + "\n", + " # We tell our left neighbor that it is allowed to send to the right.\n", + " # (and vice versa for right neighbor)\n", + " signal(LEFT, right_capacity_sem)\n", + " signal(RIGHT, left_capacity_sem)\n", + "\n", + " # --- Body ---\n", + " # At the beginning of our kernel body, we start a DMA which copies\n", + " # the result we computed in the previous phase to our neighbor.\n", + " # This allows us to overlap the communication of sending our previous phase\n", + " # with the computation for the current phase.\n", + " @pl.when(~is_start)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " # We block here until our right neighbor tells use we can send to\n", + " # the right.\n", + " pltpu.semaphore_wait(right_capacity_sem, 1)\n", + " right_copy.start()\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " # We block here until our left neighbor tells use we can send to\n", + " # the left.\n", + " pltpu.semaphore_wait(left_capacity_sem, 1)\n", + " left_copy.start()\n", + "\n", + " local_copy = pltpu.make_async_copy(\n", + " src_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " dst_ref=accum_scratch,\n", + " sem=local_copy_sem,\n", + " )\n", + " local_copy.start()\n", + " local_copy.wait()\n", + "\n", + " @pl.when(~last_iteration)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]\n", + "\n", + " local_copy = pltpu.make_async_copy(\n", + " src_ref=accum_scratch,\n", + " dst_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " sem=local_copy_sem,\n", + " )\n", + " local_copy.start()\n", + " local_copy.wait()\n", + "\n", + " @pl.when(is_start)\n", + " def _():\n", + " initial_right_copy.wait()\n", + "\n", + " # At the end of our kernel body, we wait on the DMA of the previous phase\n", + " # to make sure the results are ready for the next phase.\n", + " @pl.when(~is_start)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " right_copy.wait()\n", + " signal(LEFT, right_capacity_sem)\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " left_copy.wait()\n", + " signal(RIGHT, left_capacity_sem)\n", + "\n", + " # --- Epilogue ---\n", + " # Store result on last iteration.\n", + " @pl.when(last_iteration)\n", + " def _():\n", + " # Clean up semaphores so that they exit with a value of 0.\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " o_ref[left_copy_slice, ...] = accum_scratch[...]\n", + " pltpu.semaphore_wait(right_capacity_sem, 1)\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " o_ref[right_copy_slice, ...] = accum_scratch[...]\n", + " pltpu.semaphore_wait(left_capacity_sem, 1)\n", + "\n", + "\n", + "out_shape = (\n", + " jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output\n", + " # Shape: [working/recv, block[0], block[1]]\n", + " jax.ShapeDtypeStruct(\n", + " (2, block_size[0], block_size[1]), jnp.float32\n", + " ), # hbm_scratch\n", + ")\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " in_specs=[\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " ],\n", + " out_specs=[\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " grid=(num_devices, 2),\n", + " scratch_shapes=(\n", + " [pltpu.SemaphoreType.DMA] * 5\n", + " + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores\n", + " + [\n", + " pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)\n", + " ] # accum_scratch\n", + " ),\n", + ")\n", + "\n", + "\n", + "def pallas_reduce_scatter(input_arr):\n", + " input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])\n", + " return pl.pallas_call(\n", + " reduce_scatter_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " )(input_arr)[0]\n", + "\n", + "\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " pallas_reduce_scatter,\n", + " mesh=mesh,\n", + " in_specs=P(None, 'x'),\n", + " out_specs=P('x', None),\n", + " check_rep=False,\n", + " )\n", + ")(input_arr)\n", + "\n", + "pallas_result = jax.block_until_ready(pallas_result)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "executionInfo": { + "elapsed": 596, + "status": "ok", + "timestamp": 1722904806442, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "E-NMh-_teoi4", + "outputId": "24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input: (64, 512) [0.78051674 0.3524047 0.59993696 0.9714314 0.24692321 0.01347649\n", + " 0.01857424 0.24841607 0.86097646 0.8261659 0.9753758 0.6902338\n", + " 0.4431417 0.963323 0.3158517 0.535548 ]\n", + "Pallas Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869 1.4194957 1.4163033 1.2401303\n", + " 1.1892898 2.6545286 2.221559 2.7995253 2.08431 2.2509837 3.0726733\n", + " 2.4662397 1.9542246]\n", + "lax.psum_scatter Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869 1.4194957 1.4163033 1.2401303\n", + " 1.1892898 2.6545286 2.221559 2.7995253 2.08431 2.2509837 3.0726733\n", + " 2.4662397 1.9542246]\n", + "Difference |Pallas - lax.psum_scatter|: 2.3841858e-07\n" + ] + } + ], + "source": [ + "# Compare our result to XLA.\n", + "def lax_reduce_sum_scatter(x):\n", + " x = x.reshape(num_devices, block_size[0], block_size[1])\n", + " return lax.psum_scatter(x, 'x')\n", + "\n", + "\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lax_reduce_sum_scatter,\n", + " mesh=mesh,\n", + " in_specs=P(None, 'x'),\n", + " out_specs=P('x', None),\n", + " )\n", + ")(input_arr)\n", + "\n", + "print('Input:', input_arr.shape, input_arr[::4, 0])\n", + "print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])\n", + "print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])\n", + "print(\n", + " 'Difference |Pallas - lax.psum_scatter|:',\n", + " jnp.max(jnp.abs(pallas_result - xla_result)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ThKas40r40Ji" + }, + "source": [ + "### Nested Remote and Local DMA Pipelines\n", + "\n", + "A limitation of the previous all-reduce and reduce-scatter kernels that we wrote is that the blocks we copy via remote DMA must be small enough to fit in our working VMEM that we use for accumulation. For some kernels it may be advantageous to use larger block sizes to better utilize the TPU. For example, a matrix multiplication requires on the order of $O(N^3)$ compute operations, but only $O(N^2)$ memory transfers. Therefore, we want each block of work transferred between devices to be large enough such that the operation becomes compute bound and we can hide the communication cost using pipelining. For reference, the VMEM of a TPU (for generations v4/v5) is typically on the order of 10-100MB, whereas HBM ranges from 10-100GB.\n", + "\n", + "To address this problem, we need to be able to write an \"inner kernel\" that handles local HBM-VMEM pipelining inside of the \"outer kernel\" that handles pipelining larger HBM-HBM transfers between devices. Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. The basic call signature for `emit_pipeline` follows that of a standard `pallas_call` by specifying a `grid` and `BlockSpec`s for the inputs and outputs:\n", + "\n", + "```python\n", + "def emit_pipeline(\n", + " kernel: Callable,\n", + " grid: tuple[int],\n", + " in_specs: PyTree[BlockSpec] = None,\n", + " out_specs: PyTree[BlockSpec] = None,\n", + " should_accumulate_out: bool = False,\n", + " dimension_semantics: tuple[GridDimensionSemantics] = None,\n", + ") -> Callable:\n", + " ... # Returns a custom pipeline given an inner kernel and BlockSpecs.\n", + "```\n", + "\n", + "Indeed, one can view `pallas_call` itself as simply a wrapper around `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern:\n", + "\n", + "```python\n", + "\n", + "def outer_kernel(...):\n", + " # ... do work to pipeline remote HBM-HBM transfers (outer kernel)\n", + "\n", + " def inner_kernel(...):\n", + " # ... do work (inner kernel)\n", + " pltpu.emit_pipeline(\n", + " inner_kernel,\n", + " grid=inner_grid,\n", + " in_specs=...,\n", + " out_specs=...,\n", + " )(inner_kernel_args)\n", + " # ... do more work (outer kernel)\n", + "\n", + "pl.pallas_call(\n", + " outer_kernel,\n", + " grid=outer_grid,\n", + " in_specs=...\n", + " out_specs=...\n", + " scratch=inner_kernel_allocs\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DzFeQjYaasX5" + }, + "source": [ + "### Example: Reduce-Scatter with large HBM blocks\n", + "\n", + "In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter.\n", + "\n", + "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", + "\n", + "In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:\n", + "\n", + "```python\n", + "local_copy = pltpu.make_async_copy(\n", + " src_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " dst_ref=accum_scratch,\n", + " sem=local_copy_sem,\n", + ")\n", + "local_copy.start()\n", + "local_copy.wait()\n", + "@pl.when(~last_iteration)\n", + "def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]\n", + "local_copy = pltpu.make_async_copy(\n", + " src_ref=accum_scratch,\n", + " dst_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " sem=local_copy_sem,\n", + ")\n", + "local_copy.start()\n", + "local_copy.wait()\n", + "```\n", + "\n", + "Our new kernel replaces it with the following `emit_pipeline` call:\n", + "\n", + "```python\n", + "def inner_kernel(input_ref, accum_ref):\n", + " accum_ref[...] = input_ref[...]\n", + "accum_pipeline = pltpu.emit_pipeline(inner_kernel,\n", + " in_specs=[inner_block_spec],\n", + " out_specs=inner_block_spec,\n", + " should_accumulate_out=True,\n", + " grid=inner_grid)\n", + "@pl.when(~last_iteration)\n", + "def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],\n", + " hbm_scratch.at[working_slot, left_copy_slice],\n", + " )\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],\n", + " hbm_scratch.at[working_slot, right_copy_slice],\n", + " )\n", + "```\n", + "\n", + "The full kernel is as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "executionInfo": { + "elapsed": 1341, + "status": "ok", + "timestamp": 1722904807930, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "27jni-pSartL" + }, + "outputs": [], + "source": [ + "partition = P(None, 'x')\n", + "devices = mesh_utils.create_device_mesh((1, num_devices))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "# We pick a large outer kernel block size that we do not want to place\n", + "# in VMEM. For pedagogical purposes we use (4096, 4096), although in\n", + "# principle this can be much larger.\n", + "outer_block_size = (4096, 4096)\n", + "# We pick a smaller VMEM block size for the inner kernel.\n", + "inner_block_size = (128, 128)\n", + "input_arr = jax.random.uniform(\n", + " jax.random.key(0),\n", + " shape=(\n", + " outer_block_size[0] * num_devices,\n", + " outer_block_size[1] * num_devices,\n", + " ),\n", + ")\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "\n", + "inner_grid = (\n", + " outer_block_size[0] // inner_block_size[0] // 2,\n", + " outer_block_size[1] // inner_block_size[1],\n", + ")\n", + "inner_block_spec = pl.BlockSpec(\n", + " index_map=lambda i, j: (i, j),\n", + " block_shape=inner_block_size,\n", + " memory_space=pltpu.TPUMemorySpace.ANY,\n", + ")\n", + "\n", + "\n", + "def reduce_scatter_kernel(\n", + " x_ref,\n", + " o_ref,\n", + " hbm_scratch,\n", + " left_recv_sem,\n", + " left_send_sem,\n", + " copy_sem,\n", + " right_recv_sem,\n", + " right_send_sem,\n", + " left_capacity_sem,\n", + " right_capacity_sem,\n", + "):\n", + " outer_step = pl.program_id(0)\n", + " phase = pl.program_id(1)\n", + " is_start = jnp.logical_and(outer_step == 0, phase == 0)\n", + " last_iteration = outer_step == pl.num_programs(0) - 1\n", + "\n", + " working_slot = lax.rem(outer_step, 2)\n", + " receiving_slot = 1 - working_slot\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = mod(my_id + 1, num_devices)\n", + " left_neighbor = mod(my_id - 1, num_devices)\n", + "\n", + " left_copy_device = mod(my_id + outer_step + 1, num_devices)\n", + " right_copy_device = mod(my_id - outer_step - 1, num_devices)\n", + " left_copy_slice = pl.ds(0, outer_block_size[0] // 2)\n", + " right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)\n", + " current_phase_slice = pl.ds(\n", + " phase * (outer_block_size[0] // 2), outer_block_size[0] // 2\n", + " )\n", + "\n", + " initial_left_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref.at[my_id, left_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", + " send_sem=left_send_sem,\n", + " recv_sem=left_recv_sem,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " initial_right_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref.at[my_id, right_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", + " send_sem=right_send_sem,\n", + " recv_sem=right_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " left_copy = pltpu.make_async_remote_copy(\n", + " src_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", + " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n", + " send_sem=left_send_sem,\n", + " recv_sem=left_recv_sem,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " right_copy = pltpu.make_async_remote_copy(\n", + " src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", + " send_sem=right_send_sem,\n", + " recv_sem=right_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " # --- Prologue ---\n", + " @pl.when(is_start)\n", + " def _():\n", + " # Barrier with both neighbors at the start, since we will be\n", + " # communicating with both.\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(barrier_sem, 2)\n", + "\n", + " initial_left_copy.start()\n", + " initial_left_copy.wait()\n", + " initial_right_copy.start()\n", + "\n", + " # We tell our left neighbor that it is allowed to send to the right.\n", + " # (and vice versa for right neighbor)\n", + " signal(LEFT, right_capacity_sem)\n", + " signal(RIGHT, left_capacity_sem)\n", + "\n", + " @pl.when(~is_start)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " # We block here until our right neighbor tells use we can send to\n", + " # the right.\n", + " pltpu.semaphore_wait(right_capacity_sem, 1)\n", + " right_copy.start()\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " # We block here until our left neighbor tells use we can send to\n", + " # the left.\n", + " pltpu.semaphore_wait(left_capacity_sem, 1)\n", + " left_copy.start()\n", + "\n", + " # --- Body ---\n", + " def inner_kernel(input_ref, accum_ref):\n", + " # We do not explicitly use += because we set should_accumulate_out=True.\n", + " accum_ref[...] = input_ref[...]\n", + "\n", + " accum_pipeline = pltpu.emit_pipeline(\n", + " inner_kernel,\n", + " in_specs=[inner_block_spec],\n", + " out_specs=inner_block_spec,\n", + " should_accumulate_out=True,\n", + " grid=inner_grid,\n", + " )\n", + "\n", + " @pl.when(~last_iteration)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " accum_pipeline(\n", + " x_ref.at[left_copy_device, left_copy_slice],\n", + " hbm_scratch.at[working_slot, left_copy_slice],\n", + " )\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " accum_pipeline(\n", + " x_ref.at[right_copy_device, right_copy_slice],\n", + " hbm_scratch.at[working_slot, right_copy_slice],\n", + " )\n", + "\n", + " # --- Epilogue ---\n", + " @pl.when(is_start)\n", + " def _():\n", + " initial_right_copy.wait()\n", + "\n", + " @pl.when(~is_start)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " right_copy.wait()\n", + " signal(LEFT, right_capacity_sem)\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " left_copy.wait()\n", + " signal(RIGHT, left_capacity_sem)\n", + "\n", + " # Store result on last iteration.\n", + " @pl.when(last_iteration)\n", + " def _():\n", + " output_copy = pltpu.make_async_copy(\n", + " src_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " dst_ref=o_ref.at[current_phase_slice],\n", + " sem=copy_sem,\n", + " )\n", + " output_copy.start()\n", + " output_copy.wait()\n", + "\n", + " # Clean up semaphores so that they exit with a value of 0.\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " pltpu.semaphore_wait(right_capacity_sem, 1)\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " pltpu.semaphore_wait(left_capacity_sem, 1)\n", + "\n", + "\n", + "out_shape = (\n", + " jax.ShapeDtypeStruct(\n", + " (outer_block_size[0], outer_block_size[1]), jnp.float32\n", + " ),\n", + " # Shape: [working/recv, block[0], block[1]]\n", + " jax.ShapeDtypeStruct(\n", + " (2, outer_block_size[0], outer_block_size[1]), jnp.float32\n", + " ), # hbm_scratch\n", + ")\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " in_specs=[\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " out_specs=[\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " grid=(num_devices, 2),\n", + " scratch_shapes=(\n", + " [pltpu.SemaphoreType.DMA] * 5\n", + " + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores\n", + " ),\n", + ")\n", + "\n", + "\n", + "def pallas_reduce_scatter(input_arr):\n", + " input_arr = input_arr.reshape(\n", + " num_devices, outer_block_size[0], outer_block_size[1]\n", + " )\n", + " return pl.pallas_call(\n", + " reduce_scatter_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " )(input_arr)[0]\n", + "\n", + "\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " pallas_reduce_scatter,\n", + " mesh=mesh,\n", + " in_specs=P(None, 'x'),\n", + " out_specs=P('x', None),\n", + " check_rep=False,\n", + " )\n", + ")(input_arr)\n", + "\n", + "pallas_result = jax.block_until_ready(pallas_result)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "executionInfo": { + "elapsed": 768, + "status": "ok", + "timestamp": 1722904808851, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "cTEyiMDyx9Y0", + "outputId": "1de26695-3713-430e-9ab4-4ea646691680" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input: (16384, 16384) [0.74162567 0.0242182 0.27751946 ... 0.05213022 0.36088037 0.04494429]\n", + "Pallas Result: (16384, 4096) [2.0648427 1.674587 1.9148926 ... 1.3371865 1.3296283 1.2887063]\n", + "lax.psum_scatter Result: (16384, 4096) [2.0648427 1.674587 1.9148926 ... 1.3371865 1.3296283 1.2887063]\n", + "Difference |Pallas - lax.psum_scatter|: 2.3841858e-07\n" + ] + } + ], + "source": [ + "# Now we compare our result to XLA.\n", + "def lax_reduce_sum_scatter(x):\n", + " x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])\n", + " return lax.psum_scatter(x, 'x')\n", + "\n", + "\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lax_reduce_sum_scatter,\n", + " mesh=mesh,\n", + " in_specs=P(None, 'x'),\n", + " out_specs=P('x', None),\n", + " )\n", + ")(input_arr)\n", + "\n", + "print('Input:', input_arr.shape, input_arr[::4, 0])\n", + "print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])\n", + "print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])\n", + "print(\n", + " 'Difference |Pallas - lax.psum_scatter|:',\n", + " jnp.max(jnp.abs(pallas_result - xla_result)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zz5AFbriliyv" + }, + "source": [ + "## Final Notes\n", + "\n", + "### Interaction with XLA\n", + "\n", + "In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.\n", + "\n", + "### Next Steps\n", + "\n", + "Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md new file mode 100644 index 000000000000..b7c058b117ca --- /dev/null +++ b/docs/pallas/tpu/distributed.md @@ -0,0 +1,1527 @@ +--- +jupytext: + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.1 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + ++++ {"id": "zSNjLhGQJMgq"} + +# Distributed Computing in Pallas for TPUs + +In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. + +Some recommended readings beforehand: + - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html) + - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial) + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 1978 + status: ok + timestamp: 1722904801801 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: PyAGnWc9yI8T +outputId: 1d8229bd-cab5-495f-93e9-fff2e41db480 +--- +import jax +from jax import lax +from jax import numpy as jnp +from jax.experimental import mesh_utils +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import tpu as pltpu + +P = jax.sharding.PartitionSpec + +num_devices = jax.local_device_count() +assert num_devices > 1, "Please run this notebook with more than one device." +assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices." +print(f"Running with {num_devices} {jax.devices()[0].device_kind} devices.") +``` + ++++ {"id": "DySMGNByclMi"} + +## TPU Topologies + +TPUs are typically deployed in pods of multiple devices connected via a high-bandwidth interchip interconnect (ICI) for communication within the pod that is much faster than a typical network connection. For example, the specifications sheet for a [TPU v5p](https://cloud.google.com/tpu/docs/v5p) states an ICI bandwidth of 4.8Tb/s per chip (for reference, TPU v5p also has 21Tb/s of *local* HBM bandwidth). The ICI allows us to implement fast and performant distributed kernels that require high-bandwidth communication within a pod, and use the datacenter network for parallelization over less bandwidth-intensive operations, such as data-parallelism over a batch dimension. + +TPUs pods are typically arranged in an ND torus topology. The following graphic gives several examples of configurations of different sizes. + +![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png) + +Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod. + +![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png) + ++++ {"id": "1Oc_WD1hChfN"} + +## Remote Direct Memory Access (RDMA) Model + +TPUs communicate via a push-only model known as a remote direct memory access (RDMA). A TPU is allowed to issue copy instruction to push from a local buffer to any buffer on another device within the same pod that executes asynchronously from the main program thread. However, a TPU can only read data that is stored locally. This is in contrast to more traditional multi-core programming where it is possible to both read from and write to values to a shared memory. + +### Async Remote Copy Operation +The `pltpu.make_async_remote_copy` function is used to create a remote DMA descriptor object which parameterizes both a "send" operation and a "receive" operation. Here's its signature: + +```python + def make_async_remote_copy( + src_ref: Ref, + dst_ref: Ref, + send_sem: Ref[SemaphoreType], + recv_sem: Ref[SemaphoreType], + device_id: int | tuple[int, ...], + device_id_type: DeviceIdType + ) -> AsyncCopyDescriptor: +``` + +- `src_ref` is the local `Ref` (in any memory space) containing the data you wish to send to `dst_ref` on another device. +- `dst_ref` is the remote `Ref` (in any memory space) at which data will be copied to on the target device. +- `send_sem` is a DMA semaphore used to block until all data has been sent from `src_ref`. +- `recv_sem` is a DMA semaphore used to block until the expected number of bytes have been received at `dst_ref`. The sender of the DMA will write to the receiver's `recv_sem`. +- `device_id` is the device ID of the target device to send to. +- `device_id_type` specifies the format of `device_id`, which can either be in LOGICAL format (integer device ID), or in MESH format (an ND-tuple index into the logical device mesh). The default mode is MESH. + +`make_async_remote_copy` returns a descriptor object on which you use the `.start()` method to initiate the DMA, and the `.wait_send()` to block on `send_sem` and `.wait_recv()` to block on `recv_sem` (or `.wait()` to block on both). If a device is only expected to send data, it is sufficient to only call `.start()` and `.wait_send()`, and likewise if a device is only receiving it is sufficient to only call `.wait_recv()`. If using a SPMD pattern where all devices execute the DMA, each device will generally call both `.start()` and `.wait()`. +```python +dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id) +dma_descriptor.start() # Initiate the DMA (non-blocking). +# ... do other work +dma_descriptor.wait_send() # Block until all data has been sent. +dma_descriptor.wait_recv() # Block until all data has been received. +``` + +As an example, let's visualize a DMA where we consider 4 devices (indexed 0, 1, 2, 3). We consider a scheme where device 0 copies to device 1, and device 2 & 3 copy to each other. In practice, we can create such an asymmetric communication pattern by using `@pl.when` to branch on the device ID. + +(1) Each device creates the DMA descriptor. Devices 0, 2, and 3 call `.start()` to initiate the DMA from `src_ref`. Device 1 is skips the `.start()` and does nothing, e.g. by using `pl.when`. + +![rdma_start](../../_static/pallas/distributed/rdma_start.svg) + +(2) As `.start()` is non-blocking, each device is free to do other computation while the DMA is in flight. Devices 0, 2, and 3 call `.wait_send()` to wait on `send_sem` which blocks until all data has been sent. + +![rdma_send](../../_static/pallas/distributed/rdma_send.svg) + +(3) Finally, devices 1, 2, and 3 will call `.wait_recv()` to wait on `recv_sem` until all data has arrived at `dst_ref`. + +![rdma_recv](../../_static/pallas/distributed/rdma_recv.svg) + +The above communication pattern can be written as follows: +```python +def example_kernel(input_ref, output_ref, send_sem, recv_sem): + device_id = lax.axis_index('x') + copy_0_to_1 = pltpu.make_async_remote_copy( + src_ref=input_ref, + dst_ref=output_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=1, + ) + copy_2_to_3 = pltpu.make_async_remote_copy( + src_ref=input_ref, + dst_ref=output_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=3, + ) + copy_3_to_2 = pltpu.make_async_remote_copy( + src_ref=input_ref, + dst_ref=output_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=2, + ) + @pl.when(device_id == 0) + def _(): + copy_0_to_1.start() + copy_0_to_1.wait_send() + @pl.when(device_id == 1) + def _(): + copy_0_to_1.wait_recv() + @pl.when(device_id == 2) + def _(): + copy_2_to_3.start() + copy_2_to_3.wait_send() + copy_3_to_2.wait_recv() + @pl.when(device_id == 3) + def _(): + copy_3_to_2.start() + copy_3_to_2.wait_send() + copy_2_to_3.wait_recv() +``` + +### DMA Semaphores + +`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`. + +Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore). + +### Routing + +A sender is allowed to send data to any receiver within the same pod, even if they do not share a direct connection (the exception to this rule is for TPU v5e, where devices can only route to a power of 2 offset from themselves). TPUs have an internal routing mechanism which can pass data along to the next device on the path to the destination. However, communicating in this way is not recommended as you have no control over network contention as a kernel writer. The examples we will cover in this tutorial minimize inefficient communication by only transferring data to neighboring devices. + +### Failure modes + +If using remote DMAs incorrectly, you may encounter several failure modes which can be difficult to debug. The general symptoms of buggy DMA usage are crashes, hanging, or silent data corruption: +- If semaphores exit the program with an invalid non-zero value, Pallas will crash and exit the program. +- If semaphores are waited on but an insufficient number of bytes are received (i.e. there is no sender, or if the sent data is less than the size of `dst_ref` on the receiving device), the program may hang indefinitely waiting for bytes that are never sent. In this case the program would need to be restarted. +- If encountering a race condition, there could be silent data corruption if two simultaneous writes or a simultaneous read and write occur. + +Some common causes of the above include: +- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang. +- If a device is sent a more bytes than it expected to receive, it may also crash due to non-zero semaphore states. If sent less, it may hang indefinitely. +- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states. +- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states. + +### Megacore + +Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core. + ++++ {"id": "vpGSN1Sui0Bu"} + +### Example: Right Permute (`lax.ppermute`) + +Let's dive into a very basic example. We will implement a kernel that performs a right permutation, where each device sends its slice of the data to its right neighbor. + +Suppose we had an array with 512 elements, which we shard into slices of size 128 across 4 devices. Each device will pass its slice to the next device, and the output will consist of the same data, but with the slices rotated by 1. This is identical to the `lax.ppermute` operation where the permutation is set to `(n, (n+1) % 4)`. + +In order to call the kernel in distributed mode, we wrap the `pallas_call` in a `shard_map` transformation. From there, we can write the kernel the same way as you would write a normal single-device Pallas kernel, except we now have access to remote DMA instructions. JAX collective primitives such as `lax.axis_index` can be used to obtain a `device_id` that can be used to compute which target devices to copy to, by referencing the same named axes names passed into `shard_map`. + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 1606 + status: ok + timestamp: 1722904803566 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: YkyIKN2thZ-V +outputId: 9b7ed142-d161-4237-fed8-cbce41adc5f0 +--- +partition = P(None, 'x') +devices = mesh_utils.create_device_mesh((1, num_devices)) +mesh = jax.sharding.Mesh(devices, partition) +sharding = jax.sharding.NamedSharding(mesh, partition) + +# Create an input array that shards the last dimension across +# all devices. +input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices)) +input_arr = jax.device_put(input_arr, sharding) + + +def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): + my_id = lax.axis_index('x') + right_neighbor = lax.rem(my_id + 1, num_devices) + remote_copy_op = pltpu.make_async_remote_copy( + src_ref=input_ref, + dst_ref=output_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + remote_copy_op.start() + remote_copy_op.wait() + + +out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + scratch_shapes=( + # We allocate DMA semaphores in scratch memory. + [pltpu.SemaphoreType.DMA] * 2 + ), +) +right_permute = pl.pallas_call( + right_permute_kernel, + out_shape=out_shape, + grid_spec=grid_spec, +) +# Wrap the kernel within a shard_map to call. +pallas_result = jax.jit( + shard_map.shard_map( + right_permute, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_rep=False, + ) +)(input_arr) + +# Compare Pallas result to XLA shard_map result. +perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices)) + +xla_result = jax.jit( + shard_map.shard_map( + lambda x: lax.ppermute(x, 'x', perm), + mesh=mesh, in_specs=partition, out_specs=partition) +)(input_arr) + +print('Input = ', input_arr[0, ::128]) +print('Pallas Result = ', pallas_result[0, ::128]) +print('lax.ppermute Result = ', xla_result[0, ::128]) +print( + 'Difference |Pallas - lax.ppermute| = ', + jnp.mean(jnp.abs(pallas_result - xla_result)), +) +``` + ++++ {"id": "iyfhdGXuUnq2"} + +### Example: All-gather (`lax.all_gather`) + +In this next example we will implement the all-gather collective operation, which has a JAX equivalent in `lax.all_gather`. In contrast with the right-permute example from above which only involves a pair of source and destination neighbors, an all-gather operation requires communication between all devices and therefore we must think about how data is routed between them. The specifics of how we implement this are dictated by the device topology, for which we assume is a ring. + +#### Ring Communication Pattern + +We will write our kernel assuming a ring topology. Rings are a natural fit for TPUs as slicing along any dimension of a torus produces a ring. When writing collectives, we often only need to think about 1D slices of our torus at a time because the different dimensions of the torus are reserved for different types of parallelism (data vs. model, for example). + +The strategy we will use is to write a looped kernel, where on each iteration a device receives one slice of the sharded array from its left neighbor, and copies the previously received slice to its right neighbor. After `num_devices` iterations, each device will have a copy of the entire array in its local HBM. + +![all_gather](../../_static/pallas/distributed/all_gather.svg) + +We can re-purpose Pallas's `grid` argument to implement the loop. Rather than iterating over tiles of an array as we have done in previous tutorials, we instead set the grid to `(num_devices,)` to indicate that we want to loop over the number of devices and use `pl.program_id` to obtain the loop iteration inside of the Pallas kernel. The following code snippet demonstrates how to implement this: + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 812 + status: ok + timestamp: 1722904804531 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: ojQEZB5mBRqM +outputId: e1648f54-737c-4921-ca3b-b4c639a38d2b +--- +partition = P('x', None) +devices = mesh_utils.create_device_mesh((num_devices, 1)) +mesh = jax.sharding.Mesh(devices, partition) +sharding = jax.sharding.NamedSharding(mesh, partition) + +# Create an input array that shards the first dimension across +# all devices. +input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128)) +input_arr = jax.device_put(input_arr, sharding) + + +def all_gather_kernel(input_ref, + output_ref, + local_copy_sem, + send_sem, + recv_sems): + outer_step = pl.program_id(0) + my_id = lax.axis_index('x') + right_neighbor = lax.rem(my_id + 1, num_devices) + copy_slot = my_id - outer_step + copy_slot = lax.rem(copy_slot + num_devices, num_devices) + + @pl.when(outer_step == 0) + def _(): + local_copy_op = pltpu.make_async_copy( + src_ref=input_ref, + dst_ref=output_ref.at[my_id], + sem=local_copy_sem, + ) + local_copy_op.start() + local_copy_op.wait() + + # Copy to our right neighbor. + # Note that we will also be receiving data from our left neighbor, + # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact + # that the indices do not need to be symmetric between remote DMAs. + remote_copy_op = pltpu.make_async_remote_copy( + src_ref=output_ref.at[copy_slot], + dst_ref=output_ref.at[copy_slot], + send_sem=send_sem, + recv_sem=recv_sems.at[outer_step], + device_id=(right_neighbor, 0), + device_id_type=pltpu.DeviceIdType.MESH, + ) + remote_copy_op.start() + remote_copy_op.wait() + +out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32) +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + scratch_shapes=( + # DMA semaphores are allocated in scratch memory. + # We allocated one semaphore for a local HBM-VMEM copy, + # and one for the remote send semaphore. + [pltpu.SemaphoreType.DMA] * 2 + # We additionally allocate one receive semaphore per device. + # This is to avoid situations where we have multiple + # DMAs in flight, as we do not want to share a receive + # semaphore between the DMAs. + + [pltpu.SemaphoreType.DMA((num_devices-1,))] + + ), + grid=(num_devices-1,) + ) + +all_gather = pl.pallas_call( + all_gather_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + ) + +# Wrap the kernel within a shard_map to call. +pallas_result = jax.jit( + shard_map.shard_map( + all_gather, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_rep=False + ) +)(input_arr) + +# Compare Pallas result to XLA shard_map result. +xla_result = jax.jit( + shard_map.shard_map( + lambda x: lax.all_gather(x, 'x'), + mesh=mesh, in_specs=partition, out_specs=partition + ) +)(input_arr) + +print('Input: ', input_arr.shape, input_arr[::8, 0]) +print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0]) +print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0]) +print('Difference |Pallas - lax.all_gather| = ', + jnp.mean(jnp.abs(pallas_result - xla_result))) +``` + ++++ {"id": "KgU7HI2pS4om"} + +A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure. + ++++ {"id": "KgU7HI2pS4om"} + +## Advanced Techniques + +Now that we have seen how to write several basic kernels using remote DMA operations, we will go over more advanced techniques for synchronization and writing efficient kernels. + ++++ {"id": "8M_kdl0FCtrL"} + +### Synchronization: Regular and Barrier Semaphores + +The examples we implemented in the basic tutorial do not require special handling of synchronization as all necessary communication writes to disjoint buffers. However, other operations may require more complex communication patterns that need additional synchronization primitives to avoid race conditions. Pallas provides two additional primitives to help with this: regular and barrier semaphores. + +#### Regular Semaphores + +Regular semaphores are the standard tool used to synchronize across multiple devices. Semaphores are fundamentally counters - they can be incremented by any device after which a device can block until the value of the semaphore reaches a specific value (and then decrement the value). + +The three main operations that can be used on regular semaphores are signal, wait, and read: +```python +def semaphore_signal( + sem: Ref[SemaphoreType], + inc: int, + device_id: int | tuple[int, ...], + device_id_type: DeviceIdType +) -> None: + ... # Increments the semaphore `sem` on the target device `device_id` by `inc`. + +def semaphore_wait( + semaphore: Ref[SemaphoreType], + value: int, +) -> None: + ... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed. + +def semaphore_read( + sem: Ref[SemaphoreType], +) -> jax.Array: + ... # Returns the current value of `sem` as an `int32[]`. +``` + +In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`. + +Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen: + - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program. + - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted. + +#### Barrier Semaphores + +Barrier semaphores are globally-allocated semaphores used to synchronize devices across an entire program and ensure that all devices have entered the Pallas kernel. + +If a Pallas kernel is executed within the context of a larger XLA program, we need to ensure that all devices that communicate have entered the kernel. However, DMA and regular semaphores are both locally scoped - they are only understood by other devices that have entered the kernel. Barrier semaphores serve as a globally understood semaphore that can be used for synchronization no matter where in the XLA program the device is currently executing. + +By default, if you do not specify a barrier semaphore, Pallas will automatically insert a barrier semaphore at the beginning of your program. However, it can be more efficient to write your own. Barrier semaphores are similar to regular semaphores in that they are counters that can be incremented via `semaphore_signal` and can be decremented via `semaphore_wait`. They are created by calling `get_barrier_semaphore()` within a kernel. Typically, we use barriers once at the beginning of a kernel to synchronize with all devices we are communicating with. + +```python +from jax.experimental.pallas import tpu as pltpu + +def example_kernel(...): + # Use barrier semaphores at the beginning of a kernel. + # is_start_of_kernel = ... + # right_neighbor = ... + # ... + @pl.when(is_start_of_kernel) + def _(): + barrier_sem = pltpu.get_barrier_semaphore() + # Increment the semaphore of your right neighbor. + pltpu.semaphore_signal( + barrier_sem, + device_id=right_neighbor, + device_id_type=pltpu.DeviceIdType.LOGICAL, + ) + # Wait until your left neighbor has incremented your semaphore + pltpu.semaphore_wait(barrier_sem, 1) + # ... +``` + +When using barrier semaphores, the `collective_id` compiler parameter must be passed to `pallas_call` to specify which barrier semaphore is being used. A TPU has a small, fixed number of barrier semaphores available (typically on the order of 20-30) and therefore they should be used sparingly. In order to ensure correctness, only kernels that share the same communication pattern should use the same `collective_id`. For example, if two kernels synchronize only with neighbors on the same mesh axis, they are allowed to share the same `collective_id`. However, if two kernels synchronize along different axes, they must have different `collective_id`s. Failure to do so may result in race conditions that are difficult to debug. + +```python +kernel = pl.pallas_call( + example_kernel, + ..., + compiler_params=dict(mosaic=dict(collective_id=0)), +) +``` + ++++ {"id": "zy20AxN5TSLA"} + +### Double-buffering + +In order to avoid reading from a local `Ref` that is also being written into by another device and creating a race condition, a useful technique is the "double-buffered" strategy where we allocate a two `Ref`s for each destination value. On each iteration, one `Ref` will be designated as a "working" slot, and the other will be designated as a "receiving" slot. The device is free to use the working slot for computation, but will only copy data into its neighbor's receiving slot. The working and receiving slots alternate every iteration, so that once a copy is finished, the old receiving slot becomes the new working slot, and vice versa. Using this scheme properly, data is never read from and written to the same buffer. + +The following code skeleton demonstrates how double-buffering can be used. We keep a running iteration counter in the variable `iteration`, and the `working_slot` and `receiving_slot` alternate between 0 and 1 every iteration. `dst_ref` is allocated as a double-buffer and has the size `[2, ...]`. On each iteration, we read from the working slot using `dst_ref.at[working_slot, ...]` and use the value to perform computation. Simultaneously, we copy to our neighbor's `dst_ref.at[receiving_slot]` to avoid overwriting their `working_slot` value. By structuring our communication in this fashion it is possible to overlap the communication latency of the remote DMA with local computation while minimizing the risk of race conditions. +```python +def kernel(...): + # ... + iteration = pl.program_id(0) + working_slot = lax.rem(iteration, 2) + receiving_slot = 1 - working_slot + # ... + + local_copy_op = pltpu.make_async_copy( + src_ref=dst_ref.at[working_slot, ...], + dst_ref=local_scratch_ref, + sem=local_copy_sem, + ) + local_copy_op.start() + remote_copy_op = pltpu.make_async_remote_copy( + src_ref=src_ref, + dst_ref=dst_ref.at[receiving_slot, ...], + send_sem=send_sem, + recv_sem=recv_sem, + device_id=target_device, + device_id_type=pltpu.DeviceIdType.MESH, + ) + remote_copy_op.start() + + local_copy_op.wait() + # ... do work on local_scratch while waiting for async_copy_op to finish. + remote_copy_op.wait() + +``` + +In terms of synchronization, the double-buffered construction works if all devices are executing on the same iteration. If a sender manages to get one iteration ahead of its receiver, it's `working_slot` and `receiving_slot` indices will be flipped compared to the receiver, meaning that it could be writing into the `working_slot` at the same time the receiver is reading from it. In order to avoid this, it may be necessary to use a semaphore to synchronize the sender with the receiver, or add additional buffering slots ("triple", "quadruple", or N-buffered) to allow additional run-ahead at the cost of more memory. In our previous `all_gather` example, note that the kernel contained a receiving buffer with N slots, which avoids race conditions altogether. In our next kernel, we will instead go through an example which uses a double-buffer with explicit synchronization. + ++++ {"id": "Or0Itv72No5d"} + +### Example: All-Reduce Sum (`lax.psum`) + +We will now implement an all-reduce sum kernel using double-buffering and semaphores for synchronization. For those familiar with collective operations in JAX, the equivalent operation is `lax.psum`. All-reduce is a standard collective operation where the objective is to reduce along an axis of an array, but the array is sharded across multiple devices. + +![reduce_sum_1](../../_static/pallas/distributed/reduce_sum_1.svg) + +In the above example, we have the array [5, 2, 1, 3] sharded across 4 devices. An all-reduce sum operation would sum all values and replicate the result on each device, leading to the result [11, 11, 11, 11] sharded across all 4 devices. + +The naive implementation of all-reduce would be to gather all required values onto each device, and then reduce. However, we can improve the performance of this implementation by interleaving communication with computation. An interleaved, single-direction all-reduce can be visualized as follows. On each iteration, we receive an input value from our left neighbor, and concurrently pass input along to our next neighbor while incrementing it with our local accumulator. After N-1 iterations, each device will have a copy of the full sum in it's memory. + +![reduce_sum_2](../../_static/pallas/distributed/reduce_sum_2.svg) + +#### Putting it all together + +The following kernel demonstrates how to combine these principles into a functional kernel. + +The prologue (executed when `outer_step==0`) first initiates a barrier with both neighbors to ensure that they have also entered the kernel. It also handles initialization for all `Ref`s and handles the first remote copy to the right neighbor's "working" slot. + +The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`). + +A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artifically hang a device. + +Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections. + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 254 + status: ok + timestamp: 1722904804952 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: XrY5bMlvBroQ +outputId: 77497000-4496-462e-cc3c-73fb640cc14c +--- +partition = P(None, 'x') +devices = mesh_utils.create_device_mesh((1, num_devices)) +mesh = jax.sharding.Mesh(devices, partition) +sharding = jax.sharding.NamedSharding(mesh, partition) + +input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices)) +input_arr = jax.device_put(input_arr, sharding) + + +def all_reduce_kernel( + x_ref, + o_ref, + hbm_scratch, + copy_sem, + remote_recv_sem, + remote_send_sem, + capacity_sem, + receive_scratch, +): + outer_step = pl.program_id(0) + working_slot = lax.rem(outer_step, 2) + receiving_slot = 1 - working_slot + + my_id = lax.axis_index('x') + right_neighbor = lax.rem(my_id + 1, num_devices) + left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices) + + @pl.when(outer_step == 0) + def _(): + # Barrier with both neighbors at the start, since we will be + # communicating with both. + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + + # Initialize o_ref, acc_scratch, and hbm_scratch. + o_ref[...] = jnp.zeros_like(o_ref) + receive_scratch[...] = jnp.zeros_like(receive_scratch) + initial_copy = pltpu.make_async_remote_copy( + src_ref=x_ref, + dst_ref=hbm_scratch.at[working_slot], + send_sem=remote_send_sem, + recv_sem=remote_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + initial_copy.start() + initial_copy.wait() + + # Signal to our left neighbor that we are ready to receive. + # Without this signal, our left neighbor can be >=1 iteration ahead, + # meaning it could write into our working slot. + pltpu.semaphore_signal( + capacity_sem, + inc=1, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + # Copy the partial result our left neighbor sent to us into VMEM for + # computation. + local_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot], + dst_ref=receive_scratch, + sem=copy_sem, + ) + local_copy.start() + + # Block until our right neighbor is ready to receive. + pltpu.semaphore_wait(capacity_sem, 1) + # Pass the value to our right neighbor. + remote_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[working_slot], + dst_ref=hbm_scratch.at[receiving_slot], + send_sem=remote_send_sem, + recv_sem=remote_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + remote_copy.start() + # Finish local copy and accumulate while remote_copy is happening. + local_copy.wait() + o_ref[...] += receive_scratch[...] + # Block until remote copy finishes. + remote_copy.wait() + + +out_shape = ( + jax.ShapeDtypeStruct((8, 128), jnp.float32), + # We allocate the double-buffer as a Pallas output so that it is + # resident in HBM. + jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch +) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + # Our input lives in VMEM + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=[ + # Our output lives in VMEM + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + # Our double-buffer lives in HBM + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + grid=(num_devices,), + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 3 + + [pltpu.SemaphoreType.REGULAR] # capacity_sem + + [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch + ), +) + +kernel = pl.pallas_call( + all_reduce_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + compiler_params=dict(mosaic=dict(collective_id=0)), +) + +pallas_result = jax.jit( + shard_map.shard_map( + kernel, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_rep=False, + ) +)(input_arr) +pallas_result = jax.block_until_ready(pallas_result)[0] + + +def lax_sum(x): + return lax.psum(x, 'x') + + +xla_result = jax.jit( + shard_map.shard_map( + lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x') + ) +)(input_arr) + +print('Input = ', input_arr[0, ::128]) +print('Pallas result = ', pallas_result[0, ::128]) +print('lax.psum result = ', xla_result[0, ::128]) +difference = jnp.mean(jnp.abs(pallas_result - xla_result)) +print('Difference |Pallas - lax.psum| = ', difference) +``` + ++++ {"id": "d8bsZAzQreC_"} + +### Run-ahead and Race Conditions + +As a general rule of thumb, to maximize performance we want to allow a device to run-ahead of other devices without synchronization as much as possible without sacrificing correctness of the program. While we could enforce a barrier across all devices at the beginning of each iteration, this bottlenecks the performance of the program to the slowest device on each loop. By relaxing synchronization and allowing a moderate amount of run-ahead, we can better accommodate variance in latency between iterations and devices because a device that is slow on one iteration could catch up on the next iteration. + +In the all-reduce kernel we wrote previously, we allow devices to run ahead but by less than one iteration compared to its neighbors (however, non-neighboring devices could be more than 1 iteration apart). To see why the semaphore synchronization is necessary, consider the case when one device (say device 2) hangs and falls behind the other devices. An RDMA has no "handshake" — only the receiver is blocked while waiting for the data to arrive. Therefore, each device can run up to one iteration ahead before it becomes blocked waiting for the next RDMA to arrive. If we have N devices, this means that the final device can be up to N iterations ahead of the first device. + +![race_condition](../../_static/pallas/distributed/race_condition.svg) + +Without adding synchronization in the other direction (forcing senders to block), device 1 could potentially run up to `N` iterations (`N = num_devices`) ahead of device 2, sending multiple writes and overwriting values in the process. To solve this in the `all_reduce` kernel we wrote previously we implemented a "handshake" protocol where the receiver signals back to the sender that it is ready to receive, and only then does the sender begin issuing the next RDMA. + ++++ {"id": "UD8lNrqsUeXy"} + +### Bi-directional Communication + +In our previous kernels, we communicated in a single direction around a ring from left-to-right. However, as ICI connections are bi-directional, we are effectively wasting half of the total bandwidth by not sending values in the opposite direction from right-to-left. In this next kernel we will demonstrate an example which communicates in both directions to maximize ICI bandwidth. + ++++ {"id": "4KjakLhbBk73"} + +### Example: Bi-directional Reduce-Scatter (`lax.psum_scatter`) + +A reduce-scatter operation is the combination of an all-reduce followed by a scatter. Or alternatively, an all-reduce is the combination of a reduce-scatter followed by all-gather. + +The following graphic depicts the semantics of this operation. We assume that each device starts with a collection of partial sums (denoted by a letter + number, such as `A0`). The goal is to reduce along one axis (numbers), while sharding along the other axis (letters). + +![reduce_scatter_1](../../_static/pallas/distributed/reduce_scatter_1.svg) + +In order to implement a bi-directional communication strategy, we slice each input block in half, and designate a direction for each half. The top half of each block will be passed from right-to-left, and the bottom half will be passed from left-to-right. A second deviation from the communication patterns of our previous all-reduce and all-gather kernels is that we will also pass around accumulators or partial sums and keep the inputs local to each device. This is in contrast to the previous examples where we passed around inputs but kept the accumulator local to the device. Passing around the accumulator is a more natural fit for this problem as in contrast to all-reduce, most of the data in the inputs are not part of the output that will be stored locally on the device. (e.g. `B0`, `C0`, and `D0` in the above graphic will not be stored on the device holding `A` at the end). + +The following diagram illustrates this communication pattern, where the colored boxes represent accumulators (not inputs!). Initially, the accumulator is simply the value that was contained in the input. At each iteration of the algorithm, we will receive a partial sum from our neighbors in each direction. We then compute the correct slice of our input to accumulate into the partial buffer, then pass the new partial sum along to our next neighbor. After N iterations, the accumulator will have passed through each device, meaning that it will hold the full sum in the end. + +![reduce_scatter_2](../../_static/pallas/distributed/reduce_scatter_2.svg) + +In terms of construction of the kernel, we introduce an additional `phase` dimension to the Pallas grid, which denotes which accumulator (left or right) we are currently computing on. We let `phase=0` denote the accumulator moving to the left, and `phase=1` denote the accumulator moving to the right. We then pipeline the two phases, such that while computing the result for one phase we are transferring our previously computed values in the opposite direction in preparation for the next phase. For example, when we are on `phase=0` (left), we first begin a DMA to transfer results we computed in the previous iteration to our right neighbor (right-DMA). Then, we accumulate into the left-buffer and save the result to HBM. We then wait for the right-DMA to complete so that it is ready for `phase=1` (right). + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 544 + status: ok + timestamp: 1722904805699 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: nRauUAxNHg28 +--- +partition = P(None, 'x') +devices = mesh_utils.create_device_mesh((1, num_devices)) +mesh = jax.sharding.Mesh(devices, partition) +sharding = jax.sharding.NamedSharding(mesh, partition) + +# We need a block size of (16, 128) to ensure that a half-slice is at least +# of size (8, 128), which is the size of a VREG. This makes tiling easier +# for the compiler. +block_size = (16, 128) +input_arr = jax.random.uniform( + jax.random.key(0), + shape=(block_size[0] * num_devices, block_size[1] * num_devices), +) +input_arr = jax.device_put(input_arr, sharding) + +LEFT = 0 +RIGHT = 1 + + +def mod(x, n): + return lax.rem(x + n, n) + + +def signal(left_or_right, semaphore): + my_id = lax.axis_index('x') + if left_or_right == LEFT: + neighbor = mod(my_id - 1, num_devices) + else: + neighbor = mod(my_id + 1, num_devices) + pltpu.semaphore_signal( + semaphore, + inc=1, + device_id=(0, neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + +def reduce_scatter_kernel( + x_ref, + o_ref, + hbm_scratch, + local_copy_sem, + left_recv_sem, + left_send_sem, + right_recv_sem, + right_send_sem, + left_capacity_sem, + right_capacity_sem, + accum_scratch, +): + outer_step = pl.program_id(0) + phase = pl.program_id(1) + is_start = jnp.logical_and(outer_step == 0, phase == 0) + last_iteration = outer_step == pl.num_programs(0) - 1 + + working_slot = lax.rem(outer_step, 2) + receiving_slot = 1 - working_slot + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, num_devices) + left_neighbor = mod(my_id - 1, num_devices) + + left_copy_device = mod(my_id + outer_step + 1, num_devices) + right_copy_device = mod(my_id - outer_step - 1, num_devices) + # Slices can be specified using pl.ds(start, size) + left_copy_slice = pl.ds(0, block_size[0] // 2) + right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2) + current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2) + + initial_left_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, left_copy_slice], + dst_ref=hbm_scratch.at[working_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + initial_right_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + left_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[working_slot, left_copy_slice], + dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + right_copy = pltpu.make_async_remote_copy( + # Note: Right copy is flipped with regards to slots since we are copying + # to the next outer_step iteration. + src_ref=hbm_scratch.at[receiving_slot, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + # --- Prologue --- + @pl.when(is_start) + def _(): + # Barrier with both neighbors at the start, since we will be + # communicating with both. + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + + # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies. + o_ref[...] = jnp.zeros_like(o_ref[...]) + accum_scratch[...] = jnp.zeros_like(accum_scratch[...]) + + initial_left_copy.start() + initial_left_copy.wait() + initial_right_copy.start() + + # We tell our left neighbor that it is allowed to send to the right. + # (and vice versa for right neighbor) + signal(LEFT, right_capacity_sem) + signal(RIGHT, left_capacity_sem) + + # --- Body --- + # At the beginning of our kernel body, we start a DMA which copies + # the result we computed in the previous phase to our neighbor. + # This allows us to overlap the communication of sending our previous phase + # with the computation for the current phase. + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + # We block here until our right neighbor tells use we can send to + # the right. + pltpu.semaphore_wait(right_capacity_sem, 1) + right_copy.start() + + @pl.when(phase == RIGHT) + def _(): + # We block here until our left neighbor tells use we can send to + # the left. + pltpu.semaphore_wait(left_capacity_sem, 1) + left_copy.start() + + local_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot, current_phase_slice], + dst_ref=accum_scratch, + sem=local_copy_sem, + ) + local_copy.start() + local_copy.wait() + + @pl.when(~last_iteration) + def _(): + @pl.when(phase == LEFT) + def _(): + accum_scratch[...] += x_ref[left_copy_device, left_copy_slice] + + @pl.when(phase == RIGHT) + def _(): + accum_scratch[...] += x_ref[right_copy_device, right_copy_slice] + + local_copy = pltpu.make_async_copy( + src_ref=accum_scratch, + dst_ref=hbm_scratch.at[working_slot, current_phase_slice], + sem=local_copy_sem, + ) + local_copy.start() + local_copy.wait() + + @pl.when(is_start) + def _(): + initial_right_copy.wait() + + # At the end of our kernel body, we wait on the DMA of the previous phase + # to make sure the results are ready for the next phase. + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + right_copy.wait() + signal(LEFT, right_capacity_sem) + + @pl.when(phase == RIGHT) + def _(): + left_copy.wait() + signal(RIGHT, left_capacity_sem) + + # --- Epilogue --- + # Store result on last iteration. + @pl.when(last_iteration) + def _(): + # Clean up semaphores so that they exit with a value of 0. + @pl.when(phase == LEFT) + def _(): + o_ref[left_copy_slice, ...] = accum_scratch[...] + pltpu.semaphore_wait(right_capacity_sem, 1) + + @pl.when(phase == RIGHT) + def _(): + o_ref[right_copy_slice, ...] = accum_scratch[...] + pltpu.semaphore_wait(left_capacity_sem, 1) + + +out_shape = ( + jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output + # Shape: [working/recv, block[0], block[1]] + jax.ShapeDtypeStruct( + (2, block_size[0], block_size[1]), jnp.float32 + ), # hbm_scratch +) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + grid=(num_devices, 2), + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 5 + + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores + + [ + pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32) + ] # accum_scratch + ), +) + + +def pallas_reduce_scatter(input_arr): + input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1]) + return pl.pallas_call( + reduce_scatter_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + compiler_params=dict(mosaic=dict(collective_id=0)), + )(input_arr)[0] + + +pallas_result = jax.jit( + shard_map.shard_map( + pallas_reduce_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + check_rep=False, + ) +)(input_arr) + +pallas_result = jax.block_until_ready(pallas_result) +``` + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 596 + status: ok + timestamp: 1722904806442 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: E-NMh-_teoi4 +outputId: 24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0 +--- +# Compare our result to XLA. +def lax_reduce_sum_scatter(x): + x = x.reshape(num_devices, block_size[0], block_size[1]) + return lax.psum_scatter(x, 'x') + + +xla_result = jax.jit( + shard_map.shard_map( + lax_reduce_sum_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + ) +)(input_arr) + +print('Input:', input_arr.shape, input_arr[::4, 0]) +print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0]) +print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0]) +print( + 'Difference |Pallas - lax.psum_scatter|:', + jnp.max(jnp.abs(pallas_result - xla_result)), +) +``` + ++++ {"id": "ThKas40r40Ji"} + +### Nested Remote and Local DMA Pipelines + +A limitation of the previous all-reduce and reduce-scatter kernels that we wrote is that the blocks we copy via remote DMA must be small enough to fit in our working VMEM that we use for accumulation. For some kernels it may be advantageous to use larger block sizes to better utilize the TPU. For example, a matrix multiplication requires on the order of $O(N^3)$ compute operations, but only $O(N^2)$ memory transfers. Therefore, we want each block of work transferred between devices to be large enough such that the operation becomes compute bound and we can hide the communication cost using pipelining. For reference, the VMEM of a TPU (for generations v4/v5) is typically on the order of 10-100MB, whereas HBM ranges from 10-100GB. + +To address this problem, we need to be able to write an "inner kernel" that handles local HBM-VMEM pipelining inside of the "outer kernel" that handles pipelining larger HBM-HBM transfers between devices. Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. The basic call signature for `emit_pipeline` follows that of a standard `pallas_call` by specifying a `grid` and `BlockSpec`s for the inputs and outputs: + +```python +def emit_pipeline( + kernel: Callable, + grid: tuple[int], + in_specs: PyTree[BlockSpec] = None, + out_specs: PyTree[BlockSpec] = None, + should_accumulate_out: bool = False, + dimension_semantics: tuple[GridDimensionSemantics] = None, +) -> Callable: + ... # Returns a custom pipeline given an inner kernel and BlockSpecs. +``` + +Indeed, one can view `pallas_call` itself as simply a wrapper around `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern: + +```python + +def outer_kernel(...): + # ... do work to pipeline remote HBM-HBM transfers (outer kernel) + + def inner_kernel(...): + # ... do work (inner kernel) + pltpu.emit_pipeline( + inner_kernel, + grid=inner_grid, + in_specs=..., + out_specs=..., + )(inner_kernel_args) + # ... do more work (outer kernel) + +pl.pallas_call( + outer_kernel, + grid=outer_grid, + in_specs=... + out_specs=... + scratch=inner_kernel_allocs +) +``` + ++++ {"id": "DzFeQjYaasX5"} + +### Example: Reduce-Scatter with large HBM blocks + +In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter. + +We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size. + +In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM: + +```python +local_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot, current_phase_slice], + dst_ref=accum_scratch, + sem=local_copy_sem, +) +local_copy.start() +local_copy.wait() +@pl.when(~last_iteration) +def _(): + @pl.when(phase == LEFT) + def _(): + accum_scratch[...] += x_ref[left_copy_device, left_copy_slice] + @pl.when(phase == RIGHT) + def _(): + accum_scratch[...] += x_ref[right_copy_device, right_copy_slice] +local_copy = pltpu.make_async_copy( + src_ref=accum_scratch, + dst_ref=hbm_scratch.at[working_slot, current_phase_slice], + sem=local_copy_sem, +) +local_copy.start() +local_copy.wait() +``` + +Our new kernel replaces it with the following `emit_pipeline` call: + +```python +def inner_kernel(input_ref, accum_ref): + accum_ref[...] = input_ref[...] +accum_pipeline = pltpu.emit_pipeline(inner_kernel, + in_specs=[inner_block_spec], + out_specs=inner_block_spec, + should_accumulate_out=True, + grid=inner_grid) +@pl.when(~last_iteration) +def _(): + @pl.when(phase == LEFT) + def _(): + accum_pipeline(x_ref.at[left_copy_device, left_copy_slice], + hbm_scratch.at[working_slot, left_copy_slice], + ) + @pl.when(phase == RIGHT) + def _(): + accum_pipeline(x_ref.at[right_copy_device, right_copy_slice], + hbm_scratch.at[working_slot, right_copy_slice], + ) +``` + +The full kernel is as follows: + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 1341 + status: ok + timestamp: 1722904807930 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: 27jni-pSartL +--- +partition = P(None, 'x') +devices = mesh_utils.create_device_mesh((1, num_devices)) +mesh = jax.sharding.Mesh(devices, partition) +sharding = jax.sharding.NamedSharding(mesh, partition) + +# We pick a large outer kernel block size that we do not want to place +# in VMEM. For pedagogical purposes we use (4096, 4096), although in +# principle this can be much larger. +outer_block_size = (4096, 4096) +# We pick a smaller VMEM block size for the inner kernel. +inner_block_size = (128, 128) +input_arr = jax.random.uniform( + jax.random.key(0), + shape=( + outer_block_size[0] * num_devices, + outer_block_size[1] * num_devices, + ), +) +input_arr = jax.device_put(input_arr, sharding) + + +inner_grid = ( + outer_block_size[0] // inner_block_size[0] // 2, + outer_block_size[1] // inner_block_size[1], +) +inner_block_spec = pl.BlockSpec( + index_map=lambda i, j: (i, j), + block_shape=inner_block_size, + memory_space=pltpu.TPUMemorySpace.ANY, +) + + +def reduce_scatter_kernel( + x_ref, + o_ref, + hbm_scratch, + left_recv_sem, + left_send_sem, + copy_sem, + right_recv_sem, + right_send_sem, + left_capacity_sem, + right_capacity_sem, +): + outer_step = pl.program_id(0) + phase = pl.program_id(1) + is_start = jnp.logical_and(outer_step == 0, phase == 0) + last_iteration = outer_step == pl.num_programs(0) - 1 + + working_slot = lax.rem(outer_step, 2) + receiving_slot = 1 - working_slot + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, num_devices) + left_neighbor = mod(my_id - 1, num_devices) + + left_copy_device = mod(my_id + outer_step + 1, num_devices) + right_copy_device = mod(my_id - outer_step - 1, num_devices) + left_copy_slice = pl.ds(0, outer_block_size[0] // 2) + right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2) + current_phase_slice = pl.ds( + phase * (outer_block_size[0] // 2), outer_block_size[0] // 2 + ) + + initial_left_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, left_copy_slice], + dst_ref=hbm_scratch.at[working_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + initial_right_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + left_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[working_slot, left_copy_slice], + dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + right_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[receiving_slot, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + # --- Prologue --- + @pl.when(is_start) + def _(): + # Barrier with both neighbors at the start, since we will be + # communicating with both. + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + + initial_left_copy.start() + initial_left_copy.wait() + initial_right_copy.start() + + # We tell our left neighbor that it is allowed to send to the right. + # (and vice versa for right neighbor) + signal(LEFT, right_capacity_sem) + signal(RIGHT, left_capacity_sem) + + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + # We block here until our right neighbor tells use we can send to + # the right. + pltpu.semaphore_wait(right_capacity_sem, 1) + right_copy.start() + + @pl.when(phase == RIGHT) + def _(): + # We block here until our left neighbor tells use we can send to + # the left. + pltpu.semaphore_wait(left_capacity_sem, 1) + left_copy.start() + + # --- Body --- + def inner_kernel(input_ref, accum_ref): + # We do not explicitly use += because we set should_accumulate_out=True. + accum_ref[...] = input_ref[...] + + accum_pipeline = pltpu.emit_pipeline( + inner_kernel, + in_specs=[inner_block_spec], + out_specs=inner_block_spec, + should_accumulate_out=True, + grid=inner_grid, + ) + + @pl.when(~last_iteration) + def _(): + @pl.when(phase == LEFT) + def _(): + accum_pipeline( + x_ref.at[left_copy_device, left_copy_slice], + hbm_scratch.at[working_slot, left_copy_slice], + ) + + @pl.when(phase == RIGHT) + def _(): + accum_pipeline( + x_ref.at[right_copy_device, right_copy_slice], + hbm_scratch.at[working_slot, right_copy_slice], + ) + + # --- Epilogue --- + @pl.when(is_start) + def _(): + initial_right_copy.wait() + + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + right_copy.wait() + signal(LEFT, right_capacity_sem) + + @pl.when(phase == RIGHT) + def _(): + left_copy.wait() + signal(RIGHT, left_capacity_sem) + + # Store result on last iteration. + @pl.when(last_iteration) + def _(): + output_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot, current_phase_slice], + dst_ref=o_ref.at[current_phase_slice], + sem=copy_sem, + ) + output_copy.start() + output_copy.wait() + + # Clean up semaphores so that they exit with a value of 0. + @pl.when(phase == LEFT) + def _(): + pltpu.semaphore_wait(right_capacity_sem, 1) + + @pl.when(phase == RIGHT) + def _(): + pltpu.semaphore_wait(left_capacity_sem, 1) + + +out_shape = ( + jax.ShapeDtypeStruct( + (outer_block_size[0], outer_block_size[1]), jnp.float32 + ), + # Shape: [working/recv, block[0], block[1]] + jax.ShapeDtypeStruct( + (2, outer_block_size[0], outer_block_size[1]), jnp.float32 + ), # hbm_scratch +) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + grid=(num_devices, 2), + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 5 + + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores + ), +) + + +def pallas_reduce_scatter(input_arr): + input_arr = input_arr.reshape( + num_devices, outer_block_size[0], outer_block_size[1] + ) + return pl.pallas_call( + reduce_scatter_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + compiler_params=dict(mosaic=dict(collective_id=0)), + )(input_arr)[0] + + +pallas_result = jax.jit( + shard_map.shard_map( + pallas_reduce_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + check_rep=False, + ) +)(input_arr) + +pallas_result = jax.block_until_ready(pallas_result) +``` + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 768 + status: ok + timestamp: 1722904808851 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: cTEyiMDyx9Y0 +outputId: 1de26695-3713-430e-9ab4-4ea646691680 +--- +# Now we compare our result to XLA. +def lax_reduce_sum_scatter(x): + x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1]) + return lax.psum_scatter(x, 'x') + + +xla_result = jax.jit( + shard_map.shard_map( + lax_reduce_sum_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + ) +)(input_arr) + +print('Input:', input_arr.shape, input_arr[::4, 0]) +print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0]) +print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0]) +print( + 'Difference |Pallas - lax.psum_scatter|:', + jnp.max(jnp.abs(pallas_result - xla_result)), +) +``` + ++++ {"id": "zz5AFbriliyv"} + +## Final Notes + +### Interaction with XLA + +In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations. + +### Next Steps + +Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead. diff --git a/docs/pallas/tpu/index.rst b/docs/pallas/tpu/index.rst index 83898b331517..d5efe3a73a40 100644 --- a/docs/pallas/tpu/index.rst +++ b/docs/pallas/tpu/index.rst @@ -8,3 +8,4 @@ TPU specific documentation. details pipelining + distributed From e57a7e3f05447fb3c0a4eb2c58438db75222866b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Thu, 8 Aug 2024 20:13:32 -0700 Subject: [PATCH 040/702] [Mosaic] Column shift relayouts for non-native tilings and packed types, except for (1, n) and packed PiperOrigin-RevId: 661091012 --- .../tpu/transforms/apply_vector_layout.cc | 353 ++++++++++++++---- 1 file changed, 283 insertions(+), 70 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 0670ef1e2f09..f85692a61624 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -5008,13 +5008,290 @@ FailureOr> tpu_rotate_with_overflow( return out_tiles; } +void rotateVregs(OpBuilder &builder, xla::Array &vregs, + const int64_t amount, const int dimension) { + if (amount != 0) { + vregs.Each([&](absl::Span idx, Value *vreg) { + CHECK(vreg); + *vreg = builder + .create(vreg->getLoc(), *vreg, + /*amount=*/amount, + /*dimension=*/dimension, + /*stride=*/nullptr, + /*stride_dimension=*/nullptr) + .getResult(); + }); + } +}; + +void rotateSublanes(OpBuilder &builder, xla::Array &vregs, + const int64_t amount) { + rotateVregs(builder, vregs, amount, 0); +} + +void rotateLanes(OpBuilder &builder, xla::Array &vregs, + const int64_t amount) { + rotateVregs(builder, vregs, amount, 1); +} + +// Relayout src_vregs from layout src to layout dst, where dst is the same as +// src except that the column offset is dst_col_offset. +FailureOr> doColumnShiftRelayout( + OpBuilder &builder, const ArrayRef shape, + xla::Array src_vregs, const VectorLayout &src, + const int64_t dst_col_offset, const std::array target_shape) { + CHECK(src.offsets()[1]); + const std::array tiled_ishape = + src.getImplicitTiledDims(shape, 1); + const Location loc = src_vregs.begin()->getLoc(); + const std::array tiling = src.tiling(); + const std::array vreg_slice = src.vregSlice(target_shape); + const int bitwidth = src.bitwidth(); + const int packing = src.packing(); + const VectorLayout dst(bitwidth, {src.offsets()[0], dst_col_offset}, tiling, + src.implicit_dim()); + const int64_t col_diff = dst_col_offset - *src.offsets()[1]; + if (tiling[0] % packing != 0 || tiling[1] != target_shape[1]) { + return emitError(loc, + "Not implemented: Unsupported tiling for column shift"); + } + // When shifting columns with multiple tiles per vreg, the overflowing + // columns of a tile move to the next tile, and they have to be shifted + // down. For example, for a 32-bit layout with (2, 128 tiling), when shifting + // a vreg right by 138 (128 + 10): + // + // +---------------+---------+ +---------+---------------+ + // | 0:118 | 118:128 | |-138:-128| -128:-10 | + // +---------------+---------+ +---------+---------------+ + // | 128:246 | 246:256 | | -10:0 | 0:118 | + // +---------------+---------+ -> +---------+---------------+ + // | 256:382 | 382:392 | | 118:128 | 128:246 | + // +---------------+---------+ +---------+---------------+ + // | 392:502 | 502:512 | | 246:256 | 256:382 | + // +---------------+---------+ +---------+---------------+ + // + // The negative numbers above are used for column intervals coming from the + // previous vreg (if there is one). + // + // We can break the result vreg down into four parts: + // + // +---------+---------------+ + // | UL | UR | + // + +---------------+ + // | | LR | + // +---------+ + + // | LL | | + // + + + + // | | | + // +---------+---------------+ + // + // Our example shifts right, which causes the upper parts to come from the + // previous (along the minor dim) vreg of the array (if it exists) and the + // lower parts to come from the original "current" vreg. + // + // - LR (Lower Right) comes from the current vreg lane-rotated by 10, and + // sublane-rotated down by 2 (1 tile). + // - LL (Lower Left) comes from the current vreg lane-rotated by 10, and + // sublane-rotated down by 4 (2 tiles). + // - UR (Upper Right) comes from the previous vreg lane-shifted by 10, and + // sublane-rotated down by 2 (1 tile). + // - UL (Upper Left) comes from the previous vreg lane-shifted by 10, and + // sublane-rotated down by 4 (2 tiles). + // + // This partitioning also works similarly for left shifts, except that the + // upper parts come from the current vreg, and the lower parts come from the + // next vreg. + // + // In general, for any tiling and shift amount, we will partition the result + // vreg into four like we did here. However, for some tilings and shift + // amounts, some of the partitions may be empty. There are some notable cases: + // + // - Tile-aligned shifts result in empty left parts. + // - Native tiling (a single tile per vreg) results in empty upper right and + // lower left parts. + // - Shifts right by less than 1 tile result in empty upper right parts, and + // shifts left by less than 1 tile result in empty lower left parts. + + const int64_t sublanes_per_tile = src.sublanesPerTile(target_shape); + const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape); + + int64_t split_offset = col_diff; + int64_t upper_idx_delta = -1; + int64_t lower_idx_delta = 0; + if (col_diff < 0) { + split_offset += vreg_slice[1]; + ++upper_idx_delta; + ++lower_idx_delta; + } + const int64_t left_tile_split = llvm::divideCeil(split_offset, tiling[1]); + const int64_t right_tile_split = split_offset / tiling[1]; + const int64_t left_right_split = split_offset % tiling[1]; + + rotateLanes(builder, src_vregs, left_right_split); + // TODO(tlongeri): Clean up. Some of these rotations may end up unused: + // - The left part of the first vreg and the right part of the last vreg + // may be entirely padding. + // - The entire left part may be unused if the shift is tile-aligned. + // They will be removed as dead code anyway, but it would be nicer to not + // generate them in the first place. + // Also, sometimes the rotation amount is 0, so we don't need to allocate + // another array (and we should steal the allocation for src_tiles, too). + xla::Array left_part = src_vregs; + xla::Array right_part = src_vregs; + rotateSublanes(builder, left_part, + left_tile_split * sublanes_per_tile % target_shape[0]); + rotateSublanes(builder, right_part, + right_tile_split * sublanes_per_tile % target_shape[0]); + // We assemble left and right, and then put them together. + // TODO(tlongeri): Lower and upper first is probably better, it can be + // reused for consecutive vregs. We can assemble lower_left+lower_right + // for one vreg and upper_left+upper_right for the next one in the same + // vselect. But the mask for assembling upper+lower is not as simple, so + // it might be a bit more expensive to generate. Worth it for large vreg + // arrays, I'm not sure about small ones (especially in older TPU gens). + const auto mask_vreg_ty = VectorType::get( + packing == 1 + ? target_shape + : ArrayRef{target_shape[0], target_shape[1], packing}, + builder.getI1Type()); + Value left_mask = nullptr; + Value right_mask = nullptr; + Value left_right_mask = nullptr; + auto get_left_mask = [&]() { + if (left_mask == nullptr) { + left_mask = builder.create( + loc, mask_vreg_ty, + ArrayRef{IdxConst(0, builder, loc), IdxConst(0, builder, loc)}, + ArrayRef{ + IdxConst(left_tile_split * sublanes_per_tile, builder, loc), + IdxConst(target_shape[1], builder, loc)}); + } + return left_mask; + }; + auto get_right_mask = [&]() { + if (right_mask == nullptr) { + right_mask = builder.create( + loc, mask_vreg_ty, + ArrayRef{IdxConst(0, builder, loc), IdxConst(0, builder, loc)}, + ArrayRef{ + IdxConst(right_tile_split * sublanes_per_tile, builder, loc), + IdxConst(target_shape[1], builder, loc)}); + } + return right_mask; + }; + auto get_left_right_mask = [&]() { + if (left_right_mask == nullptr) { + left_right_mask = builder.create( + loc, mask_vreg_ty, + ArrayRef{IdxConst(0, builder, loc), IdxConst(0, builder, loc)}, + ArrayRef{IdxConst(target_shape[0], builder, loc), + IdxConst(left_right_split, builder, loc)}); + } + return left_right_mask; + }; + xla::Array dst_vregs(VectorLayout(bitwidth, + {src.offsets()[0], dst_col_offset}, + tiling, src.implicit_dim()) + .tileArrayImplicitShape(shape, target_shape)); + dst_vregs.Each([&](absl::Span dst_idx, Value *dst_vreg) { + SmallVector dst_idx_local(toArrayRef(dst_idx)); + Value lower_left = nullptr; + Value lower_right = nullptr; + Value upper_left = nullptr; + Value upper_right = nullptr; + // Set parts if their size is non-empty and the source vreg exists. + *(dst_idx_local.end() - 1) += lower_idx_delta; + if (*(dst_idx_local.end() - 1) < *(src_vregs.dimensions().end() - 1)) { + if (left_tile_split < tiles_per_vreg && 0 < left_right_split) { + lower_left = left_part(dst_idx_local); + } + if (right_tile_split < tiles_per_vreg) { + lower_right = right_part(dst_idx_local); + } + } + *(dst_idx_local.end() - 1) -= lower_idx_delta; + *(dst_idx_local.end() - 1) += upper_idx_delta; + if (*(dst_idx_local.end() - 1) >= 0) { + if (0 < left_tile_split && 0 < left_right_split) { + upper_left = left_part(dst_idx_local); + } + if (0 < right_tile_split) { + upper_right = right_part(dst_idx_local); + } + } + *(dst_idx_local.end() - 1) -= upper_idx_delta; + + // For the first and last vregs, some parts may be all padding, so + // unset them if this is the case. Note that the first and last vreg + // are the same when there is only one. + if (*(dst_idx_local.end() - 1) == 0) { + // We check the final offset (note that this is different from the rotate + // amount) against the thresholds of the last columns of vreg parts. + if (right_tile_split * tiling[1] <= dst_col_offset) { + // Note: When shifting right, UR is always all-padding. + upper_right = nullptr; + } + if (split_offset <= dst_col_offset) { + // Note: When shifting right, UL is always all-padding. When shifting + // left, UL is never all-padding (unless this is also the last vreg, + // possibly). + upper_left = nullptr; + } + if (vreg_slice[1] - tiling[1] + left_right_split <= dst_col_offset) { + // Note: When shifting right, LL is only all-padding if the source + // offset is in the last tile. When shifting left, LL is never + // all-padding (unless this is also the last vreg, possibly). + lower_left = nullptr; + } + } + if (*(dst_idx_local.end() - 1) == *(dst_vregs.dimensions().end() - 1) - 1) { + // We check the final end offset against the thresholds of the first + // columns of vreg parts. + const uint64_t end_offset = + (dst_col_offset + tiled_ishape[1] - 1) % vreg_slice[1] + 1; + if (end_offset <= left_tile_split * tiling[1]) { + // Note: When shifting left, LL is always all-padding. + lower_left = nullptr; + } + if (end_offset <= split_offset) { + // Note: When shifting left, LR is always all-padding. When shifting + // right, LR is never all-padding (unless this is also the first vreg, + // possibly). + lower_right = nullptr; + } + if (end_offset <= left_right_split) { + // Note: When shifting left, UR is only all-padding if the original + // end offset is in the first tile. When shifting right, UR is never + // all-padding (unless this is also the last vreg, possibly). + upper_right = nullptr; + } + } + // Combine parts into the final vreg (see comment in mask definitions). + auto combine_parts = [&builder](Value part1, Value part2, + auto get_mask_fn) -> Value { + if (part1 && part2) { + return builder.create(part1.getLoc(), get_mask_fn(), + part1, part2); + } else if (part1) { + return part1; + } else { + return part2; + } + }; + Value left = combine_parts(upper_left, lower_left, get_left_mask); + Value right = combine_parts(upper_right, lower_right, get_right_mask); + *dst_vreg = combine_parts(left, right, get_left_right_mask); + CHECK(*dst_vreg); + }); + return dst_vregs; +} + FailureOr>> changeOffsets( OpBuilder &builder, const std::array target_shape, const Location loc, const VectorType vty, const VectorLayout src, xla::Array vregs, const LayoutOffsets dst_offsets) { const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(), src.implicit_dim()); - const auto &tiling = src.tiling(); const int packing = src.packing(); const int8_t bitwidth = src.bitwidth(); @@ -5061,15 +5338,7 @@ FailureOr>> changeOffsets( if (sublane_diff < 0) { sublane_diff += target_shape[0]; } - vregs.Each([&](absl::Span idx, Value *tile) { - *tile = - builder - .create(loc, *tile, - /*amount=*/sublane_diff, - /*dimension=*/0, /*stride=*/nullptr, - /*stride_dimension=*/nullptr) - .getResult(); - }); + rotateSublanes(builder, vregs, sublane_diff); } const int src_subelem = *src.offsets()[0] % packing; const int dst_subelem = *dst.offsets()[0] % packing; @@ -5108,68 +5377,12 @@ FailureOr>> changeOffsets( SmallVector dst_tiles_shape = dst.tileArrayImplicitShape(vty.getShape(), target_shape); CHECK_EQ(*(dst_tiles_shape.end() - 2), *(vregs.dimensions().end() - 2)); - if (dst_tiles_shape.back() != vregs.dimensions().back()) { - return emitError(loc, - "Not implemented: Offsets changing the vreg array shape"); - } + // TODO(tlongeri): Clean up col_diff and pass the dst offset directly. if (col_diff != 0) { - if (bitwidth != 32 || tiling != target_shape) { - return emitError(loc, - "Not implemented: Only 32-bit column shifts for " - "native layouts supported"); - } - TPU_ASSERT_GE_LOC(loc, vregs.num_dimensions(), 1); - std::optional maybe_create_mask; - if (*(vregs.dimensions().end() - 1) > 1) { - int64_t lane_start, lane_end; - if (col_diff > 0) { - lane_start = 0; - lane_end = col_diff; - } else { // col_diff < 0 - lane_start = target_shape[1] + col_diff; - lane_end = target_shape[1]; - } - auto boundIdxConst = - std::bind(IdxConst, std::placeholders::_1, builder, loc); - maybe_create_mask = builder.create( - loc, VectorType::get(target_shape, builder.getI1Type()), - ValueRange{boundIdxConst(0), boundIdxConst(lane_start)}, - ValueRange{boundIdxConst(target_shape[0]), boundIdxConst(lane_end)}); - } - auto rotated_vregs = vregs; - rotated_vregs.Each([&](absl::Span idx, Value *tile) { - *tile = builder - .create(loc, *tile, - /*amount=*/col_diff < 0 - ? target_shape[1] + col_diff - : col_diff, - /*dimension=*/1, /*stride=*/nullptr, - /*stride_dimension=*/nullptr) - .getResult(); - }); - vregs.Each([&](absl::Span idx, Value *result) { - Value rot_tile = rotated_vregs(idx); - Value prev_rot_tile; - if (col_diff > 0) { - if (*(idx.end() - 1) != 0) { - SmallVector prev_idx(idx.begin(), idx.end()); - --*(prev_idx.end() - 1); - prev_rot_tile = rotated_vregs(prev_idx); - } - } else { // col_diff < 0 - if (*(idx.end() - 1) != *(rotated_vregs.dimensions().end() - 1) - 1) { - SmallVector prev_idx(idx.begin(), idx.end()); - ++*(prev_idx.end() - 1); - prev_rot_tile = rotated_vregs(prev_idx); - } - } - if (prev_rot_tile != nullptr) { - rot_tile = builder.create( - loc, maybe_create_mask->getResult(), prev_rot_tile, rot_tile); - } - *result = rot_tile; - }); + FAILUREOR_ASSIGN_OR_RETURN( + vregs, doColumnShiftRelayout(builder, vty.getShape(), std::move(vregs), + src, *dst.offsets()[1], target_shape)); } return std::make_pair(dst, std::move(vregs)); } From 6ee1555d21ddaaed78e40b0f797973d360ea27c4 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 9 Aug 2024 14:23:56 +0530 Subject: [PATCH 041/702] Fix broken links in jnp.fft.fftfreq and jnp.fft.rfftfreq --- jax/_src/numpy/fft.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index e246b1fb6929..be98a153c02a 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -986,7 +986,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, """Return sample frequencies for the discrete Fourier transform. JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate - for use with the outputs of :func:`~jax.numpy.fft` and :func:`~jax.numpy.ifft`. + for use with the outputs of :func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`. Args: n: length of the FFT window @@ -1000,8 +1000,8 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, Array of sample frequencies, length ``n``. See also: - - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.rfft` - and :func:`~jax.numpy.irfft`. + - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with + :func:`~jax.numpy.fft.rfft` and :func:`~jax.numpy.fft.irfft`. """ dtype = dtype or dtypes.canonicalize_dtype(jnp.float_) if isinstance(n, (list, tuple)): @@ -1037,7 +1037,8 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, """Return sample frequencies for the discrete Fourier transform. JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate - for use with the outputs of :func:`~jax.numpy.rfft` and :func:`~jax.numpy.irfft`. + for use with the outputs of :func:`~jax.numpy.fft.rfft` and + :func:`~jax.numpy.fft.irfft`. Args: n: length of the FFT window @@ -1051,8 +1052,8 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, Array of sample frequencies, length ``n // 2 + 1``. See also: - - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.fft` - and :func:`~jax.numpy.ifft`. + - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with + :func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`. """ dtype = dtype or dtypes.canonicalize_dtype(jnp.float_) if isinstance(n, (list, tuple)): From f9bc4c643b4badae9c551c5c4767c89d53e67079 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Wed, 10 Jul 2024 10:15:35 +0000 Subject: [PATCH 042/702] [jax.distributed] Allow setting local device ids via env var --- jax/_src/distributed.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 5e8e956cf98b..b4a72c7678b0 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -45,10 +45,12 @@ def initialize(self, initialization_timeout: int = 300, coordinator_bind_address: str | None = None): coordinator_address = (coordinator_address or - os.environ.get('JAX_COORDINATOR_ADDRESS', None)) + os.environ.get('JAX_COORDINATOR_ADDRESS')) if isinstance(local_device_ids, int): local_device_ids = [local_device_ids] + if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')): + local_device_ids = list(map(int, env_ids.split(","))) (coordinator_address, num_processes, process_id, local_device_ids) = ( clusters.ClusterEnv.auto_detect_unset_distributed_params( From 77afe251e70500d6e505449242b4343cb86c402f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Fri, 9 Aug 2024 05:25:39 -0700 Subject: [PATCH 043/702] [Mosaic TPU][Python] Check validity of VectorLayout on init PiperOrigin-RevId: 661226283 --- jaxlib/mlir/_mlir_libs/tpu_ext.cc | 11 +++++++++-- .../mosaic/dialect/tpu/integrations/c/tpu_dialect.cc | 5 +++++ .../mosaic/dialect/tpu/integrations/c/tpu_dialect.h | 3 +++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 1d024a8b77a4..b09e5744b619 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -374,14 +374,21 @@ PYBIND11_MODULE(_tpu_ext, m) { .def(py::init([](int bitwidth, py::tuple offsets, py::tuple tiling, MlirTpuImplicitDim implicit_dim) { if (offsets.size() != 2) { - throw py::value_error("offsets should be of length 2"); + throw py::value_error("Offsets should be of length 2"); } - return mlirTpuVectorLayoutCreate( + if (tiling.size() != 2) { + throw py::value_error("Tiling should be of length 2"); + } + MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate( bitwidth, {offsetFromPyOffset(offsets[0]), offsetFromPyOffset(offsets[1])}, {tiling[0].cast(), tiling[1].cast()}, implicit_dim); + if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) { + throw py::value_error("Layout not valid for target shape"); + } + return layout; }), py::arg("bitwidth"), py::arg("offsets"), py::arg("tiling"), py::arg("implicit_dim")) diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index ef7d3fecfb22..3cc9b36972d6 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -312,6 +312,11 @@ void mlirTpuVectorLayoutPrint( unwrap(layout)->print(stream); } +bool mlirTpuVectorLayoutIsValid(MlirTpuVectorLayout layout, + MlirTpuI64TargetTuple target_shape) { + return unwrap(layout)->isValid(unwrap(target_shape)); +} + void mlirTpuVregDataBoundsDestroy(MlirTpuVregDataBounds data_bounds) { delete unwrap(data_bounds); } diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h index 42c974b3a961..5b2a7009e9e6 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h @@ -191,6 +191,9 @@ MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquivalentTo( MLIR_CAPI_EXPORTED void mlirTpuVectorLayoutPrint( MlirTpuVectorLayout layout, MlirStringCallback callback, void* user_data); +MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutIsValid( + MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape); + MLIR_CAPI_EXPORTED void mlirTpuVregDataBoundsDestroy( MlirTpuVregDataBounds data_bounds); From 5ced6db6927be3a82d739670c4e113db9692515c Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Fri, 9 Aug 2024 09:33:11 -0700 Subject: [PATCH 044/702] Cache `_get_tpu_generation` to avoid repeated calls to jax.devices(). We use `util.cache` such that if the default backend changes this function will (correctly) be re-evaluated. PiperOrigin-RevId: 661293560 --- jax/_src/pallas/mosaic/pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 7fde6665d394..812c496306f0 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -72,6 +72,7 @@ def add_leaves(i, x): return tree_util.tree_unflatten(treedef, broadcast_leaves) +@jax_util.cache(trace_context_in_key=False) def _get_tpu_generation() -> int: kind = jax.devices()[0].device_kind if kind.endswith(' lite'): From ff1f199d09c1c8861a74cea7e0566fc76b112bc0 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 9 Aug 2024 23:07:17 +0530 Subject: [PATCH 045/702] Improved docs for jnp.fft.rfftn and jnp.fft.irfftn --- jax/_src/numpy/fft.py | 160 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 157 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index e246b1fb6929..845ad7ee7ac2 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -252,17 +252,171 @@ def ifftn(a: ArrayLike, s: Shape | None = None, return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm) -@implements(np.fft.rfftn) def rfftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: + """Compute a multidimensional discrete Fourier transform of a real-valued array. + + JAX implementation of :func:`numpy.fft.rfftn`. + + Args: + a: real-valued input array. + s: optional sequence of integers. Controls the effective size of the input + along each specified axis. If not specified, it will default to the + dimension of input along ``axes``. + axes: optional sequence of integers, default=None. Specifies the axes along + which the transform is computed. If not specified, the transform is computed + along the last ``len(s)`` axes. If neither ``axes`` nor ``s`` is specified, + the transform is computed along all the axes. + norm: string, default="backward". The normalization mode. "backward", "ortho" + and "forward" are supported. + + Returns: + An array containing the multidimensional discrete Fourier transform of ``a`` + having size specified in ``s`` along the axes ``axes`` except along the axis + ``axes[-1]``. The size of the output along the axis ``axes[-1]`` is + ``s[-1]//2+1``. + + See also: + - :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier + transform of real-valued array. + - :func:`jax.numpy.fft.rfft2`: Computes a two-dimensional discrete Fourier + transform of real-valued array. + - :func:`jax.numpy.fft.irfftn`: Computes a real-valued multidimensional inverse + discrete Fourier transform. + + Examples: + >>> x = jnp.array([[[1, 3, 5], + ... [2, 4, 6]], + ... [[7, 9, 11], + ... [8, 10, 12]]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfftn(x) + Array([[[ 78.+0.j , -12.+6.93j], + [ -6.+0.j , 0.+0.j ]], + + [[-36.+0.j , 0.+0.j ], + [ 0.+0.j , 0.+0.j ]]], dtype=complex64) + + When ``s=[3, 3, 4]``, size of the transform along ``axes (-3, -2)`` will + be (3, 3), and along ``axis -1`` will be ``4//2+1 = 3`` and size along + other axes will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfftn(x, s=[3, 3, 4]) + Array([[[ 78. +0.j , -16. -26.j , 26. +0.j ], + [ 15. -36.37j, -16.12 +1.93j, 5. -12.12j], + [ 15. +36.37j, 8.12-11.93j, 5. +12.12j]], + + [[ -7.5 -49.36j, -20.45 +9.43j, -2.5 -16.45j], + [-25.5 -7.79j, -0.6 +11.96j, -8.5 -2.6j ], + [ 19.5 -12.99j, -8.33 -6.5j , 6.5 -4.33j]], + + [[ -7.5 +49.36j, 12.45 -4.43j, -2.5 +16.45j], + [ 19.5 +12.99j, 0.33 -6.5j , 6.5 +4.33j], + [-25.5 +7.79j, 4.6 +5.04j, -8.5 +2.6j ]]], dtype=complex64) + + When ``s=[3, 5]`` and ``axes=(0, 1)``, size of the transform along ``axis 0`` + will be ``3``, along ``axis 1`` will be ``5//2+1 = 3`` and dimension along + other axes will be same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfftn(x, s=[3, 5], axes=[0, 1]) + Array([[[ 18. +0.j , 26. +0.j , 34. +0.j ], + [ 11.09 -9.51j, 16.33-13.31j, 21.56-17.12j], + [ -0.09 -5.88j, 0.67 -8.23j, 1.44-10.58j]], + + [[ -4.5 -12.99j, -2.5 -16.45j, -0.5 -19.92j], + [ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j], + [ -4.95 +0.72j, -5.78 -0.2j , -6.61 -1.12j]], + + [[ -4.5 +12.99j, -2.5 +16.45j, -0.5 +19.92j], + [ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j], + [ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64) + + For 1-D input: + + >>> x1 = jnp.array([1, 2, 3, 4]) + >>> jnp.fft.rfftn(x1) + Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64) + """ return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm) -@implements(np.fft.irfftn) def irfftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: + """Compute a real-valued multidimensional inverse discrete Fourier transform. + + JAX implementation of :func:`numpy.fft.irfftn`. + + Args: + a: input array. + s: optional sequence of integers. Specifies the size of the output in each + specified axis. If not specified, the dimension of output along axis + ``axes[-1]`` is ``2*(m-1)``, ``m`` is the size of input along axis ``axes[-1]`` + and the dimension along other axes will be the same as that of input. + axes: optional sequence of integers, default=None. Specifies the axes along + which the transform is computed. If not specified, the transform is computed + along the last ``len(s)`` axes. If neither ``axes`` nor ``s`` is specified, + the transform is computed along all the axes. + norm: string, default="backward". The normalization mode. "backward", "ortho" + and "forward" are supported. + + Returns: + A real-valued array containing the multidimensional inverse discrete Fourier + transform of ``a`` with size ``s`` along specified ``axes``, and the same as + the input along other axes. + + See also: + - :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier + transform of a real-valued array. + - :func:`jax.numpy.fft.irfft`: Computes a real-valued one-dimensional inverse + discrete Fourier transform. + - :func:`jax.numpy.fft.irfft2`: Computes a real-valued two-dimensional inverse + discrete Fourier transform. + + Examples: + ``jnp.fft.irfftn`` computes the transform along all the axes by default. + + >>> x = jnp.array([[[1, 3, 5], + ... [2, 4, 6]], + ... [[7, 9, 11], + ... [8, 10, 12]]]) + >>> jnp.fft.irfftn(x) + Array([[[ 6.5, -1. , 0. , -1. ], + [-0.5, 0. , 0. , 0. ]], + + [[-3. , 0. , 0. , 0. ], + [ 0. , 0. , 0. , 0. ]]], dtype=float32) + + When ``s=[3, 4]``, size of the transform along ``axes (-2, -1)`` will be + ``(3, 4)`` and size along other axes will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.irfftn(x, s=[3, 4]) + Array([[[ 2.33, -0.67, 0. , -0.67], + [ 0.33, -0.74, 0. , 0.41], + [ 0.33, 0.41, 0. , -0.74]], + + [[ 6.33, -0.67, 0. , -0.67], + [ 1.33, -1.61, 0. , 1.28], + [ 1.33, 1.28, 0. , -1.61]]], dtype=float32) + + When ``s=[3]`` and ``axes=[0]``, size of the transform along ``axes 0`` will + be ``3`` and dimension along other axes will be same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.irfftn(x, s=[3], axes=[0]) + Array([[[ 5., 7., 9.], + [ 6., 8., 10.]], + + [[-2., -2., -2.], + [-2., -2., -2.]], + + [[-2., -2., -2.], + [-2., -2., -2.]]], dtype=float32) + """ return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm) @@ -937,7 +1091,7 @@ def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), discrete Fourier transform. Examples: - ``jnp.fft.ifft2`` computes the transform along the last two axes by default. + ``jnp.fft.irfft2`` computes the transform along the last two axes by default. >>> x = jnp.array([[[1, 3, 5], ... [2, 4, 6]], From 3bd3597703c88b1f68533c111d0f42aeb41431d7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 9 Aug 2024 12:17:28 -0700 Subject: [PATCH 046/702] Improves error message in case of invalid sharding mesh PiperOrigin-RevId: 661358450 --- jax/_src/sharding_impls.py | 23 +++++++++++------------ tests/pjit_test.py | 6 +++--- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 98fd8c7b02c2..1a23f4ba74ad 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -53,18 +53,17 @@ class TransferToMemoryKind: @util.cache(max_size=128, trace_context_in_key=False) def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes): - try: - for p in parsed_pspec: - if p is not None: - for r in p: - mesh.shape[r] - if r in _manual_axes: - raise ValueError( - f"Axis: {r} of {parsed_pspec.get_partition_spec()} " - f"is also found in manual_axes: {_manual_axes}.") from None - except KeyError as e: - raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is " - "undefined.") from None + for p in parsed_pspec: + if p is not None: + for r in p: + if r not in mesh.shape: + raise ValueError( + f"Resource axis: {r} of {parsed_pspec.get_partition_spec()} " + f"is not found in mesh: {tuple(mesh.shape.keys())}.") + if r in _manual_axes: + raise ValueError( + f"Axis: {r} of {parsed_pspec.get_partition_spec()} " + f"is also found in manual_axes: {_manual_axes}.") from None def hashed_index(x) -> int: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 915a6d4b52d2..b933944090a2 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4398,7 +4398,7 @@ def testUndefinedResourcesArgs(self, mesh, resources): spec = P(resources,) with self.assertRaisesRegex( ValueError, - r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"): + r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."): pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x) @check_1d_2d_mesh(set_mesh=False) @@ -4408,7 +4408,7 @@ def testUndefinedResourcesOuts(self, mesh, resources): spec = P(resources,) with self.assertRaisesRegex( ValueError, - r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"): + r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."): pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x) @check_1d_2d_mesh(set_mesh=False) @@ -4418,7 +4418,7 @@ def testUndefinedResourcesConstraint(self, mesh, resources): spec = P(resources,) with self.assertRaisesRegex( ValueError, - r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"): + r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."): pjit( lambda x: with_sharding_constraint(x, spec), in_shardings=None, From a3ae5e18d3a59b5be60bfd98bee7a56ea2f49042 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Fri, 9 Aug 2024 12:53:31 -0700 Subject: [PATCH 047/702] Remove `build_cuda_plugin_from_source` flag which is no longe used. https://github.com/google/jax/commit/751b5742fdb5ff3e36212e8e64ef668de41363ff PiperOrigin-RevId: 661370449 --- jax/BUILD | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 66df0d2f7272..ec350b4b99a7 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -57,20 +57,6 @@ config_setting( }, ) -# When `build_cuda_plugin_from_source` is true, it assumes running `bazel test` without preinstalled -# cuda plugin. -bool_flag( - name = "build_cuda_plugin_from_source", - build_setting_default = False, -) - -config_setting( - name = "enable_build_cuda_plugin_from_source", - flag_values = { - ":build_cuda_plugin_from_source": "True", - }, -) - exports_files([ "LICENSE", "version.py", From 8bd84913a6b00024c238dbc0c80ca6df9a3591fd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 9 Aug 2024 13:34:50 -0700 Subject: [PATCH 048/702] Better docs for array, asarray, linspace This allows removal of extra_params handling from util.implements --- jax/_src/numpy/lax_numpy.py | 198 +++++++++++++++++++++++++++++++++--- jax/_src/numpy/util.py | 12 --- 2 files changed, 186 insertions(+), 24 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 49bd8ca412fe..0f27cb5ff409 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3399,13 +3399,74 @@ def _supports_buffer_protocol(obj): deprecations.register("jax-numpy-array-none") -@util.implements(np.array, lax_description=_ARRAY_DOC, extra_params=""" -device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. -""") + def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0, *, device: xc.Device | Sharding | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.array`. + + Args: + object: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with an ``__array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + copy: specify whether to force a copy of the input. Default: True. + order: not implemented in JAX + ndmin: integer specifying the minimum number of dimensions in the + output array. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.asarray`: like `array`, but by default only copies + when necessary. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.array(True) + Array(True, dtype=bool) + >>> jnp.array(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.array(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.array(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.array([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.array([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.array(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.array(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.array(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + """ if order is not None and order != "K": raise NotImplementedError("Only implemented for order='K'") @@ -3594,13 +3655,72 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, return _array_copy(result) if copy else result -@util.implements(np.asarray, lax_description=_ARRAY_DOC, extra_params=""" -device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. -""") def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, *, copy: bool | None = None, device: xc.Device | Sharding | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.asarray`. + + Args: + a: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with an ``__array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + order: not implemented in JAX + copy: optional boolean specifying the copy mode. If True, then always + return a copy. If False, then error if a copy is necessary. Default is + None, which will only copy when necessary. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.asarray(True) + Array(True, dtype=bool) + >>> jnp.asarray(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.asarray(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.asarray(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.asarray(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.asarray(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.asarray(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + """ # For copy=False, the array API specifies that we raise a ValueError if the input supports # the buffer protocol but a copy is required. Since array() supports the buffer protocol # via numpy, this is only the case when the default device is not 'cpu' @@ -4450,15 +4570,69 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... -@util.implements(np.linspace, extra_params=""" -device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. -""") def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: + """Return evenly-spaced numbers within an interval. + + JAX implementation of :func:`numpy.linspace`. + + Args: + start: scalar or array of starting values. + stop: scalar or array of stop values. + num: number of values to generate. Default: 50. + endpoint: if True (default) then include the ``stop`` value in the result. + If False, then exclude the ``stop`` value. + retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the + interval between adjacent values in ``result``. + axis: integer axis along which to generate the linspace. Defaults to zero. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where: + + - ``values`` is an array of evenly-spaced values from ``start`` to ``stop`` + - ``step`` is the interval between adjacent values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List of 5 values between 0 and 10: + + >>> jnp.linspace(0, 10, 5) + Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32) + + List of 8 values between 0 and 10, excluding the endpoint: + + >>> jnp.linspace(0, 10, 8, endpoint=False) + Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32) + + List of values and the step size between them + + >>> vals, step = jnp.linspace(0, 10, 9, retstep=True) + >>> vals + Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) + >>> step + Array(1.25, dtype=float32) + + Multi-dimensional linspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 10]) + >>> jnp.linspace(start, stop, 5) + Array([[ 0. , 5. ], + [ 1.25, 6.25], + [ 2.5 , 7.5 ], + [ 3.75, 8.75], + [ 5. , 10. ]], dtype=float32) + """ num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 09ff99cb40a1..e9d1db26731c 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -111,19 +111,12 @@ def _parse_parameters(body: str) -> dict[str, str]: return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters} -def _parse_extra_params(extra_params: str) -> dict[str, str]: - """Parse the extra parameters passed to implements()""" - parameters = _parameter_break.split(extra_params.strip('\n')) - return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters} - - def implements( original_fun: Callable[..., Any] | None, update_doc: bool = True, lax_description: str = "", sections: Sequence[str] = ('Parameters', 'Returns', 'References'), skip_params: Sequence[str] = (), - extra_params: str | None = None, module: str | None = None, ) -> Callable[[_T], _T]: """Decorator for JAX functions which implement a specified NumPy function. @@ -145,9 +138,6 @@ def implements( ["Parameters", "Returns", "References"] skip_params: a list of strings containing names of parameters accepted by the function that should be skipped in the parameter list. - extra_params: an optional string containing additional parameter descriptions. - When ``update_doc=True``, these will be added to the list of parameter - descriptions in the updated doc. module: an optional string specifying the module from which the original function is imported. This is useful for objects such as ufuncs, where the module cannot be determined from the original function itself. @@ -176,8 +166,6 @@ def decorator(wrapped_fun): code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None) # Remove unrecognized parameter descriptions. parameters = _parse_parameters(parsed.sections['Parameters']) - if extra_params: - parameters.update(_parse_extra_params(extra_params)) parameters = {p: desc for p, desc in parameters.items() if (code is None or p in code.co_varnames) and p not in skip_params} From 4863a568f91f74f3156a9447e9b614996e0e9938 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 9 Aug 2024 14:40:08 -0700 Subject: [PATCH 049/702] Fix array_test.py when jax_pmap_no_rank_reduction is flipped to true. The problem is that squeezing was happening on noncommitted arrays so list(x) was moving all the shards to device 0. This will potentially cause ooms. PiperOrigin-RevId: 661408226 --- jax/_src/array.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 03b0e49d3201..e7bc2e933531 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -339,18 +339,18 @@ def __getitem__(self, idx): except ValueError: arr_idx = None if arr_idx is not None: - a = self._arrays[arr_idx] - out = ArrayImpl( - a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False, - _skip_checks=True) + out = self._arrays[arr_idx] + sharding = SingleDeviceSharding(_get_device(out)) if config.pmap_no_rank_reduction.value: # If cidx was the index of a single shard, then it corresponds to one # shard of the chunked dimension. dims = tuple(i for i, x in enumerate(cidx) if isinstance(x, int)) - return lax.squeeze(out, dimensions=dims) - else: - return out + # Squeeze on committed arrays to avoid data movement to shard 0. + out = lax.squeeze(out, dimensions=dims) + + return ArrayImpl( + out.aval, sharding, [out], committed=False, _skip_checks=True) return lax_numpy._rewriting_take(self, idx) From 7a75c96aa98d369cb69c88fe9e7daae1f6e14853 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 9 Aug 2024 14:52:55 -0700 Subject: [PATCH 050/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/46e205a0b6b38586a0d9edfb2ffecdcbd0b7590b. PiperOrigin-RevId: 661412627 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 815b007ce5fc..888b96545278 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "76978f280df19d6bfcaa4559ccb5573e13367c7b" -XLA_SHA256 = "124185ce5c8da06f7e3f48eac5e2fdfd4049d17f4ef7b0dc343bb087722037f0" +XLA_COMMIT = "46e205a0b6b38586a0d9edfb2ffecdcbd0b7590b" +XLA_SHA256 = "7a91ff40e4abbef76428a2175e580159d468429992d99c9e11a96e5870dc700e" def repo(): tf_http_archive( From abc9ba00e93fad04e829c366e8f06c0b396269c5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 9 Aug 2024 20:03:06 -0700 Subject: [PATCH 051/702] Rename `count_jit_and_pmap_compiles` to `count_jit_and_pmap_lowerings` PiperOrigin-RevId: 661496993 --- jax/_src/test_util.py | 4 ++-- tests/api_test.py | 16 ++++++++-------- tests/checkify_test.py | 2 +- tests/lax_control_flow_test.py | 2 +- tests/memories_test.py | 2 +- tests/pjit_test.py | 6 +++--- tests/pmap_test.py | 10 +++++----- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index e4de7e7b787b..b19110dfc516 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -365,7 +365,7 @@ def compiled_call_count(*args, **kwargs): @contextmanager -def count_jit_and_pmap_compiles(): +def count_jit_and_pmap_lowerings(): # No need to clear any caches since we generally jit and pmap fresh callables # in tests. @@ -405,7 +405,7 @@ def mlir_lower_and_count(ctx, name, *args, **kwargs): @contextmanager def assert_num_jit_and_pmap_compilations(times): - with count_jit_and_pmap_compiles() as count: + with count_jit_and_pmap_lowerings() as count: yield if count[0] != times: raise AssertionError(f"Expected exactly {times} XLA compilations, " diff --git a/tests/api_test.py b/tests/api_test.py index fa2f389f494b..52670bfd1590 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1442,7 +1442,7 @@ def test_caches_dont_depend_on_unnamed_axis_env(self): # https://github.com/google/jax/issues/9187 f = jax.jit(lambda: jnp.sin(1)) expected = f() - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = jax.vmap(f, axis_size=2, out_axes=None)() self.assertEqual(count[0], 0) # no compiles self.assertArraysAllClose(ans, expected, check_dtypes=True) @@ -3433,11 +3433,11 @@ def test_grad_of_jit_compilation_caching2(self): def f(x): return jnp.sin(x) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 _ = jax.grad(f)(3.) self.assertEqual(count[0], 2) # one for fwd, one for bwd - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 _ = jax.grad(f)(3.) _ = jax.grad(f)(4.) self.assertEqual(count[0], 0) # cache hits on both fwd and bwd @@ -4352,7 +4352,7 @@ def test_vmap_caching(self): jf = jax.jit(f) x = jax.random.uniform(jax.random.key(0), shape=(8, 4)) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(5): jax.hessian(jf)(x).block_until_ready() @@ -5929,7 +5929,7 @@ def test_linearize_caching(self): # https://github.com/google/jax/issues/9661 identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) _, f_lin = jax.linearize(identity, 1.) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): f_lin(1.).block_until_ready() self.assertEqual(count[0], 1) # cached after first execution @@ -5947,7 +5947,7 @@ def test_vjp_caching_static_argnums(self): identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x), static_argnums=(1,)) _, f_vjp = jax.vjp(lambda x: identity(x, True), 1.) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): f_vjp(1.)[0].block_until_ready() self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd @@ -5955,7 +5955,7 @@ def test_vjp_caching_static_argnums(self): def test_fwd_caching(self): # see above test also identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): y, _ = jax.vjp(identity, 1.) y.block_until_ready() @@ -5964,7 +5964,7 @@ def test_fwd_caching(self): def test_fwd_caching_static_argnums(self): # see above test also identity = jax.checkpoint(jax.jit(lambda x: 2 * x), static_argnums=(0,)) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): y = identity(1.) y.block_until_ready() diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 730f14ddcdd1..726e89d1b3e9 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -815,7 +815,7 @@ def g(x): def test_retracing(self): f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2)) _ = f(3.) - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: _ = f(3.) self.assertEqual(count[0], 0) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d52862ec42ac..192603de3655 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2590,7 +2590,7 @@ def f(x): def g(x): return x + 2 - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: for x in range(10): lax.cond(x, f, g, x) # Should observe a maximum of 4 compiles: convert_element_type, f, g, cond diff --git a/tests/memories_test.py b/tests/memories_test.py index 0d3720e807c0..ae07faf0f9de 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1200,7 +1200,7 @@ def mul(x): f = jax.jit(mul, in_shardings=s) g = jax.jit(mul, in_shardings=s2) - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: out = f(np_inp) out2 = g(np_inp2) self.assertEqual(count[0], 1) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b933944090a2..e072d5b22ec5 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3523,7 +3523,7 @@ def test_sharding_on_output_with_vmap(self): arr = jax.device_put( np.arange(16).reshape(8, 2), NamedSharding(mesh, P(None, 'x'))) - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: vf = jax.vmap(pjit(lambda x: x * 2, in_shardings=ns)) out = vf(arr) self.assertIsInstance(out.sharding, NamedSharding) @@ -3867,7 +3867,7 @@ def g(a): b = jax.device_put(out_a, NamedSharding(mesh2, P('x'))) f(b) # lowering cache *hit* - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: g(np.arange(8)) self.assertEqual(count[0], 1) @@ -3890,7 +3890,7 @@ def g(a): b = jax.device_put(out_a, NamedSharding(mesh2, P())) f(b) # lowering cache *miss* - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: g(np.arange(8)) self.assertEqual(count[0], 2) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index e576fae91a83..c0a3d27dadef 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1285,7 +1285,7 @@ def testPmapConstant(self): device_count = jax.device_count() f = self.pmap(lambda x: 3) x = jnp.arange(device_count) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): fix this expected = np.repeat(3, device_count) @@ -1306,7 +1306,7 @@ def testPmapConstantDevices(self): shuffle(devices) f = self.pmap(lambda x: 3, devices=devices) x = jnp.arange(len(devices)) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants expected = np.repeat(3, len(devices)) @@ -1342,7 +1342,7 @@ def testNestedPmapConstant(self): f = self.pmap(self.pmap(lambda x: 3)) shape = (2, jax.device_count() // 2, 3) x = jnp.arange(math.prod(shape)).reshape(shape) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants expected = 3 * np.ones(shape[:2]) @@ -1368,7 +1368,7 @@ def testNestedPmapConstantDevices(self): f = self.pmap(self.pmap(lambda x: 3), devices=devices) shape = (2, len(devices) // 2, 3) x = jnp.arange(math.prod(shape)).reshape(shape) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants expected = 3 * np.ones(shape[:2]) @@ -2039,7 +2039,7 @@ def f(x): _, f_bwd = jax.vjp(f, x) _ = f_bwd(x) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 _, f_bwd2 = jax.vjp(f, x) _ = f_bwd(x) _ = f_bwd2(x) From c08656c61d0e2460f5d902b0af808b74c76a48ca Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 9 Aug 2024 23:16:54 -0700 Subject: [PATCH 052/702] [Rollback] We still want to allow multiple meshes in the user program Reverts dd958adc39550d2758ecdb13809c6d85df7658a2 PiperOrigin-RevId: 661537233 --- jax/_src/custom_partitioning.py | 6 +--- jax/_src/interpreters/mlir.py | 6 ++-- jax/_src/interpreters/pxla.py | 23 +++++++-------- jax/_src/sharding_impls.py | 1 - tests/memories_test.py | 17 +++++++++++ tests/pjit_test.py | 51 ++++----------------------------- 6 files changed, 38 insertions(+), 66 deletions(-) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 8f48746dda37..c038ef0641a8 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -24,7 +24,6 @@ from typing import Any import weakref -import numpy as np import jax from jax import tree_util from jax._src import api_util @@ -482,20 +481,17 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, infer_sharding_from_operands, decode_shardings, static_args): + mesh = mesh_lib.thread_resources.env.physical_mesh axis_context = ctx.module_context.axis_context if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)): return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values) - mesh = mesh_lib.thread_resources.env.physical_mesh if isinstance(axis_context, sharding_impls.ShardingContext): devices = axis_context.device_assignment if devices is None: raise AssertionError( 'Please file a bug at https://github.com/google/jax/issues') - if axis_context.mesh_shape is not None: - ma, ms = list(zip(*axis_context.mesh_shape)) - mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma) elif isinstance(axis_context, sharding_impls.SPMDAxisContext): devices = axis_context.mesh._flat_devices_tuple else: diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 3a666e357df1..84b1824bd2fe 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -957,6 +957,7 @@ def lower_jaxpr_to_module( input_output_aliases: None | tuple[int | None, ...] = None, propagated_out_mem_kinds: tuple[None | str, ...] | None = None, lowering_parameters: LoweringParameters, + mesh_shape_tuple: tuple[tuple[str, int], ...] | None = None, ) -> LoweringResult: """Lowers a top-level jaxpr to an MLIR module. @@ -1044,14 +1045,13 @@ def lower_jaxpr_to_module( # XLA computation preserves the module name. attrs = ctx.module.operation.attributes if config.use_shardy_partitioner.value: - assert (isinstance(axis_context, sharding_impls.ShardingContext) and - axis_context.mesh_shape is not None) + assert mesh_shape_tuple is not None ctx.module.body.append( dialects.sdy.MeshOp( "mesh", dialects.sdy.MeshAttr.get( [dialects.sdy.MeshAxisAttr.get(name, size) - for name, size in axis_context.mesh_shape]))) + for name, size in mesh_shape_tuple]))) module_name = _module_name_regex.sub("_", module_name) attrs["sym_name"] = ir.StringAttr.get(module_name) attrs["mhlo.num_replicas"] = i32_attr(num_replicas) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 2cd09299e60a..4dcdfbcbf495 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1881,7 +1881,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, propagated_out_mem_kinds: tuple[None | str, ...], platforms: tuple[str, ...], lowering_parameters: mlir.LoweringParameters, - mesh_shape_tuple: tuple[tuple[str, int], ...] | None): + mesh_shape_tuple: tuple[tuple[str, int], ...]): jaxpr = closed_jaxpr.jaxpr in_shardings = semantic_in_shardings.shardings out_shardings = semantic_out_shardings.shardings @@ -1911,8 +1911,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings) out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings) replicated_args = [False] * len(global_in_avals) - axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment, - mesh_shape=mesh_shape_tuple) + axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment) num_partitions = num_devices else: # This path is triggered for `jit(pmap)` cases. @@ -1958,7 +1957,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, all_default_mem_kind=all_default_mem_kind, input_output_aliases=inout_aliases, propagated_out_mem_kinds=propagated_out_mem_kinds, - lowering_parameters=lowering_parameters) + lowering_parameters=lowering_parameters, + mesh_shape_tuple=mesh_shape_tuple) tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform) unordered_effects = list( effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) @@ -2203,15 +2203,14 @@ def lower_sharding_computation( semantic_out_shardings = SemanticallyEqualShardings( out_shardings, global_out_avals) # type: ignore prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr) - - # TODO(yashkatariya): Initialize with context_mesh here? mesh_shape_tuple = None - for sharding in it.chain( - in_shardings, out_shardings, - [js for js, _ in unique_intermediate_shardings]): - if isinstance(sharding, sharding_impls.NamedSharding): - mesh_shape_tuple = sharding.mesh.shape_tuple - break + if config.use_shardy_partitioner.value: + for sharding in it.chain( + in_shardings, out_shardings, + [js for js, _ in unique_intermediate_shardings]): + if isinstance(sharding, sharding_impls.NamedSharding): + mesh_shape_tuple = sharding.mesh.shape_tuple + break (module, keepalive, host_callbacks, unordered_effects, ordered_effects, nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 1a23f4ba74ad..d41ef7410d1a 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1162,7 +1162,6 @@ class ShardingContext: """ num_devices: int device_assignment: tuple[xc.Device, ...] | None = None - mesh_shape: tuple[tuple[str, int], ...] | None = None def __post_init__(self): if self.device_assignment is not None: diff --git a/tests/memories_test.py b/tests/memories_test.py index ae07faf0f9de..87c85ffc47d8 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1178,6 +1178,23 @@ def test_jit_cpp_cache_hit(self): self.assertArraysEqual(out, np_inp @ np_inp.T) self.assertArraysEqual(out2, np_inp @ np_inp.T) + def test_jit_compilation_cache_hit(self): + mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y")) + inp2 = jax.device_put( + np_inp, GSPMDSharding(tuple(mesh.devices.flat), + s._to_xla_hlo_sharding(inp.ndim), + memory_kind="device") + ) + + f = jax.jit(lambda x: x @ x.T) + + with (jtu.count_pjit_cpp_cache_miss() as cpp_count, + jtu.count_jit_and_pmap_lowerings() as compile_count): + f(inp) + f(inp2) + self.assertEqual(cpp_count[0], 2) + self.assertEqual(compile_count[0], 1) + def test_jit_cpp_cache_output_hit(self): _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e072d5b22ec5..44782ec15bc0 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1483,45 +1483,6 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape): pjit_f = pjit(jit_f, in_shardings=(P('x')), out_shardings=P('x')) self.assertArraysEqual(x, pjit_f(x)) - def test_custom_partitioning_no_mesh_context(self): - self.skip_if_custom_partitioning_not_supported() - - @custom_partitioning - def f(x): - return x - - def partition(mesh, arg_shapes, result_shape): - def lower_fn(x): - @jax.jit - def g(y): - return y - - return g(x) - - x_shard = arg_shapes[0].sharding - return ( - mesh, - lower_fn, - NamedSharding(x_shard.mesh, P('x')), - (NamedSharding(x_shard.mesh, P('x')),), - ) - - def infer_sharding_from_operands(mesh, arg_shapes, result_shape): - x_shard = arg_shapes[0].sharding - return NamedSharding(x_shard.mesh, P('x')) - - f.def_partition( - infer_sharding_from_operands=infer_sharding_from_operands, - partition=partition, - ) - - mesh = jtu.create_global_mesh((4,), ('x',)) - x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32) - s = NamedSharding(mesh, P('x')) - - pjit_f = jax.jit(f, in_shardings=s, out_shardings=s) - self.assertArraysEqual(x, pjit_f(x)) - @jtu.with_mesh([('x', 4)]) def test_custom_partitioner_with_scan(self): self.skip_if_custom_partitioning_not_supported() @@ -3448,8 +3409,8 @@ def mul(x): cache_info4 = pxla._cached_compilation.cache_info() self.assertIsInstance(out4.sharding, PositionalSharding) - self.assertEqual(cache_info4.hits, cache_info3.hits) - self.assertEqual(cache_info4.misses, cache_info3.misses + 1) + self.assertEqual(cache_info4.hits, cache_info3.hits + 1) + self.assertEqual(cache_info4.misses, cache_info3.misses) def test_cache_hit_pjit_lower_with_cpp_cache_miss(self): mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) @@ -3560,8 +3521,8 @@ def test_jit_mul_sum_sharding_preserved(self): self.assertIsInstance(out3.sharding, PositionalSharding) self.assertEqual(count[0], 1) - self.assertEqual(cache_info2.hits, cache_info1.hits) - self.assertEqual(cache_info2.misses, cache_info1.misses + 1) + self.assertEqual(cache_info2.hits, cache_info1.hits + 1) + self.assertEqual(cache_info2.misses, cache_info1.misses) self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits) self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1) @@ -3853,7 +3814,7 @@ def test_lowering_cache_hit_different_devices(self): self.skipTest('Requires >=4 devices') mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x') - mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'x') + mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y') @jax.jit def f(x): @@ -3864,7 +3825,7 @@ def g(a): out_a = f(a) # lowering cached # same num_devices but different devices. - b = jax.device_put(out_a, NamedSharding(mesh2, P('x'))) + b = jax.device_put(out_a, NamedSharding(mesh2, P('y'))) f(b) # lowering cache *hit* with jtu.count_jit_and_pmap_lowerings() as count: From c2c116dc5cdc3dba0568e66ad652c8f2acbc3c08 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sat, 10 Aug 2024 05:22:05 -0700 Subject: [PATCH 053/702] jnp.intersect1d: add support for static size argument. --- jax/_src/numpy/setops.py | 181 +++++++++++++++++++++++++++------------ jax/numpy/__init__.pyi | 3 +- tests/lax_numpy_test.py | 49 +++++++++-- 3 files changed, 171 insertions(+), 62 deletions(-) diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 2e953c67abd6..db4237dbd069 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -33,7 +33,7 @@ sort, where, zeros) from jax._src.numpy.reductions import any, cumsum from jax._src.numpy.ufuncs import isnan -from jax._src.numpy.util import check_arraylike +from jax._src.numpy.util import check_arraylike, promote_dtypes from jax._src.util import canonicalize_axis from jax._src.typing import Array, ArrayLike @@ -68,10 +68,9 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, JAX implementation of :func:`numpy.setdiff1d`. Because the size of the output of ``setdiff1d`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.setdiff1d`` to be used in such contexts. - transformations. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified statically + for ``jnp.setdiff1d`` to be used in such contexts. Args: ar1: first array of elements to be differenced. @@ -156,10 +155,9 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, JAX implementation of :func:`numpy.union1d`. Because the size of the output of ``union1d`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.union1d`` to be used in such contexts. - transformations. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.union1d`` to be used in such contexts. Args: ar1: first array of elements to be unioned. @@ -272,30 +270,97 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr @partial(jit, static_argnames=['return_indices']) -def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> tuple[Array, ...]: - # JIT-compatible helper function for intersect1d - ar = concatenate((ar1, ar2)) +def _intersect1d_sorted_mask(arr1: Array, arr2: Array, + return_indices: bool) -> tuple[Array, Array, Array | None]: + """JIT-compatible helper function for intersect1d""" + assert arr1.ndim == arr2.ndim == 1 + arr = concatenate((arr1, arr2)) if return_indices: - iota = lax.broadcasted_iota(np.int64, np.shape(ar), dimension=0) - aux, indices = lax.sort_key_val(ar, iota) + iota = lax.broadcasted_iota(np.int64, np.shape(arr), dimension=0) + aux, indices = lax.sort_key_val(arr, iota) else: - aux = sort(ar) - + aux = sort(arr) + indices = None mask = aux[1:] == aux[:-1] + return aux, mask, indices + + +@partial(jit, static_argnames=['fill_value', 'assume_unique', 'size', 'return_indices']) +def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, assume_unique: bool, + size: int, return_indices: bool) -> Array | tuple[Array, Array, Array]: + """Jit-compatible helper function for intersect1d with size specified.""" + # Ensured by caller + assert arr1.ndim == arr2.ndim == 1 + assert arr1.dtype == arr2.dtype + + # First step: we concatenate the unique values of arr1 and arr2. + # The resulting values are: + # num_unique1/num_unique2: number of unique values in arr1/arr2 + # aux[:num_unique1 + num_unique2] contains the sorted concatenated + # unique values drawn from arr1 and arr2. + # aux_sorted_indices: indices mapping aux to concatenation of arr1 and arr2 + # ind1[:num_unique1], ind2[:num_unique2]: indices of sorted unique + # values in arr1/arr2 + # mask: boolean mask of relevant values in aux & aux_sorted_indices + if assume_unique: + ind1, num_unique1 = arange(arr1.size), asarray(arr1.size) + ind2, num_unique2 = arange(arr2.size), asarray(arr2.size) + arr = concatenate([arr1, arr2]) + aux, aux_sort_indices = lax.sort([arr, arange(arr.size)], is_stable=True, num_keys=1) + mask = ones(arr.size, dtype=bool) + else: + arr1, ind1, num_unique1 = _unique(arr1, 0, size=arr1.size, return_index=True, return_true_size=True, fill_value=0) + arr2, ind2, num_unique2 = _unique(arr2, 0, size=arr2.size, return_index=True, return_true_size=True, fill_value=0) + arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2)) + arr = arr.at[:arr1.size].set(arr1) + arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,)) + mask = arange(arr.size) < num_unique1 + num_unique2 + _, aux, aux_sort_indices = lax.sort([~mask, arr, arange(arr.size)], is_stable=True, num_keys=2) + + # Second step: extract the intersection values from aux + # Since we've sorted the unique entries in arr1 and arr2, any place where + # adjacent entries are equal is a value of the intersection. + # relevant results here: + # num_results: number of values in the intersection of arr1 and arr2 + # vals: array where vals[:num_results] contains the intersection of arr1 and arr2, + # and vals[num_results:] contains the appropriate fill_value. + aux_mask = (aux[1:] == aux[:-1]) & mask[1:] + num_results = aux_mask.sum() + val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0] + vals = aux.at[val_indices].get(mode='fill', fill_value=0) + if fill_value is None: + vals = where(arange(len(vals)) < num_results, vals, vals.max()) + vals = where(arange(len(vals)) < num_results, vals, vals.min()) + else: + vals = where(arange(len(vals)) < num_results, vals, fill_value) + + # Third step: extract the indices of the intersection values. + # This requires essentially unwinding aux_sort_indices and ind1/ind2 to find + # the appropriate list of indices from the original arrays. if return_indices: - return aux, mask, indices + arr1_indices = aux_sort_indices.at[val_indices].get(mode='fill', fill_value=arr1.size) + arr1_indices = where(arange(len(arr1_indices)) < num_results, arr1_indices, arr1.size) + arr2_indices = aux_sort_indices.at[val_indices + 1].get(mode='fill', fill_value=arr2.size) - num_unique1 + arr2_indices = where(arange(len(arr2_indices)) < num_results, arr2_indices, arr2.size) + if not assume_unique: + arr1_indices = ind1.at[arr1_indices].get(mode='fill', fill_value=ind1.size) + arr2_indices = ind2.at[arr2_indices].get(mode='fill', fill_value=ind2.size) + return vals, arr1_indices, arr2_indices else: - return aux, mask + return vals def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, - return_indices: bool = False) -> Array | tuple[Array, Array, Array]: + return_indices: bool = False, *, size: int | None = None, + fill_value: ArrayLike | None = None) -> Array | tuple[Array, Array, Array]: """Compute the set intersection of two 1D arrays. JAX implementation of :func:`numpy.intersect1d`. - Because the size of the output of ``intersect1d`` is data-dependent, the function is not - compatible with JIT or other JAX transformations. + Because the size of the output of ``intersect1d`` is data-dependent, the function + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.intersect1d`` to be used in such contexts. Args: ar1: first array of values to intersect. @@ -305,6 +370,12 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, arrays contain duplicates, the behavior is undefined. default: False. return_indices: If True, return arrays of indices specifying where the intersected values first appear in the input arrays. + size: if specified, return only the first ``size`` sorted elements. If there are fewer + elements than ``size`` indicates, the return value will be padded with ``fill_value``, + and returned indices will be padded with an out-of-bound index. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the smallest value + in the intersection. Returns: An array ``intersection``, or if ``return_indices=True``, a tuple of arrays @@ -353,35 +424,35 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, Array(True, dtype=bool) """ check_arraylike("intersect1d", ar1, ar2) - ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()") - ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()") + arr1, arr2 = promote_dtypes(ar1, ar2) + del ar1, ar2 + arr1 = ravel(arr1) + arr2 = ravel(arr2) + + if size is not None: + return _intersect1d_size(arr1, arr2, return_indices=return_indices, + size=size, fill_value=fill_value, assume_unique=assume_unique) if not assume_unique: if return_indices: - ar1, ind1 = unique(ar1, return_index=True) - ar2, ind2 = unique(ar2, return_index=True) + arr1, ind1 = unique(arr1, return_index=True) + arr2, ind2 = unique(arr2, return_index=True) else: - ar1 = unique(ar1) - ar2 = unique(ar2) - else: - ar1 = ravel(ar1) - ar2 = ravel(ar2) + arr1 = unique(arr1) + arr2 = unique(arr2) - if return_indices: - aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices) - else: - aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices) + aux, mask, aux_sort_indices = _intersect1d_sorted_mask(arr1, arr2, return_indices) int1d = aux[:-1][mask] if return_indices: - ar1_indices = aux_sort_indices[:-1][mask] - ar2_indices = aux_sort_indices[1:][mask] - np.size(ar1) + assert aux_sort_indices is not None + arr1_indices = aux_sort_indices[:-1][mask] + arr2_indices = aux_sort_indices[1:][mask] - np.size(arr1) if not assume_unique: - ar1_indices = ind1[ar1_indices] - ar2_indices = ind2[ar2_indices] - - return int1d, ar1_indices, ar2_indices + arr1_indices = ind1[arr1_indices] + arr2_indices = ind2[arr2_indices] + return int1d, arr1_indices, arr2_indices else: return int1d @@ -517,9 +588,9 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal JAX implementation of :func:`numpy.unique`. Because the size of the output of ``unique`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.unique`` to be used in such contexts. Args: ar: N-dimensional array from which unique values will be extracted. @@ -729,9 +800,9 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, and `equal_nan` set to True. Because the size of the output of ``unique_all`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.unique`` to be used in such contexts. Args: x: N-dimensional array from which unique values will be extracted. @@ -810,9 +881,9 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, :func:`jax.numpy.unique` with `return_counts` and `equal_nan` set to True. Because the size of the output of ``unique_counts`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.unique`` to be used in such contexts. Args: x: N-dimensional array from which unique values will be extracted. @@ -870,9 +941,9 @@ def unique_inverse(x: ArrayLike, /, *, size: int | None = None, :func:`jax.numpy.unique` with `return_inverse` and `equal_nan` set to True. Because the size of the output of ``unique_inverse`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.unique`` to be used in such contexts. Args: x: N-dimensional array from which unique values will be extracted. @@ -935,9 +1006,9 @@ def unique_values(x: ArrayLike, /, *, size: int | None = None, :func:`jax.numpy.unique` with `equal_nan` set to True. Because the size of the output of ``unique_values`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified statically + for ``jnp.unique`` to be used in such contexts. Args: x: N-dimensional array from which unique values will be extracted. diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index dfea8a8ddd74..583f6886e915 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -495,7 +495,8 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, right: ArrayLike | str | None = ..., period: ArrayLike | None = ...) -> Array: ... def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ..., - return_indices: builtins.bool = ...) -> Array | tuple[Array, Array, Array]: ... + return_indices: builtins.bool = ..., *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> Array | tuple[Array, Array, Array]: ... def invert(x: ArrayLike, /) -> Array: ... def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a4a7fa896ae4..c750c3021004 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -779,17 +779,54 @@ def np_fun(ar1, ar2): @jtu.sample_product( dtype1=[s for s in default_dtypes if s != jnp.bfloat16], dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, + shape1=[(), (5,), (2, 5)], + shape2=[(), (5,), (2, 5)], assume_unique=[False, True], return_indices=[False, True], + size=[None, 3, 5], + fill_value=[None, -1] ) def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, - return_indices): + return_indices, size, fill_value): rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + def args_maker(): + # Generate two arrays with overlapping values. + size1, size2 = math.prod(shape1), math.prod(shape2) + num_vals = max(size1, size2) + min(size1, size2) // 2 + vals = rng((num_vals,), 'int32') + arr1 = vals[:size1].astype(dtype1).reshape(shape1) + arr2 = vals[-size2:].astype(dtype2).reshape(shape2) + # if assume_unique is True, we need the results to contain unique values. + # This may lead to different shapes than requested, but ¯\_(ツ)_/¯ + if assume_unique: + arr1 = np.unique(arr1) + self.rng().shuffle(arr1) # inplace + arr1 = arr1.reshape(shape1) if arr1.shape == size1 else arr1 + arr2 = np.unique(arr2) + self.rng().shuffle(arr2) # inplace + arr2 = arr1.reshape(shape2) if arr2.shape == size2 else arr2 + return arr1, arr2 + + def jnp_fun(ar1, ar2): + return jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices, + size=size, fill_value=fill_value) + + def np_fun(ar1, ar2): + result = np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + def correct_size(x, fill_value): + if size is None or size == len(x): + return x + elif size < len(x): + return x[:size] + else: + if fill_value is None: + fill_value = x.min() + return np.pad(x, (0, size - len(x)), constant_values=fill_value) + if return_indices: + return tuple(correct_size(r, f) for r, f in zip(result, [fill_value, ar1.size, ar2.size])) + else: + return correct_size(result, fill_value) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) From 9e86416a325220062faa8030eee9783ec8b4c2e9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 10 Aug 2024 14:19:31 -0700 Subject: [PATCH 054/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/dfb2a8b4980f1b9ee01e4f6bcc23ad25403a2124. PiperOrigin-RevId: 661668331 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 888b96545278..ab4affe07816 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "46e205a0b6b38586a0d9edfb2ffecdcbd0b7590b" -XLA_SHA256 = "7a91ff40e4abbef76428a2175e580159d468429992d99c9e11a96e5870dc700e" +XLA_COMMIT = "dfb2a8b4980f1b9ee01e4f6bcc23ad25403a2124" +XLA_SHA256 = "255dc4ed80eb9f15a9723c65ea9f27faecbdd81ba012c767b69c2b1e6e8859ca" def repo(): tf_http_archive( From 96045043a47908f3fc4e2eb633b0e07762012981 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Sun, 11 Aug 2024 01:46:52 -0700 Subject: [PATCH 055/702] Move ir_attribute builder from extend.ffi to interpreters.mlir. While this function is currently only used for lowering FFI calls, it could be used most places where `ir.*Attr` objects are directly constructed. PiperOrigin-RevId: 661761712 --- jax/_src/extend/ffi.py | 26 +---------- jax/_src/interpreters/mlir.py | 88 ++++++++++++++++++++++++++++++++--- jax/interpreters/mlir.py | 1 + tests/extend_test.py | 11 ----- 4 files changed, 84 insertions(+), 42 deletions(-) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index df1c09efffc5..66af4f331d78 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -22,7 +22,6 @@ from jax._src import core from jax._src import dispatch -from jax._src import dtypes from jax._src import util from jax._src.callback import _check_shape_dtype, callback_batching_rule from jax._src.interpreters import ad @@ -32,7 +31,6 @@ from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray -import numpy as np map, unsafe_map = util.safe_map, map @@ -137,7 +135,7 @@ def _lowering( kwargs = dict(lowering_args) kwargs.setdefault("api_version", 4) kwargs["backend_config"] = dict( - backend_config or {}, **{k: _ir_attribute(v) for k, v in params.items()}) + backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()}) if "result_types" not in kwargs: kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] if operand_layouts is None: @@ -154,28 +152,6 @@ def _default_layouts(shapes: Iterable[Sequence[DimSize]]) -> list[list[DimSize]] return [list(reversed(range(len(shape)))) for shape in shapes] -def _ir_attribute(obj: Any) -> ir.Attribute: - # TODO(dfm): Similar functions exist in Pallas and Mosaic GPU. Perhaps these - # could be consolidated into mlir or similar. - if isinstance(obj, str): - return ir.StringAttr.get(obj) - elif isinstance(obj, bool): - return ir.BoolAttr.get(obj) - elif isinstance(obj, int): - return mlir.i64_attr(obj) - elif isinstance(obj, float): - return ir.FloatAttr.get_f64(obj) - elif hasattr(obj, "dtype"): - if not (dtypes.is_python_scalar(obj) or np.isscalar(obj)): - raise TypeError("Only scalar attributes are supported") - mlir_type = mlir.dtype_to_ir_type(obj.dtype) - if isinstance(mlir_type, ir.IntegerType): - return ir.IntegerAttr.get(mlir_type, obj) - elif isinstance(mlir_type, ir.FloatType): - return ir.FloatAttr.get(mlir_type, obj) - raise TypeError(f"Unsupported attribute type: {type(obj)}") - - def ffi_call( target_name: str, result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray], diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 84b1824bd2fe..4d15e803adfd 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -277,12 +277,7 @@ def ir_constant(val: Any) -> IrValues: raise TypeError(f"No constant handler for type: {type(val)}") def _numpy_array_constant(x: np.ndarray | np.generic) -> IrValues: - element_type = dtype_to_ir_type(x.dtype) - shape = x.shape - if x.dtype == np.bool_: - x = np.packbits(x, bitorder='little') # type: ignore - x = np.ascontiguousarray(x) - attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore + attr = _numpy_array_attribute(x) return hlo.constant(attr) @@ -344,6 +339,87 @@ def _token_constant_handler(val): return hlo.create_token() register_constant_handler(core.Token, _token_constant_handler) +# Attributes + +AttributeHandler = Callable[[Any], ir.Attribute] +_attribute_handlers: dict[type[Any], AttributeHandler] = {} + +def register_attribute_handler(type_: type[Any], handler_fun: AttributeHandler): + _attribute_handlers[type_] = handler_fun + +def get_attribute_handler(type_: type[Any]) -> AttributeHandler: + return _attribute_handlers[type_] + +def _numpy_scalar_attribute(val: Any) -> ir.Attribute: + mlir_type = dtype_to_ir_type(val.dtype) + if isinstance(mlir_type, ir.IntegerType): + return ir.IntegerAttr.get(mlir_type, val) + elif isinstance(mlir_type, ir.FloatType): + return ir.FloatAttr.get(mlir_type, val) + else: + raise TypeError(f"Unsupported scalar attribute type: {type(val)}") + +def _numpy_array_attribute(x: np.ndarray | np.generic) -> ir.Attribute: + element_type = dtype_to_ir_type(x.dtype) + shape = x.shape + if x.dtype == np.bool_: + x = np.packbits(x, bitorder='little') # type: ignore + x = np.ascontiguousarray(x) + return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore + +def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute: + if 0 in val.strides and val.size > 0: + raise ValueError( + "NumPy arrays with zero strides are not supported as MLIR attributes") + if val.dtype == dtypes.float0: + val = np.zeros(val.shape, dtype=np.bool_) + if dtypes.is_python_scalar(val) or np.isscalar(val): + return _numpy_scalar_attribute(val) + else: + return _numpy_array_attribute(val) + +register_attribute_handler(np.ndarray, _numpy_array_attribute_handler) + +for _scalar_type in [np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + np.float16, np.float32, np.float64, + np.complex64, np.complex128, + np.bool_, np.longlong, dtypes.bfloat16]: + register_attribute_handler(_scalar_type, _numpy_array_attribute_handler) # type: ignore + +def _python_scalar_attribute_handler(dtype, val): + return _numpy_scalar_attribute(np.array(val, dtype)) + +for ptype, dtype in dtypes.python_scalar_dtypes.items(): + register_attribute_handler( + ptype, partial(_python_scalar_attribute_handler, dtype)) + +register_attribute_handler(str, ir.StringAttr.get) +register_attribute_handler(bytes, ir.StringAttr.get) + +def _dict_attribute_handler(val: dict[str, Any]) -> ir.Attribute: + return ir.DictAttr.get({k: ir_attribute(v) for k, v in val.items()}) + +register_attribute_handler(dict, _dict_attribute_handler) + +def _sequence_attribute_handler(val: Sequence[Any]) -> ir.Attribute: + return ir.ArrayAttr.get([ir_attribute(v) for v in val]) + +register_attribute_handler(list, _sequence_attribute_handler) +register_attribute_handler(tuple, _sequence_attribute_handler) + +def ir_attribute(val: Any) -> ir.Attribute: + """Convert a Python value to an MLIR attribute.""" + for t in type(val).__mro__: + handler = _attribute_handlers.get(t) + if handler: + out = handler(val) + assert isinstance(out, ir.Attribute), (type(val), out) + return out + if hasattr(val, '__jax_array__'): + return ir_attribute(val.__jax_array__()) + raise TypeError(f"No attribute handler defined for type: {type(val)}") + # Source locations def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str: diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index edfc56ddd4fd..78b070614621 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -49,6 +49,7 @@ i32_attr as i32_attr, i64_attr as i64_attr, ir as ir, + ir_attribute as ir_attribute, ir_constant as ir_constant, ir_type_handlers as ir_type_handlers, jaxpr_subcomp as jaxpr_subcomp, diff --git a/tests/extend_test.py b/tests/extend_test.py index 45c689239a1d..098d52b10eeb 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -31,8 +31,6 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.interpreters import mlir -from jax._src.lib.mlir import ir -from jax._src.extend import ffi jax.config.parse_flags_with_absl() @@ -103,15 +101,6 @@ def testHeadersExist(self): for header in ["c_api.h", "api.h", "ffi.h"]: self.assertTrue(os.path.exists(os.path.join(base_dir, header))) - @parameterized.parameters( - [True, int(1), float(5.0), - np.int32(-5), np.float32(0.5)]) - def testIrAttribute(self, value): - with mlir.make_ir_context(), ir.Location.unknown(): - const = mlir.ir_constant(value) - attr = ffi._ir_attribute(value) - assert const.type.element_type == attr.type - @parameterized.parameters([True, 1, 5.0, "param", np.float32(0.5)]) def testParams(self, param): prim = core.Primitive("test_ffi") From ded5b5366bd9060b4160037f93e22c8a5a38b2ce Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sat, 10 Aug 2024 18:01:23 -0700 Subject: [PATCH 056/702] indent consistently in auto-parallelization and shard_map tutorials --- ...arrays_and_automatic_parallelization.ipynb | 20 +++++++++---------- ...ed_arrays_and_automatic_parallelization.md | 20 +++++++++---------- docs/notebooks/shard_map.ipynb | 20 +++++++++---------- docs/notebooks/shard_map.md | 20 +++++++++---------- 4 files changed, 40 insertions(+), 40 deletions(-) diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 09fbbd3a74c3..3d8c5b0203d5 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -1749,20 +1749,20 @@ "outputs": [], "source": [ "def init_layer(key, n_in, n_out):\n", - " k1, k2 = jax.random.split(key)\n", - " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", - " b = jax.random.normal(k2, (n_out,))\n", - " return W, b\n", + " k1, k2 = jax.random.split(key)\n", + " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", + " b = jax.random.normal(k2, (n_out,))\n", + " return W, b\n", "\n", "def init_model(key, layer_sizes, batch_size):\n", - " key, *keys = jax.random.split(key, len(layer_sizes))\n", - " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", + " key, *keys = jax.random.split(key, len(layer_sizes))\n", + " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", "\n", - " key, *keys = jax.random.split(key, 3)\n", - " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n", - " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n", + " key, *keys = jax.random.split(key, 3)\n", + " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n", + " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n", "\n", - " return params, (inputs, targets)\n", + " return params, (inputs, targets)\n", "\n", "layer_sizes = [784, 8192, 8192, 8192, 10]\n", "batch_size = 8192\n", diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index 43c14bc41da4..cb5d4602c055 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -631,20 +631,20 @@ gradfun = jax.jit(jax.grad(loss)) :id: R0x62AIa3vGU def init_layer(key, n_in, n_out): - k1, k2 = jax.random.split(key) - W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) - b = jax.random.normal(k2, (n_out,)) - return W, b + k1, k2 = jax.random.split(key) + W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) + b = jax.random.normal(k2, (n_out,)) + return W, b def init_model(key, layer_sizes, batch_size): - key, *keys = jax.random.split(key, len(layer_sizes)) - params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) + key, *keys = jax.random.split(key, len(layer_sizes)) + params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) - key, *keys = jax.random.split(key, 3) - inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) - targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) + key, *keys = jax.random.split(key, 3) + inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) + targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) - return params, (inputs, targets) + return params, (inputs, targets) layer_sizes = [784, 8192, 8192, 8192, 10] batch_size = 8192 diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 919690d230ab..b1792a6c039e 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -1483,20 +1483,20 @@ "outputs": [], "source": [ "def init_layer(key, n_in, n_out):\n", - " k1, k2 = jax.random.split(key)\n", - " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", - " b = jax.random.normal(k2, (n_out,))\n", - " return W, b\n", + " k1, k2 = jax.random.split(key)\n", + " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", + " b = jax.random.normal(k2, (n_out,))\n", + " return W, b\n", "\n", "def init(key, layer_sizes, batch_size):\n", - " key, *keys = jax.random.split(key, len(layer_sizes))\n", - " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", + " key, *keys = jax.random.split(key, len(layer_sizes))\n", + " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", "\n", - " key, *keys = jax.random.split(key, 3)\n", - " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n", - " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n", + " key, *keys = jax.random.split(key, 3)\n", + " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n", + " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n", "\n", - " return params, (inputs, targets)" + " return params, (inputs, targets)" ] }, { diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 6f9dfbb659e1..7469bf32b516 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -1035,20 +1035,20 @@ def loss(params, batch): ```{code-cell} def init_layer(key, n_in, n_out): - k1, k2 = jax.random.split(key) - W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) - b = jax.random.normal(k2, (n_out,)) - return W, b + k1, k2 = jax.random.split(key) + W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) + b = jax.random.normal(k2, (n_out,)) + return W, b def init(key, layer_sizes, batch_size): - key, *keys = jax.random.split(key, len(layer_sizes)) - params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) + key, *keys = jax.random.split(key, len(layer_sizes)) + params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) - key, *keys = jax.random.split(key, 3) - inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) - targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) + key, *keys = jax.random.split(key, 3) + inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) + targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) - return params, (inputs, targets) + return params, (inputs, targets) ``` ```{code-cell} From 371935cc10ea50e4545f43112591404c13a242e9 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sun, 11 Aug 2024 08:09:47 -0700 Subject: [PATCH 057/702] update README and several docs to typed RNG keys --- README.md | 2 +- benchmarks/api_benchmark.py | 4 ++-- benchmarks/sparse_benchmark.py | 2 +- cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb | 6 +++--- cloud_tpu_colabs/JAX_demo.ipynb | 10 +++++----- cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb | 2 +- cloud_tpu_colabs/Pmap_Cookbook.ipynb | 2 +- docs/Custom_Operation_for_GPUs.md | 2 +- docs/Custom_Operation_for_GPUs.py | 2 +- docs/_tutorials/advanced-autodiff.md | 6 +++--- docs/notebooks/shard_map.ipynb | 2 +- docs/notebooks/shard_map.md | 2 +- jax/_src/nn/initializers.py | 2 +- 13 files changed, 22 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index b19d7b9ff128..8f50aa42a125 100644 --- a/README.md +++ b/README.md @@ -273,7 +273,7 @@ from jax import random, pmap import jax.numpy as jnp # Create 8 random 5000 x 6000 matrices, one per GPU -keys = random.split(random.PRNGKey(0), 8) +keys = random.split(random.key(0), 8) mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys) # Run a local matmul on each device in parallel (no data transfer) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index c68dab85dc8e..710ffb6d7cad 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -839,7 +839,7 @@ def f(x): out = out + y * x[0] return out - x = jax.random.normal(jax.random.PRNGKey(0), (2, 2)) + x = jax.random.normal(jax.random.key(0), (2, 2)) f(x).block_until_ready() # compile while state: f(x).block_until_ready() @@ -929,7 +929,7 @@ def jit_add_chain(state): def g(x, y): return lax.add(x, y) - x = jax.random.normal(jax.random.PRNGKey(0), (2, 2)) + x = jax.random.normal(jax.random.key(0), (2, 2)) while state: @jax.jit def f(x): diff --git a/benchmarks/sparse_benchmark.py b/benchmarks/sparse_benchmark.py index 65550b9cfee0..d6328881d5c6 100644 --- a/benchmarks/sparse_benchmark.py +++ b/benchmarks/sparse_benchmark.py @@ -109,7 +109,7 @@ def sparse_bcoo_todense_compile(state): def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): shape = (2000, 2000) nse = 10000 - key = jax.random.PRNGKey(1701) + key = jax.random.key(1701) mat = sparse.random_bcoo( key, nse=nse, diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index 1278bd01c91f..279aef3e9c65 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -38,7 +38,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "key, subkey = random.split(key)\n", "x = random.normal(key, (5000, 5000))\n", "\n", @@ -189,7 +189,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "x = random.normal(key, ())\n", "\n", "print(grad(f)(x))\n", @@ -261,7 +261,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "x = random.normal(key, (5000, 5000))" ] }, diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index 6a6993f44ed2..9acb1971c3b6 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -27,7 +27,7 @@ "import jax.numpy as jnp\n", "from jax import random\n", "\n", - "key = random.PRNGKey(0)" + "key = random.key(0)" ] }, { @@ -194,7 +194,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "x = random.normal(key, ())\n", "\n", "print(grad(f)(x))\n", @@ -246,7 +246,7 @@ "\n", "layer_sizes = [5, 2, 3]\n", "\n", - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "key, *keys = random.split(key, len(layer_sizes))\n", "params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", "\n", @@ -351,7 +351,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "x = random.normal(key, (5000, 5000))" ] }, @@ -754,7 +754,7 @@ }, "outputs": [], "source": [ - "keys = random.split(random.PRNGKey(0), 8)\n", + "keys = random.split(random.key(0), 8)\n", "mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n", "result = pmap(jnp.dot)(mats, mats)\n", "print(pmap(jnp.mean)(result))" diff --git a/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb b/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb index 1777d3d1ef79..84abf865851a 100644 --- a/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb +++ b/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb @@ -366,7 +366,7 @@ "\n", "# set some initial conditions for each replicate\n", "ys = jnp.zeros((N_dev, N, 3))\n", - "state0 = jr.uniform(jr.PRNGKey(1), \n", + "state0 = jr.uniform(jr.key(1), \n", " minval=-1., maxval=1.,\n", " shape=(N_dev, 3))\n", "state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))\n", diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index 4f4ba8c165a3..981f0a9e80a7 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -263,7 +263,7 @@ "from jax import random\n", "\n", "# create 8 random keys\n", - "keys = random.split(random.PRNGKey(0), 8)\n", + "keys = random.split(random.key(0), 8)\n", "# create a 5000 x 6000 matrix on each device by mapping over keys\n", "mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)\n", "# the stack of matrices is represented logically as a single array\n", diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index 28a81715428e..8490bd489608 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -306,7 +306,7 @@ per_core_batch_size=4 seq_len=512 emb_dim=512 x = jax.random.normal( - jax.random.PRNGKey(0), + jax.random.key(0), shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim), dtype=jnp.bfloat16, ) diff --git a/docs/Custom_Operation_for_GPUs.py b/docs/Custom_Operation_for_GPUs.py index 4c0b4b6f7b38..31a00c49071e 100644 --- a/docs/Custom_Operation_for_GPUs.py +++ b/docs/Custom_Operation_for_GPUs.py @@ -479,7 +479,7 @@ def custom_p_rms_norm_bwd(eps, res, g): emb_dim = 512 assert jax.local_device_count() > 1, "Only 1 GPU, the example work, but it is this really what you want?" x = jax.random.normal( - jax.random.PRNGKey(0), + jax.random.key(0), shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim), dtype=jnp.float16, ) diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index da95f96d8b25..d58a45d1ddf3 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -590,7 +590,7 @@ def vmap_mjp(f, x, M): outs, = vmap(vjp_fun)(M) return outs -key = random.PRNGKey(0) +key = random.key(0) num_covecs = 128 U = random.normal(key, (num_covecs,) + y.shape) @@ -714,7 +714,7 @@ Here's a check: ```{code-cell} def check(seed): - key = random.PRNGKey(seed) + key = random.key(seed) # random coeffs for u and v key, subkey = random.split(key) @@ -768,7 +768,7 @@ Here's a check of the VJP rules: ```{code-cell} def check(seed): - key = random.PRNGKey(seed) + key = random.key(seed) # random coeffs for u and v key, subkey = random.split(key) diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 919690d230ab..6fbe67f4d05f 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -1509,7 +1509,7 @@ "layer_sizes = [784, 128, 128, 128, 128, 128, 8]\n", "batch_size = 32\n", "\n", - "params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)" + "params, batch = init(jax.random.key(0), layer_sizes, batch_size)" ] }, { diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 6f9dfbb659e1..389b62c8a0e6 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -1055,7 +1055,7 @@ def init(key, layer_sizes, batch_size): layer_sizes = [784, 128, 128, 128, 128, 128, 8] batch_size = 32 -params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size) +params, batch = init(jax.random.key(0), layer_sizes, batch_size) ``` Compare these examples with the purely [automatic partitioning examples in the diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 7d228e4beef4..eb1bb1609bbf 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -184,7 +184,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.truncated_normal(5.0) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 2.9047365, 5.2338114, 5.29852 ], [-3.836303 , -4.192359 , 0.6022964]], dtype=float32) """ From 4f8f66f10bbe6b938c3fc4c1262cd6de5ddcb57e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Sun, 11 Aug 2024 12:33:14 -0700 Subject: [PATCH 058/702] Add more complete tests for attribute serialization when lowering an FFI call. PiperOrigin-RevId: 661849681 --- tests/extend_test.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/extend_test.py b/tests/extend_test.py index 098d52b10eeb..d43063c7a0a1 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -25,7 +25,6 @@ from jax._src import abstract_arrays from jax._src import api -from jax._src import core from jax._src import linear_util from jax._src import prng from jax._src import test_util as jtu @@ -101,16 +100,34 @@ def testHeadersExist(self): for header in ["c_api.h", "api.h", "ffi.h"]: self.assertTrue(os.path.exists(os.path.join(base_dir, header))) - @parameterized.parameters([True, 1, 5.0, "param", np.float32(0.5)]) - def testParams(self, param): - prim = core.Primitive("test_ffi") - prim.def_abstract_eval(lambda *args, **kwargs: args[0]) - mlir.register_lowering(prim, jex.ffi.ffi_lowering("test_ffi")) - - # TODO(dfm): Currently testing that lowering works with different types of - # parameters, but we should probably actually check the emitted HLO. - func = jax.jit(lambda *args: prim.bind(*args, param=param)) - func.lower(jnp.linspace(0, 5, 10)) + @parameterized.parameters([ + (True, mlir.ir.BoolAttr.get), + (1, mlir.i64_attr), + (5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)), + ("param", mlir.ir.StringAttr.get), + (np.float32(0.5), + lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)), + ]) + def testParams(self, param, expected_builder): + def fun(x): + return jex.ffi.ffi_call("test_ffi", x, x, param=param) + + # Here we inspect the lowered IR to test that the parameter has been + # serialized with the appropriate type. + module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo") + for func in module.body.operations: + for block in func.body.blocks: + for op in block.operations: + if op.OPERATION_NAME == "stablehlo.custom_call": + config = op.attributes["mhlo.backend_config"] + self.assertIsInstance(config, mlir.ir.DictAttr) + self.assertIn("param", config) + with mlir.make_ir_context(), mlir.ir.Location.unknown(): + expected = expected_builder(param) + self.assertEqual(type(config["param"]), type(expected)) + self.assertTrue(expected.type.isinstance(config["param"].type)) + return + self.fail("No custom_call found in the lowered IR") @jtu.sample_product( shape=[(1,), (4,), (5,)], From dd535d88a7858df6000b71804a0e72cb553a268a Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sun, 11 Aug 2024 12:41:50 -0700 Subject: [PATCH 059/702] emphasize typed over legacy RNG keys in `random` module docs Update both docstrings and move the `PRNGKey` function listing lower in the API reference. --- docs/jax.random.rst | 2 +- jax/_src/random.py | 26 ++++++++++++++++++-------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/docs/jax.random.rst b/docs/jax.random.rst index 9d6369d2d2b1..6c5427c05e66 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -12,13 +12,13 @@ Key Creation & Manipulation .. autosummary:: :toctree: _autosummary - PRNGKey key key_data wrap_key_data fold_in split clone + PRNGKey Random Samplers ~~~~~~~~~~~~~~~ diff --git a/jax/_src/random.py b/jax/_src/random.py index 113bcc450100..6105d56f9148 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -197,9 +197,10 @@ def key(seed: int | ArrayLike, *, impl: PRNGSpecDesc | None = None) -> KeyArray: """Create a pseudo-random number generator (PRNG) key given an integer seed. - The result is a scalar array with a key that indicates the default PRNG - implementation, as determined by the optional ``impl`` argument or, - otherwise, by the ``jax_default_prng_impl`` config flag. + The result is a scalar array containing a key, whose dtype indicates + the default PRNG implementation, as determined by the optional + ``impl`` argument or, otherwise, by the ``jax_default_prng_impl`` + config flag at the time when this function is called. Args: seed: a 64- or 32-bit integer used as the value of the key. @@ -214,11 +215,20 @@ def key(seed: int | ArrayLike, *, def PRNGKey(seed: int | ArrayLike, *, impl: PRNGSpecDesc | None = None) -> KeyArray: - """Create a pseudo-random number generator (PRNG) key given an integer seed. - - The resulting key carries the default PRNG implementation, as - determined by the optional ``impl`` argument or, otherwise, by the - ``jax_default_prng_impl`` config flag. + """Create a legacy PRNG key given an integer seed. + + This function produces old-style legacy PRNG keys, which are arrays + of dtype ``uint32``. For more, see the note in the `PRNG keys + `_ + section. When possible, :func:`jax.random.key` is recommended for + use instead. + + The resulting key does not carry a PRNG implementation. The returned + key matches the implementation given by the optional ``impl`` + argument or, otherwise, determined by the ``jax_default_prng_impl`` + config flag. Callers must ensure that same implementation is set as + the default when passing this key as an argument to other functions + (such as ``jax.random.split`` and ``jax.random.normal``). Args: seed: a 64- or 32-bit integer used as the value of the key. From c54ffd41bca1c3c6635c2a89391a43e33c0b1e15 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sun, 11 Aug 2024 12:44:50 -0700 Subject: [PATCH 060/702] in `dot` docstring, format and link to `dot_general` --- jax/_src/lax/lax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c0ed971770b4..3924d355e591 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -721,7 +721,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, `_ operator. - For more general contraction, see the `dot_general` operator. + For more general contraction, see the :func:`jax.lax.dot_general` operator. Args: lhs: an array of dimension 1 or 2. From 3c7bd54c54daaaf53ff5afa6db7d14c825ca5713 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 11 Aug 2024 15:16:10 -0700 Subject: [PATCH 061/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/a6fc99fadc72358a5d79dd3ece66340ac5e45ad7. PiperOrigin-RevId: 661883548 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index ab4affe07816..db188b0a9b68 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "dfb2a8b4980f1b9ee01e4f6bcc23ad25403a2124" -XLA_SHA256 = "255dc4ed80eb9f15a9723c65ea9f27faecbdd81ba012c767b69c2b1e6e8859ca" +XLA_COMMIT = "a6fc99fadc72358a5d79dd3ece66340ac5e45ad7" +XLA_SHA256 = "da0b6beeb418933b380c439f34416b6635931809a4c2dc9a99eceb6ff35363fe" def repo(): tf_http_archive( From 4b7c198a1c99049e2b41c18a9e06a5173e4e54f2 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Thu, 25 Jul 2024 19:02:33 +0000 Subject: [PATCH 062/702] [ROCm]: Add get_arch_details for triton kernel call --- jaxlib/gpu/triton.cc | 12 ++++++++++++ jaxlib/gpu_triton.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 1274eeba466b..500034af3ebb 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -132,6 +132,18 @@ NB_MODULE(_triton, m) { return major * 10 + minor; })); + m.def( + "get_arch_details", + ValueOrThrowWrapper([](int device) -> absl::StatusOr { +#ifdef JAX_GPU_HIP + hipDeviceProp_t prop; + hipGetDeviceProperties(&prop, 0); + return prop.gcnArchName; +#else + return absl::UnimplementedError("Not a HIP GPU"); +#endif + })); + m.def("get_serialized_metadata", ValueOrThrowWrapper( [](nb::bytes opaque) -> absl::StatusOr { diff --git a/jaxlib/gpu_triton.py b/jaxlib/gpu_triton.py index f2d37bfec03d..77f315e5b4b1 100644 --- a/jaxlib/gpu_triton.py +++ b/jaxlib/gpu_triton.py @@ -35,6 +35,7 @@ create_array_parameter = _cuda_triton.create_array_parameter create_scalar_parameter = _cuda_triton.create_scalar_parameter get_compute_capability = _cuda_triton.get_compute_capability + get_arch_details = _cuda_triton.get_arch_details get_custom_call = _cuda_triton.get_custom_call get_serialized_metadata = _cuda_triton.get_serialized_metadata @@ -58,5 +59,6 @@ create_array_parameter = _hip_triton.create_array_parameter create_scalar_parameter = _hip_triton.create_scalar_parameter get_compute_capability = _hip_triton.get_compute_capability + get_arch_details = _hip_triton.get_arch_details get_custom_call = _hip_triton.get_custom_call get_serialized_metadata = _hip_triton.get_serialized_metadata From ae5b4284d5b5b817fe532a38f08c62ba50409247 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 12 Aug 2024 03:08:13 -0700 Subject: [PATCH 063/702] Make `ffi_call` tests backwards compatible with the released jaxlib. PiperOrigin-RevId: 662017095 --- tests/extend_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/extend_test.py b/tests/extend_test.py index d43063c7a0a1..3194a3ef9073 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -173,6 +173,8 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True) dtype=pivots.dtype, ), pivots, + # TODO(b/358275922): Remove this after jaxlib v0.4.32 is released. + permutation_size=np.int32(permutation_size), vectorized=vectorized, ) From 3c014a4c27fd341a2fc1c62f372b7020fece05fb Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 12 Aug 2024 03:38:55 -0700 Subject: [PATCH 064/702] Add support for shape polymorphism with lu_pivots_to_permutation. This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used. PiperOrigin-RevId: 662024940 --- jax/_src/export/_export.py | 3 +- .../cuda_lu_pivots_to_permutation.py | 55 +++++++++++++++++++ jax/_src/lax/linalg.py | 25 +++++---- jaxlib/gpu_linalg.py | 9 +-- tests/export_back_compat_test.py | 13 +++++ tests/shape_poly_test.py | 35 ++++++++++++ 6 files changed, 124 insertions(+), 16 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 4ee5dca86455..10b6d09a2d91 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -960,7 +960,8 @@ def _check_lowering(lowering) -> None: "lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf", # schur on CPU "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", - # # lu on GPU + # lu on GPU + "cu_lu_pivots_to_permutation", # "cublas_getrf_batched", "cusolver_getrf", # "hipblas_getrf_batched", "hipsolver_getrf", # TODO(b/357034884): This can be added once the mimimum version of jaxlib diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py new file mode 100644 index 000000000000..12285a45b77a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from numpy import array, int32 + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_08 = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation'], + serialized_date=datetime.date(2024, 8, 8), + inputs=(), + expected_outputs=(array([[[0, 1, 2, 3, 4, 5, 6, 7], + [4, 5, 6, 7, 0, 1, 2, 3], + [0, 1, 2, 3, 4, 5, 6, 7]], + + [[0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7]]], dtype=int32),), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<24xi32> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<24xi32>) -> tensor<2x3x4xi32> loc(#loc5) + %c = stablehlo.constant dense<2> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc6) + %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {permutation_size = 8 : i32}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) + return %2 : tensor<2x3x8xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":348:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1d\x05\x01\x03\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\xa7}\x17\x01Q\x0f\x07\x0b\x0b\x0f\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03-\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f///\x0b\x0b\x0b\x13\x0b\x0fo\x01\x05\x0b\x0f\x03\x13\x0f\x07\x1b\x07\x13\x13\x1b\x13\x07\x02Z\x04\x1d57\x1f\x05\x13\x05\x15\x11\x03\x05\x05\x17\x03\t\x0f\x11\x13\t\x15\t\x0b\x17\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1bQ\x1dW\x1fY\x0bc!e\x05!\x05#\x05%\x05'\x03\x03%g\x05)\x1d)+\x05+\x17\x05n\x055\x1d/1\x05-\x17\x05n\x05\x1d\x03\x03\x07i\x05/\x17\x05r\x05\x17\x03\x03\x07k\x03\x03\x07m\x03\x13?oASCqEQGsIuKUMQOU\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x01\x1dC\x03\x03{#\r\x03\x03[\r\x05]S_a\x1dE\x1dG\x1dI\x1dK\x1dM\x13\x0b\x01\x1f\x05\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dO\x05\x01\r\x03wy\x1dQ\x13\x07!\x1f\x131\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x0b\x1b)\x07\t\r!\x07\x1d\x11\x01\x03\t)\x03a\x07)\x07\t\r\x11\x07)\x03\r\x15\x13\x04{\x05\x01\x11\x03\r\x07\x03\x01\x05\x05\x11\x03\x19\x07\x03\r\x1d\x07\x03'#\x03\x0f\t\x06-\x03\x11\x03\x01\x03\x03\x013\x03\x05\x03\x03\x019\x03\x05\x03\x03\x01;\x03\x05\x0b\x07\x01=\x03\t\x03\x03\r\x04\x03\x03\x0b\x06\x03\x01\x05\x01\x00f\x0cS#9\x0f\x0b\x11#!\x03\x1f/!)!)#\x1f\x19\x8b\x8b\x85\x1f\x1f\x15\x1d\x15\x1b%)9\x13\ri\x15\x1f\x17\x11\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00\x00jax.result_info\x00mhlo.layout_mode\x00default\x00main\x00public\x00cu_lu_pivots_to_permutation\x00permutation_size\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index cecfa253ec11..c7ef1462361f 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1159,8 +1159,9 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): len(batch_dims)) if m == 0: return permutation - result, _ = lax.fori_loop(np.array(0, np.int32), np.array(k, np.int32), - _lu_pivots_body_fn, (permutation, swaps)) + upper = np.array(k, np.int32) if is_constant_dim(k) else k + result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn, + (permutation, swaps)) return result @@ -1171,19 +1172,14 @@ def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size): raise ValueError( 'Argument to lu_pivots_to_permutation must have rank >= 1 and dtype ' 'int32. Got shape={} and dtype={}'.format(pivots.shape, pivots.dtype)) - pivots_size = pivots.shape[-1] - if permutation_size < pivots_size: + if not permutation_size >= pivots_size: raise ValueError( 'Output permutation size {} has to exceed the trailing dimension of ' 'the pivots. Got pivots size {}'.format(permutation_size, pivots_size)) - - batch_dims = pivots.shape[:-1] - permutations = pivots.update(shape=batch_dims + (permutation_size,)) + return pivots.update(shape=(*pivots.shape[:-1], permutation_size)) else: - permutations = pivots - - return permutations + return pivots def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *, @@ -1196,7 +1192,14 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *, def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *, permutation_size): - return lowering(pivots, permutation_size=permutation_size) + # TODO(danfm): Remove once jaxlib 0.4.32 is the minimum version. + if jaxlib_version >= (0, 4, 32): + pivots_aval, = ctx.avals_in + pivots_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, pivots_aval.shape) + kwargs = dict(pivots_shape_vals=pivots_shape_vals) + else: + kwargs = {} + return lowering(pivots, permutation_size=permutation_size, **kwargs) lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation') diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 32d31e7206a4..f392cc690046 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -61,16 +61,17 @@ _prod = lambda xs: functools.reduce(operator.mul, xs, 1) -def _lu_pivots_to_permutation_hlo(platform, pivots, *, permutation_size): +def _lu_pivots_to_permutation_hlo(platform, pivots, *, permutation_size, + pivots_shape_vals): """Kernel for the transformation of pivots to permutations on GPU.""" typ = ir.RankedTensorType(pivots.type) - dims = typ.shape i32_type = ir.IntegerType.get_signless(32) assert typ.element_type == i32_type, typ + assert len(pivots_shape_vals) >= 1 - pivots_layout = tuple(range(len(dims) - 1, -1, -1)) + pivots_layout = tuple(range(len(pivots_shape_vals) - 1, -1, -1)) permutations_layout = pivots_layout - permutations_dims = (*dims[:-1], permutation_size) + permutations_dims = (*pivots_shape_vals[:-1], permutation_size) result_types, result_shapes = mk_result_types_and_shapes( [(permutations_dims, i32_type)]) return custom_call( diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 045f2e233465..7118495efa9b 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -44,6 +44,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32 +from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK @@ -124,6 +125,7 @@ def test_custom_call_coverage(self): cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30, cpu_lu_lapack_getrf.data_2023_06_14, + cuda_lu_pivots_to_permutation.data_2024_08_08, cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17, rocm_qr_hipsolver_geqrf.data_2024_08_05, rocm_eigh_hipsolver_syev.data_2024_08_05, @@ -342,6 +344,17 @@ def test_tpu_Eigh(self): self.run_one_test(func, data, rtol=1e-3, check_results=partial(self.check_eigh_results, operand)) + @staticmethod + def lu_pivots_to_permutation_harness(shape): + operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=np.int32), shape) + return lax.linalg.lu_pivots_to_permutation(operand, permutation_size=8) + + def test_cuda_lu_pivots_to_permutation(self): + shape = (2, 3, 4) + func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape) + data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08) + self.run_one_test(func, data) + @staticmethod def qr_harness(shape, dtype): # In order to keep inputs small, we construct the input programmatically diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 404c9ce0c4fd..cff9eb4d6e8e 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2731,6 +2731,41 @@ def test_vmap_error(self): ((2, 3, 8, 4), "b1, b2, ...", True), ] ], + [ + PolyHarness( + "lu_pivots_to_permutation", + f"shape={jtu.format_shape_dtype_string(shape, np.int32)}_poly={poly}_{permutation_size=}", + lax.linalg.lu_pivots_to_permutation, + arg_descriptors=[RandArg(shape, np.int32), StaticArg(permutation_size)], + polymorphic_shapes=[poly], + symbolic_constraints=constraints, + ) + for shape, poly, permutation_size, constraints in [ + ((4,), None, 8, ()), + ((2, 3, 4), "b1, b2, ...", 8, ()), + ((4,), "b", 8, ["b <= 8"]), + ((2, 3, 4), "b1, b2, b3", 8, ["b3 <= 8"]), + ] + ], + [ + # Tracing errors are only thrown when the trailing dimension of pivots + # is static. Otherwise, the error is thrown at runtime. + PolyHarness( + "lu_pivots_to_permutation_error", + f"shape={jtu.format_shape_dtype_string(shape, np.int32)}_poly={poly}_{permutation_size=}", + lax.linalg.lu_pivots_to_permutation, + arg_descriptors=[RandArg(shape, np.int32), StaticArg(permutation_size)], + polymorphic_shapes=[poly], + symbolic_constraints=constraints, + expect_error=(ValueError, "Output permutation size"), + ) + for shape, poly, permutation_size, constraints in [ + ((4,), None, 3, ()), + ((2, 3, 4), "b1, b2, ...", 3, ()), + ((4,), "b", 8, ["b >= 9"]), + ((2, 3, 4), "b1, b2, b3", 8, ["b3 >= 9"]), + ] + ], [ # The random primitive tests, with threefry (both partitionable and # non-partitionable), and unsafe_rbg. From c9142cbe753c876503098ba9a7da8a166630971a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 12 Aug 2024 12:49:18 +0100 Subject: [PATCH 065/702] Collapsed a few unnecessary ``if TYPE_CHECKING`` blocks --- jax/_src/lax/lax.py | 96 +++++++++++++++++---------------------------- jax/_src/util.py | 69 +++++++++++++------------------- 2 files changed, 64 insertions(+), 101 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 3924d355e591..1e1f2d48c538 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -22,7 +22,7 @@ import itertools import math import operator -from typing import Any, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING +from typing import Any, TypeVar, Union, cast as type_cast, overload import warnings import numpy as np @@ -634,64 +634,42 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: _precision_strings: dict[Any, Precision] = {} -# TODO(b/333851820): pytype does not properly handle _missing_ in enums. -# We work around that by defining `Precision` as a normal class. -if TYPE_CHECKING: - - class Precision: - DEFAULT: ClassVar[Precision] - HIGH: ClassVar[Precision] - HIGHEST: ClassVar[Precision] - - def __new__(cls, value: Precision | int | str | None) -> Precision: - raise NotImplementedError - - @property - def name(self) -> str: - raise NotImplementedError - - @property - def value(self) -> int: - raise NotImplementedError - -else: - - class Precision(enum.Enum): - """Precision enum for lax matrix multiply related functions. - - The device-dependent `precision` argument to JAX functions generally - controls the tradeoff between speed and accuracy for array computations on - accelerator backends, (i.e. TPU and GPU). Has no impact on CPU backends. - This only has an effect on float32 computations, and does not affect the - input/output datatypes. Members are: - - DEFAULT: - Fastest mode, but least accurate. On TPU: performs float32 computations in - bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100 - GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases: - ``'default'``, ``'fastest'``. - HIGH: - Slower but more accurate. On TPU: performs float32 computations in 3 - bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise - float32. Aliases: ``'high'``.. - HIGHEST: - Slowest but most accurate. On TPU: performs float32 computations in 6 - bfloat16. Aliases: ``'highest'``. On GPU: uses float32. - """ - - DEFAULT = 0 - HIGH = 1 - HIGHEST = 2 - - @classmethod - def _missing_(cls, value: object) -> Precision | None: - return _precision_strings.get(value) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}.{self.name}' - - def __str__(self) -> str: - return self.name +class Precision(enum.Enum): + """Precision enum for lax matrix multiply related functions. + + The device-dependent `precision` argument to JAX functions generally + controls the tradeoff between speed and accuracy for array computations on + accelerator backends, (i.e. TPU and GPU). Has no impact on CPU backends. + This only has an effect on float32 computations, and does not affect the + input/output datatypes. Members are: + + DEFAULT: + Fastest mode, but least accurate. On TPU: performs float32 computations in + bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100 + GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases: + ``'default'``, ``'fastest'``. + HIGH: + Slower but more accurate. On TPU: performs float32 computations in 3 + bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise + float32. Aliases: ``'high'``.. + HIGHEST: + Slowest but most accurate. On TPU: performs float32 computations in 6 + bfloat16. Aliases: ``'highest'``. On GPU: uses float32. + """ + + DEFAULT = 0 + HIGH = 1 + HIGHEST = 2 + + @classmethod + def _missing_(cls, value: object) -> Precision | None: + return _precision_strings.get(value) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}.{self.name}' + + def __str__(self) -> str: + return self.name _precision_strings['highest'] = Precision.HIGHEST diff --git a/jax/_src/util.py b/jax/_src/util.py index 5174b21c2323..fce342c493ed 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -611,55 +611,40 @@ def wrapper(func: T) -> T: return wrapper -if TYPE_CHECKING: - def use_cpp_class(cpp_cls: Any) -> Callable[[T], T]: - def wrapper(cls: T) -> T: - return cls - return wrapper +def use_cpp_class(cpp_cls: type[Any]) -> Callable[[type[T]], type[T]]: + """A decorator replacing a Python class with its C++ version at runtime.""" - def use_cpp_method(is_enabled: bool = True) -> Callable[[T], T]: - def wrapper(cls: T) -> T: + def wrapper(cls): + if cpp_cls is None: return cls - return wrapper -else: - def use_cpp_class(cpp_cls): - """A helper decorator to replace a python class with its C++ version""" + exclude_methods = {'__module__', '__dict__', '__doc__'} - def wrapper(cls): - if cpp_cls is None: - return cls + originals = {} + for attr_name, attr in cls.__dict__.items(): + if attr_name not in exclude_methods: + if hasattr(_original_func(attr), "_use_cpp"): + originals[attr_name] = attr + else: + setattr(cpp_cls, attr_name, attr) - exclude_methods = {'__module__', '__dict__', '__doc__'} + cpp_cls.__doc__ = cls.__doc__ + # TODO(pschuh): Remove once fastpath is gone. + cpp_cls._original_py_fns = originals + return cpp_cls - originals = {} - for attr_name, attr in cls.__dict__.items(): - if attr_name not in exclude_methods: - if hasattr(_original_func(attr), "_use_cpp"): - originals[attr_name] = attr - else: - setattr(cpp_cls, attr_name, attr) - - cpp_cls.__doc__ = cls.__doc__ - # TODO(pschuh): Remove once fastpath is gone. - cpp_cls._original_py_fns = originals - return cpp_cls - - return wrapper + return wrapper - def use_cpp_method(is_enabled=True): - """A helper decorator to exclude methods from the set that are forwarded to C++ class""" - def decorator(f): - if is_enabled: - original_func = _original_func(f) - original_func._use_cpp = True - return f - - if not isinstance(is_enabled, bool): - raise TypeError( - "Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)" - ) - return decorator +def use_cpp_method(is_enabled: bool = True) -> Callable[[T], T]: + """A decorator excluding methods from the set that are forwarded to C++ class.""" + if not isinstance(is_enabled, bool): + raise TypeError("``is_enabled`` must be a bool") + def decorator(f): + if is_enabled: + original_func = _original_func(f) + original_func._use_cpp = True + return f + return decorator try: From 7f680aaab8beab48f21056a172d1572a7067a63b Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 12 Aug 2024 05:08:56 -0700 Subject: [PATCH 066/702] [pallas] Move ops_test.py from jax_triton to jax/pallas The `jax_triton/ops_test.py` has over time accumulated many tests that are in fact platform-independent tests. Furthermore, those tests were only Google-internal, and they can be external as well. This moves test coverage for Pallas from the jax_triton package to the Pallas core package. A small number of the tests were deleted, because they were already present in Pallas, e.g., tests in `jax_triton/ops_test.py:ControlFlowTest`, and tests for unary and binary ops in `jax_triton/ops_test.py:OpsTest`. The other tests were distributed to different files in the Pallas repo, according to their purpose: * tests in `jax_triton/ops_test.py:PrettyPrintingTest` are moved to `tpu_pallas_test.py::PrettyPrintingTest` * tests in `jax_triton/ops_test.py::IndexingTest` are appended to `indexing_test.py::IndexingTest`; some other indexing tests from `jax_triton/ops_test.py::LoadStoreTest` are also moved there. * some tests in `jax_triton/ops_test.py:OpsTest` are moved to `ops_test.py::OpsTest`. * some tests for TPU specific ops in `jax_triton/ops_test.py:OpsTest` are moved to a new test file `tpu_ops_tests.py` Some of this required adding sharding and hypothesis support to `ops_test.py`, and adding TPU versions of `indexing_test.py`. PiperOrigin-RevId: 662045774 --- tests/pallas/BUILD | 25 ++- tests/pallas/indexing_test.py | 317 ++++++++++++++++++++++++++- tests/pallas/ops_test.py | 373 +++++++++++++++++++++++++++++++- tests/pallas/tpu_ops_test.py | 183 ++++++++++++++++ tests/pallas/tpu_pallas_test.py | 37 ++++ 5 files changed, 929 insertions(+), 6 deletions(-) create mode 100644 tests/pallas/tpu_ops_test.py diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 2076519f1af3..d2a7ec56db07 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -84,12 +84,17 @@ jax_test( "gpu_a100_x32", "gpu_h100_x32", ], + shard_count = { + "cpu": 4, + "gpu": 4, + "tpu": 4, + }, deps = [ "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) jax_test( @@ -99,10 +104,10 @@ jax_test( ], disable_backends = [ "gpu", - "tpu", ], deps = [ "//jax:pallas", + "//jax:pallas_tpu", ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) @@ -317,6 +322,22 @@ jax_test( ], ) +jax_test( + name = "tpu_ops_test", + srcs = [ + "tpu_ops_test.py", + ], + disable_backends = [ + "gpu", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), +) + jax_test( name = "tpu_pallas_distributed_test", srcs = ["tpu_pallas_distributed_test.py"], diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 11402ed99741..696f12b0ed72 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -15,12 +15,13 @@ """Tests for Pallas indexing logic and abstractions.""" from __future__ import annotations - +import sys import unittest from absl.testing import absltest from absl.testing import parameterized import jax +from jax import random from jax._src import test_util as jtu from jax._src import util from jax._src.state import indexing @@ -28,6 +29,11 @@ import jax.numpy as jnp from jax.experimental import pallas as pl +if sys.platform != "win32": + from jax.experimental.pallas import tpu as pltpu +else: + pltpu = None + try: import hypothesis as hp except (ModuleNotFoundError, ImportError): @@ -46,6 +52,26 @@ ds = indexing.ds +_INDEXING_TEST_CASES = [ + ((4, 8, 128), (...,), (4, 8, 128)), + ((4, 8, 128), (0,), (8, 128)), + ((4, 8, 128), (pl.ds(1, 2),), (2, 8, 128)), + ((4, 8, 128), (pl.ds(2, 2),), (2, 8, 128)), + ((4, 8, 128), (pl.ds(2, 2), 0), (8, 128)), + ((4, 8, 128), (pl.ds(2, 2), 1), (8, 128)), + ((4, 8, 128), (slice(2, 4), 1), (8, 128)), + ((4, 8, 128), (slice(2, 4), slice(0, 1), 0), (8, 128)), + ((4, 8, 128), ((0, pl.ds(0, 8), pl.ds(0, 128)), ...), (8, 128)), + ((4, 8, 128), (..., (0, pl.ds(0, 8), pl.ds(0, 128)), ...), (8, 128)), +] + + +def _maybe_ds_to_slice(x: int | slice | indexing.Slice) -> int | slice: + if isinstance(x, indexing.Slice): + return slice(x.start, x.start + x.size) + return x + + def int_indexer_strategy(dim) -> hps.SearchStrategy[int]: return hps.integers(min_value=np.iinfo(np.int32).min, max_value=dim - 1) @@ -88,7 +114,23 @@ def nd_indexer_strategy(draw, shape) -> NDIndexer: return NDIndexer.from_indices_shape(indices, shape) +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if not self.INTERPRET: + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only interpret mode supported on non-TPU") + + super().setUp() + + @classmethod + def pallas_call(cls, *args, **kwargs): + return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) + + class IndexerTest(jtu.JaxTestCase): + """These are unit tests for the indexer logic, not using pallas_call.""" def test_simple_ndindexer(self): indices = (0, 0) @@ -206,7 +248,11 @@ def test_ndindexer(self, data): indexer.get_indexer_shape()) +class IndexerOpsTest(PallasBaseTest): + def test_multi_indexing_interpreter_only(self): + if not self.INTERPRET: + self.skipTest("Only supported in interpret mode") # Interpreter only test! YMMV actually compiling this. def permute(left, right, left_out_ref, right_out_ref): left_out = jnp.zeros_like(left) @@ -254,6 +300,8 @@ def invoke_permutes(x_ref, y_ref, x_out_ref, y_out_ref): )(x, y) def test_ellipsis_indexing_iterpret_only(self): + if not self.INTERPRET: + self.skipTest("Only supported in interpret mode") # Interpreter only test! YMMV actually compiling this. def permute_columns_in_row_kernel(left, right, new_left, new_right): shape = left.shape @@ -296,18 +344,281 @@ def permute_columns_in_row_kernel(left, right, new_left, new_right): interpret=True, )(left, right) - import numpy as np # noqa: F811 left_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) right_np = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32) left_out_np = left_np.copy() right_out_np = right_np.copy() - permute_columns_in_row_kernel(left_np, right_np, left_out_np, right_out_np) np.testing.assert_array_equal(left_out_np, left_out) np.testing.assert_array_equal(right_out_np, right_out) + @hp.given(hps.data()) + def test_vmap_nd_indexing(self, data): + self.skipTest("TODO(necula): enable this test; was in jax_triton.") + vmap_shape = data.draw(hnp.array_shapes(min_dims=1, max_dims=3, min_side=2), + label="vmap_shape") + el_shape = data.draw(hnp.array_shapes(min_dims=2), label="el_shape") + # TODO(sharadmv,apaszke): enable rank 0 and rank 1 Refs + # hp.assume(len(el_shape) >= 2) + nd_indexer = data.draw(nd_indexer_strategy(el_shape), label="nd_indexer") + expected_shape = jax.eval_shape(lambda x: x[nd_indexer], + jax.ShapeDtypeStruct(el_shape, jnp.float32)) + + ref = lambda x: x[nd_indexer] + def kernel(x_ref, y_ref): + x = pl.load(x_ref, nd_indexer) + pl.store(y_ref, (slice(None),) * len(y_ref.shape), x) + func = pl.pallas_call(kernel, out_shape=expected_shape) + + shape = el_shape + for vmap_dim in vmap_shape[::-1]: + index = data.draw(hps.integers(min_value=0, + max_value=max(0, len(shape) - 2)), + label="index") + # hp.assume(index <= max(0, len(shape) - 2)) + # TODO(sharadmv,apaszke): enable vmapping over batch axes in 2 minormost + # dimensions + shape = (*shape[:index], vmap_dim, *shape[index:]) + ref = jax.vmap(ref, in_axes=index, out_axes=0) + func = jax.vmap(func, in_axes=index, out_axes=0) + key = random.PRNGKey(0) + x = random.normal(key, shape, dtype=jnp.float32) + expected = ref(x) + y = func(x) + np.testing.assert_array_equal(y, expected) + + @parameterized.product( + indexer_type=["state", "pallas"], + case=_INDEXING_TEST_CASES, + ) + def test_can_load_with_ref_at(self, indexer_type, case): + if self.INTERPRET: + self.skipTest("TODO: fails in interpret mode.") + in_shape, indexers, out_shape = case + dtype = jnp.float32 + def body(x_ref, y_ref): + for indexer in indexers[:-1]: + x_ref = x_ref.at[indexer] + if indexer_type == "state": + x = x_ref[indexers[-1]] + y_ref[...] = x + elif indexer_type == "pallas": + x = pl.load(x_ref, indexers[-1]) + pl.store(y_ref, ..., x) + + x = random.normal(random.key(0), in_shape, dtype=dtype) + y = x + for indexer in indexers: + if not isinstance(indexer, tuple): + indexer = (indexer,) + indexer = tuple(map(_maybe_ds_to_slice, indexer)) + y = y[indexer] + assert y.shape == out_shape + out = self.pallas_call(body, out_shape=y)(x) + self.assertAllClose(out, y) + + @parameterized.product( + indexer_type=["state", "pallas"], + case=_INDEXING_TEST_CASES, + ) + def test_can_store_with_ref_at(self, indexer_type, case): + if self.INTERPRET: + self.skipTest("TODO: fails in interpret mode.") + in_shape, indexers, val_shape = case + dtype = jnp.float32 + def body(x_ref, y_ref): + y_ref[...] = jnp.zeros_like(y_ref) + for indexer in indexers[:-1]: + y_ref = y_ref.at[indexer] + if indexer_type == "state": + x = x_ref[...] + y_ref[indexers[-1]] = x + elif indexer_type == "pallas": + x = pl.load(x_ref, ...) + pl.store(y_ref, indexers[-1], x) + + val = random.normal(random.key(0), val_shape, dtype=dtype) + # Use NumPy arrays to do nested indexing and mutation. This is really + # annoying to do in vanilla JAX. + x = np.zeros(in_shape, dtype=dtype) + y = x + for indexer in indexers: + if not isinstance(indexer, tuple): + indexer = (indexer,) + indexer = tuple(map(_maybe_ds_to_slice, indexer)) + y = y[indexer] + assert y.shape == val_shape + y[...] = val + out = self.pallas_call(body, out_shape=x)(val) + self.assertAllClose(out, x) + + @parameterized.product( + indexer_type=["state", "pallas"], + slice_type=["slice", "ds"], + ) + @hp.given( + ref_shape=hps.sampled_from(((8, 8, 32), (7, 7, 33))), + indices=hps.tuples( + hps.integers(0, 6), hps.integers(0, 6), hps.integers(0, 31) + ), + strides=hps.tuples( + hps.integers(1, 10), hps.integers(1, 10), hps.integers(1, 10) + ), + ) + def test_strided_load_and_store( + self, indexer_type, slice_type, ref_shape, indices, strides + ): + if self.INTERPRET: + self.skipTest("TODO: fails in interpret mode.") + ref_shape = (*ref_shape, 128) + indices = (*indices, 0) + strides = (*strides, 1) + vec_shape = [ + (l - i + s - 1) // s for l, i, s in zip(ref_shape, indices, strides) + ] + dtype = jnp.float32 + + def body(x_ref, y_ref1, y_ref2): + if slice_type == "slice": + slices = tuple( + [slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides)] + ) + else: + slices = tuple( + [pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides)] + ) + if indexer_type == "state": + y_ref1[...] = x_ref[slices] + y_ref2[slices] = y_ref1[...] + elif indexer_type == "pallas": + pl.store(y_ref1, ..., pl.load(x_ref, slices)) + pl.store(y_ref2, slices, pl.load(y_ref1, ...)) + + x = random.normal(random.key(0), ref_shape, dtype=dtype) + y1, y2 = self.pallas_call( + body, + out_shape=[ + jax.ShapeDtypeStruct(vec_shape, dtype), + jax.ShapeDtypeStruct(ref_shape, dtype), + ], + )(x) + slices = tuple( + slice(i, l, s) for l, i, s in zip(ref_shape, indices, strides) + ) + expected = x[slices] + self.assertAllClose(y1, expected, err_msg="Strided Load Error") + self.assertAllClose( + y2[slices], expected, err_msg="Strided Store Error" + ) + + def test_load_with_dynamic_2nd_minor_index(self): + # We can take any dynamic index on the 2nd minor dimension as long as + # the minormost dimsize is vreg lane count. + m, n = 32, 128 + k = 10 + start = 2 + + def kernel(x_ref, indices, y_ref): + y_ref[...] = pl.load(x_ref, pl.ds(indices[0], k)) + + x = jnp.arange(m * n, dtype=jnp.int32).reshape((m, n)) + indices = jnp.array([start]) + + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((k, n), jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + ), + )(x, indices) + self.assertAllClose(res, x[start : start + k, :], atol=0., rtol=0.) + + def test_store_with_dynamic_2nd_minor_index(self): + # We can take any dynamic index on the 2nd minor dimension as long as + # the minormost dimsize is vreg lane count. + m, n = 10, 128 + k = 32 + start = 2 + + def kernel(x_ref, indices, y_ref): + pl.store(y_ref, pl.ds(indices[0], m), x_ref[...]) + + x = jnp.arange(m * n, dtype=jnp.int32).reshape((m, n)) + indices = jnp.array([start]) + + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((k, n), jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + ), + )(x, indices) + self.assertAllClose(res[start : start + m, :], x, atol=0., rtol=0.) + + def test_load_one_row_with_dynamic_2nd_minor_index(self): + # This test triggers strided load. We can take any dynamic index on the + # 2nd minor dimension as long as we load one row on the 2nd minor dim. + b, m, n = 4, 16, 256 + start = 3 + + def kernel(x_ref, indices, y_ref): + y_ref[...] = x_ref[:, pl.ds(indices[0], 1), :] + + x = jnp.arange(b * m * n, dtype=jnp.int32).reshape((b, m, n)) + indices = jnp.array([start]) + + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((b, 1, n), jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + ), + )(x, indices) + self.assertAllClose(res, x[:, start : start + 1, :], atol=0., rtol=0.) + + def test_store_one_row_with_dynamic_2nd_minor_index(self): + # This test triggers strided store. We can take any dynamic index on the + # 2nd minor dimension as long as we store one row on the 2nd minor dim. + b, m, n = 4, 16, 256 + start = 3 + + def kernel(x_ref, indices, y_ref): + y_ref[:, pl.ds(indices[0], 1), :] = x_ref[...] + + x = jnp.arange(b * 1 * n, dtype=jnp.int32).reshape((b, 1, n)) + indices = jnp.array([start]) + + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((b, m, n), jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + ), + )(x, indices) + self.assertAllClose(res[:, start : start + 1, :], x, atol=0., rtol=0.) + + +class IndexerOpsInterpreterTest(IndexerOpsTest): + INTERPRET = True + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index ee0cd3531e90..6fbca406931a 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -14,10 +14,13 @@ """Tests for common JAX operations within pallas_call.""" +from collections.abc import Sequence import contextlib import functools import itertools import sys +from typing import Any +import unittest import numpy as np from absl.testing import absltest @@ -41,10 +44,19 @@ plgpu = None pltpu = None +try: + import hypothesis as hp +except (ModuleNotFoundError, ImportError): + raise unittest.SkipTest("tests depend on hypothesis library") + +import hypothesis.extra.numpy as hnp +import hypothesis.strategies as hps + # There are many inherited redefinitions of _ # ruff: noqa: F811 jax.config.parse_flags_with_absl() +jtu.setup_hypothesis(max_examples=100) def smem_on_tpu(): @@ -54,6 +66,181 @@ def smem_on_tpu(): return None +def _random_value(key: jax.Array, shape_dtype: jax.ShapeDtypeStruct + ) -> jax.Array: + if jnp.issubdtype(shape_dtype.dtype, jnp.floating): + return random.normal(key, shape_dtype.shape, dtype=shape_dtype.dtype) + elif jnp.issubdtype(shape_dtype.dtype, jnp.integer): + return random.randint( + key, shape_dtype.shape, minval=-4, maxval=4, dtype=shape_dtype.dtype + ) + raise NotImplementedError(shape_dtype) + + +_DTYPES = ( + "float32", + "bfloat16", + "int32", + "int16", + "int8", + "bool", +) + + +@hps.composite +def make_shape_dtype_strategy( + draw, *, + min_rank: int, + max_rank: int, + min_size_exp: int, + max_size_exp: int, + valid_dtypes: Sequence[jnp.dtype], + max_bytes: int = 2**16, +) -> jax.ShapeDtypeStruct: + dtype = draw(hps.sampled_from(valid_dtypes)) + # To generate shapes with power-of-two sizes, we draw the exponents of the + # sizes, and then generate the sizes from the exponents. + shape_exponents = tuple( + draw(hps.lists( + hps.integers(min_value=min_size_exp, max_value=max_size_exp), + min_size=min_rank, max_size=max_rank)) + ) + shape = tuple(2**exp for exp in shape_exponents) + size = np.prod(shape) * dtype.itemsize + hp.assume(size <= max_bytes) # Make sure we don't take more than 4K VMEM + return jax.ShapeDtypeStruct(shape, dtype) + + +@hps.composite +def arrays( + draw, shape: tuple[int, ...], dtype: np.dtype, + *, elements: hps.SearchStrategy[Any] | None = None, +) -> np.ndarray: + cast_to_bf16 = False + if dtype == np.dtype(jnp.bfloat16): + dtype = np.dtype('float32') + cast_to_bf16 = True + arr = draw(hnp.arrays(shape=shape, dtype=dtype, elements=elements)) + if cast_to_bf16: + arr = arr.astype(np.dtype(jnp.bfloat16)) + return arr + + +@hps.composite +def select_n_strategy( + draw, *, max_cases: int = 4, + min_rank: int = 0, max_rank: int = 2, + min_size_exp: int = 0, max_size_exp: int = 8, +) -> tuple[np.ndarray, ...]: + n_cases = draw(hps.integers(min_value=1, max_value=max_cases)) + case_shape_dtype = draw( + make_shape_dtype_strategy( + min_rank=min_rank, max_rank=max_rank, + min_size_exp=min_size_exp, max_size_exp=max_size_exp, + valid_dtypes=[ + np.dtype("int32"), + np.dtype("float32"), + # TODO(sharadmv,apaszke): enable bf16 + # np.dtype(jnp.bfloat16), + ], + ) + ) + allowed_elements = hps.integers(min_value=0, max_value=n_cases - 1) + pred_shape = draw(hps.sampled_from([(), case_shape_dtype.shape])) + # TODO(sharadmv,apaszke): enable passing bool arrays into Pallas kernels + if n_cases == 2 and not pred_shape: + pred_dtype = draw(hps.sampled_from([np.dtype(np.bool_), + np.dtype(np.int32)])) + allowed_elements = hps.booleans() + else: + pred_dtype = np.int32 + pred = draw(arrays(shape=pred_shape, dtype=pred_dtype, + elements=allowed_elements)) + cases = ( + draw( + arrays(shape=case_shape_dtype.shape, dtype=case_shape_dtype.dtype) + ) + for _ in range(n_cases) + ) + return pred, *cases + + +UNARY_PRIMITIVES = [ + # TODO(sharadmv,apaszke): enable zero rank + # TODO(sharadmv,apaszke): enable one rank + # TODO(sharadmv,apaszke): enable zero dim sizes + # TODO(sharadmv,apaszke): enable one dim sizes + ( + lax.neg_p, + make_shape_dtype_strategy( + min_rank=2, + max_rank=3, + min_size_exp=1, + max_size_exp=6, + valid_dtypes=[jnp.dtype("float32"), jnp.dtype("int32")], + ), + ), + ( + lax.not_p, + make_shape_dtype_strategy( + min_rank=2, + max_rank=3, + min_size_exp=1, + max_size_exp=6, + valid_dtypes=[jnp.dtype("int32")], + ), + ), + *[ + ( + prim, + make_shape_dtype_strategy( + min_rank=2, + max_rank=3, + min_size_exp=1, + max_size_exp=6, + valid_dtypes=[jnp.dtype("float32")], + ), + ) + for prim in [ + lax.exp_p, + lax.tanh_p, + lax.logistic_p, + lax.rsqrt_p, + lax.log_p, + lax.exp2_p, + lax.abs_p, + lax.log1p_p, + lax.sin_p, + lax.sqrt_p, + ] + ], +] + +UNARY_FUNCTIONS = [ + (prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES +] + [ + ( + name, + func, + make_shape_dtype_strategy( + min_rank=2, + max_rank=3, + min_size_exp=1, + max_size_exp=6, + valid_dtypes=[jnp.dtype("float32")], + ), + ) + for name, func in [ + ("relu", jax.nn.relu), + ("pow2", lambda x: jnp.power(2, x)), + ("square", jnp.square), + ("reciprocal", jnp.reciprocal), + ("round", jnp.round), + ("rint", jnp.rint), + ] +] + + class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -292,6 +479,190 @@ def kernel(x_ref, o_ref): np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5) + # TODO(sharadmv): test rank < 2, size < 2 + @hp.given(select_n_strategy(max_cases=2, min_rank=2, max_rank=4, + min_size_exp=1)) + def test_select_n(self, args): + if jtu.test_device_matches(["gpu"]): + self.skipTest("TODO: error on GPU, lowering bug for select_n") + pred, *cases = args + scalar_pred = not pred.shape + + def kernel(*refs): + if scalar_pred: + *case_refs, o_ref = refs + pred_ = pred + else: + pred_ref, *case_refs, o_ref = refs + pred_ = pred_ref[...] + vals = [case_ref[...] for case_ref in case_refs] + o_ref[...] = lax.select_n(pred_, *vals) + out_ref = lax.select_n(pred, *cases) + if scalar_pred: + args = cases + else: + args = [pred, *cases] + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(out_ref.shape, out_ref.dtype), + )(*args) + if out.dtype == jnp.bfloat16: + out, out_ref = out.astype(jnp.float32), out_ref.astype(jnp.float32) + np.testing.assert_allclose(out, out_ref) + + @parameterized.named_parameters( + (name, name, func, strategy) + for name, func, strategy in UNARY_FUNCTIONS + ) + @hp.given(hps.data()) + def test_unary_primitives(self, name, func, shape_dtype_strategy, data): + if self.INTERPRET: + self.skipTest("This hypothesis test is slow, even more so in interpret mode.") + # We want exact equality here to match how JAX lowers to XLA + tol = 0. + if jtu.test_device_matches(["gpu"]): + if func == jnp.round or func == jnp.rint: + self.skipTest("TODO: not implemented on GPU") + if name == "tanh": + tol = 1e-6 + elif name == "exp2": + tol = 1e-6 + elif jtu.test_device_matches(["tpu"]): + if not jtu.is_device_tpu_at_least(version=5) and False: + self.skipTest("TODO: not implemented on TPU v{3,4}") + + def kernel(x_ref, y_ref): + y_ref[...] = func(x_ref[...]) + x_shape_dtype = data.draw(shape_dtype_strategy) + key = random.key(0) + x = _random_value(key, x_shape_dtype) + out = self.pallas_call(kernel, out_shape=x_shape_dtype)(x) + self.assertAllClose(out, func(x), atol=tol, rtol=tol) + + @parameterized.product(from_dtype=_DTYPES, to_dtype=_DTYPES) + @hp.given(hps.data()) + def test_cast(self, from_dtype, to_dtype, data): + if from_dtype == to_dtype: + self.skipTest("Unnecessary test") + if jtu.is_device_tpu(version=4): + if from_dtype in {"int16", "int8"} or to_dtype in {"int16", "int8"}: + self.skipTest( + "Not supported: TPU generation doesn't support this cast." + ) + if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4: + if from_dtype in {"int32", "float32", "bfloat16"} and to_dtype in {"int16", "int8"}: + self.skipTest( + "Not supported: TPU generation doesn't support this cast." + ) + + # TODO(sharadmv,apaszke): add support for the following casts + if from_dtype == "int16" and to_dtype == "int8": + self.skipTest("Not supported: bad canonicalization") + if from_dtype == "int8" and to_dtype == "int16": + self.skipTest("Not supported: bad canonicalization") + if from_dtype == "bool" and to_dtype in {"int16", "int8"}: + self.skipTest("Not supported: cannot extend to sub-32 bit types") + if from_dtype in {"int32", "bfloat16", "float32"} and to_dtype == "bool": + self.skipTest("Not supported: unsupported relayout") + if from_dtype == "bool" and to_dtype in {"int32", "bfloat16", "float32"}: + self.skipTest("Not supported: unsupported relayout") + if from_dtype in {"int16", "int8"} and to_dtype == "bool": + self.skipTest("Not supported: cannot truncate from sub-32 bit types") + if from_dtype in {"int16", "int8"} and to_dtype == "bool": + self.skipTest("Not supported: cannot truncate from sub-32 bit types") + if jtu.test_device_matches(["gpu"]): + if (from_dtype in {"bfloat16", "float32"} and + to_dtype in {"int8", "int16", "int32"}): + self.skipTest("TODO: wrong result on GPU") + + if from_dtype == "bfloat16": + from_dtype = jnp.bfloat16 + if to_dtype == "bfloat16": + to_dtype = jnp.bfloat16 + + if from_dtype == jnp.bfloat16: + x = jnp.asarray(data.draw(hnp.arrays(jnp.float32, (8, 128)))) + x = x.astype(jnp.bfloat16) + else: + x = data.draw(hnp.arrays(from_dtype, (8, 128))) + x = jnp.asarray(x) + if from_dtype == jnp.dtype("bool"): + x = x.astype(jnp.int32) + def kernel(x_ref, y_ref): + x = x_ref[...] + if from_dtype == jnp.dtype("bool"): + x = x.astype(jnp.dtype("bool")) + y = x.astype(to_dtype) + if to_dtype == jnp.dtype("bool"): + y = y.astype(jnp.int32) + y_ref[...] = y + if (y_dtype := to_dtype) == jnp.dtype("bool"): + y_dtype = jnp.int32 + y = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(x.shape, y_dtype))(x) + if to_dtype == jnp.dtype("bool"): + y = y.astype(jnp.dtype("bool")) + y_ref = x.astype(to_dtype) + if to_dtype == jnp.bfloat16: + y, y_ref = y.astype(np.float32), y_ref.astype(np.float32) + np.testing.assert_allclose(y, y_ref, atol=0., rtol=0.) + + @parameterized.product( + shape=((64,), (8, 8)), + dtype=(jnp.int32, jnp.int16, jnp.int8), + ) + def test_scalar_map(self, shape, dtype): + if dtype != jnp.int32 and len(shape) < 2: + # TODO(b/299280718): Implement this. + self.skipTest( + "Loads and stores not implemented for 1D arrays of non-32bit types" + ) + def kernel(x_ref, y_ref): + for idx in np.ndindex(shape): + x = x_ref[idx].astype(jnp.int32) + y_ref[idx] = (x * x).astype(y_ref.dtype) + f = self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + ) + x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + self.assertAllClose(f(x), x * x) + + @jtu.skip_on_devices("gpu") # TODO: not implemented + def test_extract_scalar(self): + def kernel(x_ref, y_ref): + y_ref[0, 0] = x_ref[:][0, 0] + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1, 1), jnp.float32), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + ) + x = np.arange(1024, dtype=jnp.float32).reshape(8, 128) + 10 + self.assertAllClose(f(x).item(), 10.0) + + @jtu.skip_on_devices("gpu") # TODO: not implemented + def test_concat_constant(self): + def kernel(out): + result = [] + for i in range(16): + result.append(jnp.full((1, 128), i, jnp.float32)) + out[:] = jnp.stack(result).reshape(16, 128) + + def run(interpret=False): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + interpret=interpret, + )() + expected = run(True) + if not self.INTERPRET: + actual = run(False) + self.assertAllClose(actual, expected) + class OpsInterpreterTest(OpsTest): INTERPRET = True @@ -1130,7 +1501,7 @@ def reduce(x_ref, y_ref): np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) -class OpsExtraInterpreterTest(OpsTest): +class OpsExtraInterpreterTest(OpsExtraTest): INTERPRET = True diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py new file mode 100644 index 000000000000..75aab92af909 --- /dev/null +++ b/tests/pallas/tpu_ops_test.py @@ -0,0 +1,183 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for TPU specific operations within pallas_call.""" + +import sys +import unittest + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu +from jax.experimental import pallas as pl + +if sys.platform != "win32": + from jax.experimental.pallas import tpu as pltpu +else: + pltpu = None + +try: + import hypothesis as hp +except (ModuleNotFoundError, ImportError): + raise unittest.SkipTest("tests depend on hypothesis library") + +import hypothesis.strategies as hps + +jax.config.parse_flags_with_absl() +jtu.setup_hypothesis(max_examples=100) + +_JAX_DTYPES = ( + jnp.float32, + jnp.bfloat16, + jnp.int32, + jnp.int16, + jnp.int8, + jnp.bool_, +) + + +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if not self.INTERPRET: + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only interpret mode supported on non-TPU") + + super().setUp() + + @classmethod + def pallas_call(cls, *args, **kwargs): + return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) + + +class OpsTest(PallasBaseTest): + + @parameterized.product(from_dtype=_JAX_DTYPES, to_dtype=_JAX_DTYPES) + def test_bitcast(self, from_dtype, to_dtype): + # TODO(jevinjiang): remove this after 2nd minor large tiling is enabled. + if (not jtu.is_device_tpu_at_least(version=5)) and ( + from_dtype in (jnp.int8, jnp.int16) or to_dtype in (jnp.int8, jnp.int16) + ): + self.skipTest( + "Not implemented: packing and unpacking int8, int16 are not supported" + " on < TPUv5" + ) + if from_dtype == to_dtype: + self.skipTest("No bitcast needed") + if from_dtype == jnp.bool_ or to_dtype == jnp.bool_: + self.skipTest("Bitcasting with bool is not supported") + + def kernel(x_ref, y_ref): + y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype) + + m, n = 32, 256 + shape = (m, n) + out_shape = (m * from_dtype.dtype.itemsize // to_dtype.dtype.itemsize, n) + inp = np.arange(np.prod(shape), dtype=from_dtype).reshape(shape) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype), + )(inp) + if not self.INTERPRET: + out_interpret = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype), + interpret=True, + )(inp) + self.assertAllClose(out, out_interpret) + + @parameterized.product(is_dynamic=(False, True)) + @hp.given( + axis=hps.integers(0, 3), + shift=hps.integers(0, 3), + stride=hps.one_of(hps.just(None), hps.integers(0, 2)), + # Stride dimension on the minor most is not supported. + stride_axis=hps.one_of(hps.just(None), hps.integers(0, 2)), + ) + @hp.example(3, 9, 1, 2) + @hp.example(3, 9, 2, 2) + @hp.example(0, 9, 0, 1) + @hp.example(0, 9, 1, 1) + def test_roll(self, is_dynamic, axis, shift, stride, stride_axis): + if (stride is None) != (stride_axis is None): + self.skipTest( + "Roll op requires both stride and stride_axis to be either specified" + " or not specified." + ) + if (not jtu.is_device_tpu(version=5)) and stride_axis == 2: + self.skipTest( + "Roll op with stride axis on 2nd minor requires at least TPU v5" + ) + shape = (4, 4, 32, 512) + + def kernel(s_ref, x_ref, y_ref): + amt = s_ref[0] if is_dynamic else shift + y_ref[...] = pltpu.roll( + x_ref[...], amt, axis, stride=stride, stride_axis=stride_axis + ) + + def roll(x, shift, axis, stride=None, stride_axis=None): + assert (stride is None) == (stride_axis is None) + if stride is None: + return np.roll(x, shift, axis) + outputs = [ + np.roll(xs, shift + i * stride, axis) + for i, xs in enumerate(np.split(x, x.shape[stride_axis], stride_axis)) + ] + return np.concatenate(outputs, stride_axis) + + inp = np.arange(np.prod(shape), dtype=jnp.int32).reshape(shape) + ref = roll(inp, shift, axis, stride, stride_axis) + dynamic_shift = jnp.array([abs(shift)], jnp.int32) + for interpret in [False, True]: + out = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(shape, jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec(num_scalar_prefetch=1), + interpret=interpret, + )(dynamic_shift, inp) + np.testing.assert_array_equal(out, ref, err_msg=f"{interpret=}") + + def test_interleave_vectors(self): + if not jtu.is_device_tpu_at_least(version=4): + self.skipTest("Expect TPUv4+") + + def kernel(x_ref, y_ref, out_ref): + x = pltpu.bitcast(x_ref[...].astype(jnp.float32), jnp.int32) + y = pltpu.bitcast(y_ref[...].astype(jnp.float32), jnp.int32) + shift = jax.lax.broadcast(16, x.shape) + out_ref[...] = pltpu.bitcast( + y | jax.lax.shift_right_logical(x, shift), jnp.bfloat16 + ) + + m, n = 16, 128 + inp = np.arange(m * n * 2, dtype=jnp.bfloat16).reshape(m, n * 2) + x, y = np.split(inp, 2, axis=1) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m * 2, n), jnp.bfloat16), + )(x, y) + np.testing.assert_array_equal(out, inp.reshape(m * 2, n)) + + +class OpsInterpreterTest(OpsTest): + INTERPRET = True + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 2f813afeab14..bf00acfb670f 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2170,6 +2170,43 @@ class PallasCallTPUCheckifyInterpretTest(PallasCallTPUCheckifyTest): INTERPRET: bool = True +class PrettyPrintingTest(PallasBaseTest): + + @parameterized.parameters( + ( + lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)), + 'dma_start c[d,:,:] -> e[...] f', + ), + ( + lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)), + 'dma_start c[0,d:d+8,:] -> e[...] f', + ), + ( + lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)), + 'dma_start c[d,2:6,:100] -> e[...] f', + ), + ( + lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)), + 'dma_start c[d,2:,4:104] -> e[...] f', + ), + ) + def test_dma_custom_pretty_print(self, indexer, expected): + def body(x_hbm_ref, i): + def inner(x_ref, sem): + pltpu.async_copy(x_hbm_ref.at[indexer(i)], x_ref, sem).wait() + + pl.run_scoped( + inner, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA + ) + return [] + + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(body), [state.shaped_array_ref((2, 8, 128), jnp.int32), + jax.core.ShapedArray((), jnp.int32)] + ) + self.assertIn(expected, jaxpr.pretty_print(use_color=False)) + + def only_passes_in_interpret(unless_generation: int | None = None): def decorator(f): def wrapper(self): From cd4e91b2b022e02211096cc1b35265110d367a15 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Mon, 12 Aug 2024 07:33:32 -0700 Subject: [PATCH 067/702] [mosaic_gpu] Store untiled splat layout PiperOrigin-RevId: 662077826 --- jax/experimental/mosaic/gpu/fragmented_array.py | 17 +++++++++++++++++ tests/mosaic/gpu_test.py | 16 ++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 44f6904e335e..259cfe4ae430 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -577,11 +577,28 @@ def store_untiled(self, ref: ir.Value): match self.layout: case WGMMAFragLayout(): self._store_untiled_wgmma(ref) + case WGSplatFragLayout(): + self._store_untiled_splat(ref) case WGStridedFragLayout(): self._store_untiled_wg_strided(ref) case _: raise NotImplementedError(self.layout) + def _store_untiled_splat(self, ref: ir.Value): + vec_size = 8 // mgpu.bytewidth(self.mlir_dtype) + if np.prod(self.shape) < vec_size * WARPGROUP_SIZE: + vec_size = 1 + + if np.prod(self.shape) % WARPGROUP_SIZE * vec_size: + raise ValueError(self.shape, WARPGROUP_SIZE, vec_size) + + fa = FragmentedArray.splat( + self.registers.flat[0], + self.shape, + layout=WGStridedFragLayout(shape=self.shape, vec_size=vec_size), + ) + fa.store_untiled(ref) + def _store_untiled_wg_strided(self, ref: ir.Value): ref_ty = ir.MemRefType(ref.type) ref_shape = tuple(ref_ty.shape) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index dec9452fd9d7..ca42a383adef 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -374,6 +374,22 @@ def kernel(ctx, out, _): )() np.testing.assert_array_equal(iota, expected) + @parameterized.named_parameters( + ("f32", ir.F32Type, jnp.float32, 256), + ("f16", ir.F16Type, jnp.float16, 256), + ("f16_small", ir.F16Type, jnp.float16, 128), + ) + def test_store_untiled_splat(self, mlir_dtype_cls, jax_dtype, size): + mlir_dtype = mlir_dtype_cls.get() + def kernel(ctx, out, _): + del ctx + mgpu.FragmentedArray.splat(c(1., mlir_dtype), (size,)).store_untiled(out) + expected = np.ones((size,), jax_dtype) + mosaic_ones = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), expected, () + )() + np.testing.assert_array_equal(mosaic_ones, expected) + @parameterized.product( dtypes=( (ir.F32Type.get, jnp.float32), From 4eb5ef28efc6cb9cf1d07fc269a1b896fd4b50ed Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 12 Aug 2024 08:12:36 -0700 Subject: [PATCH 068/702] Update shape polymorphism tests to skip lu_pivots_to_permutations tests when jaxlib version is too old. PiperOrigin-RevId: 662088901 --- tests/shape_poly_test.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index cff9eb4d6e8e..d5b32cdbd7fc 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -49,6 +49,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version import numpy as np config.parse_flags_with_absl() @@ -3395,9 +3396,22 @@ def test_harness(self, harness: PolyHarness): "vmap_qr:gpu", "qr:gpu", "vmap_svd:gpu", } - if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses: + name_device_key = f"{harness.group_name}:{jtu.device_under_test()}" + if name_device_key in custom_call_harnesses: raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778") + # This list keeps track of the minimum jaxlib version that supports shape + # polymorphism for some new primitives as we add them. This check is + # required so that we can still run the test suite with older versions of + # jaxlib. + version_gated = { + # TODO(danfm): remove these checks when jaxlib 0.4.32 is released. + "lu_pivots_to_permutation:gpu": (0, 4, 32), + "lu_pivots_to_permutation_error:gpu": (0, 4, 32), + } + if version_gated.get(name_device_key, jaxlib_version) > jaxlib_version: + raise unittest.SkipTest(f"shape polymorphism not supported by jaxlib version {jaxlib_version}") + if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]): raise unittest.SkipTest("schur decomposition is only implemented on CPU.") From ad74e55dbce46eca114d809796f0dde323da064b Mon Sep 17 00:00:00 2001 From: Zhuo Peng Date: Mon, 12 Aug 2024 09:23:40 -0700 Subject: [PATCH 069/702] Support `None` leaves in arguments to gradient of a call_tf wrapped function. PiperOrigin-RevId: 662115139 --- jax/experimental/jax2tf/call_tf.py | 17 +++++++++++------ jax/experimental/jax2tf/tests/call_tf_test.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 6cb1ec7e4cb2..9018f781198c 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -224,9 +224,11 @@ def make_call_vjp_bwd(residual_jax, ct_res_jax): def tf_vjp_fun(args_tf, ct_res_tf): """Invoke TF gradient.""" - # TF does not like us to watch non-float vars - def replace_non_float(arg_tf): - if arg_tf.dtype.is_floating or arg_tf.dtype.is_complex: + # TF does not like us to watch non-float vars or Nones. + def replace_non_float_or_none(arg_tf): + if arg_tf is not None and ( + arg_tf.dtype.is_floating or arg_tf.dtype.is_complex + ): return arg_tf else: # When watched, this will be ignored. When used in results it will @@ -234,17 +236,20 @@ def replace_non_float(arg_tf): # replace it with a float0) return tf.zeros((), dtype=tf.float32) - watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf) + watched_args_tf = tf.nest.map_structure( + replace_non_float_or_none, args_tf + ) with tf.GradientTape(persistent=True) as tape: tape.watch(watched_args_tf) res = callable_tf(*args_tf) tf.nest.assert_same_structure(res, ct_res_tf) dres_darg = tape.gradient( - tf.nest.map_structure(replace_non_float, res), + tf.nest.map_structure(replace_non_float_or_none, res), sources=watched_args_tf, output_gradients=ct_res_tf, - unconnected_gradients=tf.UnconnectedGradients.ZERO) + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) dres_darg = tree_util.tree_map( lambda x: x if x is None else tf.convert_to_tensor(x), diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 5740b76038d8..2760efea8061 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -1136,6 +1136,16 @@ def tf_f(x): # Jit mode self.assertAllClose(jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x)) + def test_grad_pytree_arg_with_none_leaf(self): + def tf_f(x, params): + return x * params["y"] + + x = jnp.array(1.0) + y = jnp.array(2.0) + actual = jax.grad( + jax2tf.call_tf(tf_f), argnums=(1,))(x, {"y": y, "other": None}) + self.assertDictEqual(actual[0], {"y": x, "other": None}) + class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): "Reloading output of call_tf into TF with jax2tf." From 60bf5b7727c9cdcc5928ca6b8b9ae4f7695892cd Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 12 Aug 2024 10:29:15 -0700 Subject: [PATCH 070/702] Add a `jax.process_indices` function. The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface. PiperOrigin-RevId: 662142636 --- CHANGELOG.md | 2 ++ docs/jax.rst | 1 + jax/__init__.py | 1 + jax/_src/xla_bridge.py | 23 +++++++++++++++++++---- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 038c0131ad12..2b30b08abb01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. more cases. Previously non-parallel computations were always dispatched synchronously. You can recover the old behavior by setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`. + * Added new {func}`jax.process_indices` function to replace the + `jax.host_ids()` function that was deprecated in JAX v0.2.13. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/docs/jax.rst b/docs/jax.rst index b112490a0912..7be3e63015d9 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -138,6 +138,7 @@ Parallelization (:code:`pmap`) device_count local_device_count process_count + process_indices Callbacks --------- diff --git a/jax/__init__.py b/jax/__init__.py index d9c4de6bb617..dc3d9af3a0c4 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -119,6 +119,7 @@ from jax._src.api import pmap as pmap from jax._src.xla_bridge import process_count as process_count from jax._src.xla_bridge import process_index as process_index +from jax._src.xla_bridge import process_indices as process_indices from jax._src.callback import pure_callback as pure_callback from jax._src.ad_checkpoint import checkpoint_wrapper as remat from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 91d761fec5d4..1d3c50403b47 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -1187,15 +1187,30 @@ def host_count(backend: str | xla_client.Client | None = None) -> int: return process_count(backend) +def process_indices( + backend: str | xla_client.Client | None = None +) -> list[int]: + """Returns the list of all JAX process indices associated with the backend. + + Args: + backend: This is an experimental feature and the API is likely to change. + Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or + ``'tpu'``. + + Returns: + List of integer process indices. + """ + return list(range(process_count(backend))) + + # TODO: remove this sometime after jax 0.2.13 is released def host_ids( backend: str | xla_client.Client | None = None ) -> list[int]: warnings.warn( - "jax.host_ids has been deprecated; please use range(jax.process_count()) " - "instead. jax.host_ids will eventually be removed; please update your " - "code.") - return list(range(process_count(backend))) + "jax.host_ids has been renamed to jax.process_indices. This alias " + "will eventually be removed; please update your code.") + return process_indices(backend) def using_pjrt_c_api(backend=None): From 53045380b11a0ce5a8ca680a5fdd429c7221ccef Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 12 Aug 2024 10:39:58 -0700 Subject: [PATCH 071/702] Make custom partitioning work without a mesh context manager. If the arguments have NamedSharding on them, then inside `partition` function, we should get NamedSharding without the existence of the mesh context manager PiperOrigin-RevId: 662146686 --- jax/_src/custom_partitioning.py | 6 ++++- jax/_src/dispatch.py | 2 +- jax/_src/interpreters/mlir.py | 6 ++--- jax/_src/interpreters/pxla.py | 24 ++++++++++++-------- jax/_src/sharding_impls.py | 1 + tests/pjit_test.py | 39 +++++++++++++++++++++++++++++++++ 6 files changed, 64 insertions(+), 14 deletions(-) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index c038ef0641a8..8f48746dda37 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -24,6 +24,7 @@ from typing import Any import weakref +import numpy as np import jax from jax import tree_util from jax._src import api_util @@ -481,17 +482,20 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, infer_sharding_from_operands, decode_shardings, static_args): - mesh = mesh_lib.thread_resources.env.physical_mesh axis_context = ctx.module_context.axis_context if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)): return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values) + mesh = mesh_lib.thread_resources.env.physical_mesh if isinstance(axis_context, sharding_impls.ShardingContext): devices = axis_context.device_assignment if devices is None: raise AssertionError( 'Please file a bug at https://github.com/google/jax/issues') + if axis_context.mesh_shape is not None: + ma, ms = list(zip(*axis_context.mesh_shape)) + mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma) elif isinstance(axis_context, sharding_impls.SPMDAxisContext): devices = axis_context.mesh._flat_devices_tuple else: diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 8605c58a81cd..068b5e3b7e25 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -204,7 +204,7 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool: # stablehlo is oblivious of physical devices. prim_requires_devices_during_lowering: set[core.Primitive] = set() -def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr): +def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool: for eqn in jaxpr.eqns: if eqn.primitive in prim_requires_devices_during_lowering: return True diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 4d15e803adfd..814c6a9886d7 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1033,7 +1033,6 @@ def lower_jaxpr_to_module( input_output_aliases: None | tuple[int | None, ...] = None, propagated_out_mem_kinds: tuple[None | str, ...] | None = None, lowering_parameters: LoweringParameters, - mesh_shape_tuple: tuple[tuple[str, int], ...] | None = None, ) -> LoweringResult: """Lowers a top-level jaxpr to an MLIR module. @@ -1121,13 +1120,14 @@ def lower_jaxpr_to_module( # XLA computation preserves the module name. attrs = ctx.module.operation.attributes if config.use_shardy_partitioner.value: - assert mesh_shape_tuple is not None + assert (isinstance(axis_context, sharding_impls.ShardingContext) and + axis_context.mesh_shape is not None) ctx.module.body.append( dialects.sdy.MeshOp( "mesh", dialects.sdy.MeshAttr.get( [dialects.sdy.MeshAxisAttr.get(name, size) - for name, size in mesh_shape_tuple]))) + for name, size in axis_context.mesh_shape]))) module_name = _module_name_regex.sub("_", module_name) attrs["sym_name"] = ir.StringAttr.get(module_name) attrs["mhlo.num_replicas"] = i32_attr(num_replicas) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 4dcdfbcbf495..ce96f7e815e1 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1881,7 +1881,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, propagated_out_mem_kinds: tuple[None | str, ...], platforms: tuple[str, ...], lowering_parameters: mlir.LoweringParameters, - mesh_shape_tuple: tuple[tuple[str, int], ...]): + mesh_shape_tuple: tuple[tuple[str, int], ...] | None): jaxpr = closed_jaxpr.jaxpr in_shardings = semantic_in_shardings.shardings out_shardings = semantic_out_shardings.shardings @@ -1911,7 +1911,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings) out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings) replicated_args = [False] * len(global_in_avals) - axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment) + axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment, + mesh_shape_tuple) num_partitions = num_devices else: # This path is triggered for `jit(pmap)` cases. @@ -1957,8 +1958,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, all_default_mem_kind=all_default_mem_kind, input_output_aliases=inout_aliases, propagated_out_mem_kinds=propagated_out_mem_kinds, - lowering_parameters=lowering_parameters, - mesh_shape_tuple=mesh_shape_tuple) + lowering_parameters=lowering_parameters) tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform) unordered_effects = list( effects.ordered_effects.filter_not_in(closed_jaxpr.effects)) @@ -2202,15 +2202,21 @@ def lower_sharding_computation( in_shardings, global_in_avals) # type: ignore semantic_out_shardings = SemanticallyEqualShardings( out_shardings, global_out_avals) # type: ignore + prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr) + mesh_shape_tuple = None - if config.use_shardy_partitioner.value: - for sharding in it.chain( - in_shardings, out_shardings, - [js for js, _ in unique_intermediate_shardings]): + if config.use_shardy_partitioner.value or prim_requires_devices: + for sharding in it.chain(in_shardings, out_shardings, + [js for js, _ in unique_intermediate_shardings]): if isinstance(sharding, sharding_impls.NamedSharding): + if (mesh_shape_tuple is not None and + mesh_shape_tuple != sharding.mesh.shape_tuple): + raise ValueError( + "mesh should be the same across the entire program. Got mesh" + f" shape for one sharding {mesh_shape_tuple} and" + f" {sharding.mesh.shape_tuple} for another") mesh_shape_tuple = sharding.mesh.shape_tuple - break (module, keepalive, host_callbacks, unordered_effects, ordered_effects, nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index d41ef7410d1a..1a23f4ba74ad 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1162,6 +1162,7 @@ class ShardingContext: """ num_devices: int device_assignment: tuple[xc.Device, ...] | None = None + mesh_shape: tuple[tuple[str, int], ...] | None = None def __post_init__(self): if self.device_assignment is not None: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 44782ec15bc0..9368f7da9cfe 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1514,6 +1514,45 @@ def f(carry, x): xs = jnp.ones([32, 16]) self.assertEqual(pjit_f(xs), xs.sum()) + def test_custom_partitioning_no_mesh_context(self): + self.skip_if_custom_partitioning_not_supported() + + @custom_partitioning + def f(x): + return x + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + @jax.jit + def g(y): + return y + + return g(x) + + x_shard = arg_shapes[0].sharding + return ( + mesh, + lower_fn, + NamedSharding(x_shard.mesh, P('x')), + (NamedSharding(x_shard.mesh, P('x')),), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + x_shard = arg_shapes[0].sharding + return NamedSharding(x_shard.mesh, P('x')) + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + ) + + mesh = jtu.create_global_mesh((4,), ('x',)) + x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32) + s = NamedSharding(mesh, P('x')) + + jit_f = jax.jit(f, in_shardings=s, out_shardings=s) + self.assertArraysEqual(x, jit_f(x)) + @jtu.pytest_mark_if_available('multiaccelerator') class AutoShardingPjitTest(jtu.JaxTestCase): From 802abfef92159d592716008ef3b3a7a9a2a10e6f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 17:54:09 +0000 Subject: [PATCH 072/702] Bump actions/upload-artifact from 4.3.5 to 4.3.6 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.3.5 to 4.3.6. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/89ef406dd8d7e03cfd12d9e0a4a378f454709029...834a144ee995460fba8ed112a2fc961b36a5ec5a) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/upstream-nightly.yml | 2 +- .github/workflows/wheel_win_x64.yml | 2 +- .github/workflows/windows_ci.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 9e0386f75053..1e345954d0f7 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -85,7 +85,7 @@ jobs: && steps.status.outcome == 'failure' && github.event_name == 'schedule' && github.repository == 'google/jax' - uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # ratchet: actions/upload-artifact@v4 + uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 with: name: output-${{ matrix.python-version }}-log.jsonl path: output-${{ matrix.python-version }}-log.jsonl diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 8dad541b81d0..61912ed8978e 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -46,7 +46,7 @@ jobs: --bazel_options=--config=win_clang ` --verbose - - uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 with: name: wheels-${{ matrix.os }}-${{ matrix.pyver }} path: ${{ github.workspace }}\dist\*.whl diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 60bebd32ea76..42083f1d087d 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -53,7 +53,7 @@ jobs: --bazel_options=--color=yes ` --bazel_options=--config=win_clang - - uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 with: name: wheels path: ${{ github.workspace }}\jax\dist\*.whl From ee31e95ecd6f7caa4d71aebc6130a372febcbd78 Mon Sep 17 00:00:00 2001 From: Brian Wieder Date: Mon, 12 Aug 2024 11:27:03 -0700 Subject: [PATCH 073/702] Register shutdown code at import to hopefully get registered before any other atexit callbacks. `atexit` callbacks are called in a LIFO order, meaning that since Jax currently registers its callback at runtime rather than import time, it gets called before any `atexit` callbacks registered at import time. PiperOrigin-RevId: 662164776 --- jax/_src/distributed.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index b4a72c7678b0..308387d21b20 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -232,11 +232,12 @@ def initialize(coordinator_address: str | None = None, global_state.initialize(coordinator_address, num_processes, process_id, local_device_ids, cluster_detection_method, initialization_timeout, coordinator_bind_address) - atexit.register(shutdown) +@atexit.register def shutdown(): """Shuts down the distributed system. - Does nothing if the distributed system is not running.""" + Does nothing if the distributed system is not running. + """ global_state.shutdown() From e5eaff84bd1865176ba42f15529cc560741d483c Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 12 Aug 2024 13:00:29 -0700 Subject: [PATCH 074/702] Replace `pjrt_c_api_gpu_plugin.so` symlink with XLA dependency. The runfiles of the original targets were lost when the symlinked files were used. This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When pjrt_c_api_gpu_plugin.so is simlinked, the content of the runfiles is lost. With proper XLA target dependency the runfiles are preserved. PiperOrigin-RevId: 662197057 --- jax_plugins/cuda/BUILD.bazel | 16 ++++------------ jax_plugins/cuda/__init__.py | 7 +++++++ jax_plugins/rocm/BUILD.bazel | 16 ++++------------ jax_plugins/rocm/__init__.py | 8 ++++++++ 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index fea9723e189b..79aebcd86826 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -14,7 +14,6 @@ licenses(["notice"]) -load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "if_windows", @@ -35,22 +34,15 @@ exports_files([ "setup.py", ]) -symlink_files( - name = "pjrt_c_api_gpu_plugin", - srcs = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], - ), - dst = ".", - flatten = True, -) - py_library_providing_imports_info( name = "cuda_plugin", srcs = [ "__init__.py", ], - data = [":pjrt_c_api_gpu_plugin"], + data = if_windows( + ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], + ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], + ), lib_rule = pytype_library, ) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index ff5a1561dbbc..9867c07b1176 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -48,6 +48,13 @@ def _get_library_path(): local_path = os.path.join( os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so' ) + if not os.path.exists(local_path): + runfiles_dir = os.getenv('RUNFILES_DIR', None) + if runfiles_dir: + local_path = os.path.join( + runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so' + ) + if os.path.exists(local_path): logger.debug( 'Native library %s does not exist. This most likely indicates an issue' diff --git a/jax_plugins/rocm/BUILD.bazel b/jax_plugins/rocm/BUILD.bazel index 08a61c786262..6e265bcd18cf 100644 --- a/jax_plugins/rocm/BUILD.bazel +++ b/jax_plugins/rocm/BUILD.bazel @@ -14,7 +14,6 @@ licenses(["notice"]) -load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "if_windows", @@ -35,21 +34,14 @@ exports_files([ "setup.py", ]) -symlink_files( - name = "pjrt_c_api_gpu_plugin", - srcs = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], - ), - dst = ".", - flatten = True, -) - py_library_providing_imports_info( name = "rocm_plugin", srcs = [ "__init__.py", ], - data = [":pjrt_c_api_gpu_plugin"], + data = if_windows( + ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], + ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], + ), lib_rule = pytype_library, ) diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index 4535f1b3bbc8..3dbcaf4491e0 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -15,6 +15,7 @@ import functools import importlib import logging +import os import pathlib import platform @@ -47,6 +48,13 @@ def _get_library_path(): local_path = ( base_path / 'pjrt_c_api_gpu_plugin.so' ) + if not local_path.exists(): + runfiles_dir = os.getenv('RUNFILES_DIR', None) + if runfiles_dir: + local_path = pathlib.Path( + os.path.join(runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so') + ) + if local_path.exists(): logger.debug( 'Native library %s does not exist. This most likely indicates an issue' From 1e58d76772afdcf006d03eeaf4743c08af6765cc Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Wed, 17 Jul 2024 19:09:07 -0500 Subject: [PATCH 075/702] [ROCm] Change ROCm builds to manylinux wheels --- build/rocm/Dockerfile.ms | 37 ++- build/rocm/build_rocm.sh | 56 +--- .../Dockerfile.manylinux_2_28_x86_64.rocm | 7 + build/rocm/ci_build | 256 ++++++++++++++++++ build/rocm/ci_build.sh | 156 ++++------- build/rocm/setup.rocm.sh | 2 +- build/rocm/tools/blacken.sh | 3 + build/rocm/tools/build_wheels.py | 222 +++++++++++++++ build/rocm/tools/fixwheel.py | 97 +++++++ build/rocm/tools/get_rocm.py | 100 +++++++ build/rocm/tools/libc.py | 48 ++++ build/rocm/tools/symbols.py | 53 ++++ 12 files changed, 878 insertions(+), 159 deletions(-) create mode 100644 build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm create mode 100755 build/rocm/ci_build create mode 100644 build/rocm/tools/blacken.sh create mode 100644 build/rocm/tools/build_wheels.py create mode 100644 build/rocm/tools/fixwheel.py create mode 100644 build/rocm/tools/get_rocm.py create mode 100644 build/rocm/tools/libc.py create mode 100644 build/rocm/tools/symbols.py diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index 5f831f111b25..899f29a14f58 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -1,6 +1,5 @@ ################################################################################ -ARG BASE_DOCKER=ubuntu:20.04 -FROM $BASE_DOCKER as rt_build +FROM ubuntu:20.04 AS rocm_base ################################################################################ # Add target file to help determine which device(s) to build for @@ -12,9 +11,9 @@ ARG ROCM_VERSION=6.0.0 ARG CUSTOM_INSTALL ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} ENV ROCM_PATH=${ROCM_PATH} -COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL} -COPY setup.rocm.sh /setup.rocm.sh -RUN /setup.rocm.sh $ROCM_VERSION +#COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL} +RUN --mount=type=bind,source=build/rocm/setup.rocm.sh,target=/setup.rocm.sh \ + /setup.rocm.sh $ROCM_VERSION # Set up paths ENV HCC_HOME=$ROCM_PATH/hcc @@ -25,13 +24,35 @@ ENV PATH="$ROCM_PATH/bin:${PATH}" ENV PATH="$OPENCL_ROOT/bin:${PATH}" ENV PATH="/root/bin:/root/.local/bin:$PATH" - # Install pyenv with different python versions -ARG PYTHON_VERSION=3.10.0 +ARG PYTHON_VERSION=3.10.14 RUN git clone https://github.com/pyenv/pyenv.git /pyenv ENV PYENV_ROOT /pyenv ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH RUN pyenv install $PYTHON_VERSION -RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip && pip install numpy setuptools build wheel six auditwheel scipy pytest pytest-html pytest_html_merger pytest-reportlog pytest-rerunfailures cloudpickle portpicker matplotlib absl-py flatbuffers hypothesis +RUN eval "$(pyenv init -)" && \ + pyenv local ${PYTHON_VERSION} && \ + pip3 install --upgrade --force-reinstall setuptools pip && \ + pip install \ + numpy setuptools build wheel six auditwheel scipy \ + pytest pytest-html pytest_html_merger pytest-reportlog \ + pytest-rerunfailures cloudpickle portpicker matplotlib absl-py \ + flatbuffers hypothesis + + +################################################################################ +FROM rocm_base AS rt_build +################################################################################ + +ARG JAX_VERSION +ARG JAX_COMMIT +ARG XLA_COMMIT +LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ + com.amdgpu.python_version="$PYTHON_VERSION" \ + com.amdgpu.jax_version="$JAX_VERSION" \ + com.amdgpu.jax_commit="$JAX_COMMIT" \ + com.amdgpu.xla_commit="$XLA_COMMIT" +RUN --mount=type=bind,source=wheelhouse,target=/wheelhouse \ + pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/build_rocm.sh b/build/rocm/build_rocm.sh index 6374a2a18929..111998d35608 100755 --- a/build/rocm/build_rocm.sh +++ b/build/rocm/build_rocm.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash + # Copyright 2022 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,57 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Environment Var Notes -# XLA_CLONE_DIR - -# Specifies filepath to where XLA repo is cloned. -# NOTE:, if this is set then XLA repo is not cloned. Must clone repo before running this script. -# Also, if this is set then setting XLA_REPO and XLA_BRANCH have no effect. -# XLA_REPO -# XLA repo to clone from. Default is https://github.com/ROCmSoftwarePlatform/tensorflow-upstream -# XLA_BRANCH -# XLA branch in the XLA repo. Default is develop-upstream-jax -# +# NOTE(mrodden): ROCm JAX build and installs have moved to wheel based builds and installs, +# but some CI scripts still try to run this script. Nothing needs to be done here, +# but we print some debugging information for logs. set -eux python -V -#If XLA_REPO is not set, then use default -if [ ! -v XLA_REPO ]; then - XLA_REPO="https://github.com/openxla/xla.git" - XLA_BRANCH="main" -elif [ -z "$XLA_REPO" ]; then - XLA_REPO="https://github.com/openxla/xla.git" - XLA_BRANCH="main" -fi - -#If XLA_CLONE_PATH is not set, then use default path. -#Note, setting XLA_CLONE_PATH makes setting XLA_REPO and XLA_BRANCH a no-op -#Set this when XLA repository has been already clone. This is useful in CI -#environments and when doing local development -if [ ! -v XLA_CLONE_DIR ]; then - XLA_CLONE_DIR=/tmp/xla - rm -rf /tmp/xla || true - git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla -elif [ -z "$XLA_CLONE_DIR" ]; then - XLA_CLONE_DIR=/tmp/xla - rm -rf /tmp/xla || true - git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla -fi - - -#Export JAX_ROCM_VERSION so that it is appened in the wheel name -export JAXLIB_RELEASE=1 -rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1) -export JAX_ROCM_VERSION=${rocm_version//./} - -#Build and install wheel -python3 ./build/build.py --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR} - -JAX_RELEASE=1 python -m build -pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA) - -#This is for CI to read without having to start the container again -if [ -v CI_RUN ]; then - pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1 > jax_version_installed - cat /opt/rocm/.info/version | cut -d "-" -f 1 > jax_rocm_version -fi +printf "Detected jaxlib version: %s\n" $(pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1) +printf "Detected ROCm version: %s\n" $(cat /opt/rocm/.info/version | cut -d "-" -f 1) diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm new file mode 100644 index 000000000000..fd2de6a0c06a --- /dev/null +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -0,0 +1,7 @@ +FROM quay.io/pypa/manylinux_2_28_x86_64 + +ARG ROCM_VERSION=6.1.1 + +RUN --mount=type=cache,target=/var/cache/dnf \ + --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + python3 get_rocm.py --rocm-version $ROCM_VERSION diff --git a/build/rocm/ci_build b/build/rocm/ci_build new file mode 100755 index 000000000000..a43bd26fdc0a --- /dev/null +++ b/build/rocm/ci_build @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +import os +import subprocess +import sys + + +def image_by_name(name): + cmd = ["docker", "images", "-q", "-f", "reference=%s" % name] + out = subprocess.check_output(cmd) + image_id = out.decode("utf8").strip().split("\n")[0] or None + return image_id + + +def dist_wheels(rocm_version, python_versions, xla_path): + xla_path = os.path.abspath(xla_path) + + # create manylinux image with requested ROCm installed + image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "") + + cmd = [ + "docker", + "build", + "-f", + "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm", + "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--tag=%s" % image, + ".", + ] + + if not image_by_name(image): + _ = subprocess.run(cmd, check=True) + + # use image to build JAX/jaxlib wheels + os.makedirs("wheelhouse", exist_ok=True) + + pyver_string = ",".join(python_versions) + + container_xla_path = "/xla" + + bw_cmd = [ + "python3", + "/jax/build/rocm/tools/build_wheels.py", + "--rocm-version", + rocm_version, + "--python-versions", + pyver_string, + ] + + if xla_path: + bw_cmd.extend(["--xla-path", container_xla_path]) + + bw_cmd.append("/jax") + + cmd = ["docker", "run", "-it"] + + mounts = [ + "-v", + "./:/jax", + "-v", + "./wheelhouse:/wheelhouse", + ] + + if xla_path: + mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)]) + + cmd.extend(mounts) + + # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID + cmd.extend( + [ + "--init", + "--rm", + image, + "bash", + "-c", + " ".join(bw_cmd), + ] + ) + + _ = subprocess.run(cmd, check=True) + + +def _fetch_jax_metadata(xla_path): + cmd = ["git", "rev-parse", "HEAD"] + jax_commit = subprocess.check_output(cmd) + xla_commit = "" + + if xla_path: + try: + xla_commit = subprocess.check_output(cmd, cwd=xla_path) + except Exception as ex: + LOG.warning("Exception while retrieving xla_commit: %s" % ex) + + cmd = ["python", "setup.py", "-V"] + env = dict(os.environ) + env["JAX_RELEASE"] = "1" + + jax_version = subprocess.check_output(cmd, env=env) + + return { + "jax_version": jax_version.decode("utf8").strip(), + "jax_commit": jax_commit.decode("utf8").strip(), + "xla_commit": xla_commit.decode("utf8").strip(), + } + + +def dist_docker( + rocm_version, + python_versions, + xla_path, + tag="rocm/jax-dev", + dockerfile=None, + keep_image=True, +): + if not dockerfile: + dockerfile = "build/rocm/Dockerfile.ms" + + python_version = python_versions[0] + + md = _fetch_jax_metadata(xla_path) + + cmd = [ + "docker", + "build", + "-f", + dockerfile, + "--target", + "rt_build", + "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--build-arg=PYTHON_VERSION=%s" % python_version, + "--build-arg=JAX_VERSION=%(jax_version)s" % md, + "--build-arg=JAX_COMMIT=%(jax_commit)s" % md, + "--build-arg=XLA_COMMIT=%(xla_commit)s" % md, + "--tag=%s" % tag, + ] + + if not keep_image: + cmd.append("--rm") + + # context dir + cmd.append(".") + + subprocess.check_call(cmd) + + +def test(image_name): + """Run unit tests like CI would inside a JAX image.""" + + gpu_args = [ + "--device=/dev/kfd", + "--device=/dev/dri", + "--group-add", + "video", + "--cap-add=SYS_PTRACE", + "--security-opt", + "seccomp=unconfined", + "--shm-size", + "16G", + ] + + cmd = [ + "docker", + "run", + "-it", + "--rm", + ] + + # NOTE(mrodden): we need jax source dir for the unit test code only, + # JAX and jaxlib are already installed from wheels + mounts = [ + "-v", + "./:/jax", + ] + + cmd.extend(mounts) + cmd.extend(gpu_args) + + container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh" + cmd.append(image_name) + cmd.extend( + [ + "bash", + "-c", + container_cmd, + ] + ) + + subprocess.check_call(cmd) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument( + "--python-versions", + type=lambda x: x.split(","), + default="3.12", + help="Comma separated list of CPython versions to build wheels for", + ) + + p.add_argument( + "--rocm-version", + default="6.1.1", + help="ROCm version used for building wheels, testing, and installing into Docker image", + ) + + p.add_argument( + "--xla-source-dir", + help="Path to XLA source to use during jaxlib build, instead of builtin XLA", + ) + + subp = p.add_subparsers(dest="action", required=True) + + dwp = subp.add_parser("dist_wheels") + + testp = subp.add_parser("test") + testp.add_argument("image_name") + + ddp = subp.add_parser("dist_docker") + ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms") + ddp.add_argument("--keep-image", action="store_true") + ddp.add_argument("--image-tag", default="rocm/jax-dev") + + return p.parse_args() + + +def main(): + args = parse_args() + + if args.action == "dist_wheels": + dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir) + + elif args.action == "test": + test(args.image_name) + + elif args.action == "dist_docker": + dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir) + dist_docker( + args.rocm_version, + args.python_versions, + args.xla_source_dir, + tag=args.image_tag, + dockerfile=args.dockerfile, + keep_image=args.keep_image, + ) + + +if __name__ == "__main__": + main() diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 9084651bed4c..ab599e2661a0 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash + # Copyright 2022 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,12 +29,8 @@ # # ROCM_VERSION: ROCm repo version # -# ROCM_PATH: ROCM path in the docker container -# # Environment variables read by this script # WORKSPACE -# XLA_REPO -# XLA_BRANCH # XLA_CLONE_DIR # BUILD_TAG # @@ -44,75 +41,63 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/build_common.sh" CONTAINER_TYPE="rocm" -DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms" +DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms" DOCKER_CONTEXT_PATH="${SCRIPT_DIR}" KEEP_IMAGE="--rm" -KEEP_CONTAINER="--rm" -PYTHON_VERSION="3.10.0" -ROCM_VERSION="6.0.0" #Point to latest release +PYTHON_VERSION="3.10" +ROCM_VERSION="6.1.3" BASE_DOCKER="ubuntu:20.04" CUSTOM_INSTALL="" -#BASE_DOCKER="compute-artifactory.amd.com:5000/rocm-plus-docker/compute-rocm-rel-6.0:91-ubuntu-20.04-stg2" -#CUSTOM_INSTALL="custom_install_dummy.sh" -#ROCM_PATH="/opt/rocm-5.6.0" POSITIONAL_ARGS=() RUNTIME_FLAG=1 while [[ $# -gt 0 ]]; do - case $1 in - --py_version) - PYTHON_VERSION="$2" - shift 2 - ;; - --dockerfile) - DOCKERFILE_PATH="$2" - DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") - shift 2 - ;; - --keep_image) - KEEP_IMAGE="" - shift 1 - ;; - --runtime) - RUNTIME_FLAG=1 - shift 1 - ;; - --keep_container) - KEEP_CONTAINER="" - shift 1 - ;; - --rocm_version) - ROCM_VERSION="$2" - shift 2 - ;; - #--rocm_path) - # ROCM_PATH="$2" - # shift 2 - # ;; - - *) - POSITIONAL_ARGS+=("$1") - shift - ;; - esac + case $1 in + --py_version) + PYTHON_VERSION="$2" + shift 2 + ;; + --dockerfile) + DOCKERFILE_PATH="$2" + DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") + shift 2 + ;; + --keep_image) + KEEP_IMAGE="" + shift 1 + ;; + --runtime) + RUNTIME_FLAG=1 + shift 1 + ;; + --keep_container) + KEEP_CONTAINER="" + shift 1 + ;; + --rocm_version) + ROCM_VERSION="$2" + shift 2 + ;; + *) + POSITIONAL_ARGS+=("$1") + shift + ;; + esac done if [[ ! -f "${DOCKERFILE_PATH}" ]]; then - die "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\"" + die "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\"" fi -ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video \ - --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G" - # Helper function to traverse directories up until given file is found. function upsearch (){ - test / == "$PWD" && return || \ - test -e "$1" && echo "$PWD" && return || \ - cd .. && upsearch "$1" + test / == "$PWD" && return || \ + test -e "$1" && echo "$PWD" && return || \ + cd .. && upsearch "$1" } -# Set up WORKSPACE. +# Set up WORKSPACE. WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" BUILD_TAG="${BUILD_TAG:-jax}" @@ -126,6 +111,7 @@ DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | sed -e 's/=/_/g' -e 's/,/-/g') # Convert to all lower-case, as per requirement of Docker image names DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]') + # Print arguments. echo "WORKSPACE: ${WORKSPACE}" echo "COMMAND: ${POSITIONAL_ARGS[*]}" @@ -135,55 +121,25 @@ echo "" echo "Building container (${DOCKER_IMG_NAME})..." echo "Python Version (${PYTHON_VERSION})" -if [[ "${RUNTIME_FLAG}" -eq 1 ]]; then - echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..." - docker build --target rt_build --tag ${DOCKER_IMG_NAME} \ - --build-arg PYTHON_VERSION=$PYTHON_VERSION --build-arg ROCM_VERSION=$ROCM_VERSION \ - --build-arg CUSTOM_INSTALL=$CUSTOM_INSTALL \ - --build-arg BASE_DOCKER=$BASE_DOCKER \ - -f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}" -else - echo "Building (CI) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..." - docker build --target ci_build --tag ${DOCKER_IMG_NAME} \ - --build-arg PYTHON_VERSION=$PYTHON_VERSION \ - --build-arg BASE_DOCKER=$BASE_DOCKER \ - -f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}" -fi - -# Check docker build status -if [[ $? != "0" ]]; then - die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}" -fi - -# Run the command inside the container. -echo "Running '${POSITIONAL_ARGS[*]}' inside ${DOCKER_IMG_NAME}..." +echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..." -export XLA_REPO="${XLA_REPO:-}" -export XLA_BRANCH="${XLA_BRANCH:-}" export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}" -export JAX_RENAME_WHL="${XLA_CLONE_DIR:-}" -if [ ! -z ${XLA_CLONE_DIR} ]; then - ROCM_EXTRA_PARAMS=${ROCM_EXTRA_PARAMS}" -v ${XLA_CLONE_DIR}:${XLA_CLONE_DIR}" -fi +# ci_build.sh is mostly a compatibility wrapper for ci_build + +# 'dist_docker' will run 'dist_wheels' followed by a Docker build to create the "JAX image", +# which is the ROCm image that is shipped for users to use (i.e. distributable). +./build/rocm/ci_build \ + --rocm-version $ROCM_VERSION \ + --python-versions $PYTHON_VERSION \ + --xla-source-dir $XLA_CLONE_DIR \ + dist_docker \ + --dockerfile $DOCKERFILE_PATH \ + --image-tag $DOCKER_IMG_NAME -docker run ${KEEP_IMAGE} --name ${DOCKER_IMG_NAME} --pid=host \ - -v ${WORKSPACE}:/workspace \ - -w /workspace \ - -e XLA_REPO=${XLA_REPO} \ - -e XLA_BRANCH=${XLA_BRANCH} \ - -e XLA_CLONE_DIR=${XLA_CLONE_DIR} \ - -e PYTHON_VERSION=$PYTHON_VERSION \ - -e CI_RUN=1 \ - ${ROCM_EXTRA_PARAMS} \ - "${DOCKER_IMG_NAME}" \ - ${POSITIONAL_ARGS[@]} - -if [[ "${KEEP_IMAGE}" != "--rm" ]] && [[ $? == "0" ]]; then - echo "Committing the docker container as ${DOCKER_IMG_NAME}" - docker stop ${DOCKER_IMG_NAME} - docker commit ${DOCKER_IMG_NAME} ${DOCKER_IMG_NAME} - docker rm ${DOCKER_IMG_NAME} # remove this temp container +# Check build status +if [[ $? != "0" ]]; then + die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}" fi echo "Jax-ROCm build was successful!" diff --git a/build/rocm/setup.rocm.sh b/build/rocm/setup.rocm.sh index 1ade67b17f6e..35c8f4c5166c 100755 --- a/build/rocm/setup.rocm.sh +++ b/build/rocm/setup.rocm.sh @@ -25,7 +25,7 @@ ROCM_DEB_REPO=${ROCM_DEB_REPO_HOME}${ROCM_VERS}/ if [ ! -f "/${CUSTOM_INSTALL}" ]; then # Add rocm repository chmod 1777 /tmp - DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update + DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update DEBIAN_FRONTEND=noninteractive apt install -y wget software-properties-common DEBIAN_FRONTEND=noninteractive apt-get clean all wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -; diff --git a/build/rocm/tools/blacken.sh b/build/rocm/tools/blacken.sh new file mode 100644 index 000000000000..7b61cbdb9e10 --- /dev/null +++ b/build/rocm/tools/blacken.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +black -t py36 build/rocm/ci_build build/rocm/tools/*.py diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py new file mode 100644 index 000000000000..9b9ff778811c --- /dev/null +++ b/build/rocm/tools/build_wheels.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +from collections import deque +import fcntl +import logging +import os +import re +import select +import subprocess +import shutil +import sys + + +LOG = logging.getLogger(__name__) + + +GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" + + +def build_rocm_path(rocm_version_str): + return "/opt/rocm-%s" % rocm_version_str + + +def update_rocm_targets(rocm_path, targets): + target_fp = os.path.join(rocm_path, "bin/target.lst") + version_fp = os.path.join(rocm_path, ".info/version") + with open(target_fp, "w") as fd: + fd.write("%s\n" % targets) + + # mimic touch + open(version_fp, "a").close() + + +def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None): + cmd = [ + "python", + "build/build.py", + "--enable_rocm", + "--build_gpu_plugin", + "--gpu_plugin_rocm_version=60", + "--rocm_path=%s" % rocm_path, + ] + + if xla_path: + cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path) + + cpy = to_cpy_ver(python_version) + py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy) + + env = dict(os.environ) + env["JAX_RELEASE"] = str(1) + env["JAXLIB_RELEASE"] = str(1) + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) + pattern = re.compile("Output wheel: (.+)\n") + + return _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stderr") + + +def build_jax_wheel(jax_path, python_version): + cmd = [ + "python", + "-m", + "build", + ] + + cpy = to_cpy_ver(python_version) + py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy) + + env = dict(os.environ) + env["JAX_RELEASE"] = str(1) + env["JAXLIB_RELEASE"] = str(1) + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) + pattern = re.compile("Successfully built jax-.+ and (jax-.+\.whl)\n") + + wheels = _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout") + + paths = list(map(lambda x: os.path.join(jax_path, "dist", x), wheels)) + return paths + + +def _run_scan_for_output(cmd, pattern, env=None, cwd=None, capture=None): + + buf = deque(maxlen=20000) + + if capture == "stderr": + p = subprocess.Popen(cmd, env=env, cwd=cwd, stderr=subprocess.PIPE) + redir = sys.stderr + cap_fd = p.stderr + else: + p = subprocess.Popen(cmd, env=env, cwd=cwd, stdout=subprocess.PIPE) + redir = sys.stdout + cap_fd = p.stdout + + flags = fcntl.fcntl(cap_fd, fcntl.F_GETFL) + fcntl.fcntl(cap_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + eof = False + while not eof: + r, _, _ = select.select([cap_fd], [], []) + for fd in r: + dat = fd.read(512) + if dat is None: + continue + elif dat: + t = dat.decode("utf8") + redir.write(t) + buf.extend(t) + else: + eof = True + + # wait and drain pipes + _, _ = p.communicate() + + if p.returncode != 0: + raise Exception( + "Child process exited with nonzero result: rc=%d" % p.returncode + ) + + text = "".join(buf) + + matches = pattern.findall(text) + + if not matches: + LOG.error("No wheel name found in output: %r" % text) + raise Exception("No wheel name found in output") + + wheels = [] + for match in matches: + LOG.info("Found built wheel: %r" % match) + wheels.append(match) + + return wheels + + +def to_cpy_ver(python_version): + tup = python_version.split(".") + return "cp%d%d" % (int(tup[0]), int(tup[1])) + + +def fix_wheel(path, jax_path): + # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8 + # so use one of the CPythons in /opt to run + env = dict(os.environ) + py_bin = "/opt/python/cp310-cp310/bin" + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + cmd = ["pip", "install", "auditwheel>=6"] + subprocess.run(cmd, check=True, env=env) + + fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") + cmd = ["python", fixwheel_path, path] + subprocess.run(cmd, check=True, env=env) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument( + "--rocm-version", default="6.1.1", help="ROCM Version to build JAX against" + ) + p.add_argument( + "--python-versions", + default=["3.10.19,3.12"], + help="Comma separated CPython versions that wheels will be built and output for", + ) + p.add_argument( + "--xla-path", + type=str, + default=None, + help="Optional directory where XLA source is located to use instead of JAX builtin XLA", + ) + + p.add_argument("jax_path", help="Directory where JAX source directory is located") + + return p.parse_args() + + +def main(): + args = parse_args() + python_versions = args.python_versions.split(",") + + print("ROCM_VERSION=%s" % args.rocm_version) + print("PYTHON_VERSIONS=%r" % python_versions) + print("JAX_PATH=%s" % args.jax_path) + print("XLA_PATH=%s" % args.xla_path) + + rocm_path = build_rocm_path(args.rocm_version) + + update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS) + + for py in python_versions: + wheel_paths = build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path) + for wheel_path in wheel_paths: + fix_wheel(wheel_path, args.jax_path) + + # build JAX wheel for completeness + jax_wheels = build_jax_wheel(args.jax_path, python_versions[-1]) + + # NOTE(mrodden): the jax wheel is a "non-platform wheel", so auditwheel will + # do nothing, and in fact will throw an Exception. we just need to copy it + # along with the jaxlib and plugin ones + + # copy jax wheel(s) to wheelhouse + wheelhouse_dir = "/wheelhouse/" + for whl in jax_wheels: + LOG.info("Copying %s into %s" % (whl, wheelhouse_dir)) + shutil.copy(whl, wheelhouse_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/build/rocm/tools/fixwheel.py b/build/rocm/tools/fixwheel.py new file mode 100644 index 000000000000..d5951cdd4fc1 --- /dev/null +++ b/build/rocm/tools/fixwheel.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +import logging +import os +from pprint import pprint +import subprocess + +from auditwheel.lddtree import lddtree +from auditwheel.wheeltools import InWheelCtx +from auditwheel.elfutils import elf_file_filter +from auditwheel.policy import WheelPolicies +from auditwheel.wheel_abi import analyze_wheel_abi + + +LOG = logging.getLogger(__name__) + + +def tree(path): + + with InWheelCtx(path) as ctx: + for sofile, fd in elf_file_filter(ctx.iter_files()): + + LOG.info("found SO file: %s" % sofile) + elftree = lddtree(sofile) + + print(elftree) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("wheel_path") + return p.parse_args() + + +def parse_wheel_name(path): + wheel_name = os.path.basename(path) + return wheel_name[:-4].split("-") + + +def fix_wheel(path): + tup = parse_wheel_name(path) + plat_tag = tup[4] + if "manylinux2014" in plat_tag: + # strip any manylinux tags from the current wheel first + from wheel.cli import tags + + plat_mod_str = "linux_x86_64" + new_wheel = tags.tags( + path, + python_tags=None, + abi_tags=None, + platform_tags=plat_mod_str, + build_tag=None, + ) + new_path = os.path.join(os.path.dirname(path), new_wheel) + LOG.info("Stripped broken tags and created new wheel at %r" % new_path) + path = new_path + + # build excludes, using auditwheels lddtree to find them + wheel_pol = WheelPolicies() + exclude = frozenset() + abi = analyze_wheel_abi(wheel_pol, path, exclude) + + plat = "manylinux_2_28_x86_64" + ext_libs = abi.external_refs.get(plat, {}).get("libs") + exclude = list(ext_libs.keys()) + + # call auditwheel repair with excludes + cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"] + + for ex in exclude: + cmd.append("--exclude") + cmd.append(ex) + + cmd.append(path) + + LOG.info("running %r" % cmd) + + rc = subprocess.run(cmd, check=True) + + +def main(): + args = parse_args() + path = args.wheel_path + fix_wheel(path) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py new file mode 100644 index 000000000000..4cc4a4682acb --- /dev/null +++ b/build/rocm/tools/get_rocm.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +import logging +import subprocess + + +LOG = logging.getLogger(__name__) + + +def which_linux(): + try: + os_rel = open("/etc/os-release").read() + + kvs = {} + for line in os_rel.split("\n"): + if line.strip(): + k, v = line.strip().split("=", 1) + v = v.strip('"') + kvs[k] = v + + print(kvs) + except OSError: + pass + + +rocm_package_names = [ + "libdrm-amdgpu", + "rocm-dev", + "rocm-ml-sdk", + "miopen-hip ", + "miopen-hip-devel", + "rocblas", + "rocblas-devel", + "rocsolver-devel", + "rocrand-devel", + "rocfft-devel", + "hipfft-devel", + "hipblas-devel", + "rocprim-devel", + "hipcub-devel", + "rccl-devel", + "hipsparse-devel", + "hipsolver-devel", +] + + +def install_rocm_el8(rocm_version_str): + + with open("/etc/yum.repos.d/rocm.repo", "w") as rfd: + rfd.write( + """ +[ROCm] +name=ROCm +baseurl=http://repo.radeon.com/rocm/rhel8/%s/main +enabled=1 +gpgcheck=1 +gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key +""" + % rocm_version_str + ) + + with open("/etc/yum.repos.d/amdgpu.repo", "w") as afd: + afd.write( + """ +[amdgpu] +name=amdgpu +baseurl=https://repo.radeon.com/amdgpu/latest/rhel/8.8/main/x86_64/ +enabled=1 +gpgcheck=1 +gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key +""" + ) + + cmd = ["dnf", "install", "-y"] + cmd.extend(rocm_package_names) + LOG.info("Running %r" % cmd) + subprocess.run(cmd, check=True) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--rocm-version", help="ROCm version to install", default="6.1.1") + return p.parse_args() + + +def main(): + args = parse_args() + install_rocm_el8(args.rocm_version) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/build/rocm/tools/libc.py b/build/rocm/tools/libc.py new file mode 100644 index 000000000000..61983d6c258c --- /dev/null +++ b/build/rocm/tools/libc.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import os +import sys + + +def get_libc_version(): + """ + Detect and return glibc version that the current Python is linked against. + + This mimics the detection behavior of the 'wheel' and 'auditwheel' projects, + but without any PyPy or libmusl support. + """ + + try: + version_str = os.confstr("CS_GNU_LIBC_VERSION") + return version_str + except Exception: + print("WARN: lookup by confstr failed", file=sys.stderr) + pass + + try: + import ctypes + except ImportError: + return None + + pn = ctypes.CDLL(None) + print(dir(pn)) + + try: + gnu_get_libc_version = pn.gnu_get_libc_version + except AttributeError: + return None + + gnu_get_libc_version.restype = ctypes.c_char_p + version_str = gnu_get_libc_version() + + return version_str + + +if __name__ == "__main__": + print(get_libc_version()) diff --git a/build/rocm/tools/symbols.py b/build/rocm/tools/symbols.py new file mode 100644 index 000000000000..dc74a0a9bb87 --- /dev/null +++ b/build/rocm/tools/symbols.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import pprint +import re +import sys +import subprocess + +""" +Utility for examining GLIBC versioned symbols +for an object file (shared object or ELF binary) +""" + + +def main(): + sofile = sys.argv[1] + + s = highest_for_file(sofile) + + print("%s: %r" % (sofile, s)) + + +def highest_for_file(sofile): + output = subprocess.check_output(["objdump", "-T", sofile]) + + r = re.compile("\(GLIBC_(.*)\)") + versions = {} + + for line in output.decode("utf-8").split("\n"): + line = line.strip() + match = r.search(line) + if match: + version_str = match.group(1) + count = versions.get(version_str, 0) + versions[version_str] = count + 1 + + vtups = list(map(lambda x: parse(x), versions.keys())) + s = sorted(vtups) + + return s[-1] + + +def parse(version_str): + return tuple(map(int, version_str.split("."))) + + +if __name__ == "__main__": + main() From 3175f13c5957317e15e1654bd48c1fd9957f0797 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Thu, 25 Jul 2024 15:08:15 -0500 Subject: [PATCH 076/702] Add internal release support to get_rocm.py --- build/rocm/tools/get_rocm.py | 281 +++++++++++++++++++++++++++++++---- 1 file changed, 250 insertions(+), 31 deletions(-) diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py index 4cc4a4682acb..60ecb7f0102b 100644 --- a/build/rocm/tools/get_rocm.py +++ b/build/rocm/tools/get_rocm.py @@ -7,14 +7,27 @@ import argparse +import json import logging +import os +import sys import subprocess +import urllib.request LOG = logging.getLogger(__name__) -def which_linux(): +def latest_rocm(): + dat = urllib.request.urlopen( + "https://api.github.com/repos/rocm/rocm/releases/latest" + ).read() + rd = json.loads(dat) + _, ver_str = rd["tag_name"].split("-") + return ver_str + + +def os_release_meta(): try: os_rel = open("/etc/os-release").read() @@ -25,33 +38,229 @@ def which_linux(): v = v.strip('"') kvs[k] = v - print(kvs) + return kvs except OSError: pass -rocm_package_names = [ - "libdrm-amdgpu", - "rocm-dev", - "rocm-ml-sdk", - "miopen-hip ", - "miopen-hip-devel", - "rocblas", - "rocblas-devel", - "rocsolver-devel", - "rocrand-devel", - "rocfft-devel", - "hipfft-devel", - "hipblas-devel", - "rocprim-devel", - "hipcub-devel", - "rccl-devel", - "hipsparse-devel", - "hipsolver-devel", -] - - -def install_rocm_el8(rocm_version_str): +class System(object): + + def __init__(self, pkgbin, rocm_package_list): + self.pkgbin = pkgbin + self.rocm_package_list = rocm_package_list + + def install_packages(self, package_specs): + cmd = [ + self.pkgbin, + "install", + "-y", + ] + cmd.extend(package_specs) + + env = dict(os.environ) + if self.pkgbin == "apt": + env["DEBIAN_FRONTEND"] = "noninteractive" + + LOG.info("Running %r" % cmd) + subprocess.check_call(cmd, env=env) + + def install_rocm(self): + self.install_packages(self.rocm_package_list) + + +UBUNTU = System( + pkgbin="apt", + rocm_package_list=[ + "rocm-dev", + "rocm-libs", + ], +) + + +RHEL8 = System( + pkgbin="dnf", + rocm_package_list=[ + "libdrm-amdgpu", + "rocm-dev", + "rocm-ml-sdk", + "miopen-hip ", + "miopen-hip-devel", + "rocblas", + "rocblas-devel", + "rocsolver-devel", + "rocrand-devel", + "rocfft-devel", + "hipfft-devel", + "hipblas-devel", + "rocprim-devel", + "hipcub-devel", + "rccl-devel", + "hipsparse-devel", + "hipsolver-devel", + ], +) + + +def get_system(): + md = os_release_meta() + + if md["ID"] == "ubuntu": + return UBUNTU + + if md["ID"] in ["almalinux", "rhel", "fedora", "centos"]: + if md["PLATFORM_ID"] == "platform:el8": + return RHEL8 + + raise Exception("No system for %r" % md) + + +def _setup_internal_repo(system, rocm_version, job_name, build_num): + # wget is required by amdgpu-repo + system.install_packages(["wget"]) + + install_amdgpu_installer_internal(rocm_version) + + amdgpu_build = ( + urllib.request.urlopen( + "http://rocm-ci.amd.com/job/%s/%s/artifact/amdgpu_kernel_info.txt" + % (job_name, build_num) + ) + .read() + .decode("utf8") + .strip() + ) + + cmd = [ + "amdgpu-repo", + "--amdgpu-build=%s" % amdgpu_build, + "--rocm-build=%s/%s" % (job_name, build_num), + ] + LOG.info("Running %r" % cmd) + subprocess.check_call(cmd) + + cmd = [ + "amdgpu-install", + "--no-dkms", + "--usecase=rocm", + "-y", + ] + + env = dict(os.environ) + if system.pkgbin == "apt": + env["DEBIAN_FRONTEND"] = "noninteractive" + + LOG.info("Running %r" % cmd) + subprocess.check_call(cmd, env=env) + + +def install_rocm(rocm_version, job_name=None, build_num=None): + s = get_system() + + if job_name and build_num: + _setup_internal_repo(s, rocm_version, job_name, build_num) + else: + if s == RHEL8: + setup_repos_el8(rocm_version) + elif s == UBUNTU: + setup_repos_ubuntu(rocm_version) + else: + raise Exception("Platform not supported") + + s.install_rocm() + + +def install_amdgpu_installer_internal(rocm_version): + """ + Download and install the "amdgpu-installer" package from internal builds + on the current system. + """ + md = os_release_meta() + url, fn = _build_installer_url(rocm_version, md) + + try: + # download installer + LOG.info("Downloading from %s", url) + urllib.request.urlretrieve(url, filename=fn) + + system = get_system() + + cmd = [system.pkgbin, "install", "-y", "./%s" % fn] + subprocess.check_call(cmd) + finally: + try: + os.remove(fn) + except FileNotFoundError: + pass + + +def _build_installer_url(rocm_version, metadata): + md = metadata + + if isinstance(rocm_version, str): + parts = rocm_version.split(".") + rv = type("Version", (), {})() + rv.major = parts[0] + rv.minor = parts[1] + + if len(parts) > 2: + rv.rev = parts[2] + else: + rv = rocm_version + + base_url = "http://artifactory-cdn.amd.com/artifactory/list" + + if md["ID"] == "ubuntu": + fmt = "amdgpu-install-internal_%(rocm_major)s.%(rocm_minor)s-%(os_version)s-1_all.deb" + package_name = fmt % { + "rocm_major": rv.major, + "rocm_minor": rv.minor, + "os_version": md["VERSION_ID"], + } + + url = "%s/amdgpu-deb/%s" % (base_url, package_name) + elif md.get("PLATFORM_ID") == "platform:el8": + fmt = "amdgpu-install-internal-%(rocm_major)s.%(rocm_minor)s_%(os_version)s-1.noarch.rpm" + package_name = fmt % { + "rocm_major": rv.major, + "rocm_minor": rv.minor, + "os_version": "8", + } + + url = "%s/amdgpu-rpm/rhel/%s" % (base_url, package_name) + else: + raise Exception("Platform not supported: %r" % md) + + return url, package_name + + +def setup_repos_ubuntu(rocm_version_str): + + s = get_system() + s.install_packages(["wget", "sudo", "gnupg"]) + + md = os_release_meta() + codename = md["VERSION_CODENAME"] + + keyadd = "wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | sudo apt-key add -" + subprocess.check_call(keyadd, shell=True) + + with open("/etc/apt/sources.list.d/amdgpu.list", "w") as fd: + fd.write( + ("deb [arch=amd64] " "https://repo.radeon.com/amdgpu/%s/ubuntu %s main\n") + % (rocm_version_str, codename) + ) + + with open("/etc/apt/sources.list.d/rocm.list", "w") as fd: + fd.write( + ("deb [arch=amd64] " "https://repo.radeon.com/rocm/apt/%s %s main\n") + % (rocm_version_str, codename) + ) + + # update indexes + subprocess.check_call(["apt-get", "update"]) + + +def setup_repos_el8(rocm_version_str): with open("/etc/yum.repos.d/rocm.repo", "w") as rfd: rfd.write( @@ -78,21 +287,31 @@ def install_rocm_el8(rocm_version_str): """ ) - cmd = ["dnf", "install", "-y"] - cmd.extend(rocm_package_names) - LOG.info("Running %r" % cmd) - subprocess.run(cmd, check=True) - def parse_args(): p = argparse.ArgumentParser() - p.add_argument("--rocm-version", help="ROCm version to install", default="6.1.1") + p.add_argument("--rocm-version", help="ROCm version to install", default="latest") + p.add_argument("--job-name", default=None) + p.add_argument("--build-num", default=None) return p.parse_args() def main(): args = parse_args() - install_rocm_el8(args.rocm_version) + if args.rocm_version == "latest": + try: + rocm_version = latest_rocm() + print("Latest ROCm release: %s" % rocm_version) + except Exception: + print( + "Latest ROCm lookup failed. Please use '--rocm-version' to specify a version instead.", + file=sys.stderr, + ) + sys.exit(-1) + else: + rocm_version = args.rocm_version + + install_rocm(rocm_version, job_name=args.job_name, build_num=args.build_num) if __name__ == "__main__": From a1a0a4ecddfd577694ba1fe6d79174694d3ffc7a Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Tue, 30 Jul 2024 14:51:00 -0500 Subject: [PATCH 077/702] Add support for ROCm development builds Use get_rocm.py changes in ci_build to pull in development builds for ROCm. Specify ROCM_BUILD_JOB and ROCM_BUILD_NUM for activating the development build path. --- build/rocm/Dockerfile.ms | 15 +++++-- .../Dockerfile.manylinux_2_28_x86_64.rocm | 4 +- build/rocm/ci_build | 43 +++++++++++++++++-- build/rocm/ci_build.sh | 12 +++++- build/rocm/tools/build_wheels.py | 6 ++- 5 files changed, 69 insertions(+), 11 deletions(-) diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index 899f29a14f58..5e28a2b24c75 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -2,18 +2,21 @@ FROM ubuntu:20.04 AS rocm_base ################################################################################ +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y python3 + # Add target file to help determine which device(s) to build for ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} # Install ROCM ARG ROCM_VERSION=6.0.0 -ARG CUSTOM_INSTALL ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} ENV ROCM_PATH=${ROCM_PATH} -#COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL} -RUN --mount=type=bind,source=build/rocm/setup.rocm.sh,target=/setup.rocm.sh \ - /setup.rocm.sh $ROCM_VERSION +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM +RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM # Set up paths ENV HCC_HOME=$ROCM_PATH/hcc @@ -24,6 +27,10 @@ ENV PATH="$ROCM_PATH/bin:${PATH}" ENV PATH="$OPENCL_ROOT/bin:${PATH}" ENV PATH="/root/bin:/root/.local/bin:$PATH" +# install pyenv dependencies +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y git libssl-dev + # Install pyenv with different python versions ARG PYTHON_VERSION=3.10.14 RUN git clone https://github.com/pyenv/pyenv.git /pyenv diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index fd2de6a0c06a..caf303d45ff3 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -1,7 +1,9 @@ FROM quay.io/pypa/manylinux_2_28_x86_64 ARG ROCM_VERSION=6.1.1 +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ - python3 get_rocm.py --rocm-version $ROCM_VERSION + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM diff --git a/build/rocm/ci_build b/build/rocm/ci_build index a43bd26fdc0a..c85ab3cadd30 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -19,8 +19,11 @@ def image_by_name(name): return image_id -def dist_wheels(rocm_version, python_versions, xla_path): - xla_path = os.path.abspath(xla_path) +def dist_wheels( + rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="" +): + if xla_path: + xla_path = os.path.abspath(xla_path) # create manylinux image with requested ROCm installed image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "") @@ -31,6 +34,8 @@ def dist_wheels(rocm_version, python_versions, xla_path): "-f", "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm", "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, + "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--tag=%s" % image, ".", ] @@ -116,6 +121,8 @@ def dist_docker( rocm_version, python_versions, xla_path, + rocm_build_job="", + rocm_build_num="", tag="rocm/jax-dev", dockerfile=None, keep_image=True, @@ -135,6 +142,8 @@ def dist_docker( "--target", "rt_build", "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, + "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--build-arg=PYTHON_VERSION=%s" % python_version, "--build-arg=JAX_VERSION=%(jax_version)s" % md, "--build-arg=JAX_COMMIT=%(jax_commit)s" % md, @@ -211,6 +220,18 @@ def parse_args(): help="ROCm version used for building wheels, testing, and installing into Docker image", ) + p.add_argument( + "--rocm-build-job", + default="", + help="ROCm build job for development ROCm builds", + ) + + p.add_argument( + "--rocm-build-num", + default="", + help="ROCm build number for development ROCm builds", + ) + p.add_argument( "--xla-source-dir", help="Path to XLA source to use during jaxlib build, instead of builtin XLA", @@ -235,17 +256,31 @@ def main(): args = parse_args() if args.action == "dist_wheels": - dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir) + dist_wheels( + args.rocm_version, + args.python_versions, + args.xla_source_dir, + args.rocm_build_job, + args.rocm_build_num, + ) elif args.action == "test": test(args.image_name) elif args.action == "dist_docker": - dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir) + dist_wheels( + args.rocm_version, + args.python_versions, + args.xla_source_dir, + args.rocm_build_job, + args.rocm_build_num, + ) dist_docker( args.rocm_version, args.python_versions, args.xla_source_dir, + rocm_build_job=args.rocm_build_job, + rocm_build_num=args.rocm_build_num, tag=args.image_tag, dockerfile=args.dockerfile, keep_image=args.keep_image, diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index ab599e2661a0..d552d0f7a7bf 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -79,6 +79,14 @@ while [[ $# -gt 0 ]]; do ROCM_VERSION="$2" shift 2 ;; + --rocm_job) + ROCM_BUILD_JOB="$2" + shift 2 + ;; + --rocm_build) + ROCM_BUILD_NUM="$2" + shift 2 + ;; *) POSITIONAL_ARGS+=("$1") shift @@ -132,7 +140,9 @@ export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}" ./build/rocm/ci_build \ --rocm-version $ROCM_VERSION \ --python-versions $PYTHON_VERSION \ - --xla-source-dir $XLA_CLONE_DIR \ + --xla-source-dir=$XLA_CLONE_DIR \ + --rocm-build-job=$ROCM_BUILD_JOB \ + --rocm-build-num=$ROCM_BUILD_NUM \ dist_docker \ --dockerfile $DOCKERFILE_PATH \ --image-tag $DOCKER_IMG_NAME diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index 9b9ff778811c..b69af98e0519 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -25,7 +25,11 @@ def build_rocm_path(rocm_version_str): - return "/opt/rocm-%s" % rocm_version_str + path = "/opt/rocm-%s" % rocm_version_str + if os.path.exists(path): + return path + else: + return os.path.realpath("/opt/rocm") def update_rocm_targets(rocm_path, targets): From abe44f6d9e3a4619a2548563ddbb96287f3757c1 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Wed, 31 Jul 2024 11:10:05 -0500 Subject: [PATCH 078/702] Add copyright and license headers to new files --- build/rocm/ci_build | 14 ++++++++++++++ build/rocm/tools/build_wheels.py | 14 ++++++++++++++ build/rocm/tools/fixwheel.py | 14 ++++++++++++++ build/rocm/tools/get_rocm.py | 14 ++++++++++++++ build/rocm/tools/libc.py | 14 ++++++++++++++ build/rocm/tools/symbols.py | 14 ++++++++++++++ 6 files changed, 84 insertions(+) diff --git a/build/rocm/ci_build b/build/rocm/ci_build index c85ab3cadd30..7ecc93eabe9e 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -1,5 +1,19 @@ #!/usr/bin/env python3 +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # NOTE(mrodden): This file is part of the ROCm build scripts, and # needs be compatible with Python 3.6. Please do not include these diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index b69af98e0519..c2da03602e49 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -1,5 +1,19 @@ #!/usr/bin/env python3 +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # NOTE(mrodden): This file is part of the ROCm build scripts, and # needs be compatible with Python 3.6. Please do not include these diff --git a/build/rocm/tools/fixwheel.py b/build/rocm/tools/fixwheel.py index d5951cdd4fc1..ea77162728d5 100644 --- a/build/rocm/tools/fixwheel.py +++ b/build/rocm/tools/fixwheel.py @@ -1,5 +1,19 @@ #!/usr/bin/env python3 +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # NOTE(mrodden): This file is part of the ROCm build scripts, and # needs be compatible with Python 3.6. Please do not include these diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py index 60ecb7f0102b..d29f67982d4c 100644 --- a/build/rocm/tools/get_rocm.py +++ b/build/rocm/tools/get_rocm.py @@ -1,5 +1,19 @@ #!/usr/bin/env python3 +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # NOTE(mrodden): This file is part of the ROCm build scripts, and # needs be compatible with Python 3.6. Please do not include these diff --git a/build/rocm/tools/libc.py b/build/rocm/tools/libc.py index 61983d6c258c..1cd16b04cd14 100644 --- a/build/rocm/tools/libc.py +++ b/build/rocm/tools/libc.py @@ -1,5 +1,19 @@ #!/usr/bin/env python3 +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # NOTE(mrodden): This file is part of the ROCm build scripts, and # needs be compatible with Python 3.6. Please do not include these diff --git a/build/rocm/tools/symbols.py b/build/rocm/tools/symbols.py index dc74a0a9bb87..f2bf2d561f72 100644 --- a/build/rocm/tools/symbols.py +++ b/build/rocm/tools/symbols.py @@ -1,5 +1,19 @@ #!/usr/bin/env python3 +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # NOTE(mrodden): This file is part of the ROCm build scripts, and # needs be compatible with Python 3.6. Please do not include these From 319ebf81c166d2f6f407d36ec7d0707121e71a2e Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Wed, 31 Jul 2024 13:58:42 -0500 Subject: [PATCH 079/702] Add defaults for ROCm build vars --- build/rocm/ci_build.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index d552d0f7a7bf..7f93af8cae4c 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -46,6 +46,8 @@ DOCKER_CONTEXT_PATH="${SCRIPT_DIR}" KEEP_IMAGE="--rm" PYTHON_VERSION="3.10" ROCM_VERSION="6.1.3" +ROCM_BUILD_JOB="" +ROCM_BUILD_NUM="" BASE_DOCKER="ubuntu:20.04" CUSTOM_INSTALL="" POSITIONAL_ARGS=() From df2d140f518b880be8944ff953299c42b6c1bf94 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Wed, 31 Jul 2024 14:20:39 -0500 Subject: [PATCH 080/702] Fix jenkins notty issue --- build/rocm/ci_build | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 7ecc93eabe9e..31ee591aedd0 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -78,7 +78,7 @@ def dist_wheels( bw_cmd.append("/jax") - cmd = ["docker", "run", "-it"] + cmd = ["docker", "run"] mounts = [ "-v", From 701cda8ebd714d28325b463c270e505e143463a6 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Tue, 6 Aug 2024 15:04:46 -0500 Subject: [PATCH 081/702] Fix not finding wheels in bazel output --- build/rocm/tools/build_wheels.py | 35 ++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index c2da03602e49..7f2a4c862bf0 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -80,7 +80,7 @@ def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None): LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) pattern = re.compile("Output wheel: (.+)\n") - return _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stderr") + _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stderr") def build_jax_wheel(jax_path, python_version): @@ -101,10 +101,7 @@ def build_jax_wheel(jax_path, python_version): LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) pattern = re.compile("Successfully built jax-.+ and (jax-.+\.whl)\n") - wheels = _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout") - - paths = list(map(lambda x: os.path.join(jax_path, "dist", x), wheels)) - return paths + _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout") def _run_scan_for_output(cmd, pattern, env=None, cwd=None, capture=None): @@ -203,6 +200,17 @@ def parse_args(): return p.parse_args() +def find_wheels(path): + wheels = [] + + for f in os.listdir(path): + if f.endswith(".whl"): + wheels.append(os.path.join(path, f)) + + LOG.info("Found wheels: %r" % wheels) + return wheels + + def main(): args = parse_args() python_versions = args.python_versions.split(",") @@ -217,12 +225,16 @@ def main(): update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS) for py in python_versions: - wheel_paths = build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path) + build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path) + wheel_paths = find_wheels(os.path.join(args.jax_path, "dist")) for wheel_path in wheel_paths: - fix_wheel(wheel_path, args.jax_path) + # skip jax wheel since it is non-platform + if not os.path.basename(wheel_path).startswith("jax-"): + fix_wheel(wheel_path, args.jax_path) # build JAX wheel for completeness - jax_wheels = build_jax_wheel(args.jax_path, python_versions[-1]) + build_jax_wheel(args.jax_path, python_versions[-1]) + wheels = find_wheels(os.path.join(args.jax_path, "dist")) # NOTE(mrodden): the jax wheel is a "non-platform wheel", so auditwheel will # do nothing, and in fact will throw an Exception. we just need to copy it @@ -230,9 +242,10 @@ def main(): # copy jax wheel(s) to wheelhouse wheelhouse_dir = "/wheelhouse/" - for whl in jax_wheels: - LOG.info("Copying %s into %s" % (whl, wheelhouse_dir)) - shutil.copy(whl, wheelhouse_dir) + for whl in wheels: + if os.path.basename(whl).startswith("jax-"): + LOG.info("Copying %s into %s" % (whl, wheelhouse_dir)) + shutil.copy(whl, wheelhouse_dir) if __name__ == "__main__": From fafa03c60f25ff732ea92d946caa2823bd0f8b71 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Tue, 6 Aug 2024 17:00:08 -0500 Subject: [PATCH 082/702] Add missing CPython build deps for pyenv --- build/rocm/Dockerfile.ms | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index 5e28a2b24c75..9d19486b6557 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -27,9 +27,14 @@ ENV PATH="$ROCM_PATH/bin:${PATH}" ENV PATH="$OPENCL_ROOT/bin:${PATH}" ENV PATH="/root/bin:/root/.local/bin:$PATH" -# install pyenv dependencies +# install pyenv and python build dependencies RUN --mount=type=cache,target=/var/cache/apt \ - apt-get update && apt-get install -y git libssl-dev + apt-get update && apt-get install -y \ + git \ + libssl-dev \ + libffi-dev \ + libreadline-dev \ + liblzma-dev # Install pyenv with different python versions ARG PYTHON_VERSION=3.10.14 From 2644299f7e66e11b31f26b4556a6f0802ee5f170 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sun, 11 Aug 2024 17:48:19 -0700 Subject: [PATCH 083/702] docs: sentence case index and sub-index headings We currently use both forms, so for consistency (and easier reading), pick this one. --- docs/advanced_guide.rst | 10 +++++----- docs/contributor_guide.rst | 2 +- docs/faq.rst | 2 +- docs/glossary.rst | 2 +- docs/index.rst | 12 ++++++------ docs/user_guides.rst | 8 ++++---- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index 5fe6c03ee059..86742313822b 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -1,6 +1,6 @@ .. _advanced_guide: -Advanced Tutorials +Advanced tutorials ================== This section contains examples and tutorials on more advanced topics, such as Multi Core computation, Custom operations, and more in depth applications @@ -13,7 +13,7 @@ This section contains examples and tutorials on more advanced topics, such as Mu notebooks/vmapped_log_probs .. toctree:: - :caption: Parallel Computation + :caption: Parallel computation :maxdepth: 1 multi_process @@ -22,7 +22,7 @@ This section contains examples and tutorials on more advanced topics, such as Mu distributed_data_loading .. toctree:: - :caption: Automatic Differentiation + :caption: Automatic differentiation :maxdepth: 1 notebooks/autodiff_cookbook @@ -30,7 +30,7 @@ This section contains examples and tutorials on more advanced topics, such as Mu notebooks/autodiff_remat .. toctree:: - :caption: JAX Internals + :caption: JAX internals :maxdepth: 1 notebooks/How_JAX_primitives_work @@ -38,7 +38,7 @@ This section contains examples and tutorials on more advanced topics, such as Mu Custom_Operation_for_GPUs .. toctree:: - :caption: Deep Dives + :caption: Deep dives :maxdepth: 1 notebooks/convolutions diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst index cb0c034be850..b5ebd5057df7 100644 --- a/docs/contributor_guide.rst +++ b/docs/contributor_guide.rst @@ -1,6 +1,6 @@ .. _contributor-guide: -Developer Documentation +Developer documentation ======================= JAX welcomes contributions from the community. diff --git a/docs/faq.rst b/docs/faq.rst index 3b63128d2c28..92868dc5df42 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -1,4 +1,4 @@ -JAX Frequently Asked Questions (FAQ) +JAX frequently asked questions (FAQ) ==================================== .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html diff --git a/docs/glossary.rst b/docs/glossary.rst index 78b7fcd246f3..179c3c75dc6c 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -1,4 +1,4 @@ -JAX Glossary of Terms +JAX glossary of terms ===================== .. glossary:: diff --git a/docs/index.rst b/docs/index.rst index 2e13c109dbbe..ef59f87dd2d0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -27,7 +27,7 @@ For an end-to-end transformer library built on JAX, see MaxText_. JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization. - .. grid-item-card:: Run Anywhere + .. grid-item-card:: Run anywhere :columns: 12 6 6 4 :class-card: sd-border-0 :shadow: None @@ -36,19 +36,19 @@ For an end-to-end transformer library built on JAX, see MaxText_. .. grid:: 3 - .. grid-item-card:: :material-regular:`rocket_launch;2em` Getting Started + .. grid-item-card:: :material-regular:`rocket_launch;2em` Getting started :columns: 12 6 6 4 :link: beginner-guide :link-type: ref :class-card: getting-started - .. grid-item-card:: :material-regular:`library_books;2em` User Guides + .. grid-item-card:: :material-regular:`library_books;2em` User guides :columns: 12 6 6 4 :link: user-guides :link-type: ref :class-card: user-guides - .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer Docs + .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer docs :columns: 12 6 6 4 :link: contributor-guide :link-type: ref @@ -58,7 +58,7 @@ For an end-to-end transformer library built on JAX, see MaxText_. .. toctree:: :hidden: :maxdepth: 1 - :caption: Getting Started + :caption: Getting started installation quickstart @@ -75,7 +75,7 @@ For an end-to-end transformer library built on JAX, see MaxText_. .. toctree:: :hidden: :maxdepth: 2 - :caption: Further Resources + :caption: Further resources user_guides advanced_guide diff --git a/docs/user_guides.rst b/docs/user_guides.rst index 57913bf6d4c8..45260067604e 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -1,6 +1,6 @@ .. _user-guides: -User Guides +User guides =========== User guides are deeper dives into particular topics within JAX @@ -9,7 +9,7 @@ or deployed codebases. .. toctree:: :maxdepth: 1 - :caption: Debugging and Performance + :caption: Debugging and performance notebooks/thinking_in_jax profiling @@ -29,7 +29,7 @@ or deployed codebases. .. toctree:: :maxdepth: 1 - :caption: Run Time + :caption: Run time aot export/index @@ -38,7 +38,7 @@ or deployed codebases. .. toctree:: :maxdepth: 1 - :caption: Custom Operations + :caption: Custom operations pallas/index ffi From aa66fb37c33fd562660227d869d005034557deb7 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 12 Aug 2024 14:41:58 -0700 Subject: [PATCH 084/702] [Pallas][XLA:Mosaic] Add python stack traces to Mosaic errors that occur in Pallas. PiperOrigin-RevId: 662232859 --- jax/_src/compiler.py | 50 +++++-- jax/_src/pallas/mosaic/BUILD | 12 ++ jax/_src/pallas/mosaic/error_handling.py | 158 +++++++++++++++++++++ jax/_src/pallas/mosaic/lowering.py | 17 +-- tests/pallas/BUILD | 18 +++ tests/pallas/pallas_error_handling_test.py | 142 ++++++++++++++++++ 6 files changed, 375 insertions(+), 22 deletions(-) create mode 100644 jax/_src/pallas/mosaic/error_handling.py create mode 100644 tests/pallas/pallas_error_handling_test.py diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 8cad5a8fe9a3..81457f1cbd07 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -21,7 +21,7 @@ import os import tempfile import time -from typing import Any +from typing import Any, Callable import warnings from jax._src import compilation_cache @@ -253,15 +253,45 @@ def backend_compile( else: built_c = module - # we use a separate function call to ensure that XLA compilation appears - # separately in Python profiling results - if host_callbacks: - return backend.compile(built_c, compile_options=options, - host_callbacks=host_callbacks) - # Some backends don't have `host_callbacks` option yet - # TODO(sharadmv): remove this fallback when all backends allow `compile` - # to take in `host_callbacks` - return backend.compile(built_c, compile_options=options) + try: + # we use a separate function call to ensure that XLA compilation appears + # separately in Python profiling results + if host_callbacks: + return backend.compile( + built_c, compile_options=options, host_callbacks=host_callbacks + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` + return backend.compile(built_c, compile_options=options) + except xc.XlaRuntimeError as e: + for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: + handler_result = error_handler(e) + if handler_result is not None: + raise handler_result from e + raise e + + +_XLA_RUNTIME_ERROR_HANDLERS = [] + + +def register_xla_runtime_error_handler( + handler_fn: Callable[[xc.XlaRuntimeError], Exception | None], +): + """Registers a custom exception handler for XLA runtime errors. + + Registering a custom handler allows re-raising a more informative exception + after encountering an XLARuntimeError. + + Args: + handler_fn: A function which returns a new exception to replace the original + XLA runtime error, or None if the original error should be propagated. + + Returns: + A new exception or None. + """ + _XLA_RUNTIME_ERROR_HANDLERS.append(handler_fn) + def compile_or_get_cached( backend: xc.Client, diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 57dad7793116..f1616962f349 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -43,6 +43,16 @@ py_library( ], ) +py_library( + name = "error_handling", + srcs = ["error_handling.py"], + deps = [ + "//jax:compiler", + "//jax:traceback_util", + "//jax/_src/lib", + ], +) + py_library( name = "primitives", srcs = ["primitives.py"], @@ -71,10 +81,12 @@ py_library( srcs = ["lowering.py"], deps = [ ":core", + ":error_handling", ":primitives", "//jax", "//jax:ad_util", "//jax:core", + "//jax:dtypes", "//jax:mesh", "//jax:mlir", "//jax:mosaic", diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py new file mode 100644 index 000000000000..5340eb3fa654 --- /dev/null +++ b/jax/_src/pallas/mosaic/error_handling.py @@ -0,0 +1,158 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for raising more informative exceptions from Pallas.""" +from collections import namedtuple +import re +import types +from jax._src import compiler +from jax._src import traceback_util +from jax._src.lib import xla_client +from jax._src.lib.mlir import ir + +# This is a simple ir.Location parsing regex that assumes the string is properly +# formatted coming from Mosaic. +# It will assume everything from the first to last parentheses +# in the string is part of the frame, and does not account for unbalanced +# parentheses. +LOCATION_PATTERN = re.compile( + r'(?Ploc\((?P\".*?\")(?P.*)\))' +) +FRAME_PATTERN = re.compile( + r'(?P\".*?\")\((?P\".*?\"):' + r'(?P[0-9]+):(?P[0-9]+)\)' +) +MLIR_ERR_PREFIX = ( + 'Pallas encountered an internal verification error.' + 'Please file a bug at https://github.com/google/jax/issues. ' + 'Error details: ' +) + +RawFrame = namedtuple('RawFrame', ['func_name', 'filename', 'lineno', 'colno']) + + +class MosaicError(Exception): + """Error thrown by Pallas when re-raising a Mosaic internal error.""" + + +class VerificationError(MosaicError): + """Error thrown by Pallas when re-raising a verification error.""" + + def __init__(self, message: str): + super().__init__(MLIR_ERR_PREFIX + message) + + +def _handle_xla_runtime_error( + base_err: xla_client.XlaRuntimeError, +) -> MosaicError | None: + """Reformats XLARuntimeError to include a Python traceback.""" + if 'Mosaic' not in str(base_err): + return None + try: + _, frames = parse_location_string(str(base_err)) + except ValueError: + # If no location string is found, skip handling and raise the original + # error. + return None + new_tb = traceback_from_raw_frames(frames) + err_msg = base_err.args[0] + err_msg = redact_locations(err_msg) + new_error = MosaicError(err_msg) + new_error.__traceback__ = traceback_util.filter_traceback(new_tb) + return new_error + + +compiler.register_xla_runtime_error_handler(_handle_xla_runtime_error) + + +def mlir_error_to_verification_error( + base_err: ir.MLIRError) -> VerificationError: + """Reformats MLIRError to include a Python traceback.""" + diagnostic = base_err.error_diagnostics[0] # pytype: disable=attribute-error + def _get_diagnostic_message(diagnostic) -> str: + current_msg = diagnostic.message + for d in diagnostic.notes: + current_msg += "\n " + _get_diagnostic_message(d) + return current_msg + + _, frames = parse_location_string(str(diagnostic.location.attr)) + new_tb = traceback_from_raw_frames(frames) + new_error = VerificationError(_get_diagnostic_message(diagnostic)) + new_error.__traceback__ = traceback_util.filter_traceback(new_tb) + return new_error + + +def redact_locations(err_msg: str) -> str: + """Removes location strings from an error message.""" + for mat in re.finditer(LOCATION_PATTERN, err_msg): + start, end = mat.span('location') + # Remove the entire line containing the location. + line_start = err_msg.rfind('\n', 0, end) + line_start = line_start if line_start >= 0 else start + line_end = err_msg.find('\n', start) + line_end = line_end if line_end >= 0 else end + return err_msg[:line_start] + err_msg[line_end+1:] + return err_msg + + +def parse_location_string(location_string: str) -> tuple[str, list[RawFrame]]: + """Parses a serialized MLIR location. + + Locations strings have the format: + `loc("location_name"())` + + Where is a nested callsite string representing the entire + call stack: + `callsite("fn_name"("filename":lineno:colno) at callsite(...))` + + Args: + location_string: A string serialization of an MLIR location. + + Returns: + A tuple (name, frames) where name is the name of the location and frames + is a list of RawFrame objects representing the Python call stack associated + with the location. + """ + frame_str = '' + loc_name = None + matches = list(re.finditer(LOCATION_PATTERN, location_string)) + if len(matches) > 1: + raise ValueError( + 'More than one location found in string: ', location_string) + for mat in matches: + loc_name = mat.group('eqn_str')[1:-1] + frame_str = mat.group('frames')[1:-1] + if loc_name is None: + raise ValueError(f'Could not find location in string {location_string}') + frames: list[RawFrame] = [] + for mat in re.finditer(FRAME_PATTERN, frame_str): + frames.append( + RawFrame( + mat.group('fun_name')[1:-1], + mat.group('filename')[1:-1], + int(mat.group('lineno')), + int(mat.group('colno')), + ) + ) + return loc_name, frames + + +def traceback_from_raw_frames(frames: list[RawFrame]) -> types.TracebackType: + """Constructs a traceback from a list of RawFrame objects.""" + xla_frames = [ + xla_client.Frame(frame.filename, frame.func_name, -1, frame.lineno + ) # type: ignore [call-arg] + for frame in frames + ] + return xla_client.Traceback.traceback_from_frames(xla_frames) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index bef39142c120..17baef85d154 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -53,6 +53,7 @@ from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core +from jax._src.pallas.mosaic import error_handling from jax._src.pallas.mosaic import primitives as tpu_primitives from jax._src.state import discharge as state_discharge from jax._src.state import indexing @@ -632,12 +633,8 @@ def body_func(*args): body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) try: body.func_op.verify() - except Exception as e: - raise LoweringException( - f"Body failed to verify: {body.func_op}.\nThis is an internal error." - " Please report a bug at:" - " https://github.com/google/jax/issues/new?assignees=sharadmv." - ) from e + except ir.MLIRError as e: + raise error_handling.mlir_error_to_verification_error(e) from e return body.func_op @@ -694,12 +691,8 @@ def body_func(*args): body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) try: body.func_op.verify() - except Exception as e: - raise LoweringException( - f"Body failed to verify: {body.func_op}.\nThis is an internal error." - " Please report a bug at:" - " https://github.com/google/jax/issues/new?assignees=sharadmv." - ) from e + except ir.MLIRError as e: + raise error_handling.mlir_error_to_verification_error(e) from e return body.func_op diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index d2a7ec56db07..f9a8c17cec30 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -267,6 +267,24 @@ jax_test( ], ) +jax_test( + name = "pallas_error_handling_test", + srcs = [ + "pallas_error_handling_test.py", + ], + disable_backends = [ + "cpu", + "gpu", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_tpu", + "//jax/_src/pallas/mosaic:random", + "//third_party/py/absl/testing:absltest", + "//third_party/py/absl/testing:parameterized", + ] + py_deps("numpy"), +) + jax_test( name = "tpu_all_gather_test", srcs = [ diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py new file mode 100644 index 000000000000..06b4bd3e3a4f --- /dev/null +++ b/tests/pallas/pallas_error_handling_test.py @@ -0,0 +1,142 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Pallas error handling.""" +import functools +import traceback + +from absl.testing import absltest +import jax +from jax import numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas.mosaic import error_handling +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + + +config.parse_flags_with_absl() + +LOCATION_TEST_STRING = ( + r'loc("/squeeze"' + r'(callsite("foo_fn"("third_party/foo.py":104:22) at ' + r'callsite("bar_fn"("third_party/bar.py":115:6) at ' + r'""("third_party/pallas_error_handling_test.py":181:2' + r")))))" +) + + +class PallasErrorHandlingTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test only works on TPU.") + + def test_vector_extract_nonzero(self): + input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) + out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + ) + + @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) + def test_kernel(input_ref, output_ref): + val = input_ref[...] + x = val[0, 0] + val[0, 1] + output_ref[0, 0] = x + + # Test that a Mosaic error is raised. This assert is a guard against + # underlying changes in Mosaic. + # If this is fixed in future Mosaic releases we will need to change + # the test example to force a different error. + with self.assertRaisesRegex( + error_handling.MosaicError, + "Not implemented: Only 0 indices supported for scalar results", + ): + test_kernel(input_arr) + + # Test that the python source is the final frame in the traceback. + tb_string = "" + try: + test_kernel(input_arr) + except error_handling.MosaicError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n") + + @jax.jit + def kernel_in_jitted_fn(x): + return test_kernel(x) + + with self.subTest("inside_jitted_fn"): + tb_string = "" + try: + kernel_in_jitted_fn(input_arr) + except error_handling.MosaicError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n") + + def test_invalid_smem_vmem_verification_error(self): + input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) + out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + ) + + @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) + def test_kernel(input_ref, output_ref): + output_ref[0, 0] = input_ref[0, 0] + + # Test that a verification error is raised. This assert is a guard against + # underlying changes in Pallas lowering. + # If this is fixed in future Pallas releases we will need to change + # the test example to force a different error. + with self.assertRaisesRegex( + error_handling.VerificationError, + "'memref.store' op failed to verify that type of 'value' matches " + "element type of 'memref'", + ): + test_kernel(input_arr) + + # Test that the python source is the final frame in the traceback. + tb_string = "" + try: + test_kernel(input_arr) + except error_handling.MosaicError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + self.assertEndsWith(tb_string, "output_ref[0, 0] = input_ref[0, 0]\n") + + def test_parse_location_string(self): + name, frames = error_handling.parse_location_string(LOCATION_TEST_STRING) + self.assertEqual(name, "/squeeze") + self.assertLen(frames, 3) + self.assertEqual(frames[0].func_name, "foo_fn") + self.assertEqual(frames[0].filename, "third_party/foo.py") + self.assertEqual(frames[0].lineno, 104) + self.assertEqual(frames[0].colno, 22) + + +if __name__ == "__main__": + absltest.main() From 7afa90780b5dc5bd4158f7fb218f3a147abfb67a Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 12 Aug 2024 15:42:45 -0700 Subject: [PATCH 085/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b18bc612506b4fb759e5930b9d4b24d4c33dbdbd. PiperOrigin-RevId: 662252977 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index db188b0a9b68..f76c1698e8a2 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "a6fc99fadc72358a5d79dd3ece66340ac5e45ad7" -XLA_SHA256 = "da0b6beeb418933b380c439f34416b6635931809a4c2dc9a99eceb6ff35363fe" +XLA_COMMIT = "b18bc612506b4fb759e5930b9d4b24d4c33dbdbd" +XLA_SHA256 = "48b2cc62b3e99ba4e60088aad9673489001d98c8103da9bdef90fdfdb8a76dd7" def repo(): tf_http_archive( From 26800f193250d3015ef2760aa06c888c8f6334eb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 12 Aug 2024 15:49:46 -0700 Subject: [PATCH 086/702] test: improve test cases for set-like operations Previously, many of the generated test cases were suboptimal because they had few overlaps. This change generates more comprehensive test cases. --- tests/lax_numpy_test.py | 98 +++++++++++++++++++++++------------------ 1 file changed, 55 insertions(+), 43 deletions(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c750c3021004..d5efe1e03f31 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -161,6 +161,23 @@ def _shapes_are_equal_length(shapes): return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) +def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0.5) -> list[jax.Array]: + """Generate multiple arrays with some overlapping values. + + This is useful for tests of set-like operations. + """ + assert 0 <= overlap <= 1 + sizes = [math.prod(jtu._dims_of_shape(shape)) for shape in shapes] + total_size = int(sum(sizes) * (1 - overlap)) + max(sizes) # non-strict upper-bound. + if unique: + vals = jtu.rand_unique_int(rng)((total_size,), 'int32') + else: + vals = jtu.rand_default(rng)((total_size,), 'int32') + offsets = [int(sum(sizes[:i]) * (1 - overlap)) for i in range(len(sizes))] + return [np.random.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype) + for (offset, size, shape, dtype) in zip(offsets, sizes, shapes, dtypes)] + + class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" @@ -685,11 +702,12 @@ def testIsin(self, element_shape, test_shape, dtype, invert): dtype2=[s for s in default_dtypes if s != jnp.bfloat16], shape1=all_shapes, shape2=all_shapes, + overlap=[0.1, 0.5, 0.9], ) - def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - + def testSetdiff1d(self, shape1, shape2, dtype1, dtype2, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker) @@ -700,10 +718,12 @@ def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): shape2=all_shapes, size=[1, 5, 10], fill_value=[None, -1], + overlap=[0.1, 0.5, 0.9], ) - def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) def np_fun(arg1, arg2): result = np.setdiff1d(arg1, arg2) if size <= len(result): @@ -719,12 +739,14 @@ def jnp_fun(arg1, arg2): @jtu.sample_product( dtype1=[s for s in default_dtypes if s != jnp.bfloat16], dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_nonscalar_array_shapes, - shape2=nonempty_nonscalar_array_shapes, + shape1=all_shapes, + shape2=all_shapes, + overlap=[0.1, 0.5, 0.9], ) - def testUnion1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def testUnion1d(self, shape1, shape2, dtype1, dtype2, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) def np_fun(arg1, arg2): dtype = jnp.promote_types(arg1.dtype, arg2.dtype) return np.union1d(arg1, arg2).astype(dtype) @@ -734,14 +756,16 @@ def np_fun(arg1, arg2): @jtu.sample_product( dtype1=[s for s in default_dtypes if s != jnp.bfloat16], dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_nonscalar_array_shapes, - shape2=nonempty_nonscalar_array_shapes, + shape1=nonempty_shapes, + shape2=nonempty_shapes, size=[1, 5, 10], fill_value=[None, -1], + overlap=[0.1, 0.5, 0.9], ) - def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) def np_fun(arg1, arg2): dtype = jnp.promote_types(arg1.dtype, arg2.dtype) result = np.union1d(arg1, arg2).astype(dtype) @@ -762,14 +786,16 @@ def jnp_fun(arg1, arg2): shape1=all_shapes, shape2=all_shapes, assume_unique=[False, True], + overlap=[0.1, 0.5, 0.9], ) - def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique) def np_fun(ar1, ar2): if assume_unique: - # pre-flatten the arrays to match with jax implementation + # numpy requires 1D inputs when assume_unique is True. ar1 = np.ravel(ar1) ar2 = np.ravel(ar2) return np.setxor1d(ar1, ar2, assume_unique) @@ -779,33 +805,19 @@ def np_fun(ar1, ar2): @jtu.sample_product( dtype1=[s for s in default_dtypes if s != jnp.bfloat16], dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=[(), (5,), (2, 5)], - shape2=[(), (5,), (2, 5)], + shape1=nonempty_shapes, + shape2=nonempty_shapes, assume_unique=[False, True], return_indices=[False, True], size=[None, 3, 5], - fill_value=[None, -1] + fill_value=[None, -1], + overlap=[0.1, 0.5, 0.9], ) def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, - return_indices, size, fill_value): - rng = jtu.rand_default(self.rng()) - def args_maker(): - # Generate two arrays with overlapping values. - size1, size2 = math.prod(shape1), math.prod(shape2) - num_vals = max(size1, size2) + min(size1, size2) // 2 - vals = rng((num_vals,), 'int32') - arr1 = vals[:size1].astype(dtype1).reshape(shape1) - arr2 = vals[-size2:].astype(dtype2).reshape(shape2) - # if assume_unique is True, we need the results to contain unique values. - # This may lead to different shapes than requested, but ¯\_(ツ)_/¯ - if assume_unique: - arr1 = np.unique(arr1) - self.rng().shuffle(arr1) # inplace - arr1 = arr1.reshape(shape1) if arr1.shape == size1 else arr1 - arr2 = np.unique(arr2) - self.rng().shuffle(arr2) # inplace - arr2 = arr1.reshape(shape2) if arr2.shape == size2 else arr2 - return arr1, arr2 + return_indices, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + unique=assume_unique, overlap=overlap) def jnp_fun(ar1, ar2): return jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices, From 734ebd570891ceaf8c7104e12256a1edfe942b14 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 12 Aug 2024 15:58:07 -0700 Subject: [PATCH 087/702] Support donating arrays with non-default layouts by setting up XLA donation directly instead of defining aliasing for arrays with potentially incompatible layouts. PiperOrigin-RevId: 662258042 --- jax/_src/interpreters/mlir.py | 24 ++++++++++++++++++ tests/layout_test.py | 48 +++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 814c6a9886d7..5cfb6a0e0699 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1010,6 +1010,23 @@ def _get_mem_kind(s: JSharding | None) -> str | None: return s.memory_kind +def _is_default_layout(curr_layout, sharding, aval): + if curr_layout is None or sharding is None: + return True + if isinstance(curr_layout, AutoLayout): + return False + d = sharding._device_assignment[0] + try: + return curr_layout == DeviceLocalLayout.from_pjrt_layout( + d.client.get_default_layout(aval.dtype, aval.shape, d)) + except xla_extension.XlaRuntimeError as e: + msg, *_ = e.args + if type(msg) is str and msg.startswith("UNIMPLEMENTED"): + return True + else: + raise + + def lower_jaxpr_to_module( module_name: str, jaxpr: core.ClosedJaxpr, @@ -1064,6 +1081,13 @@ def lower_jaxpr_to_module( "In multi-platform lowering either all or no lowering platforms " f"should support donation. Lowering for {platforms} of which " f"only {platforms_with_donation} support donation") + if (in_layouts is not None and arg_shardings is not None and + out_layouts is not None and result_shardings is not None + ) and not ( + all(map(_is_default_layout, in_layouts, arg_shardings, in_avals)) and + all(map(_is_default_layout, out_layouts, result_shardings, out_avals)) + ): + xla_donated_args = donated_args if num_partitions > 1 and ( result_shardings is None or all(s is None for s in result_shardings)): xla_donated_args = donated_args diff --git a/tests/layout_test.py b/tests/layout_test.py index c72082d0a16c..2ddd72764be0 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -500,6 +500,54 @@ def g(x): 'Layout passed to jit does not match the layout on the respective arg'): g(arr) + def test_layout_donation(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (16, 128) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + custom_dll = DLL(major_to_minor=(0, 1)) + arr = jax.device_put(np_inp, Layout(custom_dll, s)) + + @partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0) + def f(x): + return x + + out = f(arr) + self.assertTrue(arr.is_deleted()) + + def test_layout_donation_auto(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (128, 16) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + arr = jax.device_put(np_inp, s) + + @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0) + def f(x): + return x * x + + out = f(arr) + self.assertTrue(arr.is_deleted()) + + def test_layout_donation_matching_in_and_out(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (128, 16) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + custom_dll = DLL(major_to_minor=(0, 1)) + l = Layout(custom_dll, s) + arr = jax.device_put(np_inp, l) + + @partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0) + def f(x): + return x * x + + out = f(arr) + self.assertTrue(arr.is_deleted()) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From b8f8b7b07fc2250a8f265b2d80cfa0d53ad04f67 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 12 Aug 2024 13:52:27 -0700 Subject: [PATCH 088/702] docs: sentence case page titles, section headings, some content --- docs/building_on_jax.md | 12 ++++++------ docs/contributing.md | 4 ++-- docs/device_memory_profiling.md | 4 ++-- docs/glossary.rst | 2 +- docs/investigating_a_regression.md | 4 ++-- docs/jaxpr.rst | 2 +- .../Neural_Network_and_Data_Loading.ipynb | 6 +++--- docs/notebooks/Neural_Network_and_Data_Loading.md | 6 +++--- docs/notebooks/convolutions.ipynb | 8 ++++---- docs/notebooks/convolutions.md | 8 ++++---- docs/notebooks/external_callbacks.ipynb | 2 +- docs/notebooks/external_callbacks.md | 2 +- docs/notebooks/neural_network_with_tfds_data.ipynb | 6 +++--- docs/notebooks/neural_network_with_tfds_data.md | 6 +++--- docs/notebooks/thinking_in_jax.ipynb | 14 +++++++------- docs/notebooks/thinking_in_jax.md | 14 +++++++------- docs/persistent_compilation_cache.md | 2 +- docs/stateful-computations.md | 2 +- 18 files changed, 52 insertions(+), 52 deletions(-) diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index e0a4404911a7..9416b16cde10 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -11,7 +11,7 @@ and how it's used for computational speedup in other libraries. Below are examples of how JAX's features can be used to define accelerated computation across numerous domains and software packages. -## Gradient Computation +## Gradient computation Easy gradient calculation is a key feature of JAX. In the [JaxOpt library](https://github.com/google/jaxopt) value and grad is directly utilized for users in multiple optimization algorithms in [its source code](https://github.com/google/jaxopt/blob/main/jaxopt/_src/base.py#LL87C30-L87C44). @@ -19,7 +19,7 @@ Similarly the same Dynamax Optax pairing mentioned above is an example of gradients enabling estimation methods that were challenging historically [Maximum Likelihood Expectation using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html). -## Computational Speedup on a Single Core across Multiple Devices +## Computational speedup on a single core across multiple devices Models defined in JAX can then be compiled to enable single computation speedup through JIT compiling. The same compiled code can then be sent to a CPU device, to a GPU or TPU device for additional speedup, @@ -28,7 +28,7 @@ This allows for a smooth workflow from development into production. In Dynamax the computationally expensive portion of a Linear State Space Model solver has been [jitted](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/models.py#L579). A more complex example comes from PyTensor which compiles a JAX function dynamically and then [jits the constructed function](https://github.com/pymc-devs/pytensor/blob/main/pytensor/link/jax/linker.py#L64). -## Single and Multi Computer Speedup Using Parallelization +## Single and multi computer speedup using parallelization Another benefit of JAX is the simplicity of parallelizing computation using `pmap` and `vmap` function calls or decorators. In Dynamax state space models are parallelized with a [VMAP decorator](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/parallel_inference.py#L89) @@ -43,7 +43,7 @@ such as Neural Networks or State Space models or others, or provide specific functionality such as optimization. Here are more specific examples of each pattern. -### Direct Usage +### Direct usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). @@ -51,7 +51,7 @@ This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number of dependencies in your codebase. -### Composable Domain Specific Libraries with JAX exposed +### Composable domain specific libraries with JAX exposed Another common approach are packages that provide prebuilt functionality, whether it be model definition, or computation of some type. Combinations of these packages can then be mixed and matched for a full @@ -68,7 +68,7 @@ With Dynamax parameters can be estimated using [Maximum Likelihood using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html) or full Bayesian Posterior can be estimating using [MCMC from Blackjax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_hmc.html) -### JAX Totally Hidden from Users +### JAX totally hidden from users Other libraries opt to completely wrap JAX in their model specific API. An example is PyMC and [Pytensor](https://github.com/pymc-devs/pytensor), in which a user may never “see” JAX directly diff --git a/docs/contributing.md b/docs/contributing.md index 4aecf7153a03..d7fa6e9da8a3 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -162,7 +162,7 @@ possible. The `git rebase -i` command might be useful to this end. (linting-and-type-checking)= -### Linting and Type-checking +### Linting and type-checking JAX uses [mypy](https://mypy.readthedocs.io/) and [ruff](https://docs.astral.sh/ruff/) to statically test code quality; the @@ -186,7 +186,7 @@ fix the issues you can push new commits to your branch. ### Restricted test suite -Once your PR has been reviewed, a JAX maintainer will mark it as `Pull Ready`. This +Once your PR has been reviewed, a JAX maintainer will mark it as `pull ready`. This will trigger a larger set of tests, including tests on GPU and TPU backends that are not available via standard GitHub CI. Detailed results of these tests are not publicly viewable, but the JAX maintainer assigned to your PR will communicate with you regarding diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md index e4d871b780f3..a906c54d5c06 100644 --- a/docs/device_memory_profiling.md +++ b/docs/device_memory_profiling.md @@ -1,4 +1,4 @@ -# Device Memory Profiling +# Device memory profiling @@ -9,7 +9,7 @@ profile, open the `memory_viewer` tab of the Tensorboard profiler for more detailed and understandable device memory usage. ``` -The JAX Device Memory Profiler allows us to explore how and why JAX programs are +The JAX device memory profiler allows us to explore how and why JAX programs are using GPU or TPU memory. For example, it can be used to: * Figure out which arrays and executables are in GPU memory at a given time, or diff --git a/docs/glossary.rst b/docs/glossary.rst index 179c3c75dc6c..a7668e9a02b4 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -28,7 +28,7 @@ JAX glossary of terms able to target GPUs for fast operations on arrays (see also :term:`CPU` and :term:`TPU`). jaxpr - Short for *JAX Expression*, a jaxpr is an intermediate representation of a computation that + Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution. See :ref:`understanding-jaxprs` for more discussion and examples. diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index 4affae3a65d8..9b712056e9bc 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -23,7 +23,7 @@ Here is a suggested investigation strategy: 2. Hourly recompilation while keeping XLA and JAX in sync. 3. Final verification: maybe a manual check of a few commits (or a git bisect). -## Nightly investigation. +## Nightly investigation This can be done by using [JAX-Toolbox nightly containers](https://github.com/NVIDIA/JAX-Toolbox). @@ -128,7 +128,7 @@ investigate hourly between 8-24 and 8-26. There was a smaller slowdown earlier, lets ignore it for this example. It would be only another hourly investigation between those dates. -## Hourly investigation. +## Hourly investigation This does a checkout of JAX and XLA at each hour between the 2 dates, rebuilds everything and runs the test. The scripts are structured diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst index 56be62162a9e..d7b50dcb301e 100644 --- a/docs/jaxpr.rst +++ b/docs/jaxpr.rst @@ -164,7 +164,7 @@ before (with two input vars, one for each element of the input tuple) -Constant Vars +Constant vars ------------- Some values in jaxprs are constants, in that their value does not depend on the diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index f0c157655790..16e623d0f28b 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -6,7 +6,7 @@ "id": "18AF5Ab4p6VL" }, "source": [ - "# Training a Simple Neural Network, with PyTorch Data Loading\n", + "# Training a simple neural network, with PyTorch data loading\n", "\n", "\n", "\n", @@ -261,7 +261,7 @@ "id": "umJJGZCC2oKl" }, "source": [ - "## Data Loading with PyTorch\n", + "## Data loading with PyTorch\n", "\n", "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays." ] @@ -494,7 +494,7 @@ "id": "xxPd6Qw3Z98v" }, "source": [ - "## Training Loop" + "## Training loop" ] }, { diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index 2c53bb1e4ab5..87533117e56a 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -14,7 +14,7 @@ kernelspec: +++ {"id": "18AF5Ab4p6VL"} -# Training a Simple Neural Network, with PyTorch Data Loading +# Training a simple neural network, with PyTorch data loading @@ -175,7 +175,7 @@ def update(params, x, y): +++ {"id": "umJJGZCC2oKl"} -## Data Loading with PyTorch +## Data loading with PyTorch JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays. @@ -245,7 +245,7 @@ test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets) +++ {"id": "xxPd6Qw3Z98v"} -## Training Loop +## Training loop ```{code-cell} ipython3 :id: X2DnZo3iYj18 diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index 0a823353068b..5246e810de64 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -6,7 +6,7 @@ "id": "TVT_MVvc02AA" }, "source": [ - "# Generalized Convolutions in JAX\n", + "# Generalized convolutions in JAX\n", "\n", "\n", "\n", @@ -28,7 +28,7 @@ "id": "ewZEn2X12-Ng" }, "source": [ - "## Basic One-dimensional Convolution\n", + "## Basic one-dimensional convolution\n", "\n", "Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution:" ] @@ -91,7 +91,7 @@ "id": "5ndvLDIH4rv6" }, "source": [ - "## Basic N-dimensional Convolution\n", + "## Basic N-dimensional convolution\n", "\n", "For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions.\n", "\n", @@ -160,7 +160,7 @@ "id": "bxuUjFVG-v1h" }, "source": [ - "## General Convolutions" + "## General convolutions" ] }, { diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 3de8f261aa5b..2dec35847359 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -14,7 +14,7 @@ kernelspec: +++ {"id": "TVT_MVvc02AA"} -# Generalized Convolutions in JAX +# Generalized convolutions in JAX @@ -31,7 +31,7 @@ For basic convolution operations, the `jax.numpy` and `jax.scipy` operations are +++ {"id": "ewZEn2X12-Ng"} -## Basic One-dimensional Convolution +## Basic one-dimensional convolution Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution: @@ -65,7 +65,7 @@ For more information, see the {func}`jax.numpy.convolve` documentation, or the d +++ {"id": "5ndvLDIH4rv6"} -## Basic N-dimensional Convolution +## Basic N-dimensional convolution For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions. @@ -105,7 +105,7 @@ Like in the one-dimensional case, we use `mode='same'` to specify how we would l +++ {"id": "bxuUjFVG-v1h"} -## General Convolutions +## General convolutions +++ {"id": "0pcn2LeS-03b"} diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index bdf71004c01b..050a641a7845 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -6,7 +6,7 @@ "id": "7XNMxdTwURqI" }, "source": [ - "# External Callbacks in JAX\n", + "# External callbacks in JAX\n", "\n", "" ] diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index 857eef42e2b3..ab0a2fcd3317 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "7XNMxdTwURqI"} -# External Callbacks in JAX +# External callbacks in JAX diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 95c00bf1e689..7d353c924845 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -36,7 +36,7 @@ "id": "B_XlLLpcWjkA" }, "source": [ - "# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n", + "# Training a simple neural network, with tensorflow/datasets data loading\n", "\n", "\n", "\n", @@ -274,7 +274,7 @@ "id": "umJJGZCC2oKl" }, "source": [ - "## Data Loading with `tensorflow/datasets`\n", + "## Data loading with `tensorflow/datasets`\n", "\n", "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader." ] @@ -344,7 +344,7 @@ "id": "xxPd6Qw3Z98v" }, "source": [ - "## Training Loop" + "## Training loop" ] }, { diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 8f795484d5b9..2f7ba3271312 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -34,7 +34,7 @@ limitations under the License. +++ {"id": "B_XlLLpcWjkA"} -# Training a Simple Neural Network, with tensorflow/datasets Data Loading +# Training a simple neural network, with tensorflow/datasets data loading @@ -183,7 +183,7 @@ def update(params, x, y): +++ {"id": "umJJGZCC2oKl"} -## Data Loading with `tensorflow/datasets` +## Data loading with `tensorflow/datasets` JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader. @@ -229,7 +229,7 @@ print('Test:', test_images.shape, test_labels.shape) +++ {"id": "xxPd6Qw3Z98v"} -## Training Loop +## Training loop ```{code-cell} ipython3 :id: X2DnZo3iYj18 diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 1c1c9729b654..e4f9d888e6fc 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -6,7 +6,7 @@ "id": "LQHmwePqryRU" }, "source": [ - "# How to Think in JAX\n", + "# How to think in JAX\n", "\n", "\n", "\n", @@ -23,7 +23,7 @@ "source": [ "## JAX vs. NumPy\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- JAX provides a NumPy-inspired interface for convenience.\n", "- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.\n", @@ -282,7 +282,7 @@ "source": [ "## NumPy, lax & XLA: JAX API layering\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- `jax.numpy` is a high-level wrapper that provides a familiar interface.\n", "- `jax.lax` is a lower-level API that is stricter and often more powerful.\n", @@ -475,7 +475,7 @@ "source": [ "## To JIT or not to JIT\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- By default JAX executes operations one at a time, in sequence.\n", "- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.\n", @@ -675,7 +675,7 @@ "source": [ "## JIT mechanics: tracing and static variables\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.\n", "\n", @@ -932,9 +932,9 @@ "id": "r-RCl_wD5lI7" }, "source": [ - "## Static vs Traced Operations\n", + "## Static vs traced operations\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- Just as values can be either static or traced, operations can be static or traced.\n", "\n", diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 14089fa36e32..16be7b9e9369 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "LQHmwePqryRU"} -# How to Think in JAX +# How to think in JAX @@ -25,7 +25,7 @@ JAX provides a simple and powerful API for writing accelerated numerical code, b ## JAX vs. NumPy -**Key Concepts:** +**Key concepts:** - JAX provides a NumPy-inspired interface for convenience. - Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays. @@ -132,7 +132,7 @@ print(y) ## NumPy, lax & XLA: JAX API layering -**Key Concepts:** +**Key concepts:** - `jax.numpy` is a high-level wrapper that provides a familiar interface. - `jax.lax` is a lower-level API that is stricter and often more powerful. @@ -215,7 +215,7 @@ Every JAX operation is eventually expressed in terms of these fundamental XLA op ## To JIT or not to JIT -**Key Concepts:** +**Key concepts:** - By default JAX executes operations one at a time, in sequence. - Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once. @@ -308,7 +308,7 @@ This is because the function generates an array whose shape is not known at comp ## JIT mechanics: tracing and static variables -**Key Concepts:** +**Key concepts:** - JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type. @@ -417,9 +417,9 @@ Understanding which values and operations will be static and which will be trace +++ {"id": "r-RCl_wD5lI7"} -## Static vs Traced Operations +## Static vs traced operations -**Key Concepts:** +**Key concepts:** - Just as values can be either static or traced, operations can be static or traced. diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index af20aa7bba24..1d6a5d9b701a 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -1,4 +1,4 @@ -# Persistent Compilation Cache +# Persistent compilation cache diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 5a8af2b74142..4e4063a68467 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -12,7 +12,7 @@ kernelspec: name: python3 --- -# Stateful Computations +# Stateful computations From 4533aeaf265b605539096f25c783118e20f4af8d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 12 Aug 2024 19:15:01 -0700 Subject: [PATCH 089/702] Remove `jax_enable_memories` conditionals from JAX and remove it from tests too. PiperOrigin-RevId: 662322241 --- jax/_src/interpreters/pxla.py | 31 +++++++++---------- .../array_serialization/serialization.py | 2 +- .../array_serialization/serialization_test.py | 1 - tests/memories_test.py | 4 --- 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index ce96f7e815e1..12f981b309e0 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2189,13 +2189,11 @@ def lower_sharding_computation( [js for js, _ in unique_intermediate_shardings], transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types - # TODO(yashkatariya): Remove this when XLA can propagate memory kinds or when - # JAX puts memory kinds in the types of jaxpr. - if not all_default_mem_kind: + if all_default_mem_kind: + propagated_out_mem_kinds = (None,) * len(global_out_avals) + else: propagated_out_mem_kinds = get_out_memory_kinds_via_propagation( closed_jaxpr, in_shardings) - else: - propagated_out_mem_kinds = (None,) * len(global_out_avals) # 2. Build up the HLO semantic_in_shardings = SemanticallyEqualShardings( @@ -2258,6 +2256,8 @@ def lower_sharding_computation( out_layouts=out_layouts, pmap_nreps=nreps, shape_poly_state=shape_poly_state, + # TODO(yashkatariya): Remove `all_default_mem_kind` after + # MemoryDescription works in OSS. all_default_mem_kind=all_default_mem_kind, all_args_info=all_args_info, pgle_profiler=pgle_profiler, @@ -2327,18 +2327,17 @@ def get_out_shardings_from_executable( ) -> Sequence[sharding_impls.GSPMDSharding] | None: from jax._src import pjit - if config.enable_memories.value: - if all_default_mem_kind: - omk = [None] * num_out_avals - else: - try: - omk = xla_executable.get_output_memory_kinds()[0] - if num_ordered_effects > 0: - omk = omk[num_ordered_effects:] - except: - omk = [None] * num_out_avals - else: + # TODO(yashkatariya): Remove `all_default_mem_kind` branch after + # MemoryDescription works in OSS. + if all_default_mem_kind: omk = [None] * num_out_avals + else: + try: + omk = xla_executable.get_output_memory_kinds()[0] + if num_ordered_effects > 0: + omk = omk[num_ordered_effects:] + except: + omk = [None] * num_out_avals assert len(omk) == num_out_avals, (len(omk), num_out_avals) diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index c7992dc629f1..2620f5cc760c 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -189,7 +189,7 @@ async def transfer_shard_to_host(shard: array.Shard) -> np.ndarray: data = shard.data has_pinned_host = any( m.kind == "pinned_host" for m in shard.device.addressable_memories()) - if config.enable_memories.value and has_pinned_host: + if has_pinned_host: # If available, transfer to pinned host memory sharding = jax.sharding.SingleDeviceSharding(shard.device, memory_kind="pinned_host") diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 2712e2b4a819..e60bfaa1dc89 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -613,7 +613,6 @@ def test_deserialization_with_int4(self): self.assertArraysEqual(out + out, out * 2) -@jtu.with_config(jax_enable_memories=True) class TransferShardTest(jtu.JaxTestCase): @jtu.skip_on_devices('cpu') diff --git a/tests/memories_test.py b/tests/memories_test.py index 87c85ffc47d8..6140c6945df5 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -54,7 +54,6 @@ def _create_inputs(shape, pspec, mem_kind=None): return mesh, s, np_inp, inp -@jtu.with_config(jax_enable_memories=True) class ShardingMemoriesTest(jtu.JaxTestCase): def setUp(self): @@ -186,7 +185,6 @@ def test_default_memory_kind(self): self.assertEqual(dev.default_memory().kind, self._default_memory_kind) -@jtu.with_config(jax_enable_memories=True) class DevicePutTest(jtu.JaxTestCase): def setUp(self): @@ -668,7 +666,6 @@ def f(): self._check_device_put_addressable_shards(out, np_inp * 2, s_dev, 'device') -@jtu.with_config(jax_enable_memories=True) class ComputeOffload(jtu.BufferDonationTestCase): def setUp(self): @@ -1392,7 +1389,6 @@ def f(x): self.assertEqual(out.sharding.memory_kind, 'device') -@jtu.with_config(jax_enable_memories=True) class ActivationOffloadingTest(jtu.JaxTestCase): def setUp(self): From 09e73118bf0e1da283e84b0d51005201ea29b0b0 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 12 Aug 2024 20:07:49 -0700 Subject: [PATCH 090/702] docs: more sentence case --- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 18 +++++++++--------- docs/notebooks/Common_Gotchas_in_JAX.md | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 7ba192437a32..d769144406eb 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -258,7 +258,7 @@ "id": "oBdKtkVW8Lha" }, "source": [ - "## 🔪 In-Place Updates" + "## 🔪 In-place updates" ] }, { @@ -533,7 +533,7 @@ "id": "oZ_jE2WAypdL" }, "source": [ - "## 🔪 Out-of-Bounds Indexing" + "## 🔪 Out-of-bounds indexing" ] }, { @@ -868,7 +868,7 @@ "id": "MUycRNh6e50W" }, "source": [ - "## 🔪 Random Numbers" + "## 🔪 Random numbers" ] }, { @@ -888,7 +888,7 @@ "id": "Qikt9pPW9L5K" }, "source": [ - "### RNGs and State\n", + "### RNGs and state\n", "You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:" ] }, @@ -1183,7 +1183,7 @@ "id": "rg4CpMZ8c3ri" }, "source": [ - "## 🔪 Control Flow" + "## 🔪 Control flow" ] }, { @@ -1192,7 +1192,7 @@ "id": "izLTvT24dAq0" }, "source": [ - "### ✔ python control_flow + autodiff ✔\n", + "### ✔ Python control_flow + autodiff ✔\n", "\n", "If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager)." ] @@ -1231,7 +1231,7 @@ "id": "hIfPT7WMmZ2H" }, "source": [ - "### python control flow + JIT\n", + "### Python control flow + JIT\n", "\n", "Using control flow with `jit` is more complicated, and by default it has more constraints.\n", "\n", @@ -1791,7 +1791,7 @@ "id": "OxLsZUyRt_kF" }, "source": [ - "## 🔪 Dynamic Shapes" + "## 🔪 Dynamic shapes" ] }, { @@ -2194,7 +2194,7 @@ "id": "WAHjmL0E2XwO" }, "source": [ - "## 🔪 Miscellaneous Divergences from NumPy\n", + "## 🔪 Miscellaneous divergences from NumPy\n", "\n", "While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n", "Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 98c4b391c7ce..edf5c9446743 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -158,7 +158,7 @@ iter_operand = iter(range(10)) +++ {"id": "oBdKtkVW8Lha"} -## 🔪 In-Place Updates +## 🔪 In-place updates +++ {"id": "JffAqnEW4JEb"} @@ -268,7 +268,7 @@ For more details on indexed array updates, see the [documentation for the `.at` +++ {"id": "oZ_jE2WAypdL"} -## 🔪 Out-of-Bounds Indexing +## 🔪 Out-of-bounds indexing +++ {"id": "btRFwEVzypdN"} @@ -385,7 +385,7 @@ jnp.sum(jnp.array(x)) +++ {"id": "MUycRNh6e50W"} -## 🔪 Random Numbers +## 🔪 Random numbers +++ {"id": "O8vvaVt3MRG2"} @@ -395,7 +395,7 @@ jnp.sum(jnp.array(x)) +++ {"id": "Qikt9pPW9L5K"} -### RNGs and State +### RNGs and state You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness: ```{code-cell} ipython3 @@ -538,11 +538,11 @@ for subkey in subkeys: +++ {"id": "rg4CpMZ8c3ri"} -## 🔪 Control Flow +## 🔪 Control flow +++ {"id": "izLTvT24dAq0"} -### ✔ python control_flow + autodiff ✔ +### ✔ Python control_flow + autodiff ✔ If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager). @@ -562,7 +562,7 @@ print(grad(f)(4.)) # ok! +++ {"id": "hIfPT7WMmZ2H"} -### python control flow + JIT +### Python control flow + JIT Using control flow with `jit` is more complicated, and by default it has more constraints. @@ -865,7 +865,7 @@ $\ast$ = argument-value-independent loop condition - unrolls the loop +++ {"id": "OxLsZUyRt_kF"} -## 🔪 Dynamic Shapes +## 🔪 Dynamic shapes +++ {"id": "1tKXcAMduDR1"} @@ -1130,7 +1130,7 @@ x.dtype # --> dtype('float64') +++ {"id": "WAHjmL0E2XwO"} -## 🔪 Miscellaneous Divergences from NumPy +## 🔪 Miscellaneous divergences from NumPy While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge. From 69fc8bb419ab95201bfb41bc938a5f35f708fe03 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 13 Aug 2024 00:29:33 -0700 Subject: [PATCH 091/702] Consolidate handling of input argument resolution in custom_* APIs. This is a partial re-land of https://github.com/google/jax/pull/22869 with some updates to ensure that it doesn't break existing uses of `custom_vmap`. Previously, using a `custom_jvp` or `custom_vjp` with a primal function that has keyword-only arguments would result in a type error, even if these arguments weren't passed by the caller. I believe that this check is actually slightly stricter than it needed to be, as discovered when adding a similar check to `custom_vmap`. Instead, I think that it is sufficient to check that the caller hasn't _passed_ any keyword-only arguments. The previous behavior in `custom_vmap` was even harsher: it would error if any keyword arguments were passed. In this change, I have moved `resolve_kwargs` into `api_utils` so that the same function can be used in both `custom_derivatives` and `custom_batching`. I've also updated the logic to only throw a `TypeError` if the caller passes a keyword only argument when calling a `custom_*`-decorated function. This changes the behavior of `custom_jvp` and `custom_vjp`, although users shouldn't see that effect, since previously having kwargs would have errored. PiperOrigin-RevId: 662402158 --- jax/_src/api_util.py | 20 ++++++++++++++++++++ jax/_src/custom_batching.py | 4 ++-- jax/_src/custom_derivatives.py | 21 +++++---------------- tests/api_test.py | 17 +++++++++++++++++ tests/api_util_test.py | 16 ++++++++++++++++ 5 files changed, 60 insertions(+), 18 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index dd1cdcbe6bb8..329abd6b7570 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -556,6 +556,26 @@ def _assert_no_intersection(static_argnames, donate_argnames): f"{out} appear in both static_argnames and donate_argnames") +def resolve_kwargs(fun: Callable, args, kwargs) -> tuple[Any, ...]: + """Resolve input arguments to positional following a function's signature. + + This will raise a TypeError if any keyword-only arguments were passed by the + caller. + """ + if isinstance(fun, partial): + # functools.partial should have an opaque signature. + fun = lambda *args, **kwargs: None + ba = inspect.signature(fun).bind(*args, **kwargs) + ba.apply_defaults() + if ba.kwargs: + passed_kwargs = [k for k in ba.kwargs if k in kwargs] + if passed_kwargs: + raise TypeError( + f"keyword arguments ({passed_kwargs}) could not be resolved to " + "positions") + return ba.args + + def _dtype(x): try: return dtypes.result_type(x) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 1d405c4e5bbf..07df2321be95 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -27,7 +27,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src.api_util import flatten_fun_nokwargs +from jax._src.api_util import flatten_fun_nokwargs, resolve_kwargs from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters.batching import not_mapped @@ -64,7 +64,7 @@ def def_vmap(self, vmap_rule: Callable) -> Callable: @traceback_util.api_boundary def __call__(self, *args, **kwargs): - assert not kwargs + args = resolve_kwargs(self.fun, args, kwargs) fun_name = getattr(self.fun, "__name__", str(self.fun)) if not self.vmap_rule: raise AttributeError( diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 56accc273dbf..6dbf7dc11a8a 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -17,7 +17,6 @@ from collections.abc import Callable, Sequence import dataclasses from functools import update_wrapper, reduce, partial, wraps -import inspect from typing import Any, Generic, TypeVar from jax._src import config @@ -30,7 +29,8 @@ from jax._src import traceback_util from jax._src.ad_util import ( stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) -from jax._src.api_util import argnums_partial, flatten_fun_nokwargs +from jax._src.api_util import ( + argnums_partial, flatten_fun_nokwargs, resolve_kwargs) from jax._src.core import raise_to_shaped from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad @@ -56,17 +56,6 @@ ### util -def _resolve_kwargs(fun, args, kwargs): - if isinstance(fun, partial): - # functools.partial should have an opaque signature. - fun = lambda *args, **kwargs: None - ba = inspect.signature(fun).bind(*args, **kwargs) - ba.apply_defaults() - if ba.kwargs: - raise TypeError("keyword arguments could not be resolved to positions") - else: - return ba.args - def _initial_style_jaxpr(fun, in_avals): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) return jaxpr, consts @@ -240,7 +229,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp." raise AttributeError(msg) jvp_name = getattr(self.jvp, '__name__', str(self.jvp)) - args = _resolve_kwargs(self.fun, args, kwargs) + args = resolve_kwargs(self.fun, args, kwargs) if self.nondiff_argnums: nondiff_argnums = set(self.nondiff_argnums) args = tuple(_stop_gradient(x) if i in nondiff_argnums else x @@ -599,7 +588,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp." raise AttributeError(msg) fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) - args = _resolve_kwargs(self.fun, args, kwargs) + args = resolve_kwargs(self.fun, args, kwargs) if self.optimize_remat: fwd = optimize_remat_of_custom_vjp_fwd( self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums, @@ -1453,7 +1442,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: fwd_name = getattr(fwd, "__name__", str(fwd)) # Note: we use `fun` instead of `fwd` here for consistency with # custom_vjp.__call__ above. - args = _resolve_kwargs(fun, args, kwargs) + args = resolve_kwargs(fun, args, kwargs) if nondiff_argnums: for i in nondiff_argnums: _check_for_tracers(args[i]) nondiff_argnums_ = set(nondiff_argnums) diff --git a/tests/api_test.py b/tests/api_test.py index 52670bfd1590..cba74bee59e5 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -10835,6 +10835,23 @@ def f(x): return jnp.sin(x) AttributeError, "No batching rule defined for custom_vmap function f"): f(0.5) + def test_kwargs(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x=x) + self.assertAllClose(y, jnp.sin(x)) + ys = api.vmap(f)(x=xs) + self.assertAllClose(ys, jnp.cos(xs)) + class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" diff --git a/tests/api_util_test.py b/tests/api_util_test.py index 46bed8c86b8a..e34611c6e785 100644 --- a/tests/api_util_test.py +++ b/tests/api_util_test.py @@ -69,5 +69,21 @@ def test_rebase_donate_argnums(self, donate, static, expected): self.assertEqual(expected, api_util.rebase_donate_argnums(donate, static)) + def test_resolve_kwargs(self): + def fun(x, y, z=3): + return x, y, z + assert api_util.resolve_kwargs(fun, (1,), {"y": 2}) == (1, 2, 3) + assert api_util.resolve_kwargs(fun, (1, 2), {"z": 3}) == (1, 2, 3) + assert api_util.resolve_kwargs( + fun, (), {"x": 1, "y": 2, "z": 3}) == (1, 2, 3) + + def test_resolve_kwargs_with_keyword(self): + def fun(x, y, z, *, kw=True): + del kw + return x, y, z + assert api_util.resolve_kwargs(fun, (1, 2), {"z": 3}) == (1, 2, 3) + with self.assertRaisesRegex(TypeError, "keyword arguments"): + api_util.resolve_kwargs(fun, (1, 2), {"z": 3, "kw": False}) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 850edee36ec57e4664d91c1673190bda7447f221 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 12 Aug 2024 13:29:24 +0100 Subject: [PATCH 092/702] Fix bug in custom_vjp with optimize_remat and custom_vmap. When used with a `custom_vmap` that introduces a new const the previous implementation of `optimize_remat` would error in its DCE rule because of unexpected consts when closing the fwd jaxpr. This shouldn't have ever been hit, but there was a bug in the batching rule for `remat_opt_p` where we weren't properly converting constvars to invars. This fixes this bug and should unbreak internal users. --- jax/_src/custom_derivatives.py | 7 ++++++- tests/api_test.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 6dbf7dc11a8a..9a6253b1bef9 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1523,6 +1523,9 @@ def _remat_opt_vmap( batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( fwd_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name, main_type) + extra_consts = batched_fwd_jaxpr.consts + batched_fwd_jaxpr = pe.close_jaxpr( + pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr)) out_dims = [0 if b else not_mapped for b in out_batched] _, prim_batched = split_list(in_batched, [num_consts]) @@ -1535,7 +1538,8 @@ def batched_fun_jaxpr_thunk(): main_type) return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts - batched_outs = remat_opt_p.bind(*args, num_consts=num_consts, + batched_outs = remat_opt_p.bind(*extra_consts, *args, + num_consts=num_consts + len(extra_consts), num_res=num_res, fwd_jaxpr=batched_fwd_jaxpr, fun_jaxpr_thunk=batched_fun_jaxpr_thunk) @@ -1603,6 +1607,7 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"]) new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs, instantiate=instantiate) + assert not new_jaxpr.constvars closed_jaxpr = pe.close_jaxpr(new_jaxpr) invars = [v for used, v in zip(used_ins, eqn.invars) if used] new_params = dict(eqn.params) diff --git a/tests/api_test.py b/tests/api_test.py index cba74bee59e5..eee4924435ed 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -9779,6 +9779,31 @@ def f_bwd(res, g): x, y = 3.2, 1.0 jax.grad(f)(x, y) # Doesn't error + def test_optimize_remat_custom_vmap(self): + # See https://github.com/google/jax/pull/23000 + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + @jax.custom_batching.custom_vmap + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + @f_fwd.def_vmap + def f_fwd_vmap(_, in_batched, x, y): + # Insert a new const here to test the optimize_remat batching rule. + out = np.array([2.0])*f(x, y) + out_batched = (True, (True, True, True)) + return (out, (jnp.cos(x), jnp.sin(x), y)), out_batched + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) + jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error + def transpose_unary(f, x_example): def transposed(y): From 5fc992e5e12761569c87dd64c7ca54b25db3ae42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Tue, 13 Aug 2024 01:16:37 -0700 Subject: [PATCH 093/702] Determine LAPACK workspaces during SVD kernel runtime The SVD kernel implementation used to require workspace shapes to be determined prior to the custom call on the JAX's side. The new FFI kernels need not demand these shapes to be specified anymore. They are evaluated during kernel runtime. PiperOrigin-RevId: 662413273 --- jaxlib/cpu/BUILD | 1 + jaxlib/cpu/_lapack/__init__.pyi | 7 -- jaxlib/cpu/lapack.cc | 27 -------- jaxlib/cpu/lapack_kernels.cc | 112 ++++++++++++++++---------------- jaxlib/cpu/lapack_kernels.h | 29 ++++----- jaxlib/ffi_helpers.h | 29 ++++++--- 6 files changed, 91 insertions(+), 114 deletions(-) diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index d97a11e4f61c..48332ee1a4d2 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -41,6 +41,7 @@ cc_library( "@xla//xla/service:custom_call_status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], ) diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index 35c46fceeb9f..5fcb2a5ad50e 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -13,7 +13,6 @@ # limitations under the License. from . import eig as eig -from . import svd as svd def initialize() -> None: ... @@ -50,10 +49,6 @@ def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matr # FFI Kernel LAPACK Workspace Size Queries -def cgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... -def dgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... -def gesdd_iwork_size_ffi(m: int, n: int) -> int: ... -def gesdd_rwork_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... def heevd_rwork_size_ffi(n: int) -> int: ... def heevd_work_size_ffi(n: int) -> int: ... def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ... @@ -64,7 +59,5 @@ def lapack_sgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def sgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... def syevd_iwork_size_ffi(n: int) -> int: ... def syevd_work_size_ffi(n: int) -> int: ... -def zgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 3e59a4a024d6..c13608e813f5 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -37,21 +37,6 @@ svd::ComputationMode GetSvdComputationMode(bool job_opt_compute_uv, return svd::ComputationMode::kComputeFullUVt; } -template -int64_t GesddGetWorkspaceSize(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - svd::ComputationMode mode = - GetSvdComputationMode(job_opt_compute_uv, job_opt_full_matrices); - return svd::SVDType::GetWorkspaceSize(m, n, mode); -}; - -lapack_int GesddGetRealWorkspaceSize(lapack_int m, lapack_int n, - bool job_opt_compute_uv) { - svd::ComputationMode mode = GetSvdComputationMode(job_opt_compute_uv, true); - return svd::GetRealWorkspaceSize(m, n, mode); -} - // Due to enforced kComputeEigenvectors, this assumes a larger workspace size. // Could be improved to more accurately estimate the expected size based on the // eig::ComputationMode value. @@ -375,18 +360,6 @@ NB_MODULE(_lapack, m) { m.def("lapack_zungqr_workspace_ffi", &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("gesdd_iwork_size_ffi", &svd::GetIntWorkspaceSize, nb::arg("m"), - nb::arg("n")); - m.def("sgesdd_work_size_ffi", &svd::SVDType::GetWorkspaceSize, - nb::arg("m"), nb::arg("n"), nb::arg("mode")); - m.def("dgesdd_work_size_ffi", &svd::SVDType::GetWorkspaceSize, - nb::arg("m"), nb::arg("n"), nb::arg("mode")); - m.def("gesdd_rwork_size_ffi", &svd::GetRealWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("mode")); - m.def("cgesdd_work_size_ffi", &svd::SVDType::GetWorkspaceSize, - nb::arg("m"), nb::arg("n"), nb::arg("mode")); - m.def("zgesdd_work_size_ffi", &svd::SVDType::GetWorkspaceSize, - nb::arg("m"), nb::arg("n"), nb::arg("mode")); m.def("syevd_work_size_ffi", BoundWithEigvecs, nb::arg("n")); m.def("syevd_iwork_size_ffi", BoundWithEigvecs, diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index c3a32c481a8b..551800bae8f2 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "jaxlib/ffi_helpers.h" #include "xla/ffi/api/c_api.h" @@ -686,19 +687,12 @@ template struct ComplexGesdd>; namespace internal { -template -using RealBufferForComplexOrNull = - std::conditional_t(), - ffi::ResultBuffer, std::nullptr_t>; - template static ffi::Error SvdKernel( ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer singular_values, ffi::ResultBuffer u, ffi::ResultBuffer vt, - ffi::ResultBuffer info, - ffi::ResultBuffer iwork, ffi::ResultBuffer work, - svd::ComputationMode mode, RealBufferForComplexOrNull rwork) { + ffi::ResultBuffer info, svd::ComputationMode mode) { if (mode == svd::ComputationMode::kComputeVtOverwriteXPartialU) [[unlikely]] { return ffi::Error( XLA_FFI_Error_Code_UNIMPLEMENTED, @@ -710,16 +704,30 @@ static ffi::Error SvdKernel( auto* u_data = u->typed_data(); auto* vt_data = vt->typed_data(); auto* info_data = info->typed_data(); - auto* iwork_data = iwork->typed_data(); - auto* work_data = work->typed_data(); + + // Prepare LAPACK workspaces. + FFI_ASSIGN_OR_RETURN( + const auto work_size, + svd::SVDType::GetWorkspaceSize(x_rows, x_cols, mode)); + FFI_ASSIGN_OR_RETURN(const auto iwork_size, + svd::GetIntWorkspaceSize(x_rows, x_cols)); + auto work_data = AllocateScratchMemory(work_size); + auto iwork_data = AllocateScratchMemory(iwork_size); + using RealType = typename svd::SVDType::RealType; + std::unique_ptr rwork; + if constexpr (ffi::IsComplexType()) { + FFI_ASSIGN_OR_RETURN(const auto rwork_size, + svd::GetRealWorkspaceSize(x_rows, x_cols, mode)); + rwork = AllocateScratchMemory(rwork_size); + } CopyIfDiffBuffer(x, x_out); FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow(x_rows)); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); auto mode_v = static_cast(mode); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, + MaybeCastNoOverflow(work_size)); auto x_leading_dim_v = x_rows_v; auto u_leading_dim_v = x_rows_v; @@ -738,14 +746,14 @@ static ffi::Error SvdKernel( svd::SVDType::fn(&mode_v, &x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, singular_values_data, u_data, &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data, &workspace_dim_v, rwork->typed_data(), - iwork_data, info_data); + work_data.get(), &workspace_dim_v, rwork.get(), + iwork_data.get(), info_data); } else { svd::SVDType::fn(&mode_v, &x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, singular_values_data, u_data, &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data, &workspace_dim_v, iwork_data, - info_data); + work_data.get(), &workspace_dim_v, + iwork_data.get(), info_data); } x_out_data += x_out_step; singular_values_data += singular_values_step; @@ -767,7 +775,6 @@ static int64_t SvdGetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, auto x_leading_dim_v = x_rows; auto u_leading_dim_v = x_rows; auto vt_leading_dim_v = mode == svd::ComputationMode::kComputeFullUVt - ? x_cols : std::min(x_rows, x_cols); if constexpr (ffi::IsComplexType()) { @@ -791,10 +798,9 @@ ffi::Error SingularValueDecomposition::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer singular_values, ffi::ResultBuffer u, ffi::ResultBuffer vt, ffi::ResultBuffer info, - ffi::ResultBuffer iwork, ffi::ResultBuffer work, svd::ComputationMode mode) { return internal::SvdKernel(x, x_out, singular_values, u, vt, info, - iwork, work, mode, nullptr); + mode); } template @@ -802,39 +808,38 @@ ffi::Error SingularValueDecompositionComplex::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer singular_values, ffi::ResultBuffer u, ffi::ResultBuffer vt, - ffi::ResultBuffer info, - ffi::ResultBuffer rwork, - ffi::ResultBuffer iwork, ffi::ResultBuffer work, - svd::ComputationMode mode) { + ffi::ResultBuffer info, svd::ComputationMode mode) { return internal::SvdKernel(x, x_out, singular_values, u, vt, info, - iwork, work, mode, rwork); + mode); } template -int64_t SingularValueDecomposition::GetWorkspaceSize( +absl::StatusOr SingularValueDecomposition::GetWorkspaceSize( lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { return internal::SvdGetWorkspaceSize(x_rows, x_cols, mode); } template -int64_t SingularValueDecompositionComplex::GetWorkspaceSize( +absl::StatusOr +SingularValueDecompositionComplex::GetWorkspaceSize( lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { return internal::SvdGetWorkspaceSize(x_rows, x_cols, mode); } -lapack_int svd::GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, - svd::ComputationMode mode) { +absl::StatusOr svd::GetRealWorkspaceSize( + int64_t x_rows, int64_t x_cols, svd::ComputationMode mode) { const auto min_dim = std::min(x_rows, x_cols); if (!ComputesUV(mode)) { - return CastNoOverflow(7 * min_dim); + return MaybeCastNoOverflow(7 * min_dim); } const auto max_dim = std::max(x_rows, x_cols); - return CastNoOverflow( + return MaybeCastNoOverflow( std::max(5 * min_dim * min_dim + 5 * min_dim, 2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim)); } -lapack_int svd::GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols) { +absl::StatusOr svd::GetIntWorkspaceSize(int64_t x_rows, + int64_t x_cols) { return CastNoOverflow(8 * std::min(x_rows, x_cols)); } @@ -1725,33 +1730,28 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*x_out*/) \ .Ret<::xla::ffi::Buffer>(/*info*/)) -#define JAX_CPU_DEFINE_GESDD(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, SingularValueDecomposition::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*s*/) \ - .Ret<::xla::ffi::Buffer>(/*u*/) \ - .Ret<::xla::ffi::Buffer>(/*vt*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ +#define JAX_CPU_DEFINE_GESDD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SingularValueDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*s*/) \ + .Ret<::xla::ffi::Buffer>(/*u*/) \ + .Ret<::xla::ffi::Buffer>(/*vt*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ .Attr("mode")) -#define JAX_CPU_DEFINE_GESDD_COMPLEX(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, SingularValueDecompositionComplex::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \ - .Ret<::xla::ffi::Buffer>(/*u*/) \ - .Ret<::xla::ffi::Buffer>(/*vt*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ +#define JAX_CPU_DEFINE_GESDD_COMPLEX(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SingularValueDecompositionComplex::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \ + .Ret<::xla::ffi::Buffer>(/*u*/) \ + .Ret<::xla::ffi::Buffer>(/*vt*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ .Attr("mode")) #define JAX_CPU_DEFINE_SYEVD(name, data_type) \ diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 5493ec8cbffc..d78fd1b8d3d3 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -20,8 +20,9 @@ limitations under the License. #include #include -#include "xla/ffi/api/ffi.h" +#include "absl/status/statusor.h" #include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" // Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either @@ -303,6 +304,7 @@ struct SingularValueDecomposition { static_assert(!::xla::ffi::IsComplexType(), "There exists a separate implementation for Complex types"); using ValueType = ::xla::ffi::NativeType; + using RealType = ValueType; using FnType = void(char* jobz, lapack_int* m, lapack_int* n, ValueType* a, lapack_int* lda, ValueType* s, ValueType* u, lapack_int* ldu, ValueType* vt, lapack_int* ldvt, @@ -315,12 +317,11 @@ struct SingularValueDecomposition { ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, ::xla::ffi::ResultBuffer singular_values, ::xla::ffi::ResultBuffer u, ::xla::ffi::ResultBuffer vt, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer iwork, - ::xla::ffi::ResultBuffer work, svd::ComputationMode mode); + ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); - static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, - svd::ComputationMode mode); + static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols, + svd::ComputationMode mode); }; template <::xla::ffi::DataType dtype> @@ -341,13 +342,11 @@ struct SingularValueDecompositionComplex { ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> singular_values, ::xla::ffi::ResultBuffer u, ::xla::ffi::ResultBuffer vt, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork, - ::xla::ffi::ResultBuffer iwork, - ::xla::ffi::ResultBuffer work, svd::ComputationMode mode); + ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); - static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, - svd::ComputationMode mode); + static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols, + svd::ComputationMode mode); }; namespace svd { @@ -357,9 +356,9 @@ using SVDType = std::conditional_t<::xla::ffi::IsComplexType(), SingularValueDecompositionComplex, SingularValueDecomposition>; -lapack_int GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols); -lapack_int GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, - ComputationMode mode); +absl::StatusOr GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols); +absl::StatusOr GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, + ComputationMode mode); } // namespace svd diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 2374680c20eb..69c63a4ba000 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -1,11 +1,14 @@ #ifndef JAXLIB_FFI_HELPERS_H_ #define JAXLIB_FFI_HELPERS_H_ +#include #include #include #include +#include #include #include +#include #include "absl/algorithm/container.h" #include "absl/base/optimization.h" @@ -55,32 +58,40 @@ inline absl::StatusOr MaybeCastNoOverflow( } } -inline xla::ffi::Error AsFfiError(const absl::Status& status) { +inline ::xla::ffi::Error AsFfiError(const absl::Status& status) { if (ABSL_PREDICT_FALSE(!status.ok())) { - return xla::ffi::Error(static_cast(status.code()), - std::string(status.message())); + return ::xla::ffi::Error(static_cast(status.code()), + std::string(status.message())); } else { - return xla::ffi::Error::Success(); + return ::xla::ffi::Error::Success(); } } template -xla::ffi::Error CheckMatrixDimensions(xla::ffi::Span dims) { +::xla::ffi::Error CheckMatrixDimensions(::xla::ffi::Span dims) { if (dims.size() < 2) { - return xla::ffi::Error(xla::ffi::ErrorCode::kInvalidArgument, - "Matrix must have at least 2 dimensions"); + return ::xla::ffi::Error(::xla::ffi::ErrorCode::kInvalidArgument, + "Matrix must have at least 2 dimensions"); } - return xla::ffi::Error::Success(); + return ::xla::ffi::Error::Success(); } template -std::tuple SplitBatch2D(xla::ffi::Span dims) { +std::tuple SplitBatch2D(::xla::ffi::Span dims) { auto matrix_dims = dims.last(2); return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1, std::multiplies()), matrix_dims.front(), matrix_dims.back()); } +template <::xla::ffi::DataType dtype> +auto AllocateScratchMemory(std::size_t size) + -> std::unique_ptr>[]> { + // TODO(paruzelp): use std::make_unique_for_overwrite when C++20 is available. + using ValueType = std::remove_extent_t<::xla::ffi::NativeType>; + return std::unique_ptr(new ValueType[size]); +} + } // namespace jax #endif // JAXLIB_FFI_HELPERS_H_ From 1a7c6aa186f690007c3b6055dde2c52d4d51b897 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 13 Aug 2024 01:41:59 -0700 Subject: [PATCH 094/702] [pallas] Fix test timeouts PiperOrigin-RevId: 662420238 --- tests/pallas/BUILD | 16 +++++++++++++--- tests/pallas/ops_test.py | 2 +- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index f9a8c17cec30..c0cf61387cbb 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -85,10 +85,15 @@ jax_test( "gpu_h100_x32", ], shard_count = { - "cpu": 4, - "gpu": 4, - "tpu": 4, + "cpu": 8, + "gpu": 8, + "tpu": 8, }, + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. + ], deps = [ "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep @@ -105,6 +110,11 @@ jax_test( disable_backends = [ "gpu", ], + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. + ], deps = [ "//jax:pallas", "//jax:pallas_tpu", diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 6fbca406931a..2a776b6347f1 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -56,7 +56,7 @@ # ruff: noqa: F811 jax.config.parse_flags_with_absl() -jtu.setup_hypothesis(max_examples=100) +jtu.setup_hypothesis(max_examples=50) def smem_on_tpu(): From 354293da4865ec478a6fc81d06430c0a35b66b14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Tue, 13 Aug 2024 02:40:52 -0700 Subject: [PATCH 095/702] Activate Singular Value Decomposition to XLA's FFI PiperOrigin-RevId: 662436635 --- jax/_src/export/_export.py | 1 + .../cpu_svd_lapack_gesdd.py | 421 ++++++++++++++++++ jax/_src/lax/linalg.py | 4 +- jaxlib/lapack.py | 175 +++++--- tests/export_back_compat_test.py | 11 + 5 files changed, 552 insertions(+), 60 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 10b6d09a2d91..978747266448 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -920,6 +920,7 @@ def _check_lowering(lowering) -> None: _CPU_FFI_KERNELS = [ "lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi", + "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py index 192309f2a54c..2d71308caeda 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py @@ -442,3 +442,424 @@ mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b//OOo\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02N\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\x0b)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgesdd\x00', xla_call_module_version=6, ) # End paste + +data_2024_08_13 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_13["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgesdd_ffi'], + serialized_date=datetime.date(2024, 8, 13), + inputs=(array([[[-0.9247611722912019-1.3615157109291343j , + -1.0663457975211892+4.73170030936092j , + -1.4918732811689488-2.880861991859318j , + -1.111356346434667 -2.869701609083459j ], + [-4.71291623424314 -1.5444012898828912j , + -5.232967549101415 -0.41287816948482003j, + 0.8905737109262459+9.50245186328329j , + 4.397722119094926 -6.842005210371916j ], + [ 1.9369405063276903+2.3496014107398917j , + -1.5609345742256133+4.2102103739897805j , + 0.6596030248996742+5.195353435247212j , + 0.6315014498240328-1.2778849649354402j ], + [ 5.115159214503849 -0.8856276268773485j , + 1.3719934567460779-2.236070491368575j , + 0.4974504006612811-3.0462081956756637j , + -0.2620346712025989+4.424682727912594j ]], + + [[-1.8242711798401063-0.8543252170262536j , + -2.724527211360488 +2.256038331706666j , + -1.2777487543905157+0.976556823566376j , + 3.7438974536713223-0.4994301527847589j ], + [-0.6359051102028691+2.730662301129662j , + -1.2877728943263032+3.9124921723649053j , + -3.4618573226579894+1.7835551986994034j , + -1.4710491660152465+2.144967500163963j ], + [-3.6013691182532828+2.8182351980619034j , + 2.0045935428878803+1.1146211993017152j , + -2.332213857689336 -0.874915651404938j , + -1.5393862406530452+0.6852883119580928j ], + [-2.674897392856801 +2.0724239502976984j , + -3.349108041292141 -1.0215359152295307j , + 0.2603515088197114-1.9093411474619364j , + 5.41252457188561 +8.634368042893094j ]]]),), + expected_outputs=(array([[[-0.0417367825863334 +0.10796693731538422j , + 0.6813428383170979 +0.3432797958929331j , + -0.4177022900286576 +0.20028957850808846j , + -0.4344351366508529 +0.034743251442636236j], + [-0.8408468609573512 -0.13260646044648036j , + -0.21674151028481226 +0.015170556885426567j, + 0.17147327711152344 +0.15310416152982537j , + -0.3568765623609291 +0.2190438430670875j ], + [-0.26736181440441353 +0.1379833616281102j , + -0.1753427835255798 -0.3789926157696272j , + -0.8179957069096053 -0.037506032257391686j, + 0.25392637883428515 -0.009771014463849592j], + [ 0.4056923996806594 -0.08297706578106906j , + -0.4321527034953763 +0.097915456635744j , + -0.23439193826962634 -0.0842713053222817j , + -0.423482961456089 +0.625144811494929j ]], + + [[ 0.027268437398665468+0.3631205555033544j , + 0.2702977135592881 +0.13046165871625626j , + 0.042868670139236786-0.47658594176021335j , + 0.7242702256119966 +0.15420620503522459j ], + [-0.08593436615104452 +0.11899901833255505j , + 0.370502861093553 -0.6240865462984537j , + 0.46902056878805953 -0.3474794992077024j , + -0.31667671459632085 -0.1034006436993295j ], + [-0.07914843440873574 -0.033487314943774216j, + 0.4110353453489126 -0.4550908055665629j , + -0.43113180393027273 +0.40910871949631994j , + 0.137827301024203 +0.49428280062680047j ], + [-0.7478497242333215 +0.5283836938016965j , + -0.08345894989956637 +0.011807690067190318j, + -0.27178304569905287 +0.05652627940674812j , + -0.0991195491344199 -0.25988596540006825j ]]]), array([[16.80132997488892 , 7.74475561455812 , 5.831221808032042 , + 1.1195288361137763], + [12.395375946948931 , 8.218551160453815 , 4.68363485027408 , + 1.882091536383919 ]]), array([[[ 0.3579625104055671 +0.j , + 0.40179383774178024 -0.12693597167020743j , + -0.0751486661300563 -0.6109813931761134j , + -0.23049271148274275 +0.51209309438597j ], + [-0.46828614153085474 +0.j , + -0.013958972669495653+0.4210606476774212j , + -0.6006888466394118 -0.3766516564723723j , + -0.24264518623236989 -0.20408557153193463j ], + [-0.6392945524816099 +0.j , + 0.24323886076029005 -0.6679928485374246j , + 0.18168178910997027 -0.08126854868489738j , + -0.2030612067046727 -0.07124733621915219j ], + [-0.49383540371426055 +0.j , + -0.010402968929686451+0.37346249914107377j , + 0.2799428270410499 +0.019494062167627474j, + 0.32588905219319264 +0.6569569657140542j ]], + + [[ 0.26669203705168437 +0.j , + 0.24929033811571388 +0.27271089049933883j , + -0.012922512768026959+0.16383354123801502j , + 0.07388201893235019 -0.8717175469187742j ], + [-0.6156140469162427 +0.j , + -0.33787077397020177 +0.3779715465092333j , + -0.39160430587261197 -0.2839601305776179j , + -0.27148886041576736 -0.23729034093304668j ], + [ 0.5618758038857614 +0.j , + -0.5788776267734558 -0.13833058883452376j , + -0.48995086206819655 +0.19259594116096806j , + -0.22967101640965004 -0.012926826751577636j], + [-0.48393210641613604 +0.j , + -0.10492296054284367 -0.4911419972025976j , + -0.07782239226461207 +0.6751317817750168j , + 0.11941657609231512 -0.19354808489959857j ]]])), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("input")) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor loc(#loc3) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc3) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc3) + %0:5 = stablehlo.custom_call @lapack_zgesdd_ffi(%arg0) {mhlo.backend_config = {mode = 65 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>) loc(#loc3) + %c_2 = stablehlo.constant dense<0> : tensor loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3) + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_3 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_4 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc3) + %12 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc3) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc3) + return %10, %6, %14 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":574:13) +#loc3 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xab;\x01Y\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x17\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03S\x0b\x0bo\x0b\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0f\x0f\x17\x1fO/\x1f\x0f\x0b\x0b//OOo\x01\x05\x0b\x0f\x037\x1b\x0f\x07\x07\x17\x07\x07\x0f\x0b\x13\x07\x0f\x0f\x1b\x1b\x1f\x07\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\x1a\x08\x1d35\x05\x15\x1f\x03\x03\t\x9b\x05\x17\x03\x03\t\xa1\x11\x03\x05\x05\x19\x03\x03\x03{\x03\x03\x03\xa7\x03\x03\t\xa9\x03\t\x19\x1b\x1d\r\x1f\r\x0f!\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b%a'e)g\x0fu+w\x05#\x05%\x05'\x05)\x1d/\x05\x05+\x03\x03\x03y\x05-\x177\xfa\x08\x1b\x05/\x03\x13;}=\x7f?\x81A\x83C\x85E\x87G\x8dI\x8fK\x93\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x03\x03\x99\x03\x05Q\x9dS\x9f\x05C\x05E\x03\x03\x03\xa3\x03\x03\t\xa5\x1dG\x1dI\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dK\x03\x03c\r\x03Y[##\x03\x07imq\r\x05_kY[\x1dM\r\x05_oY[\x1dO\r\x05_sY[\x1dQ\x1dS\x1dU\x1f\x07\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dW\x1dY\x03\x01\x05\x01\r\x03\x89\x8b\x1d[\x13%A\x03\x03]\x03\x03\x91\x15\x03\x01\x01\x01\x03\x0b]\x95]]\x97\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f91\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x15)\x01\t\x1d\x01)\x05\t\x11\x0f\x0b\x13)\x01\x15\x03\x0f)\x03\t\x19\x1b)\x01\x19)\x01\x0f)\x07\t\x05\x05\x0b)\x07\t\x11\x11\x0b\x11\x03\x05\x07\x05\r\x05!)\x03\r\x11)\x03\t\x11)\x03\x05\x11)\x03\x01\t)\x03\t\x0b)\x05\t\x05\x0b)\x03\x05\t)\x05\t\x11\x0b)\x03\t\t)\x03\r\t\x04\x16\x03\x05\x01\x11\x05\x17\x07\x03\x01\x05\t\x11\x05#\x07\x037_\x03\x05-\x05\x03\x011\x03\x07\x05\x03\x01\x11\x03\x07\x05\x03\x01\x11\x03\x07\x0b\x07\x019\x0b\x05\r\x05\x05\x17\x03\x01\x05\x03\x01M\x03\x1b\x03\x07\x01\x07\x03\x17\x03\x13\r\x07\x01O\x03/\x05\x11\x15\x03\x07\x01\x0b\x031\x03\x17\x05\x03\x01U\x03\x1d\x03\x07\x01\x07\x03\r\x03\x1b\x03\x07\x01W\x035\x03\x19\x07\x06\x01\x03\r\x07\x1f\x0b\x1d\x03\x07\x01\x0b\x03\x1f\x03\x17\x05\x03\x01\x13\x03\x13\x03\x07\x01\x07\x03\x05\x03%\x03\x07\x01\x15\x03!\x03#\x07\x06\x01\x03\x05\x07)\r'\x03\x07\x01\x0b\x03\x1f\x03\x17\x05\x03\x01\x13\x03\x13\x03\x07\x01\x07\x03\x05\x03/\x03\x07\x01\x15\x03!\x03-\x07\x06\x01\x03\x05\x073\x0f1\x0f\x04\x05\x07+!5\x06\x03\x01\x05\x01\x00f\x0b]\x0b%\x03\x0f\x0b\t\t\t!\x11#+\x1b\x1f/!)!)#\x1f\x19i\xa3\r\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00input\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgesdd_ffi\x00mode\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_13["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgesdd_ffi'], + serialized_date=datetime.date(2024, 8, 13), + inputs=(array([[[ 1.6052934 +0.45878917j, 4.587192 -4.5177283j , + 0.4177733 -1.9419309j , -2.2248359 -4.5042715j ], + [-7.083374 -8.127356j , 2.7596245 -4.991001j , + -0.52622825+5.033981j , -0.35441273-1.8215327j ], + [-0.7996552 -2.4052901j , -0.8506142 -3.164714j , + -0.3090829 +2.2020447j , 1.2367196 +2.8830793j ], + [ 1.4633094 -0.5451007j , -3.7833478 +6.6770763j , + -3.1279542 -2.2322626j , -2.1099617 -2.9661314j ]], + + [[ 1.2560439 -5.4743752j , -2.0085676 +2.0063214j , + -0.8132642 -3.4407883j , -0.17360081+0.6419895j ], + [ 2.3756726 +6.3315964j , -0.31447247-1.9387872j , + 4.6732006 -4.286903j , 1.7702469 -1.4957623j ], + [ 1.6918924 -0.52161306j, 0.49963537+4.7751374j , + -1.9243752 -4.5870543j , 2.8829405 +1.7382988j ], + [ 1.4884951 -0.44194785j, -1.3645276 -2.8733373j , + -0.39430943+2.4366508j , -0.76268387+5.2014065j ]]], + dtype=complex64),), + expected_outputs=(array([[[ 0.016725361+0.19210356j , 0.545269 +0.5572638j , + 0.41363978 +0.18964852j , -0.26152337 -0.28195122j ], + [ 0.53678614 +0.6405725j , -0.21783227 -0.21288806j , + 0.28426635 +0.30535886j , 0.15201291 +0.1076857j ], + [ 0.21286921 +0.15473497j , 0.06647172 -0.25652882j , + -0.4074609 -0.10356678j , -0.11794218 -0.8184482j ], + [-0.39079374 -0.20583557j , -0.18335938 -0.44217706j , + 0.63489586 +0.19758745j , 0.038679928-0.363512j ]], + + [[-0.31785947 +0.39032045j , -0.12733367 -0.30841753j , + 0.2639419 +0.26815215j , -0.21332225 -0.6694792j ], + [-0.39241248 -0.60790956j , -0.14006217 +0.4104069j , + -0.08306134 -0.101844534j, -0.45091915 -0.26039878j ], + [-0.36103737 +0.28761536j , -0.49654633 +0.100843735j, + -0.13752809 -0.6203827j , 0.35439843 -0.028546259j], + [ 0.062335134-0.07821423j , 0.35014486 -0.5668197j , + -0.42214072 -0.5090834j , -0.2889286 -0.15894136j ]]], + dtype=complex64), array([[15.135656 , 9.3730345, 7.44493 , 0.4152342], + [12.316968 , 8.661011 , 5.005059 , 2.1159043]], dtype=float32), array([[[-0.65378654 +0.j , -0.20306695 -0.6166746j , + 0.29948464 +0.24257994j , -0.00760437 +0.049453575j], + [ 0.5271269 +0.j , -0.112915546-0.7116953j , + -0.08921899 -0.36348897j , -0.23654734 -0.08269382j ], + [-0.31538552 +0.j , -0.014410704+0.15958196j , + -0.17958632 -0.136909j , -0.6930434 -0.58613425j ], + [-0.44185144 +0.j , 0.17604697 -0.05049205j , + -0.42138547 -0.6948516j , 0.22373372 +0.24654455j ]], + + [[-0.64551586 +0.j , 0.3293224 -0.1167212j , + -0.09352748 +0.6710144j , -0.038554132+0.02716675j ], + [ 0.4241116 +0.j , 0.031135 -0.539813j , + -0.26271757 +0.22760022j , -0.6360964 -0.04817466j ], + [-0.45774835 +0.j , -0.15202752 +0.2734652j , + 0.18930997 -0.32975054j , -0.73310995 -0.10269694j ], + [ 0.4403465 +0.j , 0.29474002 +0.6330784j , + 0.31271845 +0.42166728j , -0.20595443 -0.02053237j ]]], + dtype=complex64)), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("input")) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor loc(#loc3) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc3) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc3) + %0:5 = stablehlo.custom_call @lapack_cgesdd_ffi(%arg0) {mhlo.backend_config = {mode = 65 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>) loc(#loc3) + %c_2 = stablehlo.constant dense<0> : tensor loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3) + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_3 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_4 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc3) + %12 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc3) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc3) + return %10, %6, %14 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":574:13) +#loc3 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xab;\x01Y\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x17\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03S\x0b\x0bo\x0b\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0f\x0f\x17\x1fO/\x1f\x0f\x0b\x0b/\x1fO/o\x01\x05\x0b\x0f\x037\x1b\x0f\x07\x07\x17\x07\x07\x0f\x0b\x13\x07\x0f\x0f\x1b\x1b\x1f\x07\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xea\x07\x1d35\x05\x15\x1f\x03\x03\t\x9b\x05\x17\x03\x03\t\xa1\x11\x03\x05\x05\x19\x03\x03\x03{\x03\x03\x03\xa7\x03\x03\t\xa9\x03\t\x19\x1b\x1d\r\x1f\r\x0f!\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b%a'e)g\x0fu+w\x05#\x05%\x05'\x05)\x1d/\x05\x05+\x03\x03\x03y\x05-\x177\xfa\x08\x1b\x05/\x03\x13;}=\x7f?\x81A\x83C\x85E\x87G\x8dI\x8fK\x93\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x03\x03\x99\x03\x05Q\x9dS\x9f\x05C\x05E\x03\x03\x03\xa3\x03\x03\t\xa5\x1dG\x1dI\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dK\x03\x03c\r\x03Y[##\x03\x07imq\r\x05_kY[\x1dM\r\x05_oY[\x1dO\r\x05_sY[\x1dQ\x1dS\x1dU\x1f\x07\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dW\x1dY\x03\x01\x05\x01\r\x03\x89\x8b\x1d[\x13%A\x03\x03]\x03\x03\x91\x15\x03\x01\x01\x01\x03\x0b]\x95]]\x97\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\t\x00\x00\xc0\x7f\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f91\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x15)\x01\t\x1d\x01)\x05\t\x11\x0f\t\x13)\x01\x15\x03\x0f)\x03\t\x19\x1b)\x01\x19)\x01\x0f)\x07\t\x05\x05\x0b)\x07\t\x11\x11\x0b\x11\x03\x05\x07\x05\r\x05!)\x03\r\x11)\x03\t\x11)\x03\x05\x11)\x03\x01\t)\x03\t\x0b)\x05\t\x05\x0b)\x03\x05\t)\x05\t\x11\x0b)\x03\t\t)\x03\r\t\x04\x16\x03\x05\x01\x11\x05\x17\x07\x03\x01\x05\t\x11\x05#\x07\x037_\x03\x05-\x05\x03\x011\x03\x07\x05\x03\x01\x11\x03\x07\x05\x03\x01\x11\x03\x07\x0b\x07\x019\x0b\x05\r\x05\x05\x17\x03\x01\x05\x03\x01M\x03\x1b\x03\x07\x01\x07\x03\x17\x03\x13\r\x07\x01O\x03/\x05\x11\x15\x03\x07\x01\x0b\x031\x03\x17\x05\x03\x01U\x03\x1d\x03\x07\x01\x07\x03\r\x03\x1b\x03\x07\x01W\x035\x03\x19\x07\x06\x01\x03\r\x07\x1f\x0b\x1d\x03\x07\x01\x0b\x03\x1f\x03\x17\x05\x03\x01\x13\x03\x13\x03\x07\x01\x07\x03\x05\x03%\x03\x07\x01\x15\x03!\x03#\x07\x06\x01\x03\x05\x07)\r'\x03\x07\x01\x0b\x03\x1f\x03\x17\x05\x03\x01\x13\x03\x13\x03\x07\x01\x07\x03\x05\x03/\x03\x07\x01\x15\x03!\x03-\x07\x06\x01\x03\x05\x073\x0f1\x0f\x04\x05\x07+!5\x06\x03\x01\x05\x01\x00f\x0b]\x0b%\x03\x0f\x0b\t\t\t!\x11#+\x1b\x1f/!)!)#\x1f\x19i\xa3\r\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00input\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgesdd_ffi\x00mode\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_13["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgesdd_ffi'], + serialized_date=datetime.date(2024, 8, 13), + inputs=(array([[[ 1.5410905 , -2.775912 , -2.374003 , 4.028736 ], + [-0.56933475, 1.6115232 , 0.9041465 , -0.8321383 ], + [-5.382895 , 4.734856 , 2.1972926 , 1.5553856 ], + [ 0.5109847 , -1.1969309 , 3.3766198 , -1.3678027 ]], + + [[ 2.2637439 , 3.406768 , 4.809871 , 2.8010902 ], + [-1.9981416 , -0.6599986 , 0.5138156 , 4.5982494 ], + [-2.335944 , -9.151717 , -1.0481138 , 2.272443 ], + [-8.257684 , 1.8223318 , 0.38403794, 5.0769973 ]]], + dtype=float32),), + expected_outputs=(array([[[-0.48540133 , 0.6682398 , -0.48819908 , -0.28196266 ], + [ 0.21800542 , -0.13631387 , 0.14819776 , -0.9549501 ], + [ 0.84570533 , 0.44643924 , -0.27943408 , 0.08597416 ], + [ 0.04052323 , -0.57928103 , -0.8133976 , -0.034290295]], + + [[-0.21146727 , 0.46376404 , 0.7863092 , 0.34917426 ], + [ 0.3461469 , 0.21883708 , 0.3399651 , -0.846591 ], + [ 0.6526193 , -0.58340365 , 0.39724028 , 0.27555162 ], + [ 0.6399629 , 0.6298205 , -0.32915345 , 0.29228795 ]]], + dtype=float32), array([[ 8.551605 , 5.3574076 , 2.8073733 , 0.52260846], + [11.457574 , 10.041604 , 5.671653 , 1.4754113 ]], + dtype=float32), array([[[-0.6319044 , 0.66122514, 0.39110142, -0.10255312], + [-0.29710513, 0.13673344, -0.50112027, 0.8011937 ], + [ 0.08969161, 0.4433049 , -0.736473 , -0.5030347 ], + [-0.7101976 , -0.5895469 , -0.23135659, -0.30745378]], + + [[-0.69643414, -0.50230867, -0.11150038, 0.50023323], + [-0.32121184, 0.7889567 , 0.31831914, 0.4159848 ], + [ 0.5096959 , -0.31399366, 0.60193473, 0.5284817 ], + [-0.3898877 , -0.16322279, 0.72382 , -0.5453722 ]]], + dtype=float32)), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("input")) -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor loc(#loc3) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc3) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc3) + %0:5 = stablehlo.custom_call @lapack_sgesdd_ffi(%arg0) {mhlo.backend_config = {mode = 65 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>) loc(#loc3) + %c_2 = stablehlo.constant dense<0> : tensor loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3) + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_4 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc3) + %12 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc3) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3) + return %10, %6, %14 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":574:13) +#loc3 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf1\xa77\x01W\x0f\x07\x0b\x13\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x17\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03Q\x0b\x0bo\x0b\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0f\x0f\x17\x1fO/\x1f\x0f\x0b\x0b/\x1fOo\x01\x05\x0b\x0f\x033\x1b\x0f\x07\x07\x17\x0f\x07\x07\x13\x07\x0f\x1b\x1b\x1f\x07\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\x9a\x07\x1d35\x1f\x05\x15\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x03\x03\x05\xa1\x11\x03\x05\x05\x19\x03\x03\x05y\x03\x03\t\xa5\x03\t\x19\x1b\x1d\x0f\x1f\x0f\x11!\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b%_'c)e\x11s+u\x05#\x05%\x05'\x05)\x1d/\x03\x05+\x03\x03\x05w\x05-\x177\xfa\x08\x1b\x05/\x03\x13;{=}?\x7fA\x81C\x83E\x85G\x8bI\x8dK\x91\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x03\x05\x97\x03\x05Q\x9bS\x9d\x05C\x05E\x03\x03\t\xa3\x1dG\x1dI\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dK\x03\x03a\r\x03WY#\x1f\x03\x07gko\r\x05]iWY\x1dM\r\x05]mWY\x1dO\r\x05]qWY\x1dQ\x1dS\x1dU\x1f\x07\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dW\x1dY\x03\x01\x05\x01\r\x03\x87\x89\x1d[\x13!A\x03\x03[\x03\x03\x8f\x15\x03\x01\x01\x01\x03\x0b[\x93[[\x95\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0f\t\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x11)\x01\t\x1d\x01)\x05\t\x11\x11)\x01\x11\t\x13)\x03\t\x17\x1b)\x01\x17)\x07\t\x05\x05\x0b)\x07\t\x11\x11\x0b\x11\x03\x05\x07\x05\r\x05!)\x03\r\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\t)\x03\t\x0b)\x05\t\x05\x0b)\x03\x05\t)\x05\t\x11\x0b)\x03\t\t)\x03\r\t\x04\x16\x03\x05\x01\x11\x03\x17\x07\x03\x01\x05\t\x11\x03#\x07\x037_\x03\x05-\x05\x03\x011\x03\x07\x05\x03\x01\x13\x03\x07\x05\x03\x01\x13\x03\x07\x0b\x07\x019\x0b\x05\r\x05\x05\x15\x03\x01\x05\x03\x01M\x03\x19\x03\x07\x01\x07\x03\x15\x03\x13\r\x07\x01O\x03+\x05\x11\x15\x03\x07\x01\x0b\x03-\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\r\x03\x1b\x03\x07\x01U\x031\x03\x19\x07\x06\x01\x03\r\x07\x1f\x0b\x1d\x03\x07\x01\x0b\x03\x1b\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\x05\x03%\x03\x07\x01\x15\x03\x1d\x03#\x07\x06\x01\x03\x05\x07)\r'\x03\x07\x01\x0b\x03\x1b\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\x05\x03/\x03\x07\x01\x15\x03\x1d\x03-\x07\x06\x01\x03\x05\x073\x0f1\x0f\x04\x03\x07+!5\x06\x03\x01\x05\x01\x00f\x0b]\x0b%\x03\x0f\x0b\t\t\t!\x11#+\x1b\x1f/!)!)#\x1f\x19i\xa3\r\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00input\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgesdd_ffi\x00mode\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_13["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgesdd_ffi'], + serialized_date=datetime.date(2024, 8, 13), + inputs=(array([[[ 0.3445689867809981 , 3.5114993759427104 , + 4.702602090972179 , -0.2702264758497052 ], + [ 2.209901632583705 , -2.6286702510632773 , + 4.591276599385847 , 3.4465035398844828 ], + [-1.5083742421154478 , 3.3225165204269635 , + 1.2596205557926703 , 3.524804355848018 ], + [ 1.5118969169108838 , 1.838885943509677 , + 2.818520751293422 , 3.06002540493494 ]], + + [[-2.4045510943950843 , -1.5657555633438576 , + -0.6061472334580296 , -0.23926156407779164], + [ 4.087879920053448 , -3.2507640936811715 , + -2.2556577657517476 , 6.090369998330348 ], + [ 1.1165401344486945 , 2.2134726894037247 , + 5.225178515435584 , 1.9794693474107725 ], + [-4.127878192684534 , -0.37313660200336163, + 0.7893465897510026 , -2.0315217791342848 ]]]),), + expected_outputs=(array([[[-0.5109626909166218 , -0.41744996156105796 , + -0.731253241567692 , 0.17297790257908272 ], + [-0.5623501368035175 , 0.7608931604238581 , + 0.03470920608540995 , 0.32186828528169453 ], + [-0.39585755254587396 , -0.49547702914054115 , + 0.6561880513437817 , 0.4089212062978682 ], + [-0.5157288533916832 , -0.035772078593888285, + 0.18297871183094855 , -0.8362194085221047 ]], + + [[-0.12124821978030864 , -0.30260506534356224 , + -0.5817463045715605 , -0.7451847292758066 ], + [ 0.8877417367326683 , -0.1579400123987918 , + -0.37611807392676866 , 0.21331843758089156 ], + [ 0.030552216758649886, 0.9244545314395404 , + -0.36861075330670934 , -0.09260936183071362 ], + [-0.443035032603635 , -0.1699086407831784 , + -0.6198649402326368 , 0.624994775612963 ]]]), array([[8.951386926411187 , 5.762891699811625 , 3.8391040088894437, + 1.269646897103325 ], + [9.215006888576916 , 6.4772976708832255, 3.246269458558178 , + 0.0511210199435459]]), array([[[-0.1789027692424481 , -0.28818125207050604, + -0.7749616998111009 , -0.5332726590950896 ], + [ 0.3871215938703837 , -0.8985113987184387 , + 0.13976186700464233, 0.1525803344591491 ], + [-0.2314069792404015 , -0.03708202130554682, + -0.5045854966104311 , 0.8309447696839618 ], + [-0.8744034999217863 , -0.32901938548360005, + 0.35396957633060844, -0.04324699218274111]], + + [[ 0.6276106632546885 , -0.267287353478729 , + -0.2299525871877408 , 0.69410671635204 ], + [ 0.28029316975925644, 0.47811378046591546, + 0.8083625695047307 , 0.1984764674680803 ], + [ 0.6187014005224261 , 0.4771409534394446 , + -0.37406866975606345, -0.4996175715979325 ], + [-0.38045915857935025, 0.6872417290515942 , + -0.3921025301835002 , 0.4787538410571401 ]]])), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("input")) -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor loc(#loc3) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc3) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc3) + %0:5 = stablehlo.custom_call @lapack_dgesdd_ffi(%arg0) {mhlo.backend_config = {mode = 65 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>) loc(#loc3) + %c_2 = stablehlo.constant dense<0> : tensor loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3) + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_3 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_4 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc3) + %12 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc3) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3) + return %10, %6, %14 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":574:13) +#loc3 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf1\xa77\x01W\x0f\x07\x0b\x13\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x17\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03Q\x0b\x0bo\x0b\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0f\x0f\x17\x1fO/\x1f\x0f\x0b\x0b//Oo\x01\x05\x0b\x0f\x033\x1b\x0f\x07\x07\x17\x0f\x07\x07\x13\x07\x0f\x1b\x1b\x1f\x07\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xaa\x07\x1d35\x1f\x05\x15\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x03\x03\x05\xa1\x11\x03\x05\x05\x19\x03\x03\x05y\x03\x03\t\xa5\x03\t\x19\x1b\x1d\x0f\x1f\x0f\x11!\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b%_'c)e\x11s+u\x05#\x05%\x05'\x05)\x1d/\x03\x05+\x03\x03\x05w\x05-\x177\xfa\x08\x1b\x05/\x03\x13;{=}?\x7fA\x81C\x83E\x85G\x8bI\x8dK\x91\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x03\x05\x97\x03\x05Q\x9bS\x9d\x05C\x05E\x03\x03\t\xa3\x1dG\x1dI\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dK\x03\x03a\r\x03WY#\x1f\x03\x07gko\r\x05]iWY\x1dM\r\x05]mWY\x1dO\r\x05]qWY\x1dQ\x1dS\x1dU\x1f\x07\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dW\x1dY\x03\x01\x05\x01\r\x03\x87\x89\x1d[\x13!A\x03\x03[\x03\x03\x8f\x15\x03\x01\x01\x01\x03\x0b[\x93[[\x95\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x11)\x01\t\x1d\x01)\x05\t\x11\x11)\x01\x11\x0b\x13)\x03\t\x17\x1b)\x01\x17)\x07\t\x05\x05\x0b)\x07\t\x11\x11\x0b\x11\x03\x05\x07\x05\r\x05!)\x03\r\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\t)\x03\t\x0b)\x05\t\x05\x0b)\x03\x05\t)\x05\t\x11\x0b)\x03\t\t)\x03\r\t\x04\x16\x03\x05\x01\x11\x03\x17\x07\x03\x01\x05\t\x11\x03#\x07\x037_\x03\x05-\x05\x03\x011\x03\x07\x05\x03\x01\x13\x03\x07\x05\x03\x01\x13\x03\x07\x0b\x07\x019\x0b\x05\r\x05\x05\x15\x03\x01\x05\x03\x01M\x03\x19\x03\x07\x01\x07\x03\x15\x03\x13\r\x07\x01O\x03+\x05\x11\x15\x03\x07\x01\x0b\x03-\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\r\x03\x1b\x03\x07\x01U\x031\x03\x19\x07\x06\x01\x03\r\x07\x1f\x0b\x1d\x03\x07\x01\x0b\x03\x1b\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\x05\x03%\x03\x07\x01\x15\x03\x1d\x03#\x07\x06\x01\x03\x05\x07)\r'\x03\x07\x01\x0b\x03\x1b\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\x05\x03/\x03\x07\x01\x15\x03\x1d\x03-\x07\x06\x01\x03\x05\x073\x0f1\x0f\x04\x03\x07+!5\x06\x03\x01\x05\x01\x00f\x0b]\x0b%\x03\x0f\x0b\t\t\t!\x11#+\x1b\x1f/!)!)#\x1f\x19i\xa3\r\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00input\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgesdd_ffi\x00mode\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index c7ef1462361f..b1e92b89af29 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1977,7 +1977,9 @@ def _svd_cpu_gpu_lowering( compute_uv=compute_uv) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - s, u, vt, info = gesvd_impl(operand_aval.dtype, operand, + # TODO(b/344892332): Remove the conditional after the compatibility period. + ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand, full_matrices=full_matrices, compute_uv=compute_uv, a_shape_vals=a_shape_vals) diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index 11ba6803d9df..d43eef8c8fc3 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -17,6 +17,7 @@ from collections.abc import Sequence from enum import Enum +from typing import Optional import numpy as np @@ -25,12 +26,12 @@ from jaxlib import xla_client +from .cpu import _lapack from .hlo_helpers import ( custom_call, hlo_u8, hlo_s32, ensure_hlo_s32, hlo_add, hlo_min, DimensionSize, ShapeTypePair, mk_result_types_and_shapes, ) -from .cpu import _lapack for _name, _value in _lapack.registrations().items(): xla_client.register_custom_call_target( @@ -69,6 +70,23 @@ def _matrix_diagonal_attr(*, unit_diag: bool): return _char_attr("U" if unit_diag else "N") +def _svd_computation_attr( + *, compute_uv: bool, full_matrices: Optional[bool] = True +): + mode = "A" + if full_matrices is None: + full_matrices = True + if not compute_uv: + # We should assert that `full_matrices` is never True here. + # This should never happen because `full_matrices` can only be computed when + # `compute_uv` is True. However, at this point there are too many tests that + # rely on this behavior. + mode = "N" + elif not full_matrices: + mode = "S" + return _char_attr(mode) + + LAPACK_DTYPE_PREFIX = { np.float32: "s", np.float64: "d", @@ -375,7 +393,7 @@ def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False, # # ?gesdd: Singular value decomposition -def gesdd_hlo(dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, +def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, a_shape_vals: tuple[DimensionSize, ...]): _lapack.initialize() a_type = ir.RankedTensorType(a.type) @@ -385,81 +403,120 @@ def gesdd_hlo(dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, assert type(n) is int batch_dims_vals = a_shape_vals[:-2] num_bd = len(batch_dims_vals) - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - + fn_base = build_lapack_fn_target(fn_base="gesdd", dtype=dtype) i32_type = ir.IntegerType.get_signless(32) workspace: list[ShapeTypePair] - if dtype == np.float32: - fn = "lapack_sgesdd" - singular_vals_type = ir.F32Type.get() - lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices) - workspace = [ - ([_lapack.gesdd_iwork_size(m, n)], i32_type), - ([lwork], a_type.element_type), - ] - workspace_layouts = [[0], [0]] - elif dtype == np.float64: - fn = "lapack_dgesdd" - singular_vals_type = ir.F64Type.get() - lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices) - workspace = [ - ([_lapack.gesdd_iwork_size(m, n)], i32_type), - ([lwork], a_type.element_type), - ] - workspace_layouts = [[0], [0]] - elif dtype == np.complex64: - fn = "lapack_cgesdd" + + # TODO(b/344892332): Remove the old kernel after the compatibility period. + if ctx.is_forward_compat(): + fn = fn_base + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + if dtype == np.float32: + singular_vals_type = ir.F32Type.get() + lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices) + workspace = [ + ([_lapack.gesdd_iwork_size(m, n)], i32_type), + ([lwork], a_type.element_type), + ] + workspace_layouts = [[0], [0]] + elif dtype == np.float64: + singular_vals_type = ir.F64Type.get() + lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices) + workspace = [ + ([_lapack.gesdd_iwork_size(m, n)], i32_type), + ([lwork], a_type.element_type), + ] + workspace_layouts = [[0], [0]] + elif dtype == np.complex64: + singular_vals_type = ir.F32Type.get() + lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices) + workspace = [ + ([_lapack.gesdd_iwork_size(m, n)], i32_type), + ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F32Type.get()), + ([lwork], a_type.element_type), + ] + workspace_layouts = [[0], [0], [0]] + elif dtype == np.complex128: + singular_vals_type = ir.F64Type.get() + lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices) + workspace = [ + ([_lapack.gesdd_iwork_size(m, n)], i32_type), + ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F64Type.get()), + ([lwork], a_type.element_type), + ] + workspace_layouts = [[0], [0], [0]] + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + scalar_layout = [] + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + + shape_type_pairs: Sequence[ShapeTypePair] = [ + (a_shape_vals, a_type.element_type), + (batch_dims_vals + (min(m, n),), singular_vals_type), + (batch_dims_vals + (m, m if full_matrices else min(m, n)), a_type.element_type), + (batch_dims_vals + (n if full_matrices else min(m, n), n), a_type.element_type), + (batch_dims_vals, i32_type), + ] + workspace + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + return custom_call( + fn, + result_types=result_types, + operands=[hlo_s32(int(full_matrices)), hlo_s32(int(compute_uv)), batch_size_val, + hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], + operand_layouts=[scalar_layout] * 6 + [layout], + result_layouts=[ + layout, + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + layout, + layout, + tuple(range(num_bd - 1, -1, -1)), + ] + workspace_layouts, + operand_output_aliases={6: 0}, + result_shapes=result_shapes + ).results[1:5] + fn = fn_base + "_ffi" + mode_attr = _svd_computation_attr( + compute_uv=compute_uv, full_matrices=full_matrices + ) + if dtype == np.float32 or dtype == np.complex64: singular_vals_type = ir.F32Type.get() - lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices) - workspace = [ - ([_lapack.gesdd_iwork_size(m, n)], i32_type), - ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F32Type.get()), - ([lwork], a_type.element_type), - ] - workspace_layouts = [[0], [0], [0]] - elif dtype == np.complex128: - fn = "lapack_zgesdd" + elif dtype == np.float64 or dtype == np.complex128: singular_vals_type = ir.F64Type.get() - lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices) - workspace = [ - ([_lapack.gesdd_iwork_size(m, n)], i32_type), - ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F64Type.get()), - ([lwork], a_type.element_type), - ] - workspace_layouts = [[0], [0], [0]] else: raise NotImplementedError(f"Unsupported dtype {dtype}") - scalar_layout = [] layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - + a_elem_type = a_type.element_type shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (min(m, n),), singular_vals_type), - (batch_dims_vals + (m, m if full_matrices else min(m, n)), a_type.element_type), - (batch_dims_vals + (n if full_matrices else min(m, n), n), a_type.element_type), - (batch_dims_vals, i32_type), - ] + workspace + (a_shape_vals, a_elem_type), + (batch_dims_vals + (min(m, n),), singular_vals_type), + (batch_dims_vals + (m, m if full_matrices else min(m, n)), a_elem_type), + (batch_dims_vals + (n if full_matrices else min(m, n), n), a_elem_type), + (batch_dims_vals, i32_type), + ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - out = custom_call( + return custom_call( fn, result_types=result_types, - operands=[hlo_s32(int(full_matrices)), hlo_s32(int(compute_uv)), batch_size_val, - hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], - operand_layouts=[scalar_layout] * 6 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ layout, (num_bd,) + tuple(range(num_bd - 1, -1, -1)), layout, layout, tuple(range(num_bd - 1, -1, -1)), - ] + workspace_layouts, - operand_output_aliases={6: 0}, - result_shapes=result_shapes - ).results - return out[1:5] + ], + operand_output_aliases={0: 0}, + result_shapes=result_shapes, + backend_config={ + "mode": mode_attr, + }, + api_version=4, + ).results[1:] # # syevd: Symmetric eigendecomposition diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 7118495efa9b..7fdf15c598ea 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -114,6 +114,7 @@ def test_custom_call_coverage(self): cpu_ffi_testdatas = [ cpu_cholesky_lapack_potrf.data_2024_05_31, cpu_lu_lapack_getrf.data_2024_05_31, + cpu_svd_lapack_gesdd.data_2024_08_13, ] # Add here all the testdatas that should cover the targets guaranteed # stable @@ -579,6 +580,16 @@ def func(input): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_svd_results, input)) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 32) + if has_xla_ffi_support: + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_svd_results, input)) @jtu.parameterized_filterable( kwargs=[ From ca6be2573be38c3aa48e7d6b76411d6f6d1d4bbd Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 13 Aug 2024 03:21:12 -0700 Subject: [PATCH 096/702] [Mosaic GPU] Move matmul tests to Hypothesis We've been generating thousands of test cases and that's just not scalable. Hypothesis should let us efficiently explore a large number of configurations. PiperOrigin-RevId: 662447113 --- .../mosaic/gpu/examples/matmul.py | 11 +- tests/mosaic/BUILD | 4 +- tests/mosaic/matmul_test.py | 135 ++++++------------ 3 files changed, 53 insertions(+), 97 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index a0d008bf1b43..a6fec6c5e3b0 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -130,6 +130,8 @@ def build_kernel( raise ValueError(f"n must be divisible by 64, but got {n=}") if stages < 2: raise ValueError(f"Need at least 2 stages, but got {stages=}") + if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2: + raise ValueError("Transpose only supported for only happen for 16bit types") lhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) rhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) @@ -301,15 +303,10 @@ def verify( cluster_m=1, cluster_n=1, profile=False, - lhs_dtype=jnp.float16, - rhs_dtype=jnp.float16, + in_dtype=jnp.float16, rhs_transpose=False, ): - if not rhs_transpose and jnp.dtype(lhs_dtype).itemsize != 2: - raise ValueError( - "Implicit transpose can only happen for 16bit types (or mixed precision" - " that is underpinned by 16bit operations)." - ) + lhs_dtype, rhs_dtype = in_dtype, in_dtype kx, ky = random.split(random.key(1234)) x = random.uniform(kx, (m, k), dtype=lhs_dtype) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index d182c99be7b1..a52a62962b9c 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -58,11 +58,11 @@ jax_test( srcs = ["matmul_test.py"], disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, - shard_count = 16, + shard_count = 5, deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) jax_test( diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index fe29615ced0d..32a2f54ab78c 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -15,6 +15,7 @@ """Test different parameterizations of a matmul.""" import os +import unittest from absl.testing import absltest, parameterized from jax._src import config @@ -27,13 +28,25 @@ matmul = None else: from jax.experimental.mosaic.gpu.examples import matmul +try: + import hypothesis as hp + import hypothesis.strategies as hps +except (ModuleNotFoundError, ImportError): + raise unittest.SkipTest("these tests require hypothesis") config.parse_flags_with_absl() +jtu.setup_hypothesis() os.environ["XLA_FLAGS"] = ( os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") +def seed_hypothesis(f): + def wrapper(self, seed): + return hp.seed(seed)(f)(self) + return wrapper + + @jtu.with_config(jax_traceback_filtering="off") class MatmulTestCase(jtu.JaxTestCase): @@ -45,60 +58,38 @@ def setUp(self): not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("Only works on GPU with capability >= sm90") - @parameterized.product( - m=(128, 256, 512, 2048), - n=(128, 256, 512, 2048), - k=(128, 256, 512, 2048), - stages=(2, 4), - tile_m=(64, 128, 256), - tile_n=(64, 128, 256), - in_dtype=(jnp.float16, jnp.bfloat16), # f32 tested separately - rhs_transpose=(False, True), - ) - def test_matmul(self, m, k, n, stages, tile_m, tile_n, in_dtype, rhs_transpose): - if stages * (128 // jnp.dtype(in_dtype).itemsize) > k: - self.skipTest("Too many stages.") - - if m < tile_m: - self.skipTest(f"No use in running a test with {m=} < {tile_m=}.") - - if n < tile_n: - self.skipTest(f"No use in running a test with {n=} < {tile_n=}.") - - try: - matmul.verify( - m, - k, - n, - stages, - tile_m=tile_m, - tile_n=tile_n, - lhs_dtype=in_dtype, - rhs_dtype=in_dtype, - rhs_transpose=rhs_transpose, - ) - except ValueError as e: - if "Mosaic GPU kernel exceeds available shared memory" in str(e): - self.skipTest("Not enough shared memory for test, skipping.") - raise e - - @parameterized.product( - m=(128, 256, 512, 2048), - n=(128, 256, 512, 2048), - k=(128, 256, 512, 2048), - stages=(2, 4), - tile_m=(64, 128, 256), - tile_n=(64, 128, 256), + @parameterized.named_parameters( + (f"_shard{i}", i) for i in range(5) ) - def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n): - if stages * (128 // jnp.dtype(jnp.float32).itemsize) > k: - self.skipTest("Too many stages.") - - if m < tile_m: - self.skipTest(f"No use in running a test with {m=} < {tile_m=}.") - - if n < tile_n: - self.skipTest(f"No use in running a test with {n=} < {tile_n=}.") + @seed_hypothesis + @hp.settings(max_examples=100) # Add verbosity=hp.Verbosity.verbose to debug + @hp.given(hps.data()) + def test_matmul(self, data): + m, n, k = ( + data.draw(hps.sampled_from([128, 256, 512, 2048]), label=d) + for d in "mnk" + ) + stages = data.draw(hps.integers(2, 5), label="stages") + tile_m = data.draw( + hps.sampled_from([t for t in [64, 128, 256] if t <= m]), label="tile_m" + ) + tile_n = data.draw( + hps.sampled_from([t for t in [64, 128, 256] if t <= n]), label="tile_n" + ) + in_dtype = data.draw( + hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]), + label="dtype", + ) + cluster_m = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_m") + hp.assume((m // tile_m) % cluster_m == 0) + cluster_n = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_n") + hp.assume((n // tile_n) % cluster_n == 0) + # TODO(apaszke): Non-portable clusters (16 blocks) sometimes deadlock. + hp.assume(cluster_m * cluster_n <= 8) + if jnp.dtype(in_dtype).itemsize == 4: + rhs_transpose = True + else: + rhs_transpose = data.draw(hps.booleans(), label="rhs_transpose") try: matmul.verify( @@ -108,46 +99,14 @@ def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n): stages, tile_m=tile_m, tile_n=tile_n, - lhs_dtype=jnp.float32, - rhs_dtype=jnp.float32, - rhs_transpose=True, - ) - except ValueError as e: - if "Mosaic GPU kernel exceeds available shared memory" in str(e): - self.skipTest("Not enough shared memory for test, skipping.") - raise e - - @parameterized.product( - m=(512, 2048), - n=(512, 2048), - k=(512, 2048), - stages=(2, 4), - tile_m=(64, 128), - tile_n=(64, 128), - cluster_m=(1, 2, 4), - cluster_n=(1, 2, 4), - ) - def test_matmul_clusters(self, m, k, n, stages, tile_m, tile_n, cluster_m, cluster_n): - if cluster_m * cluster_n > 8: - # TODO(apaszke): Investigate - self.skipTest("Tests sometimes fail with non-portable cluster sizes.") - try: - matmul.verify( - m, - k, - n, - stages, - tile_m=tile_m, - tile_n=tile_n, + in_dtype=in_dtype, cluster_m=cluster_m, cluster_n=cluster_n, - lhs_dtype=jnp.float32, - rhs_dtype=jnp.float32, - rhs_transpose=True, + rhs_transpose=rhs_transpose, ) except ValueError as e: if "Mosaic GPU kernel exceeds available shared memory" in str(e): - self.skipTest("Not enough shared memory for test, skipping.") + hp.assume(False) raise e From 5cf89b3f61faec83ee0bbf27bfb6850c8da3de50 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 13 Aug 2024 04:12:01 -0700 Subject: [PATCH 097/702] [Mosaic GPU] Add support for various swizzles in the matmul example PiperOrigin-RevId: 662459766 --- .../mosaic/gpu/examples/matmul.py | 49 ++++++++++++------- tests/mosaic/matmul_test.py | 15 +++--- 2 files changed, 40 insertions(+), 24 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index a6fec6c5e3b0..efaee88292f8 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -87,13 +87,14 @@ def wgmma( b_order: WGMMALayout, a_slice: SmemRef, b_slice: SmemRef, + swizzle: int, ) -> dict[str, WGMMAAccumulator]: """Perform a matrix multiplication. This function must guarantee that all WGMMA operations queued before it was called have completed before returning. """ - acc = wgmma(acc, a_slice, b_slice, b_order=b_order) + acc = wgmma(acc, a_slice, b_slice, b_order=b_order, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(1) return acc @@ -113,15 +114,13 @@ def build_kernel( stages: int = 2, tile_m: int = 128, tile_n: int = 128, + swizzle: int = 128, cluster: tuple[int, int] = (1, 1), rhs_transpose: bool = False, wgmma_impl=WGMMADefaultImpl, profiler_spec: profiler.ProfilerSpec | None = None, ): f32 = ir.F32Type.get() - out_128b_elems = 128 // bytewidth(f32) - out_tiling = (64, out_128b_elems) - out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), jnp.float32) if tile_m % 64 != 0: raise ValueError(f"{tile_m=} must be divisible by 64") if m % tile_m != 0: @@ -132,24 +131,36 @@ def build_kernel( raise ValueError(f"Need at least 2 stages, but got {stages=}") if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2: raise ValueError("Transpose only supported for only happen for 16bit types") + if swizzle not in {32, 64, 128}: + raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}") + + if tile_n % 32 == 0: + out_swizzle = 128 + elif tile_n % 16 == 0: + out_swizzle = 64 + else: + raise NotImplementedError(f"{tile_n=} must by divisible by 16") + out_swizzle_elems = out_swizzle // bytewidth(f32) + out_tiling = (64, out_swizzle_elems) + out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), jnp.float32) lhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) rhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) - lhs_128b_elems = 128 // lhs_elem_bytes - rhs_128b_elems = 128 // rhs_elem_bytes - tile_k = max(lhs_128b_elems, rhs_128b_elems) + lhs_swizzle_elems = swizzle // lhs_elem_bytes + rhs_swizzle_elems = swizzle // rhs_elem_bytes + tile_k = max(lhs_swizzle_elems, rhs_swizzle_elems) - if tile_n % rhs_128b_elems != 0: + if tile_n % rhs_swizzle_elems != 0: raise ValueError( - f"{tile_n=} must be divisible by 128 bytes =" - f" {((lhs_128b_elems, lhs_dtype), (rhs_128b_elems, rhs_dtype))}" + f"{tile_n=} must be divisible by {swizzle} bytes =" + f" {((lhs_swizzle_elems, lhs_dtype), (rhs_swizzle_elems, rhs_dtype))}" ) if k % tile_k != 0: raise ValueError(f"k must be divisible by {tile_k=}, but got {k=}") block_tiling = Tiling(m=tile_m, n=tile_n, k=tile_k) - tma_tiling = Tiling(m=64, n=rhs_128b_elems, k=lhs_128b_elems) + tma_tiling = Tiling(m=64, n=rhs_swizzle_elems, k=lhs_swizzle_elems) k_steps = k // block_tiling.k stages = min(stages, k_steps) @@ -186,7 +197,7 @@ def fetch(slot, ki): rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes) txcount = lhs_tma_tile_bytes + rhs_tma_tile_bytes common_copy_args = dict( - swizzle=128, barrier=barrier, arrive=False, uniform=False, + swizzle=swizzle, barrier=barrier, arrive=False, uniform=False, ) with single_thread(): barrier.arrive_expect_tx(txcount) @@ -232,7 +243,9 @@ def stage_loop_body(ki, accs): rhs_smem_order = ( WGMMALayout.COL_MAJOR if rhs_transpose else WGMMALayout.ROW_MAJOR ) - accs = wgmma_impl.wgmma(impl_smem, accs, rhs_smem_order, a_slice, b_slice) + accs = wgmma_impl.wgmma( + impl_smem, accs, rhs_smem_order, a_slice, b_slice, swizzle=swizzle + ) with ctx.named_region("TMA start"): tma_ki = arith.addi(ki, c(stages - 1)) @@ -258,7 +271,7 @@ def stage_loop_body(ki, accs): with ctx.named_region("SMEM store"): acc_val = wgmma_impl.get_result(stage_loop_body.result) - acc_val.store_tiled(epilogue_smem, swizzle=128) + acc_val.store_tiled(epilogue_smem, swizzle=out_swizzle) commit_shared() # Make sure the stores are visible to TMA. with ctx.named_region("GMEM store"): @@ -267,7 +280,7 @@ def stage_loop_body(ki, accs): dst_ref=c_device, gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)), gmem_transform=mosaic_gpu.TileTransform(out_tiling), - swizzle=128, + swizzle=out_swizzle, ) ctx.await_async_copy(0) @@ -302,6 +315,7 @@ def verify( tile_n=128, cluster_m=1, cluster_n=1, + swizzle=128, profile=False, in_dtype=jnp.float16, rhs_transpose=False, @@ -312,8 +326,6 @@ def verify( x = random.uniform(kx, (m, k), dtype=lhs_dtype) y = random.uniform(ky, (n, k) if rhs_transpose else (k, n), dtype=rhs_dtype) - impl = WGMMADefaultImpl - prof_spec = profiler.ProfilerSpec(4096) if profile else None f = build_kernel( m, n, k, @@ -323,7 +335,8 @@ def verify( tile_n=tile_n, cluster=(cluster_m, cluster_n), rhs_transpose=rhs_transpose, - wgmma_impl=impl, + swizzle=swizzle, + wgmma_impl=WGMMADefaultImpl, profiler_spec=prof_spec, ) z, runtime = profiler.measure(f, x, y) diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index 32a2f54ab78c..d0653387b311 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -65,28 +65,30 @@ def setUp(self): @hp.settings(max_examples=100) # Add verbosity=hp.Verbosity.verbose to debug @hp.given(hps.data()) def test_matmul(self, data): + in_dtype = data.draw( + hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]), + label="dtype", + ) + bytewidth = jnp.dtype(in_dtype).itemsize m, n, k = ( data.draw(hps.sampled_from([128, 256, 512, 2048]), label=d) for d in "mnk" ) stages = data.draw(hps.integers(2, 5), label="stages") + swizzle = data.draw(hps.sampled_from([32, 64, 128]), label="swizzle") tile_m = data.draw( hps.sampled_from([t for t in [64, 128, 256] if t <= m]), label="tile_m" ) tile_n = data.draw( hps.sampled_from([t for t in [64, 128, 256] if t <= n]), label="tile_n" ) - in_dtype = data.draw( - hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]), - label="dtype", - ) cluster_m = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_m") hp.assume((m // tile_m) % cluster_m == 0) cluster_n = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_n") hp.assume((n // tile_n) % cluster_n == 0) # TODO(apaszke): Non-portable clusters (16 blocks) sometimes deadlock. hp.assume(cluster_m * cluster_n <= 8) - if jnp.dtype(in_dtype).itemsize == 4: + if bytewidth == 4: rhs_transpose = True else: rhs_transpose = data.draw(hps.booleans(), label="rhs_transpose") @@ -96,12 +98,13 @@ def test_matmul(self, data): m, k, n, - stages, + stages=stages, tile_m=tile_m, tile_n=tile_n, in_dtype=in_dtype, cluster_m=cluster_m, cluster_n=cluster_n, + swizzle=swizzle, rhs_transpose=rhs_transpose, ) except ValueError as e: From 52c269c8cde2dbad848baeb20c683824cdaffb6e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 13 Aug 2024 05:24:59 -0700 Subject: [PATCH 098/702] jnp.setxor1d: add support for static size argument --- jax/_src/numpy/setops.py | 79 +++++++++++++++++++++++++++++++++------- tests/lax_numpy_test.py | 28 +++++++++++--- 2 files changed, 88 insertions(+), 19 deletions(-) diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index db4237dbd069..8e0162660951 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -61,6 +61,17 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array: return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1) +def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]: + """Utility to concatenate the unique values from two arrays.""" + arr1, arr2 = ravel(arr1), ravel(arr2) + arr1, num_unique1 = _unique(arr1, axis=0, size=arr1.size, return_true_size=True) + arr2, num_unique2 = _unique(arr2, axis=0, size=arr2.size, return_true_size=True) + arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2)) + arr = lax.dynamic_update_slice(arr, arr1, (0,)) + arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,)) + return arr, num_unique1 + num_unique2 + + def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set difference of two 1D arrays. @@ -220,7 +231,39 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, return cast(Array, out) -def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array: +@partial(jit, static_argnames=['assume_unique', 'size']) +def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *, + assume_unique: bool, size: int, ) -> Array: + # Ensured by caller + assert arr1.ndim == arr2.ndim == 1 + assert arr1.dtype == arr2.dtype + + if assume_unique: + arr = concatenate([arr1, arr2]) + aux = sort(concatenate([arr1, arr2])) + flag = concatenate((bool(aux.size), aux[1:] != aux[:-1], True), axis=None) + else: + arr, num_unique = _concat_unique(arr1, arr2) + mask = arange(arr.size + 1) < num_unique + 1 + _, aux = lax.sort([~mask[1:], arr], is_stable=True, num_keys=2) + flag = mask & concatenate((bool(aux.size), aux[1:] != aux[:-1], False), + axis=None).at[num_unique].set(True) + aux_mask = flag[1:] & flag[:-1] + num_results = aux_mask.sum() + if aux.size: + indices = nonzero(aux_mask, size=size, fill_value=len(aux))[0] + vals = aux.at[indices].get(mode='fill', fill_value=0) + else: + vals = zeros(size, aux.dtype) + if fill_value is None: + vals = where(arange(len(vals)) < num_results, vals, vals.max()) + return where(arange(len(vals)) < num_results, vals, vals.min()) + else: + return where(arange(len(vals)) < num_results, vals, fill_value) + + +def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, + size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set-wise xor of elements in two arrays. JAX implementation of :func:`numpy.setxor1d`. @@ -234,6 +277,12 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr assume_unique: if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if ``assume_unique`` is True and the input arrays contain duplicates, the behavior is undefined. default: False. + size: if specified, return only the first ``size`` sorted elements. If there are fewer + elements than ``size`` indicates, the return value will be padded with ``fill_value``, + and returned indices will be padded with an out-of-bound index. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the smallest value + in the xor result. Returns: An array of values that are found in exactly one of the input arrays. @@ -250,22 +299,21 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr Array([1, 2, 5, 6], dtype=int32) """ check_arraylike("setxor1d", ar1, ar2) - ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()") - ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()") + arr1, arr2 = promote_dtypes(ravel(ar1), ravel(ar2)) + del ar1, ar2 - ar1 = ravel(ar1) - ar2 = ravel(ar2) + if size is not None: + return _setxor1d_size(arr1, arr2, fill_value=fill_value, + assume_unique=assume_unique, size=size) if not assume_unique: - ar1 = unique(ar1) - ar2 = unique(ar2) - - aux = concatenate((ar1, ar2)) + arr1 = unique(arr1) + arr2 = unique(arr2) + aux = concatenate((arr1, arr2)) if aux.size == 0: return aux - aux = sort(aux) - flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True]))) + flag = concatenate((True, aux[1:] != aux[:-1], True), axis=None) return aux[flag[1:] & flag[:-1]] @@ -312,7 +360,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as arr1, ind1, num_unique1 = _unique(arr1, 0, size=arr1.size, return_index=True, return_true_size=True, fill_value=0) arr2, ind2, num_unique2 = _unique(arr2, 0, size=arr2.size, return_index=True, return_true_size=True, fill_value=0) arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2)) - arr = arr.at[:arr1.size].set(arr1) + arr = lax.dynamic_update_slice(arr, arr1, (0,)) arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,)) mask = arange(arr.size) < num_unique1 + num_unique2 _, aux, aux_sort_indices = lax.sort([~mask, arr, arange(arr.size)], is_stable=True, num_keys=2) @@ -326,8 +374,11 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as # and vals[num_results:] contains the appropriate fill_value. aux_mask = (aux[1:] == aux[:-1]) & mask[1:] num_results = aux_mask.sum() - val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0] - vals = aux.at[val_indices].get(mode='fill', fill_value=0) + if aux.size: + val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0] + vals = aux.at[val_indices].get(mode='fill', fill_value=0) + else: + vals = zeros(size, aux.dtype) if fill_value is None: vals = where(arange(len(vals)) < num_results, vals, vals.max()) vals = where(arange(len(vals)) < num_results, vals, vals.min()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d5efe1e03f31..860b3358d2ef 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -18,7 +18,7 @@ import collections from collections.abc import Iterator import copy -from functools import partial +from functools import partial, wraps import inspect import io import itertools @@ -174,10 +174,25 @@ def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0. else: vals = jtu.rand_default(rng)((total_size,), 'int32') offsets = [int(sum(sizes[:i]) * (1 - overlap)) for i in range(len(sizes))] - return [np.random.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype) + return [rng.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype) for (offset, size, shape, dtype) in zip(offsets, sizes, shapes, dtypes)] +def with_size_argument(fun): + @wraps(fun) + def wrapped(*args, size=None, fill_value=None, **kwargs): + result = fun(*args, **kwargs) + if size is None or size == len(result): + return result + elif size < len(result): + return result[:size] + else: + if fill_value is None: + fill_value = result.min() if result.size else 0 + return np.pad(result, (0, size - len(result)), constant_values=fill_value) + return wrapped + + class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" @@ -786,19 +801,22 @@ def jnp_fun(arg1, arg2): shape1=all_shapes, shape2=all_shapes, assume_unique=[False, True], + size=[None, 2, 5], + fill_value=[None, 99], overlap=[0.1, 0.5, 0.9], ) - def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, overlap): + def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, size, fill_value, overlap): args_maker = partial(arrays_with_overlapping_values, self.rng(), shapes=[shape1, shape2], dtypes=[dtype1, dtype2], overlap=overlap) - jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique) + jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique, + size=size, fill_value=fill_value) def np_fun(ar1, ar2): if assume_unique: # numpy requires 1D inputs when assume_unique is True. ar1 = np.ravel(ar1) ar2 = np.ravel(ar2) - return np.setxor1d(ar1, ar2, assume_unique) + return with_size_argument(np.setxor1d)(ar1, ar2, assume_unique, size=size, fill_value=fill_value) with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) From f4c0b1feb0ebf84e34f5d70ed60836d19cb8e2a4 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 13 Aug 2024 05:32:36 -0700 Subject: [PATCH 099/702] [Mosaic GPU] Add control over the output format in the matmul example PiperOrigin-RevId: 662478648 --- .../mosaic/gpu/examples/matmul.py | 47 ++++++++++++------- tests/mosaic/matmul_test.py | 9 +++- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index efaee88292f8..51b1e2041b4f 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -110,7 +110,7 @@ def wrap(*args, **kw): @mlir_context def build_kernel( m, n, k, - lhs_dtype, rhs_dtype, + lhs_dtype, rhs_dtype, out_dtype, stages: int = 2, tile_m: int = 128, tile_n: int = 128, @@ -134,15 +134,20 @@ def build_kernel( if swizzle not in {32, 64, 128}: raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}") - if tile_n % 32 == 0: - out_swizzle = 128 - elif tile_n % 16 == 0: - out_swizzle = 64 - else: - raise NotImplementedError(f"{tile_n=} must by divisible by 16") - out_swizzle_elems = out_swizzle // bytewidth(f32) + out_mlir_dtype = mlir.dtype_to_ir_type(out_dtype) + out_swizzle = swizzle + if bytewidth(out_mlir_dtype) == 4: + if tile_n % 32 == 0: + out_swizzle = 128 + elif tile_n % 16 == 0: + out_swizzle = 64 + else: + raise NotImplementedError( + f"{tile_n=} must by divisible by 16 for 32-bit output" + ) + out_swizzle_elems = out_swizzle // bytewidth(out_mlir_dtype) out_tiling = (64, out_swizzle_elems) - out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), jnp.float32) + out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), out_dtype) lhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) rhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) @@ -271,7 +276,7 @@ def stage_loop_body(ki, accs): with ctx.named_region("SMEM store"): acc_val = wgmma_impl.get_result(stage_loop_body.result) - acc_val.store_tiled(epilogue_smem, swizzle=out_swizzle) + acc_val.astype(out_mlir_dtype).store_tiled(epilogue_smem, swizzle=out_swizzle) commit_shared() # Make sure the stores are visible to TMA. with ctx.named_region("GMEM store"): @@ -292,7 +297,7 @@ def stage_loop_body(ki, accs): jax.ShapeDtypeStruct((m, k), lhs_dtype), jax.ShapeDtypeStruct((n, k) if rhs_transpose else (k, n), rhs_dtype), ), - jax.ShapeDtypeStruct((m, n), jnp.float32), + jax.ShapeDtypeStruct((m, n), out_dtype), ( smem_shape, TMABarrier(num_barriers=stages), @@ -318,6 +323,7 @@ def verify( swizzle=128, profile=False, in_dtype=jnp.float16, + out_dtype=jnp.float32, rhs_transpose=False, ): lhs_dtype, rhs_dtype = in_dtype, in_dtype @@ -329,7 +335,7 @@ def verify( prof_spec = profiler.ProfilerSpec(4096) if profile else None f = build_kernel( m, n, k, - jnp.dtype(lhs_dtype), jnp.dtype(rhs_dtype), + jnp.dtype(lhs_dtype), jnp.dtype(rhs_dtype), jnp.dtype(out_dtype), stages=stages, tile_m=tile_m, tile_n=tile_n, @@ -352,14 +358,19 @@ def verify( for v in (x, y) ) - ref_f = functools.partial( - jax.lax.dot_general, - dimension_numbers=dimension_numbers, - preferred_element_type=jnp.float32, - ) + @jax.jit + def ref_f(x, y): + return jax.lax.dot_general( + x, + y, + dimension_numbers=dimension_numbers, + preferred_element_type=jnp.float32, + ).astype(out_dtype) ref, ref_runtime = profiler.measure(ref_f, x, y) - np.testing.assert_allclose(z, ref, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose( + z.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3 + ) return runtime, ref_runtime diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index d0653387b311..b7fa615db76a 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -67,8 +67,14 @@ def setUp(self): def test_matmul(self, data): in_dtype = data.draw( hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]), - label="dtype", + label="in_dtype", ) + out_dtype = jnp.float32 + if in_dtype != jnp.float32: + out_dtype = data.draw( + hps.sampled_from([in_dtype, jnp.float32]), + label="out_dtype", + ) bytewidth = jnp.dtype(in_dtype).itemsize m, n, k = ( data.draw(hps.sampled_from([128, 256, 512, 2048]), label=d) @@ -102,6 +108,7 @@ def test_matmul(self, data): tile_m=tile_m, tile_n=tile_n, in_dtype=in_dtype, + out_dtype=out_dtype, cluster_m=cluster_m, cluster_n=cluster_n, swizzle=swizzle, From bab096e5637b50d695323af042487336416c808e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 13 Aug 2024 08:10:00 -0700 Subject: [PATCH 100/702] [Mosaic GPU] Add an autotuning harness to the matmul example PiperOrigin-RevId: 662521895 --- .../mosaic/gpu/examples/matmul.py | 63 ++++++++++++++++--- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index 51b1e2041b4f..53cb270b3cdc 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -15,9 +15,9 @@ """Matmul kernels for H100.""" import dataclasses -import functools -from typing import Any +import itertools import math +from typing import Any import jax from jax import random @@ -115,7 +115,8 @@ def build_kernel( tile_m: int = 128, tile_n: int = 128, swizzle: int = 128, - cluster: tuple[int, int] = (1, 1), + cluster_m: int = 1, + cluster_n: int = 1, rhs_transpose: bool = False, wgmma_impl=WGMMADefaultImpl, profiler_spec: profiler.ProfilerSpec | None = None, @@ -304,10 +305,10 @@ def stage_loop_body(ki, accs): ClusterBarrier( collective_dims=(gpu.Dimension.x, gpu.Dimension.y), num_barriers=stages, - ) if math.prod(cluster) > 1 else None, + ) if cluster_m * cluster_n > 1 else None, ), profiler_spec, - cluster=(*cluster, 1), + cluster=(cluster_n, cluster_m, 1), ) @@ -339,7 +340,8 @@ def verify( stages=stages, tile_m=tile_m, tile_n=tile_n, - cluster=(cluster_m, cluster_n), + cluster_m=cluster_m, + cluster_n=cluster_n, rhs_transpose=rhs_transpose, swizzle=swizzle, wgmma_impl=WGMMADefaultImpl, @@ -375,9 +377,54 @@ def ref_f(x, y): if __name__ == "__main__": - m, k, n = 4 * 33 * 128, 2048, 4 * 128 - runtime, ref_runtime = verify(m=m, k=k, n=n, cluster_m=1, cluster_n=4) + dtype = jnp.dtype(jnp.float16) + m, k, n = 16384, 2048, 16384 + + kx, ky = random.split(random.key(1234)) + x = random.uniform(kx, (m, k), dtype=dtype) + y = random.uniform(ky, (k, n), dtype=dtype) + + tile_m = tile_n = (64, 128, 256) + cluster_m = cluster_n = (1, 2) + swizzle = (128,) + stages = (2, 4, 5, 6) + configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle) + names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle") + best_runtime = float("inf") + best_kwargs = {} + for config in configs: + kwargs = dict(zip(names, config)) + if kwargs["cluster_m"] * kwargs["cluster_n"] > 8: + continue + if m < kwargs["tile_m"] or n < kwargs["tile_n"]: + continue + if (m // kwargs["tile_m"]) % kwargs["cluster_n"]: + continue + if (n // kwargs["tile_n"]) % kwargs["cluster_m"]: + continue + try: + f = build_kernel( + m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs + ) + _, runtime = profiler.measure(f, x, y) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" not in str(e): + raise + runtime = float("inf") + # Enable this to get more detailed information. + # else: + # print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000)) + if runtime < best_runtime: + best_runtime = runtime + best_kwargs = kwargs + if not best_kwargs: + raise ValueError("No valid configuration found") + + runtime, ref_runtime = verify( + m=m, k=k, n=n, in_dtype=dtype, out_dtype=dtype, **best_kwargs + ) tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12 ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12 + print("Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())) print(f"Kernel: {runtime * 1000:.1f} us = {tflops:.1f} TFLOPS") print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS") From 7d2fbd5418c6413674a673e9f273db3272016401 Mon Sep 17 00:00:00 2001 From: Loren Maggiore Date: Tue, 13 Aug 2024 08:51:32 -0700 Subject: [PATCH 101/702] [pallas] enable lowering on an AbstractMesh. PiperOrigin-RevId: 662533955 --- jax/_src/pallas/mosaic/lowering.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 17baef85d154..0c01e01ba686 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -369,7 +369,8 @@ def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): mesh_strides = pallas_utils.strides_from_shape(tuple( mesh.shape[a] for a in axis_names )) - self.mesh_info = MeshInfo(mesh.device_ids.shape, axis_names, mesh_strides) + mesh_shape = tuple(mesh.shape.values()) + self.mesh_info = MeshInfo(mesh_shape, axis_names, mesh_strides) def maybe_compress_grid(self): # If we have many leading parallel dimensions, we should "compress" them From 1bba83894a78adf0ccb70e0daae7f1947258362f Mon Sep 17 00:00:00 2001 From: John QiangZhang Date: Tue, 13 Aug 2024 10:46:13 -0700 Subject: [PATCH 102/702] Add logging the jax2tf `mlir_module_serialized` module size. PiperOrigin-RevId: 662574156 --- jax/_src/export/_export.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 978747266448..e2c60d3778fe 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -609,6 +609,10 @@ def _export_lowered( f"disabled_checks={disabled_checks}") logging.info("Exported JAX function: %s\n", logmsg) logging.info(mlir.dump_module_message(mlir_module, "export")) + logging.info( + "Size of mlir_module_serialized: %d byte", + len(mlir_module_serialized), + ) _check_module(mlir_module, disabled_checks=disabled_checks) From a755f1db837c464f6aa3d3111a1bc40b5ebdd37d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 13 Aug 2024 11:21:55 -0700 Subject: [PATCH 103/702] Import from ``mlir.dialects`` lazily These imports jointly account for ~0.3s of import time internally. PiperOrigin-RevId: 662588167 --- jax/_src/lazy_loader.py | 25 +++++++--- jax/_src/lib/mlir/dialects/__init__.py | 64 ++++++++++++++++---------- 2 files changed, 57 insertions(+), 32 deletions(-) diff --git a/jax/_src/lazy_loader.py b/jax/_src/lazy_loader.py index cf6e68e49c81..5150f38111c3 100644 --- a/jax/_src/lazy_loader.py +++ b/jax/_src/lazy_loader.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""A LazyLoader class.""" +"""Lazy loading APIs.""" from collections.abc import Callable, Sequence import importlib +import sys from typing import Any @@ -26,17 +27,27 @@ def attach(package_name: str, submodules: Sequence[str]) -> tuple[ ]: """Lazily loads submodules of a package. - Example use: - ``` - __getattr__, __dir__, __all__ = lazy_loader.attach(__name__, ["sub1", "sub2"]) - ``` + Returns: + A tuple of ``__getattr__``, ``__dir__`` function and ``__all__`` -- + a list of available global names, which can be used to replace the + corresponding definitions in the package. + + Raises: + RuntimeError: If the ``__name__`` of the caller cannot be determined. """ + owner_name = sys._getframe(1).f_globals.get("__name__") + if owner_name is None: + raise RuntimeError("Cannot determine the ``__name__`` of the caller.") - __all__: list[str] = list(submodules) + __all__ = list(submodules) def __getattr__(name: str) -> Any: if name in submodules: - return importlib.import_module(f"{package_name}.{name}") + value = importlib.import_module(f"{package_name}.{name}") + # Update module-level globals to avoid calling ``__getattr__`` again + # for this ``name``. + setattr(sys.modules[owner_name], name, value) + return value raise AttributeError(f"module '{package_name}' has no attribute '{name}") def __dir__() -> list[str]: diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index 01dc7e2725b5..a9bae8821db5 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -13,35 +13,49 @@ # limitations under the License. # ruff: noqa: F401 -from typing import Any -import jaxlib.mlir.dialects.arith as arith -import jaxlib.mlir.dialects.builtin as builtin -import jaxlib.mlir.dialects.chlo as chlo -import jaxlib.mlir.dialects.func as func -import jaxlib.mlir.dialects.math as math -import jaxlib.mlir.dialects.memref as memref -import jaxlib.mlir.dialects.mhlo as mhlo -import jaxlib.mlir.dialects.scf as scf +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from jaxlib.mlir.dialects import arith as arith + from jaxlib.mlir.dialects import builtin as builtin + from jaxlib.mlir.dialects import chlo as chlo + from jaxlib.mlir.dialects import func as func + from jaxlib.mlir.dialects import gpu as gpu + from jaxlib.mlir.dialects import llvm as llvm + from jaxlib.mlir.dialects import math as math + from jaxlib.mlir.dialects import memref as memref + from jaxlib.mlir.dialects import mhlo as mhlo + from jaxlib.mlir.dialects import nvgpu as nvgpu + from jaxlib.mlir.dialects import nvvm as nvvm + from jaxlib.mlir.dialects import scf as scf + from jaxlib.mlir.dialects import sparse_tensor as sparse_tensor + from jaxlib.mlir.dialects import vector as vector +else: + from jax._src import lazy_loader as _lazy + __getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [ + "arith", + "builtin", + "chlo", + "func", + "gpu", + "llvm", + "math", + "memref", + "mhlo", + "nvgpu", + "nvvm", + "scf", + "sparse_tensor", + "vector", + ]) + del _lazy + # TODO(bartchr): Once JAX is released with SDY, remove the try/except. try: - import jaxlib.mlir.dialects.sdy as sdy + from jaxlib.mlir.dialects import sdy as sdy except ImportError: sdy: Any = None # type: ignore[no-redef] -import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor -import jaxlib.mlir.dialects.vector as vector -try: - # pytype: disable=import-error - import jaxlib.mlir.dialects.gpu as gpu - import jaxlib.mlir.dialects.nvgpu as nvgpu - import jaxlib.mlir.dialects.nvvm as nvvm - import jaxlib.mlir.dialects.llvm as llvm - # pytype: enable=import-error -except ImportError: - pass - -from jax._src import lib - # Alias that is set up to abstract away the transition from MHLO to StableHLO. -import jaxlib.mlir.dialects.stablehlo as hlo +from jaxlib.mlir.dialects import stablehlo as hlo From 3c223cd253ef388a08630411af8d5496eadc96b6 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 12 Aug 2024 22:13:36 -0700 Subject: [PATCH 104/702] docs: tidy up titles and headings This shortens some titles and makes them more consistent. It also removes "JAX" from several titles ("in JAX", "for JAX", "JAX's", etc.). Since these are JAX docs, that ought to be clear from context. --- docs/advanced_guide.rst | 9 ++++++--- docs/debugging/index.md | 2 +- docs/device_memory_profiling.md | 2 +- docs/errors.rst | 5 +++-- docs/faq.rst | 4 ++-- docs/ffi.ipynb | 2 +- docs/ffi.md | 2 +- docs/glossary.rst | 4 ++-- docs/index.rst | 4 ++-- docs/installation.md | 2 +- docs/key-concepts.md | 2 +- docs/notebooks/external_callbacks.ipynb | 2 +- docs/notebooks/external_callbacks.md | 2 +- docs/notebooks/vmapped_log_probs.ipynb | 2 +- docs/notebooks/vmapped_log_probs.md | 2 +- docs/profiling.md | 2 +- docs/tutorials.rst | 4 ++-- 17 files changed, 28 insertions(+), 24 deletions(-) diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index 86742313822b..cb987bd6e5a0 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -1,8 +1,11 @@ .. _advanced_guide: -Advanced tutorials -================== -This section contains examples and tutorials on more advanced topics, such as Multi Core computation, Custom operations, and more in depth applications +Advanced guides +=============== + +This section contains examples and tutorials on more advanced topics, +such as multi-core computation, custom operations, and more in-depth +applications. .. toctree:: :caption: Examples diff --git a/docs/debugging/index.md b/docs/debugging/index.md index b00fcc13d0a0..724827f837e3 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -1,4 +1,4 @@ -# Runtime value debugging in JAX +# Debugging runtime values diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md index a906c54d5c06..a2fd3f68780c 100644 --- a/docs/device_memory_profiling.md +++ b/docs/device_memory_profiling.md @@ -1,4 +1,4 @@ -# Device memory profiling +# Profiling device memory diff --git a/docs/errors.rst b/docs/errors.rst index 23dbaf29c46f..96e14ed8d817 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -1,7 +1,8 @@ .. _jax-errors: -JAX Errors -========== +Errors +====== + This page lists a few of the errors you might encounter when using JAX, along with representative examples of how one might fix them. diff --git a/docs/faq.rst b/docs/faq.rst index 92868dc5df42..3ac7d89fb36e 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -1,5 +1,5 @@ -JAX frequently asked questions (FAQ) -==================================== +Frequently asked questions (FAQ) +================================ .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html .. comment Some links referenced here. Use `JAX - The Sharp Bits`_ (underscore at the end) to reference diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index 9dc49a74ec36..12a2781f7b13 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# JAX's foreign function interface\n", + "# Foreign function interface (FFI)\n", "\n", "_This tutorial requires JAX v0.4.31 or newer._\n", "\n", diff --git a/docs/ffi.md b/docs/ffi.md index aa861d9a094f..4568b670e170 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -12,7 +12,7 @@ kernelspec: name: python3 --- -# JAX's foreign function interface +# Foreign function interface (FFI) _This tutorial requires JAX v0.4.31 or newer._ diff --git a/docs/glossary.rst b/docs/glossary.rst index a7668e9a02b4..4bb9fa15667e 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -1,5 +1,5 @@ -JAX glossary of terms -===================== +Glossary of terms +================= .. glossary:: diff --git a/docs/index.rst b/docs/index.rst index ef59f87dd2d0..476de62bd713 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,4 +1,4 @@ -JAX: High-Performance Array Computing +JAX: High performance array computing ===================================== JAX is a Python library for accelerator-oriented array computation and program transformation, @@ -75,7 +75,7 @@ For an end-to-end transformer library built on JAX, see MaxText_. .. toctree:: :hidden: :maxdepth: 2 - :caption: Further resources + :caption: Resources user_guides advanced_guide diff --git a/docs/installation.md b/docs/installation.md index 20ffe436ff8a..bd0473d89201 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,5 +1,5 @@ (installation)= -# Installing JAX +# Installation diff --git a/docs/key-concepts.md b/docs/key-concepts.md index 4b114c857460..c6cfb176e645 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -13,7 +13,7 @@ kernelspec: --- (key-concepts)= -# Key Concepts +# Key concepts diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index 050a641a7845..25c551c9834e 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -6,7 +6,7 @@ "id": "7XNMxdTwURqI" }, "source": [ - "# External callbacks in JAX\n", + "# External callbacks\n", "\n", "" ] diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index ab0a2fcd3317..c93139e1658c 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "7XNMxdTwURqI"} -# External callbacks in JAX +# External callbacks diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 96b334296667..a355959ba45d 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -6,7 +6,7 @@ "id": "6umP1IKf4Dg6" }, "source": [ - "# Autobatching for Bayesian Inference\n", + "# Autobatching for Bayesian inference\n", "\n", "\n", "\n", diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index ea8b4fce2f70..9ecbd9d23a0b 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -14,7 +14,7 @@ kernelspec: +++ {"id": "6umP1IKf4Dg6"} -# Autobatching for Bayesian Inference +# Autobatching for Bayesian inference diff --git a/docs/profiling.md b/docs/profiling.md index 6eceec8f54b8..91f4d61b21b6 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -1,4 +1,4 @@ -# Profiling JAX programs +# Profiling computation diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 2f90e4226e50..be70c6d41654 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -1,7 +1,7 @@ .. _jax-tutorials: -JAX tutorials -============= +Tutorials +========= .. toctree:: :maxdepth: 1 From 9f6857620b672d2667ed937ca3d16b2d54f044bd Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 13 Aug 2024 11:57:49 -0700 Subject: [PATCH 105/702] Disable TensorRT in TF, XLA and JAX. This is needed for hermetic CUDA integration in Google ML projects since tensorRT is not distributed in the same free way as other CUDA/CUDNN distributives. PiperOrigin-RevId: 662601190 --- .bazelrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bazelrc b/.bazelrc index cb223ef02f78..ce9a219f4157 100644 --- a/.bazelrc +++ b/.bazelrc @@ -223,7 +223,7 @@ build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1" build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDA_VERSION=12 build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDNN_VERSION=9 build:rbe_linux_cuda12.3_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12" -build:rbe_linux_cuda12.3_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" +build:rbe_linux_cuda12.3_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" From 2dea3d6a0cdb3c9467103eeddc87d4791c1dc1d2 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 13 Aug 2024 14:40:37 -0700 Subject: [PATCH 106/702] [Mosaic:TPU] Add shuffled load and store. we also emulate shuffled store using (store + shuffled load + store) for previous generations. PiperOrigin-RevId: 662657663 --- jaxlib/mosaic/dialect/tpu/tpu.td | 31 ++++++++++++ jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 75 ++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 04709690e7d7..b1a9ac910998 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -233,6 +233,37 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> { let hasVerifier = 1; } +def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load"> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + DenseI32ArrayAttr:$sublane_offsets + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> { + let arguments = (ins + TPU_Vreg:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + DenseI32ArrayAttr:$sublane_offsets + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + // TODO(jevinjiang): deprecate to use dynamic_rotate. def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { let arguments = (ins diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 0202fbb3b7f7..5baec61ad138 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -488,6 +488,81 @@ LogicalResult RegionOp::verify() { return success(); } +LogicalResult ShuffledLoadOp::verify() { + if (getBase().getType().getRank() != getIndices().size()) { + return emitOpError("Base memref's rank and indices size do not match: ") + << getBase().getType().getRank() << " vs " << getIndices().size(); + } + if (getSublaneMask().size() != getType().getShape()[0]) { + return emitOpError("Expected sublane mask size equals to ") + << getType().getShape()[0] << " but got " << getSublaneMask().size(); + } + if (getSublaneOffsets().size() != getType().getShape()[0]) { + return emitOpError("Expected sublane offsets size equals to ") + << getType().getShape()[0] << " but got " + << getSublaneOffsets().size(); + } + return success(); +} + +LogicalResult ShuffledLoadOp::canonicalize(ShuffledLoadOp op, + PatternRewriter &rewriter) { + bool can_convert_to_simple_load = true; + for (int i = 0; i < op.getSublaneOffsets().size(); ++i) { + if (op.getSublaneOffsets()[i] != i) { + can_convert_to_simple_load = false; + break; + }; + } + if (can_convert_to_simple_load) { + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getBase(), op.getIndices(), op.getSublaneMask(), + /*sublane_stride=*/nullptr); + } + return success(); +} + +LogicalResult ShuffledStoreOp::verify() { + if (getBase().getType().getRank() != getIndices().size()) { + return emitOpError("Base memref's rank and indices size do not match: ") + << getBase().getType().getRank() << " vs " << getIndices().size(); + } + if (getValueToStore().getType().getRank() != getIndices().size()) { + return emitOpError( + "The rank of value to store and indices size do not match: ") + << getBase().getType().getRank() << " vs " << getIndices().size(); + } + if (getSublaneMask().size() != getValueToStore().getType().getShape()[0]) { + return emitOpError("Expected sublane mask size equals to ") + << getValueToStore().getType().getShape()[0] << " but got " + << getSublaneMask().size(); + } + if (getSublaneOffsets().size() != getValueToStore().getType().getShape()[0]) { + return emitOpError("Expected sublane offsets size equals to ") + << getValueToStore().getType().getShape()[0] << " but got " + << getSublaneOffsets().size(); + } + return success(); +} + +LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op, + PatternRewriter &rewriter) { + bool can_convert_to_simple_store = true; + for (int i = 0; i < op.getSublaneOffsets().size(); ++i) { + if (op.getSublaneOffsets()[i] != i) { + can_convert_to_simple_store = false; + break; + }; + } + if (can_convert_to_simple_store) { + rewriter.replaceOpWithNewOp(op, op.getValueToStore(), + op.getBase(), op.getIndices(), + op.getSublaneMask(), + /*mask=*/nullptr, + /*sublane_stride=*/nullptr); + } + return success(); +} } // namespace tpu } // namespace mlir From 28dfe0d280883beed597924e85cdbb96d972cc58 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 13 Aug 2024 14:48:03 -0700 Subject: [PATCH 107/702] Import etils.epath lazily This shaves off an extra 0.1-0.2s from JAX import times internally. PiperOrigin-RevId: 662660356 --- jax/_src/compilation_cache_interface.py | 8 ++++--- jax/_src/path.py | 29 +++++++++++++++++-------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/jax/_src/compilation_cache_interface.py b/jax/_src/compilation_cache_interface.py index 95d557c5531e..480457871a2f 100644 --- a/jax/_src/compilation_cache_interface.py +++ b/jax/_src/compilation_cache_interface.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import abstractmethod +from __future__ import annotations + +import abc from jax._src import path as pathlib from jax._src import util @@ -21,10 +23,10 @@ class CacheInterface(util.StrictABC): _path: pathlib.Path - @abstractmethod + @abc.abstractmethod def get(self, key: str): pass - @abstractmethod + @abc.abstractmethod def put(self, key: str, value: bytes): pass diff --git a/jax/_src/path.py b/jax/_src/path.py index 1dd523249692..8c46c5560b3c 100644 --- a/jax/_src/path.py +++ b/jax/_src/path.py @@ -14,22 +14,33 @@ import logging import pathlib +import importlib.util -logger = logging.getLogger(__name__) +__all__ = ["Path"] -try: - import etils.epath as epath - epath_installed = True -except: - epath = None - epath_installed = False +logger = logging.getLogger(__name__) # If etils.epath (aka etils[epath] to pip) is present, we prefer it because it # can read and write to, e.g., GCS buckets. Otherwise we use the builtin # pathlib and can only read/write to the local filesystem. -if epath: +epath_installed = bool( + importlib.util.find_spec("etils") and + importlib.util.find_spec("etils.epath") +) +if epath_installed: logger.debug("etils.epath found. Using etils.epath for file I/O.") - Path = epath.Path + + def __dir__(): + return ["Path"] + + def __getattr__(name): + if name != "Path": + raise AttributeError(f"module '{__name__}' has no attribute '{name}") + + global Path + from etils import epath + Path = epath.Path + return Path else: logger.debug("etils.epath was not found. Using pathlib for file I/O.") Path = pathlib.Path From 98521ad35d1d19d4e2fc54b783869b8f0fa6e769 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 13 Aug 2024 14:52:42 -0700 Subject: [PATCH 108/702] Add todo for slow codegen in Pallas pipeline PiperOrigin-RevId: 662661951 --- jax/_src/pallas/mosaic/pipeline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 812c496306f0..bc548a69c5c5 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -158,6 +158,8 @@ def _grid_size(grid): def _get_indices(step, grid, offsets): """Get indices for a given step and grid.""" + # TODO(enriqueps): Implement using bitwise ops, avoid div/rem since they are + # expensive. extended_grid = grid + (1,) strides = tuple( itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1] From daa69da3216c3f85629914abaf9840c457781d95 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 13 Aug 2024 15:17:30 -0700 Subject: [PATCH 109/702] Introduce `jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...])` and allow `with_sharding_constraint` and `shard_map` to accept an abstract mesh as input (`with_sharding_constraint` is via `NamedSharding(abstract_mesh, pspec)`). **Semantics** Inside jit, we don't need to talk about concrete devices ever so the semantics stay the same as today i.e. we can lower a NamedSharding with abstract mesh with only mesh axis names and sizes and PartitionSpec. The only restriction is that the number of devices need to be consistent throughout the program when we are tracing. During compilation, the order of devices throughout the program needs to be consistent (same as before this change). Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh. **Why do this?** There are cases, where you want the change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature. So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example: ``` mesh1 = Mesh(jax.devices()[:2], 'x') mesh2 = Mesh(jax.devices()[2:4], 'x') arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P())) arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P())) @jax.jit def f(x): y = with_sharding_constraint(x, NamedSharding(mesh1, P('x'))) return y * 2 f(arr_mesh1) f(arr_mesh2) # DEVICE MISMATCH ERROR! ``` The same problem exists for `shard_map` since it takes a mesh with concrete devices in it's signature. **Okay, so how do you fix this?** As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits is the most important** part here) The approach in this change, allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we don't encode the concrete devices anymore. Just the axis_names and axis_size of the mesh. **The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.** ``` mesh1 = Mesh(jax.devices()[:2], 'x') mesh2 = Mesh(jax.devices()[2:4], 'x') arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P())) arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P())) # Creating abstract mesh with mesh1 but since both meshes have the same shape (names # and axis size), it should be ok. abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple) @jax.jit def f(x): y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x'))) return y * 2 f(arr_mesh1) f(arr_mesh2) # tracing and lowering cache hit ``` **One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.** **What about `shard_map`?** shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`. ``` mesh1 = Mesh(jax.devices()[:2], 'x') mesh2 = Mesh(jax.devices()[2:4], 'x') arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P())) arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P())) # Creating abstract mesh with mesh1 but since both meshes have the same shape (names # and axis size), it should be ok. abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple) @jax.jit def f(x): y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x')) return y * 2 f(arr_mesh1) f(arr_mesh2) # tracing and lowering cache hit ``` This is a fully backwards change. So your current code will continue to work as is but you can opt-into this new behavior and get all the benefits! PiperOrigin-RevId: 662670932 --- jax/_src/dispatch.py | 6 +- jax/_src/interpreters/pxla.py | 3 +- jax/_src/mesh.py | 59 ++++++++++++++++++ jax/_src/pjit.py | 28 +++++++++ jax/_src/sharding_impls.py | 17 +++++- jax/_src/test_util.py | 19 ++++++ jax/experimental/shard_map.py | 39 ++++++++++-- jax/sharding.py | 1 + tests/pjit_test.py | 110 ++++++++++++++++++++++++++++++++++ tests/shard_map_test.py | 99 ++++++++++++++++++++++++++++++ 10 files changed, 371 insertions(+), 10 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 068b5e3b7e25..7b20ca9f6ac1 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -44,6 +44,7 @@ from jax._src.interpreters import xla from jax._src.interpreters import pxla from jax._src import lib +from jax._src.mesh import AbstractMesh from jax._src.lib import xla_client as xc from jax._src.monitoring import record_event_duration_secs from jax._src.partition_spec import PartitionSpec @@ -227,8 +228,11 @@ def get_intermediate_shardings( for eqn in jaxpr.eqns: if eqn.primitive is pjit.sharding_constraint_p: + s = eqn.params['sharding'] + if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): + continue source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield (eqn.params['sharding'], source_info) + yield (s, source_info) elif eqn.primitive is pjit.pjit_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) yield from ((i, source_info) for i in eqn.params['in_shardings']) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 12f981b309e0..baf475592f80 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1989,7 +1989,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings): except: return True for i in shardings: - if is_unspecified_or_auto(i): + if is_unspecified_or_auto(i) or i.memory_kind is None: continue if i.memory_kind != default_mem_kind: return False @@ -2426,6 +2426,7 @@ def _register_out_sharding_handler( def _gspmd_to_named_sharding( out_s: sharding_impls.GSPMDSharding, orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding: + assert isinstance(orig_in_s.mesh, mesh_lib.Mesh) return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) _register_out_sharding_handler( diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 39663f7d711a..7bbf3d2b42eb 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -308,6 +308,10 @@ def local_devices(self): return [d for d in self.devices.flat if d.process_index == d.client.process_index()] + @functools.cached_property + def abstract_mesh(self): + return AbstractMesh(self.shape_tuple) + EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ())) @@ -318,3 +322,58 @@ def __init__(self): self.env = self.stack[-1] thread_resources = _ThreadResourcesLocalState() + + +class AbstractMesh: + """AbstractMesh contains only axis names and axis sizes. + + It does not contain concrete devices compared to `jax.sharding.Mesh`. You + should use this as an input to the sharding passed to with_sharding_constraint + and mesh passed to shard_map to avoid tracing and lowering cache misses when + your mesh shape and names stay the same but the devices change. + See the description of https://github.com/google/jax/pull/23022 for more + details. + """ + + def __init__(self, shape_tuple: tuple[tuple[str, int], ...]): + self.shape_tuple = shape_tuple + self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple)) + + def __hash__(self): + return hash(self.shape_tuple) + + def __eq__(self, other): + if not isinstance(other, AbstractMesh): + return False + if id(self) == id(other): + return True + return self.shape_tuple == other.shape_tuple + + def __repr__(self): + return f"AbstractMesh({self.shape_tuple})" + + @property + def axis_names(self): + return self._axis_names + + @functools.cached_property + def size(self): + return math.prod(self._axis_sizes) + + @functools.cached_property + def shape(self): + return collections.OrderedDict(self.shape_tuple) + + @property + def _is_jax_device_mesh(self): + return False + + @property + def _internal_device_list(self): + return None + + def __enter__(self): + raise RuntimeError("AbstractMesh is not a context manager") + + def __exit__(self, exc_type, exc_value, traceback): + raise RuntimeError("AbstractMesh is not a context manager") diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ee7af5183ad9..a09a958ab8a3 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -64,6 +64,7 @@ from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc from jax._src import sharding +from jax._src.mesh import AbstractMesh from jax._src.sharding_impls import ( NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, @@ -1996,6 +1997,9 @@ def _pjit_batcher_for_sharding( if spmd_axis_name is None: if sharding_impls.is_op_sharding_replicated(hlo_s): return s + if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): + parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, None) + return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec) new_op = hlo_s.to_proto().clone() tad = list(new_op.tile_assignment_dimensions) tad.insert(dim, 1) @@ -2005,6 +2009,9 @@ def _pjit_batcher_for_sharding( _device_list=getattr(s, '_internal_device_list', None)) return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0] else: + if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): + parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, spmd_axis_name) + return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec) if isinstance(s, NamedSharding): mesh = s.mesh if mesh is None or mesh.empty: @@ -2470,6 +2477,27 @@ def _identity_fn(x): return x def _sharding_constraint_impl(x, sharding, layout, resource_env, unconstrained_dims): + if (isinstance(sharding, NamedSharding) and + isinstance(sharding.mesh, AbstractMesh)): + if not hasattr(x, 'sharding'): + aval = shaped_abstractify(x) + raise ValueError( + 'Target sharding contains a `jax.sharding.AbstractMesh` which' + ' requires the input passed should be a `jax.Array`. Got' + f' {type(x)} with shape {aval.str_short()}') + if not isinstance(x.sharding, NamedSharding): + raise TypeError( + 'The sharding on the input must be a `NamedSharding` since the target' + ' sharding has an `AbstractMesh` in it. Got sharding type' + f' {type(x.sharding)}') + if x.sharding.mesh.shape_tuple != sharding.mesh.shape_tuple: + raise ValueError( + f'Mesh shape of the input {x.sharding.mesh.shape_tuple} does not' + ' match the mesh shape of the target sharding' + f' {sharding.mesh.shape_tuple}') + sharding = NamedSharding._from_parsed_pspec( + x.sharding.mesh, sharding._parsed_pspec) + if layout is None: if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim): return x diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 1a23f4ba74ad..e99184299f4e 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -191,7 +191,7 @@ class NamedSharding(sharding.Sharding): >>> named_sharding = jax.sharding.NamedSharding(mesh, spec) """ - mesh: mesh_lib.Mesh + mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh spec: PartitionSpec _memory_kind: str | None _parsed_pspec: ParsedPartitionSpec @@ -199,7 +199,7 @@ class NamedSharding(sharding.Sharding): @use_cpp_method() def __init__( - self, mesh: mesh_lib.Mesh, spec: PartitionSpec, *, + self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *, memory_kind: str | None = None, _parsed_pspec=None, _manual_axes=frozenset()): self.mesh = mesh @@ -259,20 +259,32 @@ def _from_parsed_pspec( @property def device_set(self) -> set[Device]: + if isinstance(self.mesh, mesh_lib.AbstractMesh): + raise ValueError( + 'device_set is not implemented for `jax.sharding.AbstractMesh`.') return self.mesh._flat_devices_set @property def _device_assignment(self) -> XLADeviceAssignment: + if isinstance(self.mesh, mesh_lib.AbstractMesh): + raise ValueError('_device_assignment is not implemented for' + ' `jax.sharding.AbstractMesh`.') return self.mesh._flat_devices_tuple @property def is_fully_addressable(self) -> bool: + if isinstance(self.mesh, mesh_lib.AbstractMesh): + raise ValueError('is_fully_addressable is not implemented for ' + '`jax.sharding.AbstractMesh`.') # Speed up `is_fully_addressable` since there is a high chance that the # mesh across multiple NamedSharding objects will be the same. return not self.mesh.is_multi_process @property def addressable_devices(self) -> set[Device]: + if isinstance(self.mesh, mesh_lib.AbstractMesh): + raise ValueError('addressable_devices is not implemented for ' + '`jax.sharding.AbstractMesh`.') # Override addressable devices because there is a high chance that the mesh # across multiple NamedSharding objects will be the same. return self.mesh._local_devices_set @@ -1623,6 +1635,7 @@ def logical_sharding(aval, phys_sharding) -> sharding.Sharding: sharding_spec=logical_sharding_spec) elif isinstance(phys_sharding, NamedSharding): logical_gs = get_logical_gspmd_sharding(aval, phys_sharding) + assert isinstance(phys_sharding.mesh, mesh_lib.Mesh) return _gspmd_to_named_sharding_via_mesh( logical_gs, phys_sharding.mesh) else: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index b19110dfc516..b0f06624b718 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -383,6 +383,25 @@ def mlir_lower_and_count(*args, **kwargs): mlir.lower_jaxpr_to_module = mlir_lower +@contextmanager +def count_jit_compilation_cache_miss(): + # No need to clear any caches since we generally jit and pmap fresh callables + # in tests. + + jit_compilation = pxla._cached_compilation + count = [0] + + def compile_and_count(*args, **kwargs): + count[0] += 1 + return jit_compilation(*args, **kwargs) + + pxla._cached_compilation = compile_and_count + try: + yield count + finally: + pxla._cached_compilation = jit_compilation + + @contextmanager def count_subjaxpr_to_hlo_conversion(fun_name: str): # No need to clear any caches since we generally jit and pmap fresh callables diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index e16fa2814e9d..fa75c48292bf 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -26,8 +26,9 @@ import jax import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec, Mesh +from jax.sharding import NamedSharding, PartitionSpec from jax._src import ad_checkpoint +from jax._src import array from jax._src import ad_util from jax._src import callback from jax._src import config @@ -46,6 +47,7 @@ from jax._src import traceback_util from jax._src import util from jax._src.core import Tracer +from jax._src.mesh import AbstractMesh, Mesh from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, @@ -79,8 +81,9 @@ @traceback_util.api_boundary -def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, - check_rep: bool = True, auto: frozenset[AxisName] = frozenset()): +def shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, + out_specs: Specs, check_rep: bool = True, + auto: frozenset[AxisName] = frozenset()): """Map a function over shards of data. Note: @@ -134,14 +137,15 @@ def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, """ return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto) -def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs, +def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs | Callable[[], Specs], check_rep: bool, auto: frozenset[AxisName]): if not callable(f): raise TypeError("shard_map requires a callable for its first argument, " f"but got {f} of type {type(f)}.") - if not isinstance(mesh, Mesh): - raise TypeError("shard_map requires a `jax.sharding.Mesh` instance for its " + if not isinstance(mesh, (Mesh, AbstractMesh)): + raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " + "`jax.sharding.AbstractMesh` instance for its " f"second argument, but got {mesh} of type {type(mesh)}.") if not auto.issubset(mesh.axis_names): raise ValueError(f"shard_map requires auto={auto} to be a subset of " @@ -711,10 +715,33 @@ def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: # Eager evaluation +def get_mesh_from_args(args_flat, mesh): + for a in args_flat: + if hasattr(a, 'sharding'): + if not isinstance(a.sharding, NamedSharding): + raise TypeError( + "shard_map got `AbstractMesh` as an input to the `mesh` argument" + " which requires the input's sharding to be a `NamedSharding`. Got" + f" sharding type {type(a.sharding)}") + if a.sharding.mesh.shape_tuple != mesh.shape_tuple: + raise ValueError( + f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" + " match the mesh shape passed to shard_map " + f" {mesh.shape_tuple}") + mesh = a.sharding.mesh + if isinstance(mesh, AbstractMesh): + raise ValueError( + "Please pass `jax.Array`s with a `NamedSharding` as input to" + " `shard_map` when passing `AbstractMesh` to the mesh argument.") + assert isinstance(mesh, Mesh) + return mesh + def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): if auto: raise NotImplementedError del prim, auto + if isinstance(mesh, AbstractMesh): + mesh = get_mesh_from_args(args, mesh) args = map(partial(_unmatch_spec, mesh), in_names, args) in_rep = map(partial(_in_names_to_rep, mesh), in_names) with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main: diff --git a/jax/sharding.py b/jax/sharding.py index fe221f90af67..ea92e9d17e42 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -28,6 +28,7 @@ PartitionSpec as PartitionSpec, ) from jax._src.interpreters.pxla import Mesh as Mesh +from jax._src.mesh import AbstractMesh _deprecations = { # Added Jun 4, 2024. diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9368f7da9cfe..db105b527246 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -56,6 +56,7 @@ from jax._src.lib.mlir import dialects from jax._src import xla_bridge from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib import xla_extension from jax._src.util import curry, unzip2 @@ -4357,6 +4358,115 @@ def f(x): "Compiled object called with input sharding.*does not match"): compiled(cpu_arr) + @unittest.skipIf(xla_extension_version < 281, + 'Requires xla_extension_version >= 281') + def test_different_devices_wsc_abstract_mesh_cache_hit(self): + if jax.device_count() < 4: + self.skipTest('Requires >=4 devices') + + mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x') + mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'x') + + @jax.jit + def f(x): + x = with_sharding_constraint( + x, NamedSharding(mesh_lib.AbstractMesh(mesh1.shape_tuple), P('x'))) + return jnp.sin(x) + + with ( + jtu.count_jit_tracing_cache_miss() as tracing_count, + jtu.count_jit_and_pmap_lowerings() as lowering_count, + jtu.count_jit_compilation_cache_miss() as compilation_count, + ): + a = jax.device_put(np.arange(8.), NamedSharding(mesh1, P())) + out_a = f(a) # tracing and lowering cached + + # same num_devices but different devices. + b = jax.device_put(out_a, NamedSharding(mesh2, P())) + f(b) # tracing and lowering cache *hit* + self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin` + self.assertEqual(lowering_count[0], 1) + self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. + + @unittest.skipIf(xla_extension_version < 281, + 'Requires xla_extension_version >= 281') + def test_wsc_abstract_mesh(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) + + def f(x): + x = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x'))) + return x * 2 + + out = jax.jit(f)(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + out_eager = f(arr) + self.assertArraysEqual(out_eager, np_inp * 2) + self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x'))) + + @unittest.skipIf(xla_extension_version < 281, + 'Requires xla_extension_version >= 281') + def test_wsc_sds_abstract_mesh(self): + mesh = jtu.create_global_mesh((2,), 'x') + s = NamedSharding(mesh, P()) + abstract_mesh = mesh_lib.AbstractMesh(mesh.shape_tuple) + + @jax.jit + def f(x): + x = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x'))) + return x * 2 + + sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s) + f.eval_shape(sds) # doesn't crash + + @unittest.skipIf(xla_extension_version < 281, + 'Requires xla_extension_version >= 281') + def test_wsc_vmap_abstract_mesh(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, s) + + def f(x): + x = with_sharding_constraint(x, NamedSharding(mesh.abstract_mesh, P('x'))) + return x * 2 + + out = jax.jit(jax.vmap(f))(arr) # doesn't crash + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x'))) + + out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x'))) + + @unittest.skipIf(xla_extension_version < 281, + 'Requires xla_extension_version >= 281') + def test_wsc_abstract_mesh_errors(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + np_inp = np.arange(8) + abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) + s_abs = NamedSharding(abstract_mesh, P('x')) + + with self.assertRaisesRegex( + ValueError, ".*requires the input passed should be a `jax.Array`.*"): + with_sharding_constraint(np_inp, s_abs) + + with self.assertRaisesRegex( + TypeError, "The sharding on the input must be a `NamedSharding`"): + with_sharding_constraint(jnp.arange(8), s_abs) + + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + abs_mesh2 = mesh_lib.AbstractMesh( + jtu.create_global_mesh((2,), 'y').shape_tuple) + with self.assertRaisesRegex( + ValueError, + 'Mesh shape of the input.*does not' + ' match the mesh shape of the target sharding.*'): + with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y'))) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index bb1763fbcd3e..47d2224ccd09 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -38,11 +38,13 @@ from jax._src import test_util as jtu from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals +from jax._src.mesh import AbstractMesh from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu from jax._src import tree_util import jax.numpy as jnp +from jax._src.lib import xla_extension_version from jax.experimental.custom_partitioning import custom_partitioning from jax.experimental.shard_map import shard_map @@ -743,6 +745,103 @@ def f(x): self.assertIn('out_names', e.params) self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) + @unittest.skipIf(xla_extension_version < 281, + 'Requires xla_extension_version >= 281') + def test_shard_map_abstract_mesh(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + def f(x): + return shard_map(lambda x: x, mesh=mesh.abstract_mesh, in_specs=P('x'), + out_specs=P('x'))(x) + + out1 = jax.jit(f)(arr) + self.assertArraysEqual(out1, np_inp) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('x'))) + + out_eager = f(arr) + self.assertArraysEqual(out_eager, np_inp) + self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x'))) + + out1, out2 = shard_map(lambda x, y: (x, y), mesh=mesh.abstract_mesh, + in_specs=P('x'), out_specs=P('x'))(np_inp, arr) + self.assertArraysEqual(out1, np_inp) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out2, np_inp) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) + + @unittest.skipIf(xla_extension_version < 281, + 'Requires xla_extension_version >= 281') + def test_different_devices_shmap_abstract_mesh_cache_hit(self): + if jax.device_count() < 4: + self.skipTest('Requires >=4 devices') + + mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'i') + mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'i') + abstract_mesh = AbstractMesh(mesh1.shape_tuple) + + @jax.jit + def f(x): + x = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('i'), + out_specs=P('i'))(x) + return jnp.sin(x) + + with ( + jtu.count_jit_tracing_cache_miss() as tracing_count, + jtu.count_jit_and_pmap_lowerings() as lowering_count, + jtu.count_jit_compilation_cache_miss() as compilation_count, + ): + a = jax.device_put(np.arange(8.), NamedSharding(mesh1, P())) + out_a = f(a) # tracing and lowering cached + + # same num_devices but different devices. + b = jax.device_put(out_a, NamedSharding(mesh2, P())) + f(b) # tracing and lowering cache *hit* + + self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin` + self.assertEqual(lowering_count[0], 1) + self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. + + @unittest.skipIf(xla_extension_version < 281, + 'Requires xla_extension_version >= 281') + def test_shmap_abstract_mesh_errors(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + np_inp = np.arange(8) + abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) + + with self.assertRaisesRegex( + TypeError, + 'shard_map got `AbstractMesh` as an input to the `mesh` argument' + " which requires the input's sharding to be a `NamedSharding`"): + shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), + out_specs=P('x'))(jnp.arange(8)) + + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + mesh2 = jtu.create_global_mesh((2,), 'y') + abs_mesh2 = AbstractMesh(mesh2.shape_tuple) + with self.assertRaisesRegex( + ValueError, + 'Mesh shape of the input.*does not match the mesh shape passed to' + ' shard_map'): + shard_map(lambda x: x, mesh=abs_mesh2, in_specs=P('y'), + out_specs=P('y'))(arr) + + with self.assertRaisesRegex( + ValueError, + 'Please pass `jax.Array`s with a `NamedSharding` as input to' + ' `shard_map` when passing `AbstractMesh` to the mesh argument.'): + shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), + out_specs=P('x'))(np_inp) + + arr_mesh2 = jax.device_put(np_inp, NamedSharding(mesh2, P('y'))) + with self.assertRaisesRegex( + ValueError, + 'Mesh shape of the input.*does not match the mesh shape passed to' + ' shard_map'): + shard_map(lambda x, y: (x, y), mesh=abstract_mesh, in_specs=P('x'), + out_specs=P('x'))(arr, arr_mesh2) + @parameterized.parameters([True, False]) @jtu.run_on_devices('cpu', 'gpu', 'tpu') def test_debug_print_jit(self, jit): From 8f23392a8c7595da712d7017d057116af021516d Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 13 Aug 2024 15:22:02 -0700 Subject: [PATCH 110/702] [Mosaic:TPU] Refactor relayout helper functions to take ctx instead of only target shape. PiperOrigin-RevId: 662672417 --- .../tpu/transforms/apply_vector_layout.cc | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index f85692a61624..437d4d344dcb 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -5287,9 +5287,10 @@ FailureOr> doColumnShiftRelayout( } FailureOr>> changeOffsets( - OpBuilder &builder, const std::array target_shape, - const Location loc, const VectorType vty, const VectorLayout src, - xla::Array vregs, const LayoutOffsets dst_offsets) { + RewriteContext &ctx, OpBuilder &builder, const Location loc, + const VectorType vty, const VectorLayout src, xla::Array vregs, + const LayoutOffsets dst_offsets) { + const auto &target_shape = ctx.target_shape; const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(), src.implicit_dim()); const int packing = src.packing(); @@ -5389,10 +5390,10 @@ FailureOr>> changeOffsets( // TODO(b/265133506): Generalize retiling. FailureOr>> changeTiling( - OpBuilder &builder, const std::array target_shape, - const Location loc, VectorType vty, const VectorLayout src, - xla::Array vregs, const std::array dst_tiling, - bool try_replicate_rows) { + RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty, + const VectorLayout src, xla::Array vregs, + const std::array dst_tiling, bool try_replicate_rows) { + const auto &target_shape = ctx.target_shape; if (src.tiling() == dst_tiling) { return std::pair(src, std::move(vregs)); } @@ -5594,10 +5595,11 @@ FailureOr>> changeTiling( } FailureOr>> changeImplicitDim( - OpBuilder &builder, const std::array target_shape, - const Location loc, VectorType vty, const VectorLayout src, - xla::Array vregs, const VectorLayout::ImplicitDim dst_implicit_dim, + RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty, + const VectorLayout src, xla::Array vregs, + const VectorLayout::ImplicitDim dst_implicit_dim, const LayoutOffsets dst_offset_hints) { + const auto &target_shape = ctx.target_shape; if (src.implicit_dim() == dst_implicit_dim) { return std::make_pair(src, std::move(vregs)); } @@ -5625,8 +5627,7 @@ FailureOr>> changeImplicitDim( src.tiling(), dst_implicit_dim); xla::Array new_vregs( dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - new_vregs.Each([&](const absl::Span idx, - Value *tile) { + new_vregs.Each([&](const absl::Span idx, Value *tile) { const int64_t dst_2nd_minor_idx = idx.size() - 2; SmallVector src_idx(idx.begin(), idx.end()); src.insertImplicit(src_idx, 0); @@ -5751,21 +5752,21 @@ FailureOr> relayout(RewriteContext &ctx, FAILUREOR_ASSIGN_OR_RETURN( std::tie(src, src_tiles), - changeTiling(builder, ctx.target_shape, v.getLoc(), vty, src, - std::move(src_tiles), dst.tiling(), + changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), + dst.tiling(), dst.offsets()[0] == std::nullopt && src.offsets()[0] != std::nullopt)); FAILUREOR_ASSIGN_OR_RETURN( std::tie(src, src_tiles), - changeImplicitDim(builder, ctx.target_shape, v.getLoc(), vty, src, + changeImplicitDim(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), dst.implicit_dim(), dst.offsets())); FAILUREOR_ASSIGN_OR_RETURN( std::tie(src, src_tiles), - changeOffsets(builder, ctx.target_shape, v.getLoc(), vty, src, - std::move(src_tiles), dst.offsets())); + changeOffsets(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), + dst.offsets())); CHECK_EQ(src, dst); // At this point we've should be done. return assemble(builder, vty, dst, std::move(src_tiles), target_shape, From 4e580d167efda89d460ea0410a29da82bcae5f1b Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 13 Aug 2024 15:22:48 -0700 Subject: [PATCH 111/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/55476059f622468985141311ef20328993bd7ba5. PiperOrigin-RevId: 662672660 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f76c1698e8a2..e6b18e38bd40 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "b18bc612506b4fb759e5930b9d4b24d4c33dbdbd" -XLA_SHA256 = "48b2cc62b3e99ba4e60088aad9673489001d98c8103da9bdef90fdfdb8a76dd7" +XLA_COMMIT = "55476059f622468985141311ef20328993bd7ba5" +XLA_SHA256 = "1cbc9b2956154d724018302bcf23343f3cbe7dacd114b101615339614d953fb9" def repo(): tf_http_archive( From 5903c772f4c23301ed4a87c590033f2e10966642 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 13 Aug 2024 16:06:47 -0700 Subject: [PATCH 112/702] doc: clarify data_fields & meta_fields in register_dataclass --- jax/_src/tree_util.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 2b69c80edad6..07beed7276a8 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -941,14 +941,16 @@ def register_dataclass( attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among ``meta_fields`` or ``data_fields``. - meta_fields: auxiliary data field names. These fields *must* contain static, - hashable, immutable objects, as these objects are used to generate JIT cache - keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or - :class:`numpy.ndarray` objects. data_fields: data field names. These fields *must* be JAX-compatible objects such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or - pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be - ``None``, as this is recognized by JAX as an empty pytree. + pytrees whose leaves are arrays or scalars. Note that ``None`` is valid, as + this is recognized by JAX as an empty pytree. + meta_fields: auxiliary data field names. These fields will be considered static + within JAX transformations such as :func:`jax.jit`. The listed fields *must* + contain static, hashable, immutable objects, as these objects are used to + generate JIT cache keys: for example strings, Python scalars, or array shapes + and dtypes. In particular, ``meta_fields`` cannot contain :class:`jax.Array` + or :class:`numpy.ndarray` objects, as they are not hashable. Returns: The input class ``nodetype`` is returned unchanged after being added to JAX's From 323e257f678627591363f78669d3b4f002e445b6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 13 Aug 2024 17:01:33 -0700 Subject: [PATCH 113/702] Fix test failures. PiperOrigin-RevId: 662703221 --- tests/lax_numpy_reducers_test.py | 2 +- tests/lax_numpy_test.py | 2 +- tests/lax_scipy_test.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 4767e48c3f5e..0edc09fa7c14 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -715,7 +715,7 @@ def np_fun(*args): # TODO(phawkins): we currently set dtype=False because we aren't as # aggressive about promoting to float64. It's not clear we want to mimic # Numpy here. - tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6} + tol_spec = {np.float16: 4e-2, np.float32: 2e-4, np.float64: 5e-6} tol = max(jtu.tolerance(a_dtype, tol_spec), jtu.tolerance(q_dtype, tol_spec)) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 860b3358d2ef..f09553c83e14 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4768,7 +4768,7 @@ def np_fun(condlist, choicelist, default): else x.astype(np.float32) for x in choicelist] dtype = jnp.result_type(default, *choicelist) return np.select(condlist, - [np.asarray(x, dtype=dtype) for x in choicelist], + [np.asarray(x).astype(dtype) for x in choicelist], np.asarray(default, dtype=dtype)) with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(np_fun, jnp.select, args_maker, diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 66d84c427fea..50d2ee7259dd 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -226,7 +226,8 @@ def lax_fun(a): rng = jtu.rand_positive(self.rng()) args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, - tol={np.float32: 1e-3, np.float64: 1e-14}) + tol={np.float32: 1e-3, np.float64: 1e-14}, + check_dtypes=False) self._CompileAndCheck( lax_fun, args_maker, rtol={ np.float32: 5e-5 if jtu.test_device_matches(["tpu"]) else 1e-05, From 9a8f0a67f55a461432b23fc38bfaf2d84ab4f9bd Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 13 Aug 2024 17:37:27 -0700 Subject: [PATCH 114/702] Add a devices property to AbstractMesh but raise an error in it. This is to make pytype happy PiperOrigin-RevId: 662712450 --- jax/_src/mesh.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 7bbf3d2b42eb..fec1f5ef1779 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -372,8 +372,38 @@ def _is_jax_device_mesh(self): def _internal_device_list(self): return None + @property + def empty(self): + return self.size == 0 + + @property + def devices(self): + _raise_value_error("devices") + + @property + def device_ids(self): + _raise_value_error("device_ids") + + @property + def is_multi_process(self): + _raise_value_error("is_multi_process") + + @property + def local_devices(self): + _raise_value_error("local_devices") + + @property + def local_mesh(self): + _raise_value_error("local_mesh") + def __enter__(self): raise RuntimeError("AbstractMesh is not a context manager") def __exit__(self, exc_type, exc_value, traceback): raise RuntimeError("AbstractMesh is not a context manager") + + +# Create this indirection because pytype fails to recognize a property if a +# property raises an exception unconditionally. Remove this once that is fixed. +def _raise_value_error(name): + raise ValueError(f"AbstractMesh does not implement {name}") From d17edb4c4dc409775f8bad2ec959efd58629449b Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 13 Aug 2024 18:28:44 -0700 Subject: [PATCH 115/702] docs: fix `shard_map` guide headings These were off by one level, causing section titles to be listed in the guide index. --- docs/notebooks/shard_map.ipynb | 46 ++++++++++++++++++---------------- docs/notebooks/shard_map.md | 46 ++++++++++++++++++---------------- 2 files changed, 48 insertions(+), 44 deletions(-) diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 002582806bdc..aa3c0e276fb1 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -9,6 +9,8 @@ "\n", "\n", "\n", + "## Overview\n", + "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", @@ -36,7 +38,7 @@ "id": "97c57a94", "metadata": {}, "source": [ - "## So, let's see a `shard_map`!\n", + "### So, let's see a `shard_map`!\n", "\n", "Without further ado, here's a toy example:" ] @@ -189,9 +191,9 @@ "id": "532fe5f6", "metadata": {}, "source": [ - "## Slow down, start with the basics!\n", + "### Slow down, start with the basics!\n", "\n", - "### Rank-reducing vs rank-preserving maps\n", + "#### Rank-reducing vs rank-preserving maps\n", "\n", "We can think of `vmap` and `pmap` as unstacking each array input along an axis\n", "(e.g. unpacking a 2D matrix into its 1D rows), applying its body function to\n", @@ -274,7 +276,7 @@ "over 4 devices) then semantically we get 4 logical applications of the\n", "function, corresponding to the 4 devices physically computing them.\n", "\n", - "### Controlling how each input is split (unconcatenated) and tiled with `in_specs`\n", + "#### Controlling how each input is split (unconcatenated) and tiled with `in_specs`\n", "\n", "Each of the `in_specs` identifies some of the corresponding input array's axes\n", "with mesh axes by name using `PartitionSpec`s, representing how to split (or\n", @@ -354,7 +356,7 @@ "Physical data movement is possible on inputs, as each device needs to have a\n", "copy of the appropriate data.\n", "\n", - "### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs`\n", + "#### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs`\n", "\n", "Analogously to the input side, each of the `out_specs` identifies some of the\n", "corresponding output array's axes with mesh axes by name, representing how the\n", @@ -482,7 +484,7 @@ "`Array`s, or physically how to interpret the buffers across devices as the\n", "physical layout of a single logical `Array`.\n", "\n", - "# API Specification\n", + "## API Specification\n", "\n", "```python\n", "from jax.sharding import Mesh\n", @@ -508,7 +510,7 @@ "the corresponding `PartitionSpec` `spec` as roughly\n", "`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.\n", "\n", - "# Collectives tutorial\n", + "## Collectives tutorial\n", "\n", "A `shard_map` need not be a pure map: function applications can communicate\n", "with each other via _collectives_, using axis names defined in the `mesh`\n", @@ -572,7 +574,7 @@ "means communication across devices. Exactly what communication happens, and\n", "what values are computed, depend on the collective.\n", "\n", - "## `psum`\n", + "### `psum`\n", "\n", "The simplest collective may be `jax.lax.psum`, which computes an\n", "all-reduce-sum along a device mesh axis (or multiple axes).\n", @@ -714,7 +716,7 @@ "In the sequel, we'll see how `psum` can be implemented in terms of other\n", "primitives, which gives some intuition about its communication cost.\n", "\n", - "## `all_gather`\n", + "### `all_gather`\n", "\n", "Another fundamental operation is gathering array shards along an axis, so that\n", "each function application has a full copy of the data along that axis:\n", @@ -796,7 +798,7 @@ "In deep learning, we might use `all_gather`s on parameters in fully sharded\n", "data parallelism (FSDP).\n", "\n", - "## `psum_scatter`\n", + "### `psum_scatter`\n", "\n", "The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like\n", "`psum` except each function instance gets only one shard of the result:\n", @@ -871,7 +873,7 @@ "multiplies or fully-sharded data parallel gradient accumulation, as shown in\n", "the examples to follow.\n", "\n", - "## `ppermute`\n", + "### `ppermute`\n", "\n", "The `jax.lax.ppermute` collective provides the most direct way for\n", "function instances to send data to one another. Given a mesh axis and a\n", @@ -998,7 +1000,7 @@ "spatial axes and thus devices must communicate \"halos\" to each other. Or it\n", "may be used under-the-hood in tensor-parallel matrix multiplies.\n", "\n", - "## `all_to_all`\n", + "### `all_to_all`\n", "\n", "A final collective is `all_to_all`, which is essentially a block matrix\n", "transpose operating along one positional axis and one cross-device axis:\n", @@ -1059,12 +1061,12 @@ "where we first sort our local batch of examples according to which expert they\n", "should go to, then apply an `all_to_all` to redistribute examples to experts.\n", "\n", - "# Toy examples\n", + "## Toy examples\n", "\n", "How might we use `shard_map` and collective communication in practice? These\n", "examples, while simple, give some idea.\n", "\n", - "## Matrix multiplies\n", + "### Matrix multiplies\n", "\n", "Parallelizing matrix multiplication is central in scaling up deep learning\n", "models, both for training and for inference. When `jax.jit` automatically\n", @@ -1107,7 +1109,7 @@ "id": "2e2b33b9", "metadata": {}, "source": [ - "### Example 1: `all-gather` on one side\n", + "#### Example 1: `all-gather` on one side\n", "\n", "Consider performing a matrix multiplication where we shard the left-hand side\n", "argument (can think: parameters) on its leading (non-contracting) dimension:" @@ -1301,7 +1303,7 @@ "`jax.lax.fori_loop`. We might also have additional axes of parallelism\n", "involved.\n", "\n", - "### Example 2: `psum_scatter` the result\n", + "#### Example 2: `psum_scatter` the result\n", "\n", "Another sharding we might start with has both `lhs` and `rhs` sharded along\n", "their contracting dimensions, with the output sharded like `rhs` again:" @@ -1446,7 +1448,7 @@ "id": "60c2d2bc", "metadata": {}, "source": [ - "## Neural networks\n", + "### Neural networks\n", "\n", "We can use `shard_map` to parallelize computation in neural networks, either by\n", "itself or in combination with the automatic partitioning in `jax.jit`. This\n", @@ -1524,7 +1526,7 @@ "functions to use different parallelization strategies, with `shard_map` we\n", "often do.\n", "\n", - "### 8-way batch data parallelism\n", + "#### 8-way batch data parallelism\n", "\n", "The simplest multi-device parallelism strategy is to shard the batch of inputs\n", "and targets over multiple devices, replicate the parameters over those devices,\n", @@ -1608,7 +1610,7 @@ "end of the forward pass to compute the loss value, and in the backward pass to\n", "compute the total parameter gradients.\n", "\n", - "### 8-way fully sharded data parallelism (FSDP)\n", + "#### 8-way fully sharded data parallelism (FSDP)\n", "\n", "Another strategy is to additionally shard the parameters over the devices,\n", "all-gathering each one when the full value is needed for the `jnp.dot` or bias\n", @@ -1697,7 +1699,7 @@ "id": "f88ddefe", "metadata": {}, "source": [ - "### 8-way tensor parallelism (TP)\n", + "#### 8-way tensor parallelism (TP)\n", "\n", "Usually we don't use tensor model parallelism by itself, but seeing it in\n", "isolation is a good warmup on parallel matrix multiplication. It's also a good\n", @@ -1750,7 +1752,7 @@ "id": "cf59d537", "metadata": {}, "source": [ - "### FSDP + TP, with `shard_map` at the top level\n", + "#### FSDP + TP, with `shard_map` at the top level\n", "\n", "We can compose these strategies together, using multiple axes of parallelism." ] @@ -1821,7 +1823,7 @@ "id": "94a352ca", "metadata": {}, "source": [ - "### SPMD pipeline parallelism (PP)\n", + "#### SPMD pipeline parallelism (PP)\n", "\n", "With pipeline parallelism we aim to parallelize the evaluation of layers at\n", "different depths in our network. For example, one device might compute the\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index bda373983ccd..8b2c2d6fbdcd 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -18,6 +18,8 @@ kernelspec: +## Overview + `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. `shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. @@ -33,7 +35,7 @@ import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices ``` -## So, let's see a `shard_map`! +### So, let's see a `shard_map`! Without further ado, here's a toy example: @@ -120,9 +122,9 @@ print('b blocks:'); jax.debug.visualize_array_sharding(b) print('c blocks:'); jax.debug.visualize_array_sharding(c) ``` -## Slow down, start with the basics! +### Slow down, start with the basics! -### Rank-reducing vs rank-preserving maps +#### Rank-reducing vs rank-preserving maps We can think of `vmap` and `pmap` as unstacking each array input along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body function to @@ -181,7 +183,7 @@ by any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4 devices) then semantically we get 4 logical applications of the function, corresponding to the 4 devices physically computing them. -### Controlling how each input is split (unconcatenated) and tiled with `in_specs` +#### Controlling how each input is split (unconcatenated) and tiled with `in_specs` Each of the `in_specs` identifies some of the corresponding input array's axes with mesh axes by name using `PartitionSpec`s, representing how to split (or @@ -237,7 +239,7 @@ along the first axis, and used the pspec `P(('j', 'i'), None)`. Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data. -### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs` +#### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs` Analogously to the input side, each of the `out_specs` identifies some of the corresponding output array's axes with mesh axes by name, representing how the @@ -329,7 +331,7 @@ Instead, `out_specs` just encodes how to assemble the block outputs into `Array`s, or physically how to interpret the buffers across devices as the physical layout of a single logical `Array`. -# API Specification +## API Specification ```python from jax.sharding import Mesh @@ -355,7 +357,7 @@ from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and the corresponding `PartitionSpec` `spec` as roughly `tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`. -# Collectives tutorial +## Collectives tutorial A `shard_map` need not be a pure map: function applications can communicate with each other via _collectives_, using axis names defined in the `mesh` @@ -419,7 +421,7 @@ collective introduces some amount of cross-block dependence. Physically, that means communication across devices. Exactly what communication happens, and what values are computed, depend on the collective. -## `psum` +### `psum` The simplest collective may be `jax.lax.psum`, which computes an all-reduce-sum along a device mesh axis (or multiple axes). @@ -513,7 +515,7 @@ have a `grad` inside the `shard_map`ped function body, total gradients. In the sequel, we'll see how `psum` can be implemented in terms of other primitives, which gives some intuition about its communication cost. -## `all_gather` +### `all_gather` Another fundamental operation is gathering array shards along an axis, so that each function application has a full copy of the data along that axis: @@ -571,7 +573,7 @@ def all_gather_ref(_, x_blocks, *, tiled=False): In deep learning, we might use `all_gather`s on parameters in fully sharded data parallelism (FSDP). -## `psum_scatter` +### `psum_scatter` The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like `psum` except each function instance gets only one shard of the result: @@ -634,7 +636,7 @@ machine learning, `psum_scatter` can be used in tensor-parallel matrix multiplies or fully-sharded data parallel gradient accumulation, as shown in the examples to follow. -## `ppermute` +### `ppermute` The `jax.lax.ppermute` collective provides the most direct way for function instances to send data to one another. Given a mesh axis and a @@ -731,7 +733,7 @@ parallelizing the evaluation of convolutional layers, where we shard over spatial axes and thus devices must communicate "halos" to each other. Or it may be used under-the-hood in tensor-parallel matrix multiplies. -## `all_to_all` +### `all_to_all` A final collective is `all_to_all`, which is essentially a block matrix transpose operating along one positional axis and one cross-device axis: @@ -780,12 +782,12 @@ In deep learning, we might use `all_to_all` in mixture-of-expert routing, where we first sort our local batch of examples according to which expert they should go to, then apply an `all_to_all` to redistribute examples to experts. -# Toy examples +## Toy examples How might we use `shard_map` and collective communication in practice? These examples, while simple, give some idea. -## Matrix multiplies +### Matrix multiplies Parallelizing matrix multiplication is central in scaling up deep learning models, both for training and for inference. When `jax.jit` automatically @@ -810,7 +812,7 @@ def device_put(x, pspec): return jax.device_put(x, NamedSharding(mesh, pspec)) ``` -### Example 1: `all-gather` on one side +#### Example 1: `all-gather` on one side Consider performing a matrix multiplication where we shard the left-hand side argument (can think: parameters) on its leading (non-contracting) dimension: @@ -926,7 +928,7 @@ In practice, to reduce compile times we would probably roll this into a `jax.lax.fori_loop`. We might also have additional axes of parallelism involved. -### Example 2: `psum_scatter` the result +#### Example 2: `psum_scatter` the result Another sharding we might start with has both `lhs` and `rhs` sharded along their contracting dimensions, with the output sharded like `rhs` again: @@ -1011,7 +1013,7 @@ out = matmul_psumscatter_overlapped_bidi(lhs, rhs) print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) ``` -## Neural networks +### Neural networks We can use `shard_map` to parallelize computation in neural networks, either by itself or in combination with the automatic partitioning in `jax.jit`. This @@ -1065,7 +1067,7 @@ While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with `shard_map` we often do. -### 8-way batch data parallelism +#### 8-way batch data parallelism The simplest multi-device parallelism strategy is to shard the batch of inputs and targets over multiple devices, replicate the parameters over those devices, @@ -1119,7 +1121,7 @@ that the collective all-reduce-sum operations happen where we'd expect: at the end of the forward pass to compute the loss value, and in the backward pass to compute the total parameter gradients. -### 8-way fully sharded data parallelism (FSDP) +#### 8-way fully sharded data parallelism (FSDP) Another strategy is to additionally shard the parameters over the devices, all-gathering each one when the full value is needed for the `jnp.dot` or bias @@ -1184,7 +1186,7 @@ print(allclose(jax.jit(jax.grad(loss))(params, batch), jax.jit(jax.grad(loss_fsdp))(params, batch))) ``` -### 8-way tensor parallelism (TP) +#### 8-way tensor parallelism (TP) Usually we don't use tensor model parallelism by itself, but seeing it in isolation is a good warmup on parallel matrix multiplication. It's also a good @@ -1225,7 +1227,7 @@ def loss_tp(params, batch): return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum! ``` -### FSDP + TP, with `shard_map` at the top level +#### FSDP + TP, with `shard_map` at the top level We can compose these strategies together, using multiple axes of parallelism. @@ -1272,7 +1274,7 @@ print(allclose(jax.jit(jax.grad(loss))(params, batch), jax.jit(jax.grad(loss_fsdp_tp))(params, batch))) ``` -### SPMD pipeline parallelism (PP) +#### SPMD pipeline parallelism (PP) With pipeline parallelism we aim to parallelize the evaluation of layers at different depths in our network. For example, one device might compute the From 25da7add37cd420e739f38e98dbf47f13821491d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 13 Aug 2024 19:04:14 -0700 Subject: [PATCH 116/702] Add method argument for jnp.isin --- jax/_src/numpy/setops.py | 63 ++++++++++++++++++++++++++-------------- tests/lax_numpy_test.py | 5 ++-- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 8e0162660951..7e8acb090279 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -21,6 +21,7 @@ import numpy as np +import jax from jax import jit from jax import lax @@ -41,24 +42,39 @@ _lax_const = lax_internal._const -@partial(jit, static_argnames=('invert',)) -def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array: +@partial(jit, static_argnames=('assume_unique', 'invert', 'method')) +def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool, + method='auto', assume_unique=False) -> Array: check_arraylike("in1d", ar1, ar2) - ar1_flat = ravel(ar1) - ar2_flat = ravel(ar2) - # Note: an algorithm based on searchsorted has better scaling, but in practice - # is very slow on accelerators because it relies on lax control flow. If XLA - # ever supports binary search natively, we should switch to this: - # ar2_flat = jnp.sort(ar2_flat) - # ind = jnp.searchsorted(ar2_flat, ar1_flat) - # if invert: - # return ar1_flat != ar2_flat[ind] - # else: - # return ar1_flat == ar2_flat[ind] - if invert: - return (ar1_flat[:, None] != ar2_flat[None, :]).all(-1) + arr1, arr2 = promote_dtypes(ar1, ar2) + arr1, arr2 = arr1.ravel(), arr2.ravel() + if arr1.size == 0 or arr2.size == 0: + return (ones if invert else zeros)(arr1.shape, dtype=bool) + if method in ['auto', 'compare_all']: + if invert: + return (arr1[:, None] != arr2[None, :]).all(-1) + else: + return (arr1[:, None] == arr2[None, :]).any(-1) + elif method == 'binary_search': + arr2 = lax.sort(arr2) + ind = jax.numpy.searchsorted(arr2, arr1) + if invert: + return arr1 != arr2[ind] + else: + return arr1 == arr2[ind] + elif method == 'sort': + if assume_unique: + ind_out: slice | Array = slice(None) + else: + arr1, ind_out = unique(arr1, size=len(arr1), return_inverse=True, fill_value=arr2.max()) + aux, ind = lax.sort_key_val(concatenate([arr1, arr2]), arange(arr1.size + arr2.size)) + if invert: + return ones(arr1.shape, bool).at[ind[:-1]].set(aux[1:] != aux[:-1], mode='drop')[ind_out] + else: + return zeros(arr1.shape, bool).at[ind[:-1]].set(aux[1:] == aux[:-1], mode='drop')[ind_out] else: - return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1) + raise ValueError(f"{method=} is not implemented; options are " + "'compare_all', 'binary_search', 'sort', and 'auto'") def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]: @@ -148,7 +164,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return full_like(arr1, fill_value, shape=size or 0) if not assume_unique: arr1 = cast(Array, unique(arr1, size=size and arr1.size)) - mask = _in1d(arr1, ar2, invert=True) + mask = _in1d(arr1, ar2, invert=True, assume_unique=assume_unique) if size is None: return arr1[mask] else: @@ -509,7 +525,8 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, def isin(element: ArrayLike, test_elements: ArrayLike, - assume_unique: bool = False, invert: bool = False) -> Array: + assume_unique: bool = False, invert: bool = False, *, + method='auto') -> Array: """Determine whether elements in ``element`` appear in ``test_elements``. JAX implementation of :func:`numpy.isin`. @@ -519,7 +536,11 @@ def isin(element: ArrayLike, test_elements: ArrayLike, test_elements: N-dimensional array of test values to check for the presence of each element. invert: If True, return ``~isin(element, test_elements)``. Default is False. - assume_unique: unused by JAX + assume_unique: if true, input arrays are assumed to be unique, which can + lead to more efficient computation. If the input arrays are not unique + and assume_unique is set to True, the results are undefined. + method: string specifying the method used to compute the result. Supported + options are 'compare_all', 'binary_search', 'sort', and 'auto' (default). Returns: A boolean array of shape ``element.shape`` that specifies whether each element @@ -531,9 +552,9 @@ def isin(element: ArrayLike, test_elements: ArrayLike, >>> jnp.isin(elements, test_elements) Array([ True, False, True, False], dtype=bool) """ - del assume_unique # unused check_arraylike("isin", element, test_elements) - result = _in1d(element, test_elements, invert=invert) + result = _in1d(element, test_elements, invert=invert, + method=method, assume_unique=assume_unique) return result.reshape(np.shape(element)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f09553c83e14..75bfbbccc1e1 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -703,11 +703,12 @@ def testTensordotErrors(self): test_shape=all_shapes, dtype=default_dtypes, invert=[False, True], + method=['auto', 'compare_all', 'binary_search', 'sort'] ) - def testIsin(self, element_shape, test_shape, dtype, invert): + def testIsin(self, element_shape, test_shape, dtype, invert, method): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] - jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert) + jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert, method=method) np_fun = lambda e, t: np.isin(e, t, invert=invert) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) From 12eebfe8d44fbd84e1cc0adc34c2a1ecf208df85 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 13 Aug 2024 13:16:50 -0700 Subject: [PATCH 117/702] docs: reorganize sections * Create "extension guides" section * Sort developer notes into subsections * Move examples from advanced section into user guides * Reorder some listings, adjust some titles --- docs/advanced_guide.rst | 22 +++------------------- docs/contributor_guide.rst | 22 ++++++++++++++++------ docs/distributed_data_loading.md | 2 +- docs/extensions.rst | 23 +++++++++++++++++++++++ docs/index.rst | 4 ++-- docs/jax.rst | 4 ++-- docs/jax_internal_api.rst | 4 ++-- docs/user_guides.rst | 22 ++++++++++++---------- 8 files changed, 61 insertions(+), 42 deletions(-) create mode 100644 docs/extensions.rst diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index cb987bd6e5a0..5cf32f696252 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -4,24 +4,16 @@ Advanced guides =============== This section contains examples and tutorials on more advanced topics, -such as multi-core computation, custom operations, and more in-depth -applications. - -.. toctree:: - :caption: Examples - :maxdepth: 1 - - notebooks/neural_network_with_tfds_data - notebooks/Neural_Network_and_Data_Loading - notebooks/vmapped_log_probs +such as multi-core computation, automatic differentiation, and custom +operations. .. toctree:: :caption: Parallel computation :maxdepth: 1 - multi_process notebooks/Distributed_arrays_and_automatic_parallelization notebooks/shard_map + multi_process distributed_data_loading .. toctree:: @@ -32,14 +24,6 @@ applications. notebooks/Custom_derivative_rules_for_Python_code notebooks/autodiff_remat -.. toctree:: - :caption: JAX internals - :maxdepth: 1 - - notebooks/How_JAX_primitives_work - notebooks/Writing_custom_interpreters_in_Jax - Custom_Operation_for_GPUs - .. toctree:: :caption: Deep dives :maxdepth: 1 diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst index b5ebd5057df7..55094fc88958 100644 --- a/docs/contributor_guide.rst +++ b/docs/contributor_guide.rst @@ -1,18 +1,28 @@ .. _contributor-guide: -Developer documentation -======================= +Developer notes +=============== JAX welcomes contributions from the community. -See below for various install guides to get setup as a developer -as well as developer-focused resources such as Jax Enhancement Proposals. +These are guides to get set up as a developer, as well as +developer-focused resources, such as JAX Enhancement Proposals. + +See also the :doc:`extension guides<../extensions>`, which document +some of JAX's (extensible) internals. + .. toctree:: :maxdepth: 1 + :caption: Contribution guides contributing developer - jax_internal_api + investigating_a_regression + +.. toctree:: + :maxdepth: 1 + :caption: Design and internals + autodidax jep/index - investigating_a_regression + jax_internal_api diff --git a/docs/distributed_data_loading.md b/docs/distributed_data_loading.md index 70cbd26baa5c..be4d170eae81 100644 --- a/docs/distributed_data_loading.md +++ b/docs/distributed_data_loading.md @@ -12,7 +12,7 @@ kernelspec: name: python3 --- -# Distributed data loading in a multi-host/multi-process environment +# Distributed data loading in multi-host / multi-process environments diff --git a/docs/extensions.rst b/docs/extensions.rst new file mode 100644 index 000000000000..92963b71f20f --- /dev/null +++ b/docs/extensions.rst @@ -0,0 +1,23 @@ +.. _extensions: + +Extension guides +================ + +Guides for extending JAX's capabilities, and for building libraries +that use or interface with JAX. + +.. toctree:: + :caption: Extensible JAX internals + :maxdepth: 1 + + notebooks/How_JAX_primitives_work + jaxpr + notebooks/Writing_custom_interpreters_in_Jax + Custom_Operation_for_GPUs + jax.extend + +.. toctree:: + :caption: Libraries and extensions + :maxdepth: 1 + + building_on_jax diff --git a/docs/index.rst b/docs/index.rst index 476de62bd713..11d2807bf77e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -48,7 +48,7 @@ For an end-to-end transformer library built on JAX, see MaxText_. :link-type: ref :class-card: user-guides - .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer docs + .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer notes :columns: 12 6 6 4 :link: contributor-guide :link-type: ref @@ -80,7 +80,7 @@ For an end-to-end transformer library built on JAX, see MaxText_. user_guides advanced_guide contributor_guide - building_on_jax + extensions notes jax diff --git a/docs/jax.rst b/docs/jax.rst index 7be3e63015d9..79a46ff4d774 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -1,7 +1,7 @@ .. currentmodule:: jax -Public API: jax package -======================= +Public API: ``jax`` package +=========================== Subpackages ----------- diff --git a/docs/jax_internal_api.rst b/docs/jax_internal_api.rst index fe65054d22c1..1ece596d88ef 100644 --- a/docs/jax_internal_api.rst +++ b/docs/jax_internal_api.rst @@ -1,5 +1,5 @@ -Internal APIs -============= +Internal API reference +====================== core ---- diff --git a/docs/user_guides.rst b/docs/user_guides.rst index 45260067604e..e917cf2fee38 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -20,25 +20,27 @@ or deployed codebases. .. toctree:: :maxdepth: 1 - :caption: Development + :caption: Interfaces - jaxpr - notebooks/external_callbacks - type_promotion pytrees - -.. toctree:: - :maxdepth: 1 - :caption: Run time - + errors aot export/index - errors + type_promotion transfer_guard .. toctree:: :maxdepth: 1 :caption: Custom operations + notebooks/external_callbacks pallas/index ffi + +.. toctree:: + :caption: Example applications + :maxdepth: 1 + + notebooks/neural_network_with_tfds_data + notebooks/Neural_Network_and_Data_Loading + notebooks/vmapped_log_probs From dbd6aeebb7f8605d5a554ff2ac77553015a26b20 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 13 Aug 2024 22:02:17 -0700 Subject: [PATCH 118/702] Disable some asan tests, times out PiperOrigin-RevId: 662774152 --- tests/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index 07069b08e7b2..14d1d409c2ce 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1196,6 +1196,7 @@ jax_test( shard_count = { "tpu": 5, }, + tags = ["noasan"], # Times out deps = [ "//jax:experimental", "//jax:experimental_host_callback", @@ -1211,6 +1212,7 @@ jax_test( shard_count = { "gpu": 5, }, + tags = ["noasan"], # Times out deps = [ "//jax:experimental", "//jax:experimental_host_callback", From f384497f6882f5cdbea2d5b7f516b9b5e233cf9e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 14 Aug 2024 01:46:31 -0700 Subject: [PATCH 119/702] [Mosaic GPU] Add support for cluster collective loads and barriers over multiple dimensions This will be useful for an upcoming change to the matmul kernel that splits the N blocks over two cluster dimensions. PiperOrigin-RevId: 662825455 --- jax/experimental/mosaic/gpu/__init__.py | 24 ++++++-- jax/experimental/mosaic/gpu/utils.py | 51 +++++++++------- tests/mosaic/gpu_test.py | 79 ++++++++++++++++++++----- 3 files changed, 112 insertions(+), 42 deletions(-) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 5ffc149bc76f..88e80a79cc76 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -325,7 +325,7 @@ def async_copy( swizzle: int | None = None, arrive: bool | None = None, uniform: bool = True, - collective: gpu.Dimension | None = None, + collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, ): index = ir.IndexType.get() i16 = ir.IntegerType.get_signless(16) @@ -388,7 +388,11 @@ def async_copy( dyn_base_indices = list(dyn_base_indices) slice_shape = list(slice_shape) - collective_size = 1 if collective is None else self.cluster_size[collective] + collective_size = 1 + if collective is not None: + if isinstance(collective, gpu.Dimension): + collective = (collective,) + collective_size = math.prod(self.cluster_size[d] for d in collective) if collective_size > 1: def partition_dim(dim: int, idx: ir.Value, num_chunks: int): nonlocal smem_ref @@ -399,18 +403,28 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): smem_ref, (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),) ) - idx = gpu.cluster_block_id(collective) + stride = 1 + idx = c(0, index) + for d in sorted(collective): + if self.cluster_size[d] == 1: # Optimize a multiply by 0. + continue + idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))) + stride *= self.cluster_size[d] rem_collective_size = collective_size for dim, slice_size in enumerate(slice_shape[:-1]): if slice_size % rem_collective_size == 0: partition_dim(dim, idx, rem_collective_size) + rem_collective_size = 1 break - elif collective_size % slice_size == 0: + elif rem_collective_size % slice_size == 0: dim_idx = arith.remui(idx, c(slice_size, index)) partition_dim(dim, dim_idx, slice_size) idx = arith.divui(idx, c(slice_size, index)) rem_collective_size //= slice_size - else: + else: + break # We failed to partition the leading dimensions. + del idx # We overwrote the block index in the loop. + if rem_collective_size > 1: raise ValueError( "None of the leading dimensions in the transformed slice shape" f" {slice_shape} is divisible by the collective size" diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 7733146d0153..c892c672ffe2 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -622,21 +622,27 @@ class CollectiveBarrierRef: def initialize( address: ir.Value, num_barriers: int, - dims: Sequence[gpu.Dimension], + dims: Sequence[gpu.Dimension | Sequence[gpu.Dimension]], cluster_shape: tuple[int, int, int], ) -> "CollectiveBarrierRef": i32 = ir.IntegerType.get_signless(32) # With the exception of the current device, each pair of slices along # collective dims is disjoint. Since the current device is overcounted, # we must decrease the arrival count a little. - arrival_count = sum(cluster_shape[d] for d in dims) - len(dims) + 1 - if math.prod(cluster_shape[d] for d in dims) == 1: + dims_shape = [ + cluster_shape[d] + if isinstance(d, gpu.Dimension) + else math.prod(cluster_shape[dd] for dd in d) + for d in dims + ] + arrival_count = sum(dims_shape) - len(dims) + 1 + if arrival_count == 1: + assert all(s == 1 for s in dims_shape) cluster_mask = None - assert arrival_count == 1 else: cluster_mask = c(0, i32) - for d in dims: - if cluster_shape[d] == 1: + for d, size in zip(dims, dims_shape): + if size == 1: # Only the current device is in this mask, but it will also be # present in one of the non-trivial cluster dims. continue @@ -887,8 +893,11 @@ def memref_ptr(memref_arg, memory_space=None): def cluster_collective_mask( - cluster_shape: tuple[int, int, int], collective: gpu.Dimension + cluster_shape: tuple[int, int, int], + collective: Sequence[gpu.Dimension] | gpu.Dimension, ): + if isinstance(collective, gpu.Dimension): + collective = (collective,) # We first compute the linearized index of the slice along the collective # dim that contains the current block. Then, the mask is a sequence of 1s # strided by the position of the collective dim, shifted left by the linear @@ -896,20 +905,20 @@ def cluster_collective_mask( # TODO(apaszke): Make sure this gets hoisted outside of any loops. # If not, we might need to do it manually. i32 = ir.IntegerType.get_signless(32) - stride = 1 mask_shift = c(0, i32) - collective_stride = None - for cluster_dim in gpu.Dimension: - if cluster_dim != collective: - if cluster_shape[cluster_dim] != 1: # Constant-fold multiply by 0. - dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim)) - mask_shift = arith.addi( - mask_shift, arith.muli(dim_idx, c(stride, i32)), - ) - else: - collective_stride = stride - stride *= cluster_shape[cluster_dim] + # NOTE: GPU dimensions are minor-to-major. + cluster_strides = get_contiguous_strides(cluster_shape[::-1])[::-1] + for stride, cluster_dim in zip(cluster_strides, gpu.Dimension): + if cluster_dim in collective: + continue + if cluster_shape[cluster_dim] != 1: # Constant-fold multiply by 0. + dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim)) + mask_shift = arith.addi( + mask_shift, arith.muli(dim_idx, c(stride, i32)), + ) mask_unshifted = 0 - for i in range(cluster_shape[collective]): - mask_unshifted |= 1 << (i * collective_stride) + collective_strides = [cluster_strides[d] for d in collective] + collective_shape = tuple(cluster_shape[d] for d in collective) + for idx in np.ndindex(collective_shape): + mask_unshifted |= 1 << sum(i * s for i, s in zip(idx, collective_strides)) return arith.shli(c(mask_unshifted, i32), mask_shift) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index ca42a383adef..baad759df38d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -44,6 +44,7 @@ class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok else: from jax.experimental.mosaic import gpu as mosaic_gpu from jax.experimental.mosaic.gpu import dsl as mgpu + from jax.experimental.mosaic.gpu import utils as utils from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu.utils import * # noqa: F403 from jax._src.lib.mlir.dialects import gpu @@ -749,10 +750,11 @@ def kernel(ctx, dst, scratch): @parameterized.named_parameters( ( - f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", + f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}{'_group' if group_dims else ''}", collective_dims, noncollective_dims, collective_size, + group_dims, ) for collective_dims in itertools.chain.from_iterable( itertools.combinations(Dimension, n) for n in range(1, 4) @@ -761,9 +763,10 @@ def kernel(ctx, dst, scratch): itertools.combinations(Dimension, n) for n in range(3) ) for collective_size in (1, 2, 4) + for group_dims in (False,) + ((True,) if len(collective_dims) > 1 else ()) if all(d not in noncollective_dims for d in collective_dims) ) - def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size): + def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size, group_dims): i32 = ir.IntegerType.get_signless(32) index = ir.IndexType.get() cluster = [1, 1, 1] @@ -773,9 +776,21 @@ def test_collective_arrive(self, collective_dims, noncollective_dims, collective cluster[d] = 2 if math.prod(cluster) > 16: self.skipTest("Cluster too big") - def kernel(ctx, dst, collective_barrier): + is_trivial = math.prod(cluster[d] for d in collective_dims) == 1 + def kernel(ctx, dst, mask, collective_barrier): + memref.store(arith.constant(i32, 1 << 17), mask, [c(0, index)]) + gpu.barrier() collective_barrier.arrive() collective_barrier.wait() + if not is_trivial: + llvm.atomicrmw( + llvm.AtomicBinOp.min, + utils.memref_ptr(mask), + collective_barrier.cluster_mask, + llvm.AtomicOrdering.monotonic, + ) + else: + assert collective_barrier.cluster_mask is None tid = thread_idx() linear_idx = arith.index_cast(index, tid) stride = c(128, index) @@ -784,13 +799,30 @@ def kernel(ctx, dst, collective_barrier): stride = arith.muli(stride, gpu.grid_dim(d)) memref.store(arith.index_cast(i32, linear_idx), dst, [linear_idx]) out_shape = jax.ShapeDtypeStruct((math.prod(cluster) * 128,), jnp.int32) - scratch = mgpu.ClusterBarrier(collective_dims) - y = mosaic_gpu.as_gpu_kernel( - kernel, cluster, (128, 1, 1), (), out_shape, scratch, cluster=cluster, + mask_shape = jax.ShapeDtypeStruct((1,), jnp.int32) + barrier_dims = collective_dims + if group_dims: + barrier_dims = (collective_dims[:2], *collective_dims[2:]) + scratch = mgpu.ClusterBarrier(barrier_dims) + y, mask = mosaic_gpu.as_gpu_kernel( + kernel, cluster, (128, 1, 1), (), (out_shape, mask_shape), scratch, cluster=cluster, )() np.testing.assert_array_equal( y, np.arange(math.prod(cluster) * 128, dtype=np.int32) ) + if not is_trivial: + # Verify that the mask is correct. Blocks are column-major, hence the transpose. + block_bits = 1 << np.arange(math.prod(cluster), dtype=np.int32).reshape(cluster[::-1]).T + expected_mask = 0 + for bd in barrier_dims: + if isinstance(bd, gpu.Dimension): + bd = (bd,) + least_significant_slice = tuple( + slice(None) if d in bd else 0 for d in gpu.Dimension + ) + mask_bits = block_bits[least_significant_slice] + expected_mask |= np.bitwise_or.reduce(mask_bits, axis=None) + self.assertEqual(mask, expected_mask) class TMATest(TestCase): @@ -816,30 +848,36 @@ def kernel(ctx, src, dst, smem): @parameterized.named_parameters( ( - f"_{collective_dim}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", - collective_dim, + f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", + collective_dims, noncollective_dims, collective_size, ) - for collective_dim in Dimension + for collective_dims in itertools.chain.from_iterable( + itertools.combinations(Dimension, n) for n in range(1, 4) + ) for noncollective_dims in itertools.chain.from_iterable( itertools.combinations(Dimension, n) for n in range(3) ) for collective_size in (1, 2, 4) - if collective_dim not in noncollective_dims + if all(d not in noncollective_dims for d in collective_dims) ) - def test_tma_load_multicast(self, collective_dim, noncollective_dims, collective_size): + def test_tma_load_multicast(self, collective_dims, noncollective_dims, collective_dim_size): index = ir.IndexType.get() swizzle = 128 dtype = jnp.float16 cluster = [1, 1, 1] - cluster[collective_dim] = collective_size + for d in collective_dims: + cluster[d] = collective_dim_size for d in noncollective_dims: cluster[d] = 2 - noncollective_size = math.prod(cluster) // cluster[collective_dim] + if math.prod(cluster) > 16: + self.skipTest("Cluster too big") + collective_size = math.prod(cluster[d] for d in collective_dims) + noncollective_size = math.prod(cluster) // collective_size # We use the 2 dimension to exercise splitting the collective over # multiple dimensions when the cluster is large. - shape = (noncollective_size, 2, 16 * cluster[collective_dim], 64) + shape = (noncollective_size, 2, 16 * collective_size, 64) minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize shape = (*shape[:-1], minor_size) # Note that this kernel does not use the non-collective dimensions in any @@ -861,11 +899,20 @@ def kernel(ctx, src, dst, scratch): gmem_slice=(noncollective_idx,), swizzle=swizzle, barrier=barrier, - collective=collective_dim, + collective=collective_dims, ) barrier.wait() + # This is _not_ the real cluster block idx, because it does not consider + # the column-major ordering of the grid dimensions. + idx = c(0, index) + stride = 1 + for d in collective_dims: + idx = arith.addi( + idx, arith.muli(gpu.cluster_block_id(d), c(stride, index)) + ) + stride *= cluster[d] slc = ds( - arith.muli(gpu.cluster_block_id(collective_dim), c(16, index)), 16 + arith.muli(idx, c(16, index)), 16 ) copy( memref_slice(tmp, (slice(None), slc)), From 2ab7558425ec0f69370f8b959b3140c707bc4821 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 14 Aug 2024 02:17:18 -0700 Subject: [PATCH 120/702] [Mosaic GPU] Add support for grid tiling to improve L2 cache utilization While CUDA technically does not guarantee anything about the order in which blocks will be executed, in practice they are generally scheduled in column-major order within the grid. We can use this property to launch the blocks in a tiled way, which can lead to an improved rate of L2 hits and a significant performance boost. PiperOrigin-RevId: 662834982 --- .../mosaic/gpu/examples/matmul.py | 57 +++++++++++++------ tests/mosaic/matmul_test.py | 8 ++- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index 53cb270b3cdc..52d403cd0131 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -117,6 +117,7 @@ def build_kernel( swizzle: int = 128, cluster_m: int = 1, cluster_n: int = 1, + grid_tile_n: int = 1, rhs_transpose: bool = False, wgmma_impl=WGMMADefaultImpl, profiler_spec: profiler.ProfilerSpec | None = None, @@ -126,8 +127,8 @@ def build_kernel( raise ValueError(f"{tile_m=} must be divisible by 64") if m % tile_m != 0: raise ValueError(f"{m=} must be divisible by {tile_m=}") - if n % 64 != 0: - raise ValueError(f"n must be divisible by 64, but got {n=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") if stages < 2: raise ValueError(f"Need at least 2 stages, but got {stages=}") if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2: @@ -174,7 +175,11 @@ def safe_div(x, y): assert x % y == 0, (x, y) return x // y - grid = (safe_div(m, block_tiling.m), safe_div(n, block_tiling.n), 1) + grid = ( + grid_tile_n, + safe_div(m, block_tiling.m), + safe_div(n, block_tiling.n * grid_tile_n), + ) block = (128, 1, 1) c = arith.ConstantOp.create_index @@ -191,10 +196,12 @@ def _main(ctx, a_device, b_device, c_device, smem): ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), *barriers = smem tma_barriers, cluster_barrier = barriers - memref.assume_alignment(c_device, 16) - - m_start = arith.muli(c(block_tiling.m), gpu.block_id(gpu.Dimension.x)) - n_start = arith.muli(c(block_tiling.n), gpu.block_id(gpu.Dimension.y)) + m_start = arith.muli(c(block_tiling.m), gpu.block_id(gpu.Dimension.y)) + n_block_idx = arith.addi( + gpu.block_id(gpu.Dimension.x), + arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_n)), + ) + n_start = arith.muli(c(block_tiling.n), n_block_idx) def fetch(slot, ki): barrier = tma_barriers[slot] @@ -212,7 +219,7 @@ def fetch(slot, ki): dst_ref=memref_slice(lhs_smem, slot), gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)), gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk), - collective=gpu.Dimension.y, + collective=(gpu.Dimension.x, gpu.Dimension.z), **common_copy_args, ) rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n)) @@ -226,7 +233,7 @@ def fetch(slot, ki): dst_ref=memref_slice(rhs_smem, slot), gmem_slice=rhs_slice, gmem_transform=rhs_transform, - collective=gpu.Dimension.x, + collective=gpu.Dimension.y, **common_copy_args, ) @@ -290,6 +297,13 @@ def stage_loop_body(ki, accs): ) ctx.await_async_copy(0) + cluster_tile_n = min(cluster_n, grid_tile_n) + if cluster_n % cluster_tile_n: + raise ValueError( + f"{cluster_n=} must be divisible by {cluster_tile_n} (due to" + f" {grid_tile_n=})" + ) + cluster = (cluster_tile_n, cluster_m, cluster_n // cluster_tile_n) return mosaic_gpu.as_gpu_kernel( _main, grid, @@ -303,12 +317,12 @@ def stage_loop_body(ki, accs): smem_shape, TMABarrier(num_barriers=stages), ClusterBarrier( - collective_dims=(gpu.Dimension.x, gpu.Dimension.y), + collective_dims=((gpu.Dimension.x, gpu.Dimension.z), gpu.Dimension.y), num_barriers=stages, ) if cluster_m * cluster_n > 1 else None, ), profiler_spec, - cluster=(cluster_n, cluster_m, 1), + cluster=cluster, ) @@ -321,6 +335,7 @@ def verify( tile_n=128, cluster_m=1, cluster_n=1, + grid_tile_n=1, swizzle=128, profile=False, in_dtype=jnp.float16, @@ -344,6 +359,7 @@ def verify( cluster_n=cluster_n, rhs_transpose=rhs_transpose, swizzle=swizzle, + grid_tile_n=grid_tile_n, wgmma_impl=WGMMADefaultImpl, profiler_spec=prof_spec, ) @@ -384,12 +400,13 @@ def ref_f(x, y): x = random.uniform(kx, (m, k), dtype=dtype) y = random.uniform(ky, (k, n), dtype=dtype) - tile_m = tile_n = (64, 128, 256) + tile_m = tile_n = (64, 128) cluster_m = cluster_n = (1, 2) - swizzle = (128,) + swizzle = (128,) # 64 can be a good choice for some shapes too! stages = (2, 4, 5, 6) - configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle) - names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle") + grid_tile_n = (1, 4, 16) + configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle, grid_tile_n) + names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle", "grid_tile_n") best_runtime = float("inf") best_kwargs = {} for config in configs: @@ -398,9 +415,15 @@ def ref_f(x, y): continue if m < kwargs["tile_m"] or n < kwargs["tile_n"]: continue - if (m // kwargs["tile_m"]) % kwargs["cluster_n"]: + if (m // kwargs["tile_m"]) % kwargs["cluster_m"]: + continue + if (n // kwargs["tile_n"]) % kwargs["cluster_n"]: + continue + if n % kwargs["grid_tile_n"]: continue - if (n // kwargs["tile_n"]) % kwargs["cluster_m"]: + # This is a heuristic, not a strict correctness check. You can relax it + # for a more complete search space. + if kwargs["tile_m"] == kwargs["tile_n"] == 64: continue try: f = build_kernel( diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index b7fa615db76a..27ce4e3f02d7 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -88,10 +88,13 @@ def test_matmul(self, data): tile_n = data.draw( hps.sampled_from([t for t in [64, 128, 256] if t <= n]), label="tile_n" ) + grid_m, grid_n = m // tile_m, n // tile_n + grid_tile_n = data.draw(hps.sampled_from([1, 2, 4, 8, 16]), label="grid_tile_n") + hp.assume(grid_n % grid_tile_n == 0) cluster_m = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_m") - hp.assume((m // tile_m) % cluster_m == 0) + hp.assume(grid_m % cluster_m == 0) cluster_n = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_n") - hp.assume((n // tile_n) % cluster_n == 0) + hp.assume(grid_n % cluster_n == 0) # TODO(apaszke): Non-portable clusters (16 blocks) sometimes deadlock. hp.assume(cluster_m * cluster_n <= 8) if bytewidth == 4: @@ -111,6 +114,7 @@ def test_matmul(self, data): out_dtype=out_dtype, cluster_m=cluster_m, cluster_n=cluster_n, + grid_tile_n=grid_tile_n, swizzle=swizzle, rhs_transpose=rhs_transpose, ) From 6290cd77fcf49239f99a96af536091b63c08f0eb Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 14 Aug 2024 02:22:50 -0700 Subject: [PATCH 121/702] Added pl.program_id and pl.num_programs to Mosaic GPU lowering PiperOrigin-RevId: 662836490 --- jax/_src/pallas/mosaic_gpu/lowering.py | 85 +++++++++++++++++--------- tests/pallas/mosaic_gpu_test.py | 32 ++++++++++ 2 files changed, 87 insertions(+), 30 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 109e81306b4b..cd5e3bbddb8d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -19,7 +19,6 @@ from collections.abc import Sequence import dataclasses import functools -import itertools as it import math from typing import Any, cast @@ -161,22 +160,6 @@ def lower_jaxpr_to_module( ) -> LoweringResult: del cost_estimate # Unused. - in_structs_gmem = [*grid_mapping.in_shapes] - in_structs_smem = [ - jax.ShapeDtypeStruct(bm.block_shape, s.dtype) - for bm, s in zip( - grid_mapping.block_mappings[: grid_mapping.num_inputs], - grid_mapping.in_shapes, - ) - ] - out_structs_gmem = [*grid_mapping.out_shapes] - out_structs_smem = [ - jax.ShapeDtypeStruct(bm.block_shape, s.dtype) - for bm, s in zip( - grid_mapping.block_mappings[grid_mapping.num_inputs :], - grid_mapping.out_shapes, - ) - ] assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims if len(grid_mapping.grid) > 3: @@ -209,31 +192,46 @@ def lower_jaxpr_to_module( grid += (1,) * (3 - len(grid)) block = (128,) + (1,) * (len(grid) - 1) + in_structs_gmem = [*grid_mapping.in_shapes] + in_structs_smem = [ + jax.ShapeDtypeStruct(bm.block_shape, s.dtype) + for bm, s in zip( + grid_mapping.block_mappings[: grid_mapping.num_inputs], + grid_mapping.in_shapes, + ) + ] + out_structs_gmem = [*grid_mapping.out_shapes] + out_structs_smem = [ + jax.ShapeDtypeStruct(bm.block_shape, s.dtype) + for bm, s in zip( + grid_mapping.block_mappings[grid_mapping.num_inputs :], + grid_mapping.out_shapes, + ) + ] + def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): *buffers_gmem, (*buffers_smem, runtime_smem, barriers) = buffers assert len(buffers_gmem) == len(buffers_smem) - in_buffers_gmem = buffers_gmem[: len(in_structs_gmem)] - in_buffers_smem = buffers_smem[: len(in_structs_smem)] - out_buffers_gmem = buffers_gmem[len(in_structs_gmem) :] - out_buffers_smem = buffers_smem[len(in_structs_smem) :] + in_buffers_gmem, out_buffers_gmem = util.split_list( + buffers_gmem, [grid_mapping.num_inputs] + ) + in_buffers_smem, out_buffers_smem = util.split_list( + buffers_smem, [grid_mapping.num_inputs] + ) [barrier] = cast(mgpu.BarrierRef, barriers) module_ctx = ModuleContext( name_and_src_info.name, grid_mapping, runtime_smem, smem_used_bytes=0 ) - program_ids = [ - arith_dialect.index_cast( - ir.IntegerType.get_signless(32), gpu_dialect.block_id(dim) - ) - for dim in it.islice(gpu_dialect.Dimension, len(grid_mapping.grid)) - ] + program_ids = map(_program_id, range(len(grid_mapping.grid))) start_indices = map( functools.partial(_eval_index_map, module_ctx, program_ids), grid_mapping.block_mappings, ) - in_start_indices = start_indices[: len(in_structs_gmem)] - out_start_indices = start_indices[len(in_structs_gmem) :] + in_start_indices, out_start_indices = util.split_list( + start_indices, [grid_mapping.num_inputs] + ) with mgpu.single_thread(): for start_indices, b_gmem, b_smem in zip( @@ -252,7 +250,9 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): uniform=False, ) - barrier.wait() + if grid_mapping.num_inputs: + # Only wait if async copies were issued. + barrier.wait() _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, buffers_smem) mgpu.commit_shared() @@ -359,6 +359,28 @@ def write_env(var: jax_core.Var, val): return map(read_env, jaxpr.outvars) +@register_lowering_rule(primitives.program_id_p) +def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): + del ctx # Unused. + return _program_id(axis) + + +def _program_id(axis: int) -> ir.Value: + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.block_id(gpu_dialect.Dimension(axis)), + ) + + +@register_lowering_rule(primitives.num_programs_p) +def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): + del ctx # Unused. + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.block_dim(gpu_dialect.Dimension(axis)), + ) + + @register_lowering_rule(sp.get_p) def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree): del ctx, tree # Unused. @@ -510,6 +532,9 @@ def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray: return mgpu.FragmentedArray.splat( _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)), () ) + elif isinstance(x, ir.Value): + if isinstance(x.type, (ir.IntegerType, ir.FloatType)): + return mgpu.FragmentedArray.splat(x, ()) raise NotImplementedError diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index cb2301e1a2b1..3f690f9ca911 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -162,6 +162,38 @@ def body(tmp_ref): o = f(inp) np.testing.assert_array_equal(o, inp + 1.0) + def test_program_id(self): + @functools.partial( + pl.pallas_call, + in_specs=(), + out_specs=pl.BlockSpec((128,), lambda *i: i), + out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), + grid=2, + ) + def kernel(o_ref): + o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0)) + + np.testing.assert_array_equal( + kernel(), + jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32), + ) + + def test_num_programs(self): + @functools.partial( + pl.pallas_call, + in_specs=(), + out_specs=pl.BlockSpec((128,), lambda *i: i), + out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), + grid=2, + ) + def kernel(o_ref): + o_ref[...] = jnp.full(o_ref.shape, pl.num_programs(0)) + + np.testing.assert_array_equal( + kernel(), + jnp.full([256], 2, dtype=jnp.int32), + ) + if __name__ == "__main__": absltest.main() From 807dcb5a066191bce3480ce79c55d9544e78e42b Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 14 Aug 2024 07:09:13 -0700 Subject: [PATCH 122/702] Integrate LLVM at llvm/llvm-project@c8b5d30f7077 Updates LLVM usage to match [c8b5d30f7077](https://github.com/llvm/llvm-project/commit/c8b5d30f7077) PiperOrigin-RevId: 662906261 --- jax/_src/pallas/mosaic/lowering.py | 4 +--- .../dialect/tpu/transforms/apply_vector_layout.cc | 7 +------ .../dialect/tpu/transforms/canonicalize_mosaic.cc | 6 +++--- .../dialect/tpu/transforms/infer_vector_layout.cc | 11 +++-------- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 0c01e01ba686..1706b8f201a4 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1295,9 +1295,7 @@ def _proxy_fun(val, *, axes): kind, x, acc, - ir.ArrayAttr.get( - [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes] - ), + axes, ) return op.result return _lowering_rule diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 437d4d344dcb..25775a1994ab 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3578,12 +3578,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, auto acc = cast>(multi_reduction_op.getAcc()); TPU_ASSERT_OP(layouts_out.front().has_value()); - const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims(); - SmallVector dims; - dims.reserve(dim_attrs.size()); - for (const Attribute dim_attr : dim_attrs) { - dims.push_back(cast(dim_attr).getValue().getSExtValue()); - } + SmallVector dims(multi_reduction_op.getReductionDims()); std::sort(dims.begin(), dims.end()); // Make sure that the accumulator is a splat of the neutral value diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 93c362a43671..e70e01dfbce7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -203,8 +203,8 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation, return success(); } else if (element_type.isBF16()) { bool reduces_sublanes = false; - for (Attribute dim : op.getReductionDims()) { - if (cast(dim).getInt() == source_ty.getRank() - 2) { + for (int64_t dim : op.getReductionDims()) { + if (dim == source_ty.getRank() - 2) { reduces_sublanes = true; } } @@ -230,7 +230,7 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation, } auto new_op = builder.create( op.getLoc(), new_acc.getType(), op.getKindAttr(), new_source, new_acc, - op.getReductionDims()); + DenseI64ArrayAttr::get(builder.getContext(), op.getReductionDims())); auto new_result = builder.create(op.getLoc(), result_ty, new_op.getResult()); op.replaceAllUsesWith(new_result.getResult()); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 8f702432d397..4d8f3db71027 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1277,11 +1277,7 @@ class VectorLayoutInferer { auto src_ty = op.getSourceVectorType(); auto dst_ty = dyn_cast(op.getDestType()); TPU_CHECK_OP(dst_ty, "only reductions with vector results supported"); - SmallVector dims; - dims.reserve(op.getReductionDims().size()); - for (Attribute dim_attr : op.getReductionDims()) { - dims.push_back(cast(dim_attr).getInt()); - } + llvm::ArrayRef dims = op.getReductionDims(); int64_t src_rank = src_ty.getRank(); auto acc_layout = getLayout(op.getAcc()); TPU_CHECK_OP(is_fully_replicated(acc_layout), @@ -1770,9 +1766,8 @@ class VectorLayoutInferer { if (auto reduce = dyn_cast(operand.getOwner())) { bool reduces_tiled_dims = false; - for (Attribute dim : reduce.getReductionDims()) { - if (cast(dim).getInt() >= - reduce.getSourceVectorType().getRank() - 2) { + for (int64_t dim : reduce.getReductionDims()) { + if (dim >= reduce.getSourceVectorType().getRank() - 2) { reduces_tiled_dims = true; break; } From b0a144ae4b12af048a37c6f73f3c5186eaf7c843 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 14 Aug 2024 07:52:05 -0700 Subject: [PATCH 123/702] Don't export ir_attribute from interpreters.mlir. PiperOrigin-RevId: 662918256 --- jax/interpreters/mlir.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 78b070614621..edfc56ddd4fd 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -49,7 +49,6 @@ i32_attr as i32_attr, i64_attr as i64_attr, ir as ir, - ir_attribute as ir_attribute, ir_constant as ir_constant, ir_type_handlers as ir_type_handlers, jaxpr_subcomp as jaxpr_subcomp, From df2e9c3836e16a7082f8fafdfa4847e2fa89f493 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 14 Aug 2024 08:42:00 -0700 Subject: [PATCH 124/702] [Mosaic] Fix lowering for `_dot_general_lowering_rule` to match the new `vector.MultiDimReductionOp` signature. PiperOrigin-RevId: 662933072 --- jax/_src/pallas/mosaic/lowering.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 1706b8f201a4..4db430ecde9e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1460,9 +1460,7 @@ def _dot_general_lowering_rule( ir.Attribute.parse("#vector.kind"), arith.MulFOp(x, y), acc, - ir.ArrayAttr.get( - [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 1)] - ), + [1] ) return vector.ShapeCastOp(out_type, red).result From 5cc689976f792b0d571fb3265cc2dcf63d9cc05c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 14 Aug 2024 08:59:56 -0700 Subject: [PATCH 125/702] Use PEP484-style exports in several submodules --- jax/custom_batching.py | 4 ++-- jax/custom_transpose.py | 2 +- jax/distributed.py | 5 ++++- jax/dlpack.py | 6 +++++- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/jax/custom_batching.py b/jax/custom_batching.py index a4850f04c2ec..9b8dc8f8709a 100644 --- a/jax/custom_batching.py +++ b/jax/custom_batching.py @@ -13,6 +13,6 @@ # limitations under the License. from jax._src.custom_batching import ( - custom_vmap, - sequential_vmap, + custom_vmap as custom_vmap, + sequential_vmap as sequential_vmap, ) diff --git a/jax/custom_transpose.py b/jax/custom_transpose.py index 311139da2567..314163c4684a 100644 --- a/jax/custom_transpose.py +++ b/jax/custom_transpose.py @@ -13,5 +13,5 @@ # limitations under the License. from jax._src.custom_transpose import ( - custom_transpose, + custom_transpose as custom_transpose, ) diff --git a/jax/distributed.py b/jax/distributed.py index 284ae6f95f48..cf39b81f423a 100644 --- a/jax/distributed.py +++ b/jax/distributed.py @@ -12,4 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.distributed import (initialize, shutdown) +from jax._src.distributed import ( + initialize as initialize, + shutdown as shutdown, +) diff --git a/jax/dlpack.py b/jax/dlpack.py index 707e966ee243..a65496ec0cbf 100644 --- a/jax/dlpack.py +++ b/jax/dlpack.py @@ -12,4 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.dlpack import (to_dlpack, from_dlpack, SUPPORTED_DTYPES) +from jax._src.dlpack import ( + to_dlpack as to_dlpack, + from_dlpack as from_dlpack, + SUPPORTED_DTYPES as SUPPORTED_DTYPES, +) From 229cbae5ea4cad8d102dfd1351f2756a490cadef Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 14 Aug 2024 09:02:20 -0700 Subject: [PATCH 126/702] Add num_devices to Sharding interface so that it works with NamedSharding containing AbstractMesh too. PiperOrigin-RevId: 662938823 --- jax/_src/array.py | 2 +- jax/_src/dispatch.py | 2 +- jax/_src/sharding.py | 5 +++++ jax/_src/sharding_impls.py | 26 +++++++++++++++++++++++--- tests/shard_map_test.py | 17 +++++++++++++++++ 5 files changed, 47 insertions(+), 5 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index e7bc2e933531..7aaec18c9c89 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -489,7 +489,7 @@ def on_device_size_in_bytes(self): """Returns the total global on-device size of the array in bytes.""" arr = self._arrays[0] per_shard_size = arr.on_device_size_in_bytes() - return per_shard_size * len(self.sharding.device_set) + return per_shard_size * self.sharding.num_devices def devices(self) -> set[Device]: self._check_if_deleted() diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 7b20ca9f6ac1..bb6f5f4110b6 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -418,7 +418,7 @@ def _device_put_sharding_impl(x, aval, device): return _different_device_order_reshard(x, s) if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and - x.is_fully_addressable and len(s.device_set) > 1 and + x.is_fully_addressable and s.num_devices > 1 and s._internal_device_list != x.sharding._internal_device_list and # pytype: disable=attribute-error s.device_set == x.sharding.device_set): assert isinstance(s, Sharding) diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index fef61566c6ae..20fe3131dcba 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -144,6 +144,11 @@ def is_fully_addressable(self) -> bool: """ raise NotImplementedError('Subclasses should implement this method.') + @property + def num_devices(self) -> int: + """Number of devices that the sharding contains.""" + raise NotImplementedError('Subclasses should implement this method.') + @property def memory_kind(self) -> str | None: """Returns the memory kind of the sharding.""" diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index e99184299f4e..00d408a73251 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -257,6 +257,10 @@ def _from_parsed_pspec( memory_kind=memory_kind, _parsed_pspec=parsed_pspec, _manual_axes=_manual_axes) + @property + def num_devices(self) -> int: + return self.mesh.size + @property def device_set(self) -> set[Device]: if isinstance(self.mesh, mesh_lib.AbstractMesh): @@ -366,6 +370,10 @@ def __eq__(self, other): return (self._device == other._device and self.memory_kind == other.memory_kind) + @property + def num_devices(self) -> int: + return len(self.device_set) + @property def device_set(self) -> set[Device]: return {self._device} @@ -501,6 +509,10 @@ def default(cls, shape: Shape, sharded_dim: int = 0, pmap_devices = np.array(devices) return cls(pmap_devices, sharding_spec) + @property + def num_devices(self) -> int: + return len(self.device_set) + @functools.cached_property def device_set(self) -> set[Device]: return set(self.devices.flat) @@ -707,6 +719,10 @@ def __eq__(self, other) -> bool: # Sharding interface + @property + def num_devices(self) -> int: + return len(self.device_set) + @functools.cached_property def device_set(self) -> set[xc.Device]: return set(self._devices) @@ -826,6 +842,10 @@ def check_compatible_aval(self, aval_shape: Shape) -> None: f"{len(num_ways_dim_sharded)}, but was applied to a value of rank " f"{len(aval_shape)}") + @property + def num_devices(self) -> int: + return len(self.device_set) + @functools.cached_property def device_set(self) -> set[Device]: return set(self._devices) @@ -1405,12 +1425,12 @@ def get_process_index_and_count( if (tensor_sharding.is_fully_addressable or tensor_sharding.is_fully_replicated): return (0, 1) - num_devices = len(tensor_sharding.device_set) # Get device to indices map, we don't care about the concrete # global shape here, only to get the distribution of shards across the tensor # using (num_devices, num_devices, ...) This is a universal shape that is # compatible with any mesh with num_devices. - device_map = tensor_sharding.devices_indices_map((num_devices,) * ndims) + device_map = tensor_sharding.devices_indices_map( + (tensor_sharding.num_devices,) * ndims) # Get the slices for 'dim' for all devices. global_slice = {k: v[dim] for k, v in device_map.items()} @@ -1564,7 +1584,7 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: def is_single_device_sharding(sharding: sharding.Sharding) -> bool: # Special case PmapSharding here because PmapSharding maps away an axis # and needs to be handled separately.test_pjit_single_device_sharding_add - return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding) + return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding) def make_key_array_phys_sharding(aval, sharding): if is_single_device_sharding(sharding): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 47d2224ccd09..7038c0043048 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -35,6 +35,7 @@ from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import core +from jax._src import prng from jax._src import test_util as jtu from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals @@ -1756,6 +1757,22 @@ def f(x): ) self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_sharded_prng_with_abstract_mesh(self): + shape = (8, 2, 2) + mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + + np_inp = np.arange(math.prod(shape), dtype=np.uint32).reshape(shape) + key = prng.random_seed(np_inp, impl=prng.threefry_prng_impl) + key = jax.device_put(key, NamedSharding(mesh, P())) + + @jax.jit + def shard_key(key): + return shard_map( + lambda x: x, mesh=mesh.abstract_mesh, in_specs=P(), out_specs=P())(key) + + out = shard_key(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + def test_partial_auto_error_wsc_manual(self): mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) From bab70dda97cf3ecab68a915ad9f9261f25a02efc Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 14 Aug 2024 09:11:13 -0700 Subject: [PATCH 127/702] Reverts 734ebd570891ceaf8c7104e12256a1edfe942b14 PiperOrigin-RevId: 662942100 --- jax/_src/interpreters/mlir.py | 24 ------------------ tests/layout_test.py | 48 ----------------------------------- 2 files changed, 72 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 5cfb6a0e0699..814c6a9886d7 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1010,23 +1010,6 @@ def _get_mem_kind(s: JSharding | None) -> str | None: return s.memory_kind -def _is_default_layout(curr_layout, sharding, aval): - if curr_layout is None or sharding is None: - return True - if isinstance(curr_layout, AutoLayout): - return False - d = sharding._device_assignment[0] - try: - return curr_layout == DeviceLocalLayout.from_pjrt_layout( - d.client.get_default_layout(aval.dtype, aval.shape, d)) - except xla_extension.XlaRuntimeError as e: - msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - return True - else: - raise - - def lower_jaxpr_to_module( module_name: str, jaxpr: core.ClosedJaxpr, @@ -1081,13 +1064,6 @@ def lower_jaxpr_to_module( "In multi-platform lowering either all or no lowering platforms " f"should support donation. Lowering for {platforms} of which " f"only {platforms_with_donation} support donation") - if (in_layouts is not None and arg_shardings is not None and - out_layouts is not None and result_shardings is not None - ) and not ( - all(map(_is_default_layout, in_layouts, arg_shardings, in_avals)) and - all(map(_is_default_layout, out_layouts, result_shardings, out_avals)) - ): - xla_donated_args = donated_args if num_partitions > 1 and ( result_shardings is None or all(s is None for s in result_shardings)): xla_donated_args = donated_args diff --git a/tests/layout_test.py b/tests/layout_test.py index 2ddd72764be0..c72082d0a16c 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -500,54 +500,6 @@ def g(x): 'Layout passed to jit does not match the layout on the respective arg'): g(arr) - def test_layout_donation(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - s = NamedSharding(mesh, P('x', 'y')) - shape = (16, 128) - np_inp = np.arange(math.prod(shape)).reshape(shape) - - custom_dll = DLL(major_to_minor=(0, 1)) - arr = jax.device_put(np_inp, Layout(custom_dll, s)) - - @partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0) - def f(x): - return x - - out = f(arr) - self.assertTrue(arr.is_deleted()) - - def test_layout_donation_auto(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - s = NamedSharding(mesh, P('x', 'y')) - shape = (128, 16) - np_inp = np.arange(math.prod(shape)).reshape(shape) - - arr = jax.device_put(np_inp, s) - - @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0) - def f(x): - return x * x - - out = f(arr) - self.assertTrue(arr.is_deleted()) - - def test_layout_donation_matching_in_and_out(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) - s = NamedSharding(mesh, P('x', 'y')) - shape = (128, 16) - np_inp = np.arange(math.prod(shape)).reshape(shape) - - custom_dll = DLL(major_to_minor=(0, 1)) - l = Layout(custom_dll, s) - arr = jax.device_put(np_inp, l) - - @partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0) - def f(x): - return x * x - - out = f(arr) - self.assertTrue(arr.is_deleted()) - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From ad1bd387901af49fc8cf5576cd8af6abe3faf988 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 14 Aug 2024 09:20:07 -0700 Subject: [PATCH 128/702] Move logic about when to dispatch to batched LU decomposition algorithm on GPU into the kernel. This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism. PiperOrigin-RevId: 662945116 --- jaxlib/cuda/BUILD | 21 +---- jaxlib/gpu/BUILD | 2 - jaxlib/gpu/blas.cc | 4 - jaxlib/gpu/blas_kernels_ffi.cc | 133 ------------------------------- jaxlib/gpu/blas_kernels_ffi.h | 30 ------- jaxlib/gpu/gpu_kernels.cc | 3 - jaxlib/gpu/solver_kernels_ffi.cc | 103 +++++++++++++++++++++--- jaxlib/gpu_solver.py | 18 ++--- jaxlib/rocm/BUILD.bazel | 23 +----- 9 files changed, 104 insertions(+), 233 deletions(-) delete mode 100644 jaxlib/gpu/blas_kernels_ffi.cc delete mode 100644 jaxlib/gpu/blas_kernels_ffi.h diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index e515de2d3a95..8121a1058768 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -114,22 +114,6 @@ cc_library( ], ) -cc_library( - name = "cublas_kernels_ffi", - srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"], - deps = [ - ":cuda_blas_handle_pool", - ":cuda_gpu_kernel_helpers", - ":cuda_vendor", - "//jaxlib:ffi_helpers", - "@xla//xla/ffi/api:ffi", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@com_google_absl//absl/status", - ], -) - pybind_extension( name = "_blas", srcs = ["//jaxlib/gpu:blas.cc"], @@ -148,7 +132,6 @@ pybind_extension( module_name = "_blas", deps = [ ":cublas_kernels", - ":cublas_kernels_ffi", ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", "@xla//xla/tsl/cuda:cublas", @@ -238,11 +221,13 @@ cc_library( srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"], hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"], deps = [ + ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", ":cuda_solver_handle_pool", ":cuda_vendor", "//jaxlib:ffi_helpers", "@xla//xla/ffi/api:ffi", + "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", "@com_google_absl//absl/status", @@ -274,6 +259,7 @@ pybind_extension( ":cusolver_kernels", ":cusolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", + "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/python/lib/core:numpy", @@ -466,7 +452,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":cublas_kernels", - ":cublas_kernels_ffi", ":cuda_linalg_kernels", ":cuda_prng_kernels", ":cuda_vendor", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index daa03aa5be24..6bdaf4ef1322 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -29,8 +29,6 @@ exports_files(srcs = [ "blas_handle_pool.h", "blas_kernels.cc", "blas_kernels.h", - "blas_kernels_ffi.cc", - "blas_kernels_ffi.h", "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", "gpu_kernels.cc", diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index 62a1bbc94790..e8761bd32ac9 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/blas_kernels_ffi.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/tsl/python/lib/core/numpy.h" @@ -70,9 +69,6 @@ nb::dict Registrations() { nb::dict dict; dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); - - dict[JAX_GPU_PREFIX "blas_getrf_batched_ffi"] = - EncapsulateFfiHandler(GetrfBatchedFfi); return dict; } diff --git a/jaxlib/gpu/blas_kernels_ffi.cc b/jaxlib/gpu/blas_kernels_ffi.cc deleted file mode 100644 index 610ce105260e..000000000000 --- a/jaxlib/gpu/blas_kernels_ffi.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/gpu/blas_kernels_ffi.h" - -#include "absl/status/status.h" -#include "jaxlib/ffi_helpers.h" -#include "jaxlib/gpu/blas_handle_pool.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/vendor.h" -#include "xla/ffi/api/ffi.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { - -namespace ffi = ::xla::ffi; - -namespace { -#define GETRF_BATCHED_KERNEL_IMPL(type, name) \ - template <> \ - struct GetrfBatchedKernel { \ - static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \ - int* ipiv, int* info, int batch) { \ - return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \ - } \ - } - -template -struct GetrfBatchedKernel; -GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched); -#undef GETRF_BATCHED_KERNEL_IMPL - -template -ffi::Error GetrfBatchedImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch, - ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions())); - auto [batch, rows, cols] = SplitBatch2D(a.dimensions()); - if (rows != cols) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "getrf_batched only supports square matrices"); - } - FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); - FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); - - auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch); - if (!maybe_workspace.has_value()) { - return ffi::Error(ffi::ErrorCode::kUnknown, - "Unable to allocate workspace for batched getrf"); - } - auto workspace = maybe_workspace.value(); - - auto a_data = a.untyped_data(); - auto out_data = out->untyped_data(); - auto ipiv_data = ipiv->typed_data(); - auto info_data = info->typed_data(); - if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols, - gpuMemcpyDeviceToDevice, stream))); - } - - FFI_ASSIGN_OR_RETURN( - auto a_ptrs_host, - MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n)); - // TODO(phawkins, danfm): ideally we would not need to synchronize here, but - // to avoid it we need a way to keep the host-side buffer alive until the copy - // completes. - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); - - auto batch_ptrs = static_cast(workspace); - FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel::Run( - handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch)); - - return ffi::Error::Success(); -} - -ffi::Error GetrfBatchedDispatch( - gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer a, - ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { - auto dataType = a.element_type(); - if (dataType != out->element_type()) { - return ffi::Error( - ffi::ErrorCode::kInvalidArgument, - "Input and output to getrf_batched must have the same element type"); - } - if (dataType == ffi::DataType::F32) { - return GetrfBatchedImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::F64) { - return GetrfBatchedImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::C64) { - return GetrfBatchedImpl(stream, scratch, a, out, ipiv, - info); - } else if (dataType == ffi::DataType::C128) { - return GetrfBatchedImpl(stream, scratch, a, out, ipiv, - info); - } - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "Unsupported element type for getrf"); -} -} // namespace - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - GetrfBatchedFfi, GetrfBatchedDispatch, - ffi::Ffi::Bind() - .Ctx>() - .Ctx() - .Arg() // a - .Ret() // out - .Ret>() // ipiv - .Ret>() // info -); - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_kernels_ffi.h b/jaxlib/gpu/blas_kernels_ffi.h deleted file mode 100644 index ad3bf90120e9..000000000000 --- a/jaxlib/gpu/blas_kernels_ffi.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_GPU_BLAS_KERNELS_FFI_H_ -#define JAXLIB_GPU_BLAS_KERNELS_FFI_H_ - -#include "jaxlib/gpu/vendor.h" -#include "xla/ffi/api/ffi.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { - -XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfBatchedFfi); - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax - -#endif // JAXLIB_GPU_BLAS_KERNELS_FFI_H_ diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index ccca8e157b98..b76cea19ea2e 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -17,7 +17,6 @@ limitations under the License. // JAX-generated HLO code from outside of JAX. #include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/blas_kernels_ffi.h" #include "jaxlib/gpu/linalg_kernels.h" #include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/gpu/rnn_kernels.h" @@ -36,8 +35,6 @@ namespace { XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, "CUDA"); -XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cublas_getrf_batched_ffi", "CUDA", - GetrfBatchedFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 414e159b2aac..051b9fd03f9e 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -16,10 +16,12 @@ limitations under the License. #include "jaxlib/gpu/solver_kernels_ffi.h" #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "jaxlib/ffi_helpers.h" +#include "jaxlib/gpu/blas_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/vendor.h" @@ -58,12 +60,11 @@ GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf); #undef GETRF_KERNEL_IMPL template -ffi::Error GetrfImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch, +ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, ffi::Result out, ffi::Result> ipiv, ffi::Result> info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions())); - auto [batch, rows, cols] = SplitBatch2D(a.dimensions()); FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); @@ -98,6 +99,64 @@ ffi::Error GetrfImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch, return ffi::Error::Success(); } +#define GETRF_BATCHED_KERNEL_IMPL(type, name) \ + template <> \ + struct GetrfBatchedKernel { \ + static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \ + int* ipiv, int* info, int batch) { \ + return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \ + } \ + } + +template +struct GetrfBatchedKernel; +GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched); +GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched); +GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched); +GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched); +#undef GETRF_BATCHED_KERNEL_IMPL + +template +ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, + ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result> ipiv, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); + + auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch); + if (!maybe_workspace.has_value()) { + return ffi::Error(ffi::ErrorCode::kUnknown, + "Unable to allocate workspace for batched getrf"); + } + auto workspace = maybe_workspace.value(); + + auto a_data = a.untyped_data(); + auto out_data = out->untyped_data(); + auto ipiv_data = ipiv->typed_data(); + auto info_data = info->typed_data(); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( + gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols, + gpuMemcpyDeviceToDevice, stream))); + } + + FFI_ASSIGN_OR_RETURN( + auto a_ptrs_host, + MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n)); + // TODO(phawkins, danfm): ideally we would not need to synchronize here, but + // to avoid it we need a way to keep the host-side buffer alive until the copy + // completes. + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + auto batch_ptrs = static_cast(workspace); + FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel::Run( + handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch)); + + return ffi::Error::Success(); +} + ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer a, ffi::Result out, ffi::Result> ipiv, @@ -108,14 +167,36 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::ErrorCode::kInvalidArgument, "The input and output to getrf must have the same element type"); } - if (dataType == ffi::DataType::F32) { - return GetrfImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::F64) { - return GetrfImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::C64) { - return GetrfImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::C128) { - return GetrfImpl(stream, scratch, a, out, ipiv, info); + FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions())); + auto [batch, rows, cols] = SplitBatch2D(a.dimensions()); + if (batch > 1 && rows == cols && rows / batch <= 128) { + if (dataType == ffi::DataType::F32) { + return GetrfBatchedImpl(batch, cols, stream, scratch, a, out, ipiv, + info); + } else if (dataType == ffi::DataType::F64) { + return GetrfBatchedImpl(batch, cols, stream, scratch, a, out, + ipiv, info); + } else if (dataType == ffi::DataType::C64) { + return GetrfBatchedImpl(batch, cols, stream, scratch, a, + out, ipiv, info); + } else if (dataType == ffi::DataType::C128) { + return GetrfBatchedImpl( + batch, cols, stream, scratch, a, out, ipiv, info); + } + } else { + if (dataType == ffi::DataType::F32) { + return GetrfImpl(batch, rows, cols, stream, scratch, a, out, ipiv, + info); + } else if (dataType == ffi::DataType::F64) { + return GetrfImpl(batch, rows, cols, stream, scratch, a, out, ipiv, + info); + } else if (dataType == ffi::DataType::C64) { + return GetrfImpl(batch, rows, cols, stream, scratch, a, out, + ipiv, info); + } else if (dataType == ffi::DataType::C128) { + return GetrfImpl(batch, rows, cols, stream, scratch, a, + out, ipiv, info); + } } return ffi::Error(ffi::ErrorCode::kInvalidArgument, "Unsupported element type for getrf"); diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 87171fdb4611..baa84e8eb9de 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -43,10 +43,7 @@ if _cublas: for _name, _value in _cublas.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. - api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) + xla_client.register_custom_call_target(_name, _value, platform="CUDA") for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: try: @@ -78,10 +75,7 @@ if _hipblas: for _name, _value in _hipblas.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. - api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) + xla_client.register_custom_call_target(_name, _value, platform="ROCM") for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: try: @@ -115,15 +109,14 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a): num_bd = len(batch_dims) i32_type = ir.IntegerType.get_signless(32) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - batch = math.prod(batch_dims) - use_batched = batch > 1 and m == n and m // batch <= 128 # TODO(b/357034884): Remove after 3 week forward compatibility window. if ctx.is_forward_compat(): if not gpu_blas: raise GpuLibNotLinkedError() - if use_batched: + batch = math.prod(batch_dims) + if batch > 1 and m == n and m // batch <= 128: lwork, opaque = gpu_blas.build_getrf_batched_descriptor( np.dtype(dtype), batch, m) workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)) @@ -154,9 +147,8 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a): operand_output_aliases={0: 0}).results return out[:3] - target = "blas_getrf_batched_ffi" if use_batched else "solver_getrf_ffi" return custom_call( - f"{platform}{target}", + f"{platform}solver_getrf_ffi", result_types=[ a.type, ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type), diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index 1ec36fd30c8e..ba9ceb4c3fa7 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -98,23 +98,6 @@ cc_library( ], ) -cc_library( - name = "hipblas_kernels_ffi", - srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"], - deps = [ - ":hip_blas_handle_pool", - ":hip_gpu_kernel_helpers", - ":hip_vendor", - "//jaxlib:ffi_helpers", - "@com_google_absl//absl/status", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", - ], -) - pybind_extension( name = "_blas", srcs = ["//jaxlib/gpu:blas.cc"], @@ -127,7 +110,6 @@ pybind_extension( deps = [ ":hip_vendor", ":hipblas_kernels", - ":hipblas_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", @@ -176,12 +158,14 @@ cc_library( srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"], hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"], deps = [ + ":hip_blas_handle_pool", ":hip_gpu_kernel_helpers", ":hip_solver_handle_pool", ":hip_vendor", "//jaxlib:ffi_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", @@ -199,14 +183,15 @@ pybind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ - ":hip_solver_handle_pool", ":hip_gpu_kernel_helpers", + ":hip_solver_handle_pool", ":hip_vendor", ":hipsolver_kernels", ":hipsolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@nanobind", From bd9698ec6d558586da57e7113e057189f9761e45 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 14 Aug 2024 10:06:13 -0700 Subject: [PATCH 129/702] Deprecate several internal utilities in jax.core --- CHANGELOG.md | 3 +++ jax/core.py | 18 ++++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b30b08abb01..d0ac5c3737b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The `jax.experimental.array_api` module is deprecated, and importing it is no longer required to use the Array API. `jax.numpy` supports the array API directly; see {ref}`python-array-api` for more information. + * The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and + `jax.core.check_valid_jaxtype` are now deprecated, and will be removed in + the future. ## jaxlib 0.4.32 diff --git a/jax/core.py b/jax/core.py index 80025e8619f3..1f433d6f5c29 100644 --- a/jax/core.py +++ b/jax/core.py @@ -66,10 +66,7 @@ call_bind_with_continuation as call_bind_with_continuation, call_impl as call_impl, call_p as call_p, - check_eqn as check_eqn, check_jaxpr as check_jaxpr, - check_type as check_type, - check_valid_jaxtype as check_valid_jaxtype, closed_call_p as closed_call_p, concrete_aval as concrete_aval, concrete_or_error as concrete_or_error, @@ -110,7 +107,6 @@ new_sublevel as new_sublevel, no_axis_name as no_axis_name, no_effects as no_effects, - non_negative_dim as _deprecated_non_negative_dim, outfeed_primitives as outfeed_primitives, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, primitive_uses_outfeed as primitive_uses_outfeed, @@ -144,6 +140,13 @@ from jax._src import core as _src_core _deprecations = { + # Added 2024-08-14 + "check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn), + "check_type": ("jax.core.check_type is deprecated.", _src_core.check_type), + "check_valid_jaxtype": ( + ("jax.core.check_valid_jaxtype is deprecated. Instead, you can manually" + " raise an error if core.valid_jaxtype() returns False."), + _src_core.check_valid_jaxtype), # Added 2024-06-12 "pp_aval": ("jax.core.pp_aval is deprecated.", _src_core.pp_aval), "pp_eqn": ("jax.core.pp_eqn is deprecated.", _src_core.pp_eqn), @@ -181,13 +184,16 @@ ), # Added Jan 8, 2024 "non_negative_dim": ( - "jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _deprecated_non_negative_dim, + "jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim, ), } import typing if typing.TYPE_CHECKING: - non_negative_dim = _deprecated_non_negative_dim + check_eqn = _src_core.check_eqn + check_type = _src_core.check_type + check_valid_jaxtype = _src_core.check_valid_jaxtype + non_negative_dim = _src_core.non_negative_dim pp_aval = _src_core.pp_aval pp_eqn = _src_core.pp_eqn pp_eqn_rules = _src_core.pp_eqn_rules From 599c13aa0964069fa21a9e164281581fefff3cd9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 14 Aug 2024 10:57:53 -0700 Subject: [PATCH 130/702] Introduce hermetic CUDA in Google ML projects. 1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine with only GPUs and NVIDIA driver installed. When `--config=cuda` is provided in Bazel options, Bazel will download CUDA, CUDNN and NCCL redistributions in the cache, and use them during build and test phases. [Default location of CUNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/) [Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/) [Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history) 2) To include hermetic CUDA rules in your project, add the following in the WORKSPACE of the downstream project dependent on XLA. Note: use `@local_tsl` instead of `@tsl` in Tensorflow project. ``` load( "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", ) cuda_json_init_repository() load( "@cuda_redist_json//:distributions.bzl", "CUDA_REDISTRIBUTIONS", "CUDNN_REDISTRIBUTIONS", ) load( "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", "cuda_redist_init_repositories", "cudnn_redist_init_repository", ) cuda_redist_init_repositories( cuda_redistributions = CUDA_REDISTRIBUTIONS, ) cudnn_redist_init_repository( cudnn_redistributions = CUDNN_REDISTRIBUTIONS, ) load( "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure", ) cuda_configure(name = "local_config_cuda") load( "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", "nccl_redist_init_repository", ) nccl_redist_init_repository() load( "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure", ) nccl_configure(name = "local_config_nccl") ``` PiperOrigin-RevId: 662981325 --- .bazelrc | 35 +++++++++------------ CHANGELOG.md | 8 +++++ WORKSPACE | 47 ++++++++++++++++++++++++++++ build/build.py | 62 ++++++++----------------------------- docs/developer.md | 34 ++++++++++++++++++--- tests/logging_test.py | 71 ++++++++++++++++++++++++------------------- 6 files changed, 152 insertions(+), 105 deletions(-) diff --git a/.bazelrc b/.bazelrc index ce9a219f4157..767e0982a1e4 100644 --- a/.bazelrc +++ b/.bazelrc @@ -61,17 +61,15 @@ build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true - -# Build with nvcc for CUDA and clang for host -build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" -build:nvcc_clang --action_env=TF_NVCC_CLANG="1" -build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc +# Default hermetic CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +# This flag is needed to include hermetic CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true # Requires MSVC and LLVM to be installed build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl @@ -94,7 +92,6 @@ build:cuda --linkopt=-Wl,--disable-new-dtags build:cuda_clang --@local_config_cuda//:cuda_compiler=clang build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang --action_env=TF_CUDA_CLANG="1" # Disable clang extention that rejects type definitions within offsetof. # This was added in clang-16 by https://reviews.llvm.org/D133574. # Can be removed once upb is updated, since a type definition is used within @@ -104,6 +101,12 @@ build:cuda_clang --copt=-Wno-gnu-offsetof-extensions # Disable clang extention that rejects unknown arguments. build:cuda_clang --copt=-Qunused-arguments +# Build with nvcc for CUDA and clang for host +build:nvcc_clang --config=cuda +build:nvcc_clang --config=cuda_clang +build:nvcc_clang --action_env=TF_NVCC_CLANG="1" +build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc + build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true build:rocm --repo_env TF_NEED_ROCM=1 @@ -198,7 +201,6 @@ build:rbe_linux --host_linkopt=-lm # https://github.com/bazelbuild/bazel/issues/13623 build:rbe_cpu_linux_base --config=rbe_linux build:rbe_cpu_linux_base --config=cuda_clang -build:rbe_cpu_linux_base --action_env=TF_NVCC_CLANG="1" build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" @@ -218,22 +220,15 @@ build:rbe_linux_cuda_base --config=cuda build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang -build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1" -build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDA_VERSION=12 -build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDNN_VERSION=9 -build:rbe_linux_cuda12.3_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12" -build:rbe_linux_cuda12.3_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" +build:rbe_linux_cuda12.3_nvcc_base --config=nvcc_clang +build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl" -# RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility -build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b30b08abb01..d22b6a2f980a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jaxlib 0.4.32 +* Breaking changes + * Hermetic CUDA support is added. + Hermetic CUDA uses a specific downloadable version of CUDA instead of the + user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL + distributions, and then use CUDA libraries and tools as dependencies in + various Bazel targets. This enables more reproducible builds for JAX and its + supported CUDA versions. + ## jax 0.4.31 (July 29, 2024) * Deletion diff --git a/WORKSPACE b/WORKSPACE index 57e84b12ddf1..383adf810766 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -59,3 +59,50 @@ xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() + +load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + +load( + "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/build/build.py b/build/build.py index 0db2630a3e2b..f2920fba6221 100755 --- a/build/build.py +++ b/build/build.py @@ -236,16 +236,13 @@ def get_clang_major_version(clang_path): return major_version - def write_bazelrc(*, remote_build, - cuda_toolkit_path, cudnn_install_path, cuda_version, cudnn_version, rocm_toolkit_path, cpu, cuda_compute_capabilities, rocm_amdgpu_targets, target_cpu_features, wheel_cpu, enable_mkl_dnn, use_clang, clang_path, clang_major_version, enable_cuda, enable_nccl, enable_rocm, python_version): - tf_cuda_paths = [] with open("../.jax_configure.bazelrc", "w") as f: if not remote_build: @@ -263,28 +260,6 @@ def write_bazelrc(*, remote_build, # https://github.com/openxla/xla/blob/c4277a076e249f5b97c8e45c8cb9d1f554089d76/.bazelrc#L505 f.write("build --copt=-Wno-gnu-offsetof-extensions\n") - if cuda_toolkit_path: - tf_cuda_paths.append(cuda_toolkit_path) - f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n" - .format(cuda_toolkit_path=cuda_toolkit_path)) - if cudnn_install_path: - # see https://github.com/tensorflow/tensorflow/issues/51040 - if cudnn_install_path not in tf_cuda_paths: - tf_cuda_paths.append(cudnn_install_path) - f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n" - .format(cudnn_install_path=cudnn_install_path)) - if len(tf_cuda_paths): - f.write("build --action_env TF_CUDA_PATHS=\"{tf_cuda_paths}\"\n" - .format(tf_cuda_paths=",".join(tf_cuda_paths))) - if cuda_version: - f.write("build --action_env TF_CUDA_VERSION=\"{cuda_version}\"\n" - .format(cuda_version=cuda_version)) - if cudnn_version: - f.write("build --action_env TF_CUDNN_VERSION=\"{cudnn_version}\"\n" - .format(cudnn_version=cudnn_version)) - if cuda_compute_capabilities: - f.write( - f'build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') if rocm_toolkit_path: f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n" .format(rocm_toolkit_path=rocm_toolkit_path)) @@ -313,6 +288,15 @@ def write_bazelrc(*, remote_build, if use_clang: f.write("build --config=nvcc_clang\n") f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") + if cuda_version: + f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n" + .format(cuda_version=cuda_version)) + if cudnn_version: + f.write("build --repo_env HERMETIC_CUDNN_VERSION=\"{cudnn_version}\"\n" + .format(cudnn_version=cudnn_version)) + if cuda_compute_capabilities: + f.write( + f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') if enable_rocm: f.write("build --config=rocm\n") if not enable_nccl: @@ -458,7 +442,7 @@ def main(): ) parser.add_argument( "--gpu_plugin_cuda_version", - choices=["11", "12"], + choices=["12"], default="12", help="Which CUDA major version the gpu plugin is for.") parser.add_argument( @@ -481,22 +465,14 @@ def main(): "remote_build", default=False, help_str="Should we build with RBE (Remote Build Environment)?") - parser.add_argument( - "--cuda_path", - default=None, - help="Path to the CUDA toolkit.") - parser.add_argument( - "--cudnn_path", - default=None, - help="Path to CUDNN libraries.") parser.add_argument( "--cuda_version", default=None, - help="CUDA toolkit version, e.g., 11.1") + help="CUDA toolkit version, e.g., 12.3.2") parser.add_argument( "--cudnn_version", default=None, - help="CUDNN version, e.g., 8") + help="CUDNN version, e.g., 8.9.7.29") # Caution: if changing the default list of CUDA capabilities, you should also # update the list in .bazelrc, which is used for wheel builds. parser.add_argument( @@ -562,12 +538,6 @@ def main(): if args.verbose: logger.setLevel(logging.DEBUG) - if is_windows() and args.enable_cuda: - if args.cuda_version is None: - parser.error("--cuda_version is needed for Windows CUDA build.") - if args.cudnn_version is None: - parser.error("--cudnn_version is needed for Windows CUDA build.") - if args.enable_cuda and args.enable_rocm: parser.error("--enable_cuda and --enable_rocm cannot be enabled at the same time.") @@ -615,15 +585,9 @@ def main(): print(f"Target CPU: {wheel_cpu}") print(f"Target CPU features: {args.target_cpu_features}") - cuda_toolkit_path = args.cuda_path - cudnn_install_path = args.cudnn_path rocm_toolkit_path = args.rocm_path print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no")) if args.enable_cuda: - if cuda_toolkit_path: - print(f"CUDA toolkit path: {cuda_toolkit_path}") - if cudnn_install_path: - print(f"CUDNN library path: {cudnn_install_path}") if args.cuda_compute_capabilities is not None: print(f"CUDA compute capabilities: {args.cuda_compute_capabilities}") if args.cuda_version: @@ -640,8 +604,6 @@ def main(): write_bazelrc( remote_build=args.remote_build, - cuda_toolkit_path=cuda_toolkit_path, - cudnn_install_path=cudnn_install_path, cuda_version=args.cuda_version, cudnn_version=args.cudnn_version, rocm_toolkit_path=rocm_toolkit_path, diff --git a/docs/developer.md b/docs/developer.md index e2850d2a94e7..78471a530c99 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -75,11 +75,10 @@ There are two ways to build `jaxlib` with CUDA support: (1) use `python build/build.py --enable_cuda` to generate a jaxlib wheel with cuda support, or (2) use `python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` -to generate three wheels (jaxlib without cuda, jax-cuda-plugin, -and jax-cuda-pjrt). You can set `gpu_plugin_cuda_version` to 11 or 12. +to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and +jax-cuda-pjrt). -See `python build/build.py --help` for configuration options, including ways to -specify the paths to CUDA and CUDNN, which you must have installed. Here +See `python build/build.py --help` for configuration options. Here `python` should be the name of your Python 3 interpreter; on some systems, you may need to use `python3` instead. Despite calling the script with `python`, Bazel will always use its own hermetic Python interpreter and dependencies, only @@ -87,6 +86,31 @@ the `build/build.py` script itself will be processed by your system Python interpreter. By default, the wheel is written to the `dist/` subdirectory of the current directory. +* JAX versions starting from v.0.4.32: you can provide custom CUDA and CUDNN + versions in the configuration options. Bazel will download them and use as + target dependencies. + + To download the specific versions of CUDA/CUDNN redistributions, you can use + the following command: + + ```bash + python build/build.py --enable_cuda \ + --cuda_version=12.3.2 --cudnn_version=9.1.1 + ``` + + To point to CUDA/CUDNN/NCCL redistributions on local file system, you can use + the following command: + + ```bash + python build/build.py --enable_cuda \ + --bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \ + --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \ + --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" + ``` + +* JAX versions prior v.0.4.32: you must have CUDA and CUDNN installed and + provide paths to them using configuration options. + ### Building jaxlib from source with a modified XLA repository. JAX depends on XLA, whose source code is in the @@ -112,6 +136,8 @@ particular before each `jaxlib` release. ### Additional Notes for Building `jaxlib` from source on Windows +Note: JAX does not support CUDA on Windows; use WSL2 for CUDA support. + On Windows, follow [Install Visual Studio](https://docs.microsoft.com/en-us/visualstudio/install/install-visual-studio?view=vs-2019) to set up a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required. diff --git a/tests/logging_test.py b/tests/logging_test.py index 5a495d47d31b..70f619de5ee6 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -19,6 +19,7 @@ import platform import subprocess import sys +import tempfile import textwrap import unittest @@ -74,37 +75,45 @@ def test_no_log_spam(self): if sys.executable is None: raise self.skipTest("test requires access to python binary") - program = textwrap.dedent(""" - import jax - jax.device_count() - f = jax.jit(lambda x: x + 1) - f(1) - f(2) - jax.numpy.add(1, 1) - """) - python = sys.executable - assert "python" in python - env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"} - if os.getenv("PYTHONPATH"): - env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH") - if os.getenv("LD_LIBRARY_PATH"): - env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH") - # Make sure C++ logging is at default level for the test process. - proc = subprocess.run( - [python, "-c", program], - capture_output=True, - env=env_variables, - ) - - lines = proc.stdout.split(b"\n") - lines.extend(proc.stderr.split(b"\n")) - allowlist = [ - b"", - b"An NVIDIA GPU may be present on this machine, but a CUDA-enabled " - b"jaxlib is not installed. Falling back to cpu.", - ] - lines = [l for l in lines if l not in allowlist] - self.assertEmpty(lines) + # Save script in file to fix the problem with + # `tsl::Env::Default()->GetExecutablePath()` not working properly with + # command flag. + with tempfile.NamedTemporaryFile( + mode="w+", encoding="utf-8", suffix=".py" + ) as f: + f.write(textwrap.dedent(""" + import jax + jax.device_count() + f = jax.jit(lambda x: x + 1) + f(1) + f(2) + jax.numpy.add(1, 1) + """)) + python = sys.executable + assert "python" in python + env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"} + if os.getenv("PYTHONPATH"): + env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH") + if os.getenv("LD_LIBRARY_PATH"): + env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH") + # Make sure C++ logging is at default level for the test process. + proc = subprocess.run( + [python, f.name], + capture_output=True, + env=env_variables, + ) + + lines = proc.stdout.split(b"\n") + lines.extend(proc.stderr.split(b"\n")) + allowlist = [ + b"", + ( + b"An NVIDIA GPU may be present on this machine, but a" + b" CUDA-enabled jaxlib is not installed. Falling back to cpu." + ), + ] + lines = [l for l in lines if l not in allowlist] + self.assertEmpty(lines) def test_debug_logging(self): # Warmup so we don't get "No GPU/TPU" warning later. From db000459a83e8afbeda2e23428d3a294905d55cf Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 14 Aug 2024 11:07:45 -0700 Subject: [PATCH 131/702] [Pallas] Add boolean vector support. PiperOrigin-RevId: 662985359 --- jax/_src/pallas/mosaic/lowering.py | 19 +++++++---- tests/pallas/tpu_pallas_test.py | 51 ++++++++++++++++++------------ 2 files changed, 42 insertions(+), 28 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 4db430ecde9e..caf35316af63 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1148,15 +1148,20 @@ def _maybe_cast_load_to_bool( if out_aval.dtype != jnp.bool_: return val load_scalar_type = _dtype_to_ir_type(BOOL_MEMREF_TYPE) - if not out_aval.shape: - # For scalars, truncate the value to a bool. - pred = _cmpi_lowering_types[lax.ne_p] - predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) - const_zero = ir.IntegerAttr.get(load_scalar_type, 0) + pred = _cmpi_lowering_types[lax.ne_p] + predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) + const_zero = ir.IntegerAttr.get(load_scalar_type, 0) + if out_aval.shape: # Vector case. + load_vector_type = aval_to_ir_type(out_aval, is_kernel_boundary=True) + vector_zeros = arith.ConstantOp( + load_vector_type, + ir.DenseElementsAttr.get_splat(load_vector_type, const_zero) + ) + return arith.CmpIOp(predicate, val, vector_zeros).result + else: # Scalar case. const_zero = arith.ConstantOp(load_scalar_type, const_zero) return arith.CmpIOp(predicate, val, const_zero).result - else: - raise NotImplementedError("Boolean vector loads are not supported.") + def _maybe_cast_store_to_memref_type( expected_aval, val: ir.Value) -> ir.Value: diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index bf00acfb670f..5bf4cf2804ac 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1992,27 +1992,36 @@ def inner_scope(scoped_ref): def test_vector_bool_load_store(self): def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] - input = jnp.array([[False, True, True, False]]) - output_shape = jax.ShapeDtypeStruct((1, 4), jnp.bool_) - if self.INTERPRET: - result = self.pallas_call( - kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), - out_shape=output_shape, - )(input) - np.testing.assert_array_equal(result, input) - else: - # TODO(justinfu): Fix vector boolean ops so that they do not trigger - # a relayout error from changing bitwidths in Mosaic. - with self.assertRaisesRegex( - Exception, 'Boolean vector loads are not supported.'): - self.pallas_call( - kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), - out_shape=output_shape, - )(input) + input = jax.random.bernoulli(jax.random.key(0), p=0.5, shape=(8, 128)) + output_shape = jax.ShapeDtypeStruct((8, 128), jnp.bool_) + result = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=output_shape, + )(input) + np.testing.assert_array_equal(result, input) + + def test_vector_bool_masking(self): + def kernel(mask_ref, true_ref, false_ref, o_ref): + o_ref[...] = jnp.where(mask_ref[...], true_ref[...], false_ref[...]) + key = jax.random.key(0) + k1, k2, k3 = jax.random.split(key, 3) + values_1 = jax.random.normal(k1, (8, 128), jnp.float32) + values_2 = jax.random.normal(k2, (8, 128), jnp.float32) + mask = jax.random.bernoulli(k3, p=0.5, shape=(8, 128)) + output_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) + result = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=output_shape, + )(mask, values_1, values_2) + expected = jnp.where(mask, values_1, values_2) + np.testing.assert_array_equal(result, expected) def test_bool_dma_not_implemented(self): if not jtu.is_device_tpu_at_least(4): From 85fb66a26d1c5a3f453fe97768f6fd801bf8be11 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 14 Aug 2024 11:41:52 -0700 Subject: [PATCH 132/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/cb1541c5f092807fced9e5e2b261371dba888906. PiperOrigin-RevId: 662998853 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e6b18e38bd40..a6db1a1289c5 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "55476059f622468985141311ef20328993bd7ba5" -XLA_SHA256 = "1cbc9b2956154d724018302bcf23343f3cbe7dacd114b101615339614d953fb9" +XLA_COMMIT = "cb1541c5f092807fced9e5e2b261371dba888906" +XLA_SHA256 = "3069f3a3e232ee3c1e0fe4a765f5d20879fd64c383256adf01b33b8fe775f0e6" def repo(): tf_http_archive( From 020513f300deeb066c97e2fcbddbdf1d93d0e5e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 14 Aug 2024 12:47:21 -0700 Subject: [PATCH 133/702] [Mosaic] Update serde to handle upstream MLIR changes For changes from https://github.com/llvm/llvm-project/commit/5f26497da7de10c4eeec33b5a5cfcb47e96836cc PiperOrigin-RevId: 663020509 --- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index ac2389d6c238..000bfe3eaea2 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -22,11 +22,13 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" #include "mlir/include/mlir/Support/LogicalResult.h" @@ -41,7 +43,7 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; -constexpr int kVersion = 2; +constexpr int kVersion = 3; StringRef mangle(StringRef name, std::string* storage) { storage->clear(); @@ -100,10 +102,40 @@ LogicalResult semaphore_signal_rule(Operation* op, int version) { return success(); } +LogicalResult vector_multi_dim_reduce_rule(Operation* op, int version) { + // Changed reductions_dims from ArrayAttr of IntegerAttrs to DenseI64ArrayAttr + // in version 3. + if (version < 3) { + Attribute reduction_dims_attr = op->getAttr("reduction_dims"); + if (!reduction_dims_attr) { + return op->emitError("Missing reduction_dims attribute"); + } + ArrayAttr reduction_dims_array = dyn_cast(reduction_dims_attr); + if (!reduction_dims_array) { + return op->emitOpError("reduction_dims attribute is not an ArrayAttr"); + } + std::vector reduction_dims; + reduction_dims.reserve(reduction_dims_array.size()); + for (Attribute reduction_dim : reduction_dims_array) { + IntegerAttr reduction_dim_attr = dyn_cast(reduction_dim); + if (!reduction_dim_attr) { + return op->emitOpError( + "reduction_dims attribute contains a non-IntegerAttr"); + } + reduction_dims.push_back(reduction_dim_attr.getInt()); + } + op->setAttr("reduction_dims", + DenseI64ArrayAttr::get(op->getContext(), reduction_dims)); + } + return success(); +} + const llvm::StringMap& upgrade_rules() { static auto rules = new llvm::StringMap{ {EnqueueDMAOp::getOperationName(), enqueue_dma_rule}, {SemaphoreSignalOp::getOperationName(), semaphore_signal_rule}, + {vector::MultiDimReductionOp::getOperationName(), + vector_multi_dim_reduce_rule} }; return *rules; } @@ -189,4 +221,4 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { } // namespace -} // namespace mlir::tpu +} // namespace mlir::tpu \ No newline at end of file From 2737a7358b40839cfeffa91623265876cc80a37a Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 14 Aug 2024 14:42:44 -0700 Subject: [PATCH 134/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/aa2340049456d45f3b1fd7b09acc8bcf9d50b749. PiperOrigin-RevId: 663060585 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a6db1a1289c5..28551a554b70 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "cb1541c5f092807fced9e5e2b261371dba888906" -XLA_SHA256 = "3069f3a3e232ee3c1e0fe4a765f5d20879fd64c383256adf01b33b8fe775f0e6" +XLA_COMMIT = "aa2340049456d45f3b1fd7b09acc8bcf9d50b749" +XLA_SHA256 = "92cd501640e7962e90641c0fd25742a3c72c184cb10571753395efa5e9556102" def repo(): tf_http_archive( From 91f55129650e06c26d5aec86293b8834e08335f1 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 14 Aug 2024 15:37:20 -0700 Subject: [PATCH 135/702] Document methods of custom_jvp/custom_vjp --- docs/jax.rst | 21 +++++++- jax/_src/custom_derivatives.py | 95 ++++++++++++++++++++-------------- 2 files changed, 76 insertions(+), 40 deletions(-) diff --git a/docs/jax.rst b/docs/jax.rst index 7be3e63015d9..96b37d436c62 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -99,12 +99,29 @@ Automatic differentiation linearize linear_transpose vjp - custom_jvp - custom_vjp custom_gradient closure_convert checkpoint +``custom_jvp`` +~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: _autosummary + + custom_jvp + custom_jvp.defjvp + custom_jvp.defjvps + +``custom_vjp`` +~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: _autosummary + + custom_vjp + custom_vjp.defvjp + jax.Array (:code:`jax.Array`) ----------------------------- diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 9a6253b1bef9..64a37b782358 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -152,10 +152,11 @@ def defjvp(self, ``nondiff_argnums``, the ``jvp`` function should accept two arguments, where the first is a tuple of primal inputs and the second is a tuple of tangent inputs. The lengths of both tuples are equal to the number of - parameters of the ``custom_jvp`` function. The ``jvp`` function should - produce as output a pair where the first element is the primal output - and the second element is the tangent output. Elements of the input and - output tuples may be arrays or any nested tuples/lists/dicts thereof. + parameters of the :class:`~jax.custom_jvp` function. The ``jvp`` function + should produce as output a pair where the first element is the primal + output and the second element is the tangent output. Elements of the + input and output tuples may be arrays or any nested tuples/lists/dicts + thereof. symbolic_zeros: boolean, indicating whether the rule should be passed objects representing static symbolic zeros in its tangent argument in correspondence with unperturbed values; otherwise, only standard JAX @@ -166,48 +167,60 @@ def defjvp(self, ``False``. Returns: - None. + Returns ``jvp`` so that ``defjvp`` can be used as a decorator. Examples: - @jax.custom_jvp - def f(x, y): - return jnp.sin(x) * y - - @f.defjvp - def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - primal_out = f(x, y) - tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot - return primal_out, tangent_out + >>> @jax.custom_jvp + ... def f(x, y): + ... return jnp.sin(x) * y + ... + >>> @f.defjvp + ... def f_jvp(primals, tangents): + ... x, y = primals + ... x_dot, y_dot = tangents + ... primal_out = f(x, y) + ... tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot + ... return primal_out, tangent_out + + >>> x = jnp.float32(1.0) + >>> y = jnp.float32(2.0) + >>> with jnp.printoptions(precision=2): + ... print(jax.value_and_grad(f)(x, y)) + (Array(1.68, dtype=float32), Array(1.08, dtype=float32)) """ self.jvp = jvp self.symbolic_zeros = symbolic_zeros return jvp - def defjvps(self, *jvps: Callable[..., ReturnValue] | None): + def defjvps(self, *jvps: Callable[..., ReturnValue] | None) -> None: """Convenience wrapper for defining JVPs for each argument separately. This convenience wrapper cannot be used together with ``nondiff_argnums``. Args: *jvps: a sequence of functions, one for each positional argument of the - ``custom_jvp`` function. Each function takes as arguments the tangent - value for the corresponding primal input, the primal output, and the - primal inputs. See the example below. + :class:`~jax.custom_jvp` function. Each function takes as arguments + the tangent value for the corresponding primal input, the primal + output, and the ßprimal inputs. See the example below. Returns: None. Examples: - @jax.custom_jvp - def f(x, y): - return jnp.sin(x) * y - - f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y, - lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot) + >>> @jax.custom_jvp + ... def f(x, y): + ... return jnp.sin(x) * y + ... + >>> f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y, + ... lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot) + + >>> x = jnp.float32(1.0) + >>> y = jnp.float32(2.0) + >>> with jnp.printoptions(precision=2): + ... print(jax.value_and_grad(f)(x, y)) + (Array(1.68, dtype=float32), Array(1.08, dtype=float32)) """ if self.nondiff_argnums: raise TypeError("Can't use ``defjvps`` with ``nondiff_argnums``.") @@ -560,18 +573,24 @@ def defvjp(self, Examples: - @jax.custom_vjp - def f(x, y): - return jnp.sin(x) * y - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd) + >>> @jax.custom_vjp + ... def f(x, y): + ... return jnp.sin(x) * y + ... + >>> def f_fwd(x, y): + ... return f(x, y), (jnp.cos(x), jnp.sin(x), y) + ... + >>> def f_bwd(res, g): + ... cos_x, sin_x, y = res + ... return (cos_x * g * y, sin_x * g) + ... + >>> f.defvjp(f_fwd, f_bwd) + + >>> x = jnp.float32(1.0) + >>> y = jnp.float32(2.0) + >>> with jnp.printoptions(precision=2): + ... print(jax.value_and_grad(f)(x, y)) + (Array(1.68, dtype=float32), Array(1.08, dtype=float32)) """ self.fwd = fwd self.bwd = bwd From a18561df1a16b13a51149c642d80e8fefd049c4a Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 14 Aug 2024 17:52:11 -0700 Subject: [PATCH 136/702] [Pallas] Add run_scoped interpret mode rule + Enable DMA tests. PiperOrigin-RevId: 663115966 --- jax/_src/pallas/core.py | 29 ++++++-- jax/_src/pallas/mosaic/core.py | 2 + jax/_src/pallas/mosaic/lowering.py | 2 +- jax/_src/pallas/pallas_call.py | 40 +++++------ jax/_src/pallas/primitives.py | 53 +++++++++++++++ .../jax2tf/tests/primitives_test.py | 2 +- tests/pallas/tpu_pallas_test.py | 67 +++++++++++++------ 7 files changed, 148 insertions(+), 47 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index e99510f8d499..0ef208f755e5 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -182,6 +182,8 @@ def size(self, axis: int) -> int | DynamicGridDim: @dataclasses.dataclass class PallasTracingEnv(threading.local): grid_context: PallasGridContext | None = None + grid_env_stack: list[GridEnv] = dataclasses.field(default_factory=list) + is_interpret_mode: bool = False _pallas_tracing_env = PallasTracingEnv() @@ -202,22 +204,35 @@ class GridAxis: # Stores the kernel execution position and the size along grid axes. GridEnv = Sequence[GridAxis] -_grid_env_stack: list[GridEnv] = [] - - @contextlib.contextmanager def grid_env(env: GridEnv) -> Iterator[None]: - _grid_env_stack.append(env) + _pallas_tracing_env.grid_env_stack.append(env) try: yield finally: - _grid_env_stack.pop() + _pallas_tracing_env.grid_env_stack.pop() def current_grid_env() -> GridEnv | None: - if not _grid_env_stack: + if not _pallas_tracing_env.grid_env_stack: return None - return _grid_env_stack[-1] + return _pallas_tracing_env.grid_env_stack[-1] + + +@contextlib.contextmanager +def interpret_mode_env(interpret_mode: bool) -> Iterator[None]: + prev_interpret = _pallas_tracing_env.is_interpret_mode + if interpret_mode: + _pallas_tracing_env.is_interpret_mode = True + try: + yield + finally: + if interpret_mode: + _pallas_tracing_env.is_interpret_mode = prev_interpret + +def is_interpret_mode() -> bool: + """Returns whether the kernel is executing in interpret mode.""" + return _pallas_tracing_env.is_interpret_mode class Mapped: diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 75e5101de142..ca333f97626c 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -109,6 +109,8 @@ def __call__(self, shape: tuple[int, ...]): dtype = BarrierSemaphoreTy() else: dtype = SemaphoreTy() + if pallas_core.is_interpret_mode(): + dtype = jnp.int32 return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) def get_aval(self) -> AbstractMemoryRef: diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index caf35316af63..cb7d16bc5108 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2528,7 +2528,7 @@ def _bitcast_convert_type_lowering_rule( def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value: if isinstance(aval, pallas_core.AbstractMemoryRef): - memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>") + memspace = _memory_space_to_tpu_memspace(aval.memory_space) if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE memref_type = aval_to_ir_type(aval, memory_space=TPUMemorySpace.SEMAPHORE) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index f6ee5381adc8..4f3c9918f664 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -192,7 +192,7 @@ def _pallas_call_impl_interpret( with grid_mapping.trace_env(): discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ()) if debug: - print(f"\nJaxpr the the kernel in pallas_call {name_and_src_info}:") + print(f"\nJaxpr of the the kernel in pallas_call {name_and_src_info}:") print(discharged_jaxpr) out = _initialize_output_vals(grid_mapping.block_mappings_output, args, input_output_aliases) @@ -889,7 +889,7 @@ def _trace_kernel_to_jaxpr(fun: Callable, grid_mapping: GridMapping, kernel_avals: tuple[pallas_core.AbstractMemRef, ...], kernel_in_tree: tree_util.PyTreeDef, - interpret: bool + interpret: bool, ) -> jax_core.ClosedJaxpr: if interpret: kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval, @@ -1126,10 +1126,11 @@ def wrapped(*args): flat_in_avals, in_tree, in_origins, flat_out_avals, out_tree, out_origins) flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals) - jaxpr = _trace_kernel_to_jaxpr( - kernel, kernel_src_info, - grid_mapping, tuple(flat_kernel_avals), kernel_in_tree, - interpret=interpret) + with pallas_core.interpret_mode_env(interpret): + jaxpr = _trace_kernel_to_jaxpr( + kernel, kernel_src_info, + grid_mapping, tuple(flat_kernel_avals), kernel_in_tree, + interpret=interpret) for i_idx, o_idx in input_output_aliases.items(): if i_idx not in range(len(flat_in_avals)): raise ValueError( @@ -1152,19 +1153,20 @@ def wrapped(*args): f"a different abstract value {out_aval}.") index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands]) - out_flat = pallas_call_p.bind( - *dynamic_grid_bounds, - *index_args, - *rest_args, - jaxpr=jaxpr, - name_and_src_info=name_and_src_info, - debug=debug, - interpret=interpret, - grid_mapping=grid_mapping, - input_output_aliases=tuple(input_output_aliases.items()), - compiler_params=compiler_params, - cost_estimate=cost_estimate, - ) + with pallas_core.interpret_mode_env(interpret): + out_flat = pallas_call_p.bind( + *dynamic_grid_bounds, + *index_args, + *rest_args, + jaxpr=jaxpr, + name_and_src_info=name_and_src_info, + debug=debug, + interpret=interpret, + grid_mapping=grid_mapping, + input_output_aliases=tuple(input_output_aliases.items()), + compiler_params=compiler_params, + cost_estimate=cost_estimate, + ) out = tree_util.tree_unflatten(out_tree, out_flat) return out return wrapped diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 7ba5fa27791f..db364820e443 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -425,6 +425,8 @@ def uninitialized_value(shape, dtype): return jnp.full(shape, jnp.iinfo(dtype).min, dtype) elif jnp.issubdtype(dtype, jnp.bool): return jnp.full(shape, False, dtype) + elif jnp.issubdtype(dtype, pallas_core.semaphore_dtype): + return jnp.full(shape, 0, dtype) raise NotImplementedError(dtype) def _pad_values_to_avoid_dynamic_slice_oob_shift(value, @@ -843,3 +845,54 @@ def _run_scoped_abstract_eval(*args, jaxpr): ) } return [v.aval for v in jaxpr.outvars], nonlocal_effects + + +def _run_scoped_discharge_rule(in_avals, + out_avals, + *args_flat, + jaxpr, + **_): + del out_avals + num_consts = len(args_flat) + jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr) + num_return_values = len(jaxpr_noconst.outvars) + discharged_body, new_consts = state_discharge.discharge_state( + jaxpr_noconst, []) + if new_consts: + raise NotImplementedError( + "Cannot handle new consts created by state discharge.") + # Create inputs filled with uninitialized values to the body. + body_avals = [v.aval for v in discharged_body.invars[num_consts:]] + init_vals = [uninitialized_value( + aval.shape, aval.dtype) for aval in body_avals] + init_vals_with_consts = args_flat + tuple(init_vals) + out = jax_core.eval_jaxpr(discharged_body, [], *init_vals_with_consts) + # Order of outputs: + # (1) return values, (2) closed refs, (3) scoped refs. + return_values = out[:num_return_values] + ref_outputs = out[num_return_values:] + # We update all ref values with their updated values from the discharged + # body. For other values we leave them in place. + updates = [ + ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef) + else None for aval in in_avals] + assert len(ref_outputs) == len( + body_avals), f'{len(body_avals)}, != {len(ref_outputs)}' + assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}' + return updates, return_values + + +state_discharge.register_discharge_rule(run_scoped_p)( + _run_scoped_discharge_rule) + + +@functools.partial(mlir.register_lowering, run_scoped_p) +def _run_scoped_lowering_rule(ctx, *args, jaxpr): + # This lowering rule gets triggered when run_scoped is not discharged. + # In this case there are no stateful effects to handle. + def _lower_fun(*lower_fun_args): + updates, out = _run_scoped_discharge_rule([], [], *lower_fun_args, + jaxpr=jaxpr) + assert len(updates) == 0, 'Cannot lower run_scoped with effects.' + return out + return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args) diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 40af7959cd4b..51a6d45556bd 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -178,7 +178,7 @@ def test_primitive_coverage(self): if p.name == "debug_callback" or p.name == "debug_print": # TODO(sharadmv,necula): enable debug callbacks in TF continue - if p.name in ("max_contiguous", "multiple_of"): + if p.name in ("max_contiguous", "multiple_of", "run_scoped"): # Pallas-specific primitives are not supported. continue if p.name == "pallas_call": diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 5bf4cf2804ac..99c2ab117c42 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -883,8 +883,12 @@ def kernel(y_ref): def body(dma_sems, sems): self.assertTupleEqual(dma_sems.shape, (4,)) self.assertTupleEqual(sems.shape, (3,)) - self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) - self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) + if self.INTERPRET: + self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.int32)) + self.assertTrue(jnp.issubdtype(sems.dtype, jnp.int32)) + else: + self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) + self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) pl.run_scoped( body, pltpu.SemaphoreType.DMA((4,)), pltpu.SemaphoreType.REGULAR((3,)) ) @@ -898,10 +902,13 @@ def test_can_allocate_scratch_semaphore_array(self): def kernel(y_ref, dma_sems, sems): self.assertTupleEqual(dma_sems.shape, (4,)) self.assertTupleEqual(sems.shape, (3,)) - self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) - self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) + if self.INTERPRET: + self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.int32)) + self.assertTrue(jnp.issubdtype(sems.dtype, jnp.int32)) + else: + self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) + self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) - # TODO(b/345534352): Add interpret support for REGULAR semaphore. jax.block_until_ready( self.pallas_call( kernel, @@ -917,6 +924,10 @@ def kernel(y_ref, dma_sems, sems): ) def test_can_wait_on_semaphore(self): + # TODO(b/345534352): Add interpret support for semaphore signal/wait. + if self.INTERPRET: + self.skipTest('Semaphore signal/wait not supported in interpret mode.') + def kernel(y_ref): def body(sem): pltpu.semaphore_signal(sem) @@ -943,6 +954,10 @@ def body3(sem): )()) def test_can_wait_on_semaphore_array(self): + # TODO(b/345534352): Add interpret support for semaphore signal/wait. + if self.INTERPRET: + self.skipTest('Semaphore signal/wait not supported in interpret mode.') + def kernel(y_ref): def body(sems): pltpu.semaphore_signal(sems.at[0]) @@ -961,12 +976,16 @@ def body(sems): pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((3,))) # TODO(b/345534352): Add interpret support for semaphore signal/wait. - jax.block_until_ready(pl.pallas_call( + jax.block_until_ready(self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )()) def test_can_wait_on_semaphore_array_with_dynamic_index(self): + # TODO(b/345534352): Add interpret support for semaphore signal/wait. + if self.INTERPRET: + self.skipTest('Semaphore signal/wait not supported in interpret mode.') + def kernel(y_ref): i = pl.program_id(0) def body(sems): @@ -985,19 +1004,21 @@ def body(sems): pltpu.semaphore_wait(sems.at[i, 2]) pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((4, 3))) - # TODO(b/345534352): Add interpret support for semaphore signal/wait. jax.block_until_ready( - pl.pallas_call( + self.pallas_call( kernel, in_specs=[], out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), grid=4, - debug=True, )() ) def test_can_read_semaphore(self): + # TODO(b/345534352): Add interpret support for semaphore signal/wait. + if self.INTERPRET: + self.skipTest('Semaphore signal/wait not supported in interpret mode.') + m, n = 2, 3 def kernel(y_ref): @@ -1013,7 +1034,7 @@ def body(sems): # TODO(b/345534352): Add interpret support for semaphore signal/wait. y = jax.block_until_ready( - pl.pallas_call( + self.pallas_call( kernel, out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32), @@ -1024,6 +1045,9 @@ def body(sems): ) def test_can_read_dma_semaphore(self): + # TODO(b/345534352): Add interpret support for semaphore signal/wait. + if self.INTERPRET: + self.skipTest('Semaphore signal/wait not supported in interpret mode.') def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): sem_val_ref[0, 0] = 123 @@ -1032,7 +1056,7 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) # TODO(b/345534352): Add interpret support for semaphore signal/wait. y, sem_val = jax.block_until_ready( - pl.pallas_call( + self.pallas_call( kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, @@ -1152,7 +1176,7 @@ def test_vmem_hbm_dma(self): def kernel(x_ref, y_hbm_ref): def body(y_ref, sem): y_ref[...] = x_ref[...] - pltpu.async_copy(y_hbm_ref, y_ref, sem).wait() + pltpu.async_copy(y_ref, y_hbm_ref, sem).wait() pl.run_scoped( body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA ) @@ -1287,6 +1311,9 @@ def body(sem): np.testing.assert_allclose(y, x.reshape((16, 128))) def test_hbm_vmem_dma_multiple_indexing(self): + if self.INTERPRET: + self.skipTest('Multiple indexing not supported in interpret mode.') + def kernel(x_hbm_ref, y_ref): def body(sem): for i in range(3): @@ -1313,6 +1340,9 @@ def body(sem): np.testing.assert_allclose(y, x.reshape((3, 16, 128))) def test_cannot_squeeze_lane_sublane(self): + if self.INTERPRET: + self.skipTest('Only works on Mosaic TPU.') + def kernel(x_hbm_ref, y_ref): def body(sem): dma1 = pltpu.async_copy( @@ -1335,11 +1365,7 @@ def body(sem): out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) - @parameterized.named_parameters( - ('', False), - ('_interpret', True), - ) - def test_hoisted_scratch_space(self, interpret): + def test_hoisted_scratch_space(self): def kernel(x_ref, y_ref, scratch_ref): i = pl.program_id(0) @pl.when(i == 0) @@ -1352,7 +1378,7 @@ def _(): y_ref[...] = scratch_ref[...] x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, @@ -1363,7 +1389,6 @@ def _(): out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), grid=(3,), ), - interpret=interpret, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_array_equal(y, x + 3) @@ -1441,6 +1466,10 @@ def kernel(index, x, y, sem): np.testing.assert_array_equal(y, i) del y + +class PallasCallDMAInterpreterTest(PallasCallDMATest): + INTERPRET = True + def test_interpret_local_dma(self): def test_kernel(x_ref, o_ref, From 82d3cfb3c6f88321f0b29b4cc41134a464de82c2 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 14 Aug 2024 18:22:13 -0700 Subject: [PATCH 137/702] [Pallas] Fix boolean vector loads with indexing. PiperOrigin-RevId: 663124475 --- jax/_src/pallas/mosaic/lowering.py | 12 ++++++------ tests/pallas/tpu_pallas_test.py | 13 +++++++------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index cb7d16bc5108..86ce2f0b1b81 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1088,12 +1088,12 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): else: load_val = vector.LoadOp( aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts).result - load_val = _maybe_cast_load_to_bool(aval_out, load_val) - if load_aval == aval_out: - return load_val - vec_type = ir.VectorType.get(aval_out.shape, - _dtype_to_ir_type(aval_out.dtype)) - return vector.ShapeCastOp(vec_type, load_val).result + if load_aval != aval_out: + vec_type = ir.VectorType.get(aval_out.shape, + _dtype_to_ir_type(aval_out.dtype, + is_kernel_boundary=True)) + load_val = vector.ShapeCastOp(vec_type, load_val).result + return _maybe_cast_load_to_bool(aval_out, load_val) def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: """Lowering rule for loading PRNG keys from SMEM. diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 99c2ab117c42..20de5b585ac4 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2031,15 +2031,16 @@ def kernel(x_ref, o_ref): )(input) np.testing.assert_array_equal(result, input) - def test_vector_bool_masking(self): + def test_vector_bool_masking_with_indexing(self): def kernel(mask_ref, true_ref, false_ref, o_ref): - o_ref[...] = jnp.where(mask_ref[...], true_ref[...], false_ref[...]) + o_ref[0, ...] = jnp.where( + mask_ref[0, ...], true_ref[0, ...], false_ref[0, ...]) key = jax.random.key(0) k1, k2, k3 = jax.random.split(key, 3) - values_1 = jax.random.normal(k1, (8, 128), jnp.float32) - values_2 = jax.random.normal(k2, (8, 128), jnp.float32) - mask = jax.random.bernoulli(k3, p=0.5, shape=(8, 128)) - output_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) + values_1 = jax.random.normal(k1, (1, 256, 256), jnp.float32) + values_2 = jax.random.normal(k2, (1, 256, 256), jnp.float32) + mask = jax.random.bernoulli(k3, p=0.5, shape=(1, 256, 256)) + output_shape = jax.ShapeDtypeStruct((1, 256, 256), jnp.float32) result = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM), From 6913551d8d35b32b7f5a994c2284f7b9509297c7 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 15 Aug 2024 08:10:00 -0700 Subject: [PATCH 138/702] If `AbstractMesh` is an input to `shard_map`, then in eager mode require atleast one input to be a `NamedSharding` not all inputs. PiperOrigin-RevId: 663310336 --- jax/experimental/shard_map.py | 7 +------ tests/shard_map_test.py | 6 +++--- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index fa75c48292bf..656f97027c18 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -717,12 +717,7 @@ def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: def get_mesh_from_args(args_flat, mesh): for a in args_flat: - if hasattr(a, 'sharding'): - if not isinstance(a.sharding, NamedSharding): - raise TypeError( - "shard_map got `AbstractMesh` as an input to the `mesh` argument" - " which requires the input's sharding to be a `NamedSharding`. Got" - f" sharding type {type(a.sharding)}") + if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding): if a.sharding.mesh.shape_tuple != mesh.shape_tuple: raise ValueError( f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 7038c0043048..16daaca6d48e 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -812,9 +812,9 @@ def test_shmap_abstract_mesh_errors(self): abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) with self.assertRaisesRegex( - TypeError, - 'shard_map got `AbstractMesh` as an input to the `mesh` argument' - " which requires the input's sharding to be a `NamedSharding`"): + ValueError, + "Please pass `jax.Array`s with a `NamedSharding` as input to" + " `shard_map` when passing `AbstractMesh` to the mesh argument"): shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))(jnp.arange(8)) From 322d0c2f31e92e68a531f95a53c3f040d6a76bdf Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Thu, 15 Aug 2024 09:00:06 -0700 Subject: [PATCH 139/702] Rollback the change "Import from ``mlir.dialects`` lazily" Reverts a755f1db837c464f6aa3d3111a1bc40b5ebdd37d PiperOrigin-RevId: 663324497 --- jax/_src/lazy_loader.py | 25 +++------- jax/_src/lib/mlir/dialects/__init__.py | 64 ++++++++++---------------- 2 files changed, 32 insertions(+), 57 deletions(-) diff --git a/jax/_src/lazy_loader.py b/jax/_src/lazy_loader.py index 5150f38111c3..cf6e68e49c81 100644 --- a/jax/_src/lazy_loader.py +++ b/jax/_src/lazy_loader.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Lazy loading APIs.""" +"""A LazyLoader class.""" from collections.abc import Callable, Sequence import importlib -import sys from typing import Any @@ -27,27 +26,17 @@ def attach(package_name: str, submodules: Sequence[str]) -> tuple[ ]: """Lazily loads submodules of a package. - Returns: - A tuple of ``__getattr__``, ``__dir__`` function and ``__all__`` -- - a list of available global names, which can be used to replace the - corresponding definitions in the package. - - Raises: - RuntimeError: If the ``__name__`` of the caller cannot be determined. + Example use: + ``` + __getattr__, __dir__, __all__ = lazy_loader.attach(__name__, ["sub1", "sub2"]) + ``` """ - owner_name = sys._getframe(1).f_globals.get("__name__") - if owner_name is None: - raise RuntimeError("Cannot determine the ``__name__`` of the caller.") - __all__ = list(submodules) + __all__: list[str] = list(submodules) def __getattr__(name: str) -> Any: if name in submodules: - value = importlib.import_module(f"{package_name}.{name}") - # Update module-level globals to avoid calling ``__getattr__`` again - # for this ``name``. - setattr(sys.modules[owner_name], name, value) - return value + return importlib.import_module(f"{package_name}.{name}") raise AttributeError(f"module '{package_name}' has no attribute '{name}") def __dir__() -> list[str]: diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index a9bae8821db5..01dc7e2725b5 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -13,49 +13,35 @@ # limitations under the License. # ruff: noqa: F401 +from typing import Any -from typing import Any, TYPE_CHECKING - -if TYPE_CHECKING: - from jaxlib.mlir.dialects import arith as arith - from jaxlib.mlir.dialects import builtin as builtin - from jaxlib.mlir.dialects import chlo as chlo - from jaxlib.mlir.dialects import func as func - from jaxlib.mlir.dialects import gpu as gpu - from jaxlib.mlir.dialects import llvm as llvm - from jaxlib.mlir.dialects import math as math - from jaxlib.mlir.dialects import memref as memref - from jaxlib.mlir.dialects import mhlo as mhlo - from jaxlib.mlir.dialects import nvgpu as nvgpu - from jaxlib.mlir.dialects import nvvm as nvvm - from jaxlib.mlir.dialects import scf as scf - from jaxlib.mlir.dialects import sparse_tensor as sparse_tensor - from jaxlib.mlir.dialects import vector as vector -else: - from jax._src import lazy_loader as _lazy - __getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [ - "arith", - "builtin", - "chlo", - "func", - "gpu", - "llvm", - "math", - "memref", - "mhlo", - "nvgpu", - "nvvm", - "scf", - "sparse_tensor", - "vector", - ]) - del _lazy - +import jaxlib.mlir.dialects.arith as arith +import jaxlib.mlir.dialects.builtin as builtin +import jaxlib.mlir.dialects.chlo as chlo +import jaxlib.mlir.dialects.func as func +import jaxlib.mlir.dialects.math as math +import jaxlib.mlir.dialects.memref as memref +import jaxlib.mlir.dialects.mhlo as mhlo +import jaxlib.mlir.dialects.scf as scf # TODO(bartchr): Once JAX is released with SDY, remove the try/except. try: - from jaxlib.mlir.dialects import sdy as sdy + import jaxlib.mlir.dialects.sdy as sdy except ImportError: sdy: Any = None # type: ignore[no-redef] +import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor +import jaxlib.mlir.dialects.vector as vector +try: + # pytype: disable=import-error + import jaxlib.mlir.dialects.gpu as gpu + import jaxlib.mlir.dialects.nvgpu as nvgpu + import jaxlib.mlir.dialects.nvvm as nvvm + import jaxlib.mlir.dialects.llvm as llvm + # pytype: enable=import-error +except ImportError: + pass + +from jax._src import lib + # Alias that is set up to abstract away the transition from MHLO to StableHLO. -from jaxlib.mlir.dialects import stablehlo as hlo +import jaxlib.mlir.dialects.stablehlo as hlo From 1516d59744e629f4a6f5c4930c013c1ad0209a80 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 15 Aug 2024 09:32:43 -0700 Subject: [PATCH 140/702] Reverts 6fc57c0eb6f06b2da20c94f5f127fe4a551bda09 PiperOrigin-RevId: 663334727 --- jax/_src/ops/special.py | 36 +++++++----------------------------- tests/lax_scipy_test.py | 10 +--------- 2 files changed, 8 insertions(+), 38 deletions(-) diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index 45b26b0de4d3..59ad594ef2bc 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -14,12 +14,12 @@ from __future__ import annotations -from typing import Literal, overload +from typing import overload, Literal import jax from jax import lax from jax import numpy as jnp -from jax._src.numpy.reductions import Axis, _reduction_dims +from jax._src.numpy.reductions import _reduction_dims, Axis from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike import numpy as np @@ -40,7 +40,6 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: ... - def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: r"""Log-sum-exp reduction. @@ -72,22 +71,18 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, """ if b is not None: a_arr, b_arr = promote_args_inexact("logsumexp", a, b) - a_masked = jnp.where(b_arr != 0, a_arr, -jnp.inf) + a_arr = jnp.where(b_arr != 0, a_arr, -jnp.inf) else: a_arr, = promote_args_inexact("logsumexp", a) b_arr = a_arr # for type checking - a_masked = a_arr pos_dims, dims = _reduction_dims(a_arr, axis) - amax = jnp.max( - a_masked.real, axis=dims, keepdims=keepdims, where=where, initial=-jnp.inf - ) + amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-jnp.inf) amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims) - if b is None: - exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) - else: - exp_a = _stable_mulexp(a_arr - amax_with_dims.astype(a_arr.dtype), b_arr) + exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) + if b is not None: + exp_a = lax.mul(exp_a, b_arr) sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where) sign = lax.sign(sumexp) if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating): @@ -100,20 +95,3 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, with jax.debug_nans(False): out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out) return out - - -@jax.custom_jvp -def _stable_mulexp(a_scaled: Array, b: Array) -> Array: - # This helper ensures that the output of logsumexp depends on b for b == 0. - # See https://github.com/google/jax/issues/22398. - a_scaled = jnp.where(b != 0, a_scaled, -jnp.inf) - return lax.mul(lax.exp(a_scaled), b) - - -@_stable_mulexp.defjvp -def _stable_mulexp_jvp(primals, tangents): - a_scaled, b = primals - da, db = tangents - out = _stable_mulexp(a_scaled, b) - dout = _stable_mulexp(a_scaled, db) + da * out - return out, dout diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 50d2ee7259dd..b3c373b63424 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -20,9 +20,9 @@ from absl.testing import absltest import numpy as np -import scipy.cluster as osp_cluster import scipy.integrate import scipy.special as osp_special +import scipy.cluster as osp_cluster import jax import jax.dtypes @@ -202,14 +202,6 @@ def testLogSumExpWhere(self, shape, dtype): y_actual = lsp_special.logsumexp(x, where=mask) self.assertAllClose(y_expected, y_actual, check_dtypes=False) - def testLogSumExpZerosJac(self): - # Regression test for https://github.com/google/jax/issues/22398 - fun = lambda b: lsp_special.logsumexp(jnp.zeros(2), axis=0, b=b) - np.testing.assert_array_equal( - jax.jacfwd(fun)(jnp.array([1.0, 0.0])), - jnp.ones(2), - ) - @jtu.sample_product( shape=all_shapes, dtype=float_dtypes, From a498c1e66836a1b8f3c2ce67c53ba0f6caeadbc1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 15 Aug 2024 13:35:41 -0700 Subject: [PATCH 141/702] Set Clang as the default compiler in the build script. PiperOrigin-RevId: 663433112 --- build/build.py | 1 + 1 file changed, 1 insertion(+) diff --git a/build/build.py b/build/build.py index f2920fba6221..2f3d54addedd 100755 --- a/build/build.py +++ b/build/build.py @@ -390,6 +390,7 @@ def main(): add_boolean_argument( parser, "use_clang", + default = "true", help_str=( "Should we build using clang as the host compiler? Requires " "clang to be findable via the PATH, or a path to be given via " From 417fcd574b9f33410ea8eb78ffdea825ad343eee Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 15 Aug 2024 14:29:30 -0700 Subject: [PATCH 142/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/1db3272ca01754dd38827f4ea332a2f136df5d05. PiperOrigin-RevId: 663454724 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 28551a554b70..b87cb1fc345c 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "aa2340049456d45f3b1fd7b09acc8bcf9d50b749" -XLA_SHA256 = "92cd501640e7962e90641c0fd25742a3c72c184cb10571753395efa5e9556102" +XLA_COMMIT = "1db3272ca01754dd38827f4ea332a2f136df5d05" +XLA_SHA256 = "ca5821e7f95e1d26f420619daed6ee6449bbd58e76f6f91dbde63fc72c1ced1c" def repo(): tf_http_archive( From fd7c52d213dea9bb5043fa4b88e9c74e725b478f Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Thu, 15 Aug 2024 21:33:51 -0500 Subject: [PATCH 143/702] [ROCm] Fix python in rocm ci_build script. --- build/rocm/ci_build | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 31ee591aedd0..43f34f6ca758 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -118,7 +118,7 @@ def _fetch_jax_metadata(xla_path): except Exception as ex: LOG.warning("Exception while retrieving xla_commit: %s" % ex) - cmd = ["python", "setup.py", "-V"] + cmd = ["python3", "setup.py", "-V"] env = dict(os.environ) env["JAX_RELEASE"] = "1" From 9785368c7feb76c78bb5b76f803cce1e2342784c Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 16 Aug 2024 00:32:00 -0700 Subject: [PATCH 144/702] [Easy] Refactor ragged_dot transpose, combine ragged_to_dense PiperOrigin-RevId: 663630185 --- jax/_src/lax/lax.py | 51 ++++++++++++++++++--------------------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 1e1f2d48c538..ca0096c2e05c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3039,6 +3039,22 @@ def _ragged_dot_jvp_rule( return primal_out, tangent_out +def _ragged_to_dense(x, y, group_sizes): + shape = (y.shape[0], x.shape[0], x.shape[1]) + x = broadcast_in_dim(x, shape, [1, 2]) + iota = broadcasted_iota(group_sizes.dtype, shape, 1) + group_ends = jax.lax.cumsum(group_sizes) + group_starts = concatenate( + [_zeros(group_sizes)[:1], group_ends[:-1]], + dimension=0, + ) + group_ends = broadcast_in_dim(group_ends, shape, (0,)) + group_starts = broadcast_in_dim(group_starts, shape, (0,)) + mask = bitwise_and(group_starts <= iota, iota < group_ends) + x = select(mask, x, _zeros(x)) + return x + + def _ragged_dot_transpose_rule( ct, *operands, precision, preferred_element_type, group_offset ): @@ -3046,24 +3062,6 @@ def _ragged_dot_transpose_rule( if group_offset is not None: raise NotImplementedError('Unimplemented group_offset support.') - def ragged_to_dense(x, group_sizes): - group_count = group_sizes.shape[0] - shape = (group_count, x.shape[0], x.shape[1]) - x_broadcasted = jax.lax.broadcast_in_dim(x, shape, (1, 2)) - iota = jax.lax.broadcasted_iota(group_sizes.dtype, shape, 1) - group_ends = jax.lax.cumsum(group_sizes) - group_starts = concatenate( - [ - np.zeros_like([group_ends[0]], dtype=group_sizes.dtype), - group_ends[:-1], - ], - 0, - ) - group_ends = jax.lax.broadcast_in_dim(group_ends, shape, (0,)) - group_starts = jax.lax.broadcast_in_dim(group_starts, shape, (0,)) - mask = (group_starts <= iota) & (iota < group_ends) - return jax.numpy.where(mask, x_broadcasted, 0) - if ad.is_undefined_primal(y): grad_x = None else: @@ -3079,8 +3077,9 @@ def ragged_to_dense(x, group_sizes): if ad.is_undefined_primal(x): grad_y = None else: - x_dense = ragged_to_dense(x, gs) - ct_dense = ragged_to_dense(ct, gs) + y = y.aval if ad.is_undefined_primal(y) else y + x_dense = _ragged_to_dense(x, y, group_sizes=gs) + ct_dense = _ragged_to_dense(ct, y, group_sizes=gs) dimension_numbers = (([1], [1]), ([0], [0])) grad_y = jax.lax.dot_general( x_dense, @@ -3109,17 +3108,7 @@ def _ragged_dot_impl( ) -> Array: if group_offset is not None: raise NotImplementedError("Unimplemented group_offset support.") - shape = (rhs.shape[0], lhs.shape[0], lhs.shape[1]) - lhs = broadcast_in_dim(lhs, shape, [1, 2]) - iota = broadcasted_iota(group_sizes.dtype, shape, 1) - group_ends = jax.lax.cumsum(group_sizes) - group_starts = concatenate( - [_zeros(group_sizes)[:1], group_ends[:-1]], dimension=0, - ) - group_ends = broadcast_in_dim(group_ends, shape, (0,)) - group_starts = broadcast_in_dim(group_starts, shape, (0,)) - mask = bitwise_and(group_starts <= iota, iota < group_ends) - lhs = select(mask, lhs, _zeros(lhs)) + lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes) return dot_general( lhs, rhs, From acacf8884e9cc23071d1dfbd8ef8e467126000d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Fri, 16 Aug 2024 01:20:02 -0700 Subject: [PATCH 145/702] Determine LAPACK workspace during QR Factorization Kernel runtime PiperOrigin-RevId: 663641199 --- jaxlib/cpu/_lapack/__init__.pyi | 4 ---- jaxlib/cpu/lapack.cc | 12 ------------ jaxlib/cpu/lapack_kernels.cc | 31 +++++++++++++++---------------- jaxlib/cpu/lapack_kernels.h | 9 ++++----- 4 files changed, 19 insertions(+), 37 deletions(-) diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index 5fcb2a5ad50e..f2a4d943086a 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -51,13 +51,9 @@ def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matr # FFI Kernel LAPACK Workspace Size Queries def heevd_rwork_size_ffi(n: int) -> int: ... def heevd_work_size_ffi(n: int) -> int: ... -def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_sgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def syevd_iwork_size_ffi(n: int) -> int: ... def syevd_work_size_ffi(n: int) -> int: ... diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index c13608e813f5..83ed7610ced7 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -336,18 +336,6 @@ NB_MODULE(_lapack, m) { m.def("lapack_zhetrd_workspace", &Sytrd>::Workspace, nb::arg("lda"), nb::arg("n")); // FFI Kernel LAPACK Workspace Size Queries - m.def("lapack_sgeqrf_workspace_ffi", - &QrFactorization::GetWorkspaceSize, nb::arg("m"), - nb::arg("n")); - m.def("lapack_dgeqrf_workspace_ffi", - &QrFactorization::GetWorkspaceSize, nb::arg("m"), - nb::arg("n")); - m.def("lapack_cgeqrf_workspace_ffi", - &QrFactorization::GetWorkspaceSize, nb::arg("m"), - nb::arg("n")); - m.def("lapack_zgeqrf_workspace_ffi", - &QrFactorization::GetWorkspaceSize, nb::arg("m"), - nb::arg("n")); m.def("lapack_sorgqr_workspace_ffi", &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), nb::arg("n"), nb::arg("k")); diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 551800bae8f2..8b260d1408e4 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -309,17 +309,17 @@ template struct Geqrf>; template ffi::Error QrFactorization::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, - ffi::ResultBuffer tau, ffi::ResultBuffer info, - ffi::ResultBuffer work) { + ffi::ResultBuffer tau, ffi::ResultBuffer info) { auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); auto* x_out_data = x_out->typed_data(); auto* tau_data = tau->typed_data(); auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); + const int64_t work_size = GetWorkspaceSize(x_rows, x_cols); + auto work_data = AllocateScratchMemory(work_size); CopyIfDiffBuffer(x, x_out); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, + MaybeCastNoOverflow(work_size)); FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow(x_rows)); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); auto x_leading_dim_v = x_rows_v; @@ -327,8 +327,8 @@ ffi::Error QrFactorization::Kernel( const int64_t x_out_step{x_rows * x_cols}; const int64_t tau_step{std::min(x_rows, x_cols)}; for (int64_t i = 0; i < batch_count; ++i) { - fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data, work_data, - &workspace_dim_v, info_data); + fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data, + work_data.get(), &workspace_dim_v, info_data); x_out_data += x_out_step; tau_data += tau_step; ++info_data; @@ -1701,15 +1701,14 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*ipiv*/) \ .Ret<::xla::ffi::Buffer>(/*info*/)) -#define JAX_CPU_DEFINE_GEQRF(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, QrFactorization::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*tau*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/)) +#define JAX_CPU_DEFINE_GEQRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, QrFactorization::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*tau*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) #define JAX_CPU_DEFINE_ORGQR(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index d78fd1b8d3d3..8abf8e22daac 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -192,11 +192,10 @@ struct QrFactorization { inline static FnType* fn = nullptr; - static ::xla::ffi::Error Kernel(::xla::ffi::Buffer x, - ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer tau, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work); + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer tau, + ::xla::ffi::ResultBuffer info); static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); }; From 12e8bf45259e2abca37588e94d1a734dbc4ddd12 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 16 Aug 2024 11:47:21 +0200 Subject: [PATCH 146/702] Pass bazel options to requirements_update and requirements_nightly_update commands --- build/build.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/build/build.py b/build/build.py index f2920fba6221..e49aee5949d8 100755 --- a/build/build.py +++ b/build/build.py @@ -622,20 +622,17 @@ def main(): python_version=python_version, ) - if args.requirements_update: + if args.requirements_update or args.requirements_nightly_update: + if args.requirements_update: + task = "//build:requirements.update" + else: # args.requirements_nightly_update + task = "//build:requirements_nightly.update" update_command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true", "//build:requirements.update"]) + ["run", "--verbose_failures=true", task, *args.bazel_options]) print(" ".join(update_command)) shell(update_command) return - if args.requirements_nightly_update: - update_nightly_command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true", "//build:requirements_nightly.update"]) - print(" ".join(update_nightly_command)) - shell(update_nightly_command) - return - if args.configure_only: return From b6306e395347c33feed55270a741b8be24fe76d0 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 16 Aug 2024 04:36:30 -0700 Subject: [PATCH 147/702] Remove synchronization from GPU LU decomposition kernel by adding an async batch pointers builder. In the batched LU decomposition in cuBLAS, the output buffer is required to be a pointer of pointers to the appropriate batch matrices. Previously this reshaping was done on the host and then copied to the device, requiring a synchronization, but it seems straightforward to instead implement a tiny CUDA kernel to do this work. This definitely isn't a bottleneck or a high priority change, but this seemed like a reasonable time to fix a longstanding TODO. PiperOrigin-RevId: 663686539 --- jaxlib/cuda/BUILD | 12 ++++++++ jaxlib/gpu/BUILD | 2 ++ jaxlib/gpu/blas_kernels.cc | 28 ++++++----------- jaxlib/gpu/gpu_kernel_helpers.cc | 15 --------- jaxlib/gpu/gpu_kernel_helpers.h | 10 ------ jaxlib/gpu/make_batch_pointers.cu.cc | 46 ++++++++++++++++++++++++++++ jaxlib/gpu/make_batch_pointers.h | 30 ++++++++++++++++++ jaxlib/gpu/solver_kernels_ffi.cc | 10 ++---- jaxlib/rocm/BUILD.bazel | 12 ++++++++ 9 files changed, 115 insertions(+), 50 deletions(-) create mode 100644 jaxlib/gpu/make_batch_pointers.cu.cc create mode 100644 jaxlib/gpu/make_batch_pointers.h diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 8121a1058768..b31ed78e3b58 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -72,6 +72,16 @@ cc_library( ], ) +cuda_library( + name = "cuda_make_batch_pointers", + srcs = ["//jaxlib/gpu:make_batch_pointers.cu.cc"], + hdrs = ["//jaxlib/gpu:make_batch_pointers.h"], + deps = [ + ":cuda_vendor", + "@local_config_cuda//cuda:cuda_headers", + ], +) + cc_library( name = "cuda_blas_handle_pool", srcs = ["//jaxlib/gpu:blas_handle_pool.cc"], @@ -95,6 +105,7 @@ cc_library( deps = [ ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", + ":cuda_make_batch_pointers", ":cuda_vendor", "//jaxlib:kernel_helpers", "@xla//xla/service:custom_call_status", @@ -223,6 +234,7 @@ cc_library( deps = [ ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", + ":cuda_make_batch_pointers", ":cuda_solver_handle_pool", ":cuda_vendor", "//jaxlib:ffi_helpers", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 6bdaf4ef1322..706cac6b46d4 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -36,6 +36,8 @@ exports_files(srcs = [ "linalg_kernels.cc", "linalg_kernels.cu.cc", "linalg_kernels.h", + "make_batch_pointers.cu.cc", + "make_batch_pointers.h", "prng.cc", "prng_kernels.cc", "prng_kernels.cu.cc", diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc index a963aa3fd762..ac30aa9cc520 100644 --- a/jaxlib/gpu/blas_kernels.cc +++ b/jaxlib/gpu/blas_kernels.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "jaxlib/gpu/blas_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/make_batch_pointers.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" @@ -69,13 +70,9 @@ static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers, int* ipiv = static_cast(buffers[2]); int* info = static_cast(buffers[3]); - auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch, - SizeOfBlasType(d.type) * d.n * d.n); - JAX_RETURN_IF_ERROR(a_ptrs_host.status()); - // TODO(phawkins): ideally we would not need to synchronize here, but to - // avoid it we need a way to keep the host-side buffer alive until the copy - // completes. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch, + SizeOfBlasType(d.type) * d.n * d.n); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); switch (d.type) { case BlasType::F32: { float** batch_ptrs = static_cast(buffers[4]); @@ -132,17 +129,12 @@ static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, } std::vector info(d.batch); - auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch, - SizeOfBlasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(a_ptrs_host.status()); - auto tau_ptrs_host = - MakeBatchPointers(stream, buffers[2], buffers[4], d.batch, - SizeOfBlasType(d.type) * std::min(d.m, d.n)); - JAX_RETURN_IF_ERROR(tau_ptrs_host.status()); - // TODO(phawkins): ideally we would not need to synchronize here, but to - // avoid it we need a way to keep the host-side buffer alive until the copy - // completes. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch, + SizeOfBlasType(d.type) * d.m * d.n); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); + MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch, + SizeOfBlasType(d.type) * std::min(d.m, d.n)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); switch (d.type) { case BlasType::F32: { float** a_batch_ptrs = static_cast(buffers[3]); diff --git a/jaxlib/gpu/gpu_kernel_helpers.cc b/jaxlib/gpu/gpu_kernel_helpers.cc index f43122f2efaa..5a434f4b6ad5 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.cc +++ b/jaxlib/gpu/gpu_kernel_helpers.cc @@ -313,20 +313,5 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line, } #endif -absl::StatusOr> MakeBatchPointers( - gpuStream_t stream, void* buffer, void* dev_ptrs, int batch, - int batch_elem_size) { - char* ptr = static_cast(buffer); - auto host_ptrs = absl::make_unique(batch); - for (int i = 0; i < batch; ++i) { - host_ptrs[i] = ptr; - ptr += batch_elem_size; - } - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpuMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch, - gpuMemcpyHostToDevice, stream))); - return std::move(host_ptrs); -} - } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/gpu_kernel_helpers.h b/jaxlib/gpu/gpu_kernel_helpers.h index 46fca7bc4bd4..aecb8a4fdcf1 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.h +++ b/jaxlib/gpu/gpu_kernel_helpers.h @@ -67,16 +67,6 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line, const char* expr); #endif -// Builds an array of pointers to each array in a batch, in device memory. -// Caution: the return value must be kept alive (e.g., via a stream -// synchronization) until the copy enqueued by MakeBatchPointers on `stream` -// completes. -absl::StatusOr> MakeBatchPointers(gpuStream_t stream, - void* buffer, - void* dev_ptrs, - int batch, - int batch_elem_size); - } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc new file mode 100644 index 000000000000..b10655645924 --- /dev/null +++ b/jaxlib/gpu/make_batch_pointers.cu.cc @@ -0,0 +1,46 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/make_batch_pointers.h" + +#include + +#include "jaxlib/gpu/vendor.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +namespace { +__global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out, + int batch, int batch_elem_size) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch; + idx += blockDim.x * gridDim.x) { + buffer_out[idx] = buffer_in + idx * batch_elem_size; + } +} +} // namespace + +void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in, + void* buffer_out, int batch, int batch_elem_size) { + const int block_dim = 128; + const std::size_t grid_dim = + std::min(1024, (batch + block_dim - 1) / block_dim); + MakeBatchPointersAsyncKernel<<>>( + static_cast(buffer_in), static_cast(buffer_out), batch, + batch_elem_size); +} + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/make_batch_pointers.h b/jaxlib/gpu/make_batch_pointers.h new file mode 100644 index 000000000000..f2fd064961e8 --- /dev/null +++ b/jaxlib/gpu/make_batch_pointers.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_ +#define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_ + +#include "jaxlib/gpu/vendor.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in, + void* buffer_out, int batch, int batch_elem_size); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_MAKE_BATCH_POINTERS_H_ diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 051b9fd03f9e..6deb89144ec7 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -23,6 +23,7 @@ limitations under the License. #include "jaxlib/ffi_helpers.h" #include "jaxlib/gpu/blas_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/make_batch_pointers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" @@ -142,13 +143,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, gpuMemcpyDeviceToDevice, stream))); } - FFI_ASSIGN_OR_RETURN( - auto a_ptrs_host, - MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n)); - // TODO(phawkins, danfm): ideally we would not need to synchronize here, but - // to avoid it we need a way to keep the host-side buffer alive until the copy - // completes. - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + MakeBatchPointersAsync(stream, out_data, workspace, batch, sizeof(T) * n * n); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); auto batch_ptrs = static_cast(workspace); FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel::Run( diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index ba9ceb4c3fa7..ce733d827e35 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -58,6 +58,16 @@ cc_library( ]), ) +rocm_library( + name = "hip_make_batch_pointers", + srcs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.cu.cc"], + hdrs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.h"], + deps = [ + ":hip_vendor", + "@local_config_rocm//rocm:rocm_headers", + ], +) + cc_library( name = "hip_blas_handle_pool", srcs = ["//jaxlib/gpu:blas_handle_pool.cc"], @@ -80,6 +90,7 @@ cc_library( deps = [ ":hip_blas_handle_pool", ":hip_gpu_kernel_helpers", + ":hip_make_batch_pointers", ":hip_vendor", "//jaxlib:kernel_helpers", "@com_google_absl//absl/algorithm:container", @@ -160,6 +171,7 @@ cc_library( deps = [ ":hip_blas_handle_pool", ":hip_gpu_kernel_helpers", + ":hip_make_batch_pointers", ":hip_solver_handle_pool", ":hip_vendor", "//jaxlib:ffi_helpers", From e9d6fd37953dfff79691552c50b3dbfdc6baa98a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 16 Aug 2024 06:37:19 -0700 Subject: [PATCH 148/702] document jax.Array methods and attributes --- docs/jax.rst | 67 ++++ jax/_src/numpy/array_methods.py | 591 ++++++++++++++++++++++++-------- 2 files changed, 507 insertions(+), 151 deletions(-) diff --git a/docs/jax.rst b/docs/jax.rst index fe4827c252bd..b2c4ba60739b 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -133,6 +133,73 @@ jax.Array (:code:`jax.Array`) make_array_from_single_device_arrays make_array_from_process_local_data +Array properties and methods +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: _autosummary + + Array.addressable_shards + Array.all + Array.any + Array.argmax + Array.argmin + Array.argpartition + Array.argsort + Array.astype + Array.at + Array.choose + Array.clip + Array.compress + Array.conj + Array.conjugate + Array.copy + Array.copy_to_host_async + Array.cumprod + Array.cumsum + Array.device + Array.diagonal + Array.dot + Array.dtype + Array.flat + Array.flatten + Array.global_shards + Array.imag + Array.is_fully_addressable + Array.is_fully_replicated + Array.item + Array.itemsize + Array.max + Array.mean + Array.min + Array.nbytes + Array.ndim + Array.nonzero + Array.prod + Array.ptp + Array.ravel + Array.real + Array.repeat + Array.reshape + Array.round + Array.searchsorted + Array.shape + Array.sharding + Array.size + Array.sort + Array.squeeze + Array.std + Array.sum + Array.swapaxes + Array.take + Array.to_device + Array.trace + Array.transpose + Array.var + Array.view + Array.T + Array.mT + Vectorization (:code:`vmap`) ---------------------------- diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 87635be37c84..03745a7dcd45 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -26,7 +26,7 @@ import abc from functools import partial, wraps import math -from typing import Any +from typing import Any, Sequence import numpy as np import jax @@ -43,9 +43,8 @@ from jax._src.numpy import lax_numpy from jax._src.numpy import reductions from jax._src.numpy import ufuncs -from jax._src.numpy import util from jax._src.ops import scatter -from jax._src.typing import Array, ArrayLike, DimSize, DTypeLike, Shape +from jax._src.typing import Array, ArrayLike, DimSize, DTypeLike, Shape, StaticScalar from jax._src.util import safe_zip, safe_map map, unsafe_map = safe_map, map @@ -59,7 +58,56 @@ # functions, which can themselves handle instances from any of these classes. -def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: +def _all(self: ArrayLike, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: + """Test whether all array elements along a given axis evaluate to True. + + Refer to :func:`jax.numpy.all` for the full documentation. + """ + return reductions.all(self, axis=axis, out=out, keepdims=keepdims, where=where) + +def _any(self: Array, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: + """Test whether any array elements along a given axis evaluate to True. + + Refer to :func:`jax.numpy.any` for the full documentation. + """ + return reductions.any(self, axis=axis, out=out, keepdims=keepdims, where=where) + +def _argmax(self: Array, axis: int | None = None, out: None = None, + keepdims: bool | None = None) -> Array: + """Return the index of the maximum value. + + Refer to :func:`jax.numpy.argmax` for the full documentation. + """ + return lax_numpy.argmax(self, axis=axis, out=out, keepdims=keepdims) + +def _argmin(self: Array, axis: int | None = None, out: None = None, + keepdims: bool | None = None) -> Array: + """Return the index of the minimum value. + + Refer to :func:`jax.numpy.argmin` for the full documentation. + """ + return lax_numpy.argmin(self, axis=axis, out=out, keepdims=keepdims) + +def _argpartition(self: Array, kth: int, axis: int = -1, + kind: str = 'introselect', order: None = None) -> Array: + """Return the indices that partially sort the array. + + Refer to :func:`jax.numpy.argpartition` for the full documentation. + """ + return lax_numpy.argpartition(self, kth=kth, axis=axis, kind=kind, order=order) + +def _argsort(self: Array, axis: int | None = -1, *, kind: None = None, order: None = None, + stable: bool = True, descending: bool = False) -> Array: + """Return the indices that sort the array. + + Refer to :func:`jax.numpy.argsort` for the full documentation. + """ + return lax_numpy.argsort(self, axis=axis, kind=kind, order=order, + stable=stable, descending=descending) + +def _astype(self: Array, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. This is implemented via :func:`jax.lax.convert_element_type`, which may @@ -67,42 +115,300 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """ - return lax_numpy.astype(arr, dtype, copy=copy, device=device) + return lax_numpy.astype(self, dtype, copy=copy, device=device) -def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *, - stream: int | Any | None = None): - if stream is not None: - raise NotImplementedError("stream argument of array.to_device()") - return api.device_put(arr, device) +def _choose(self: Array, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: + """Construct an array choosing from elements of multiple arrays. + Refer to :func:`jax.numpy.choose` for the full documentation. + """ + return lax_numpy.choose(self, choices=choices) -def _nbytes(arr: ArrayLike) -> int: - """Total bytes consumed by the elements of the array.""" - return np.size(arr) * dtypes.dtype(arr, canonicalize=True).itemsize +def _clip(number: ArrayLike, + min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: + """Return an array whose values are limited to a specified range. + + Refer to :func:`jax.numpy.clip` for full documentation. + """ + return lax_numpy.clip(number, min=min, max=max) + +def _compress(self: Array, condition: ArrayLike, + axis: int | None = None, *, out: None = None, + size: int | None = None, fill_value: ArrayLike = 0) -> Array: + """Return selected slices of this array along given axis. + + Refer to :func:`jax.numpy.compress` for full documentation. + """ + return lax_numpy.compress(condition, self, axis=axis, out=out, + size=size, fill_value=fill_value) + +def _conj(self: Array) -> Array: + """Return the complex conjugate of the array. + + Refer to :func:`jax.numpy.conj` for the full documentation. + """ + return ufuncs.conj(self) + +def _conjugate(self: Array) -> Array: + """Return the complex conjugate of the array. + + Refer to :func:`jax.numpy.conjugate` for the full documentation. + """ + return ufuncs.conjugate(self) + +def _copy(self: Array) -> Array: + """Return a copy of the array. + + Refer to :func:`jax.numpy.copy` for the full documentation. + """ + return lax_numpy.copy(self) + +def _cumprod(self: Array, /, axis: int | Sequence[int] | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Return the cumulative product of the array. + + Refer to :func:`jax.numpy.cumprod` for the full documentation. + """ + return reductions.cumprod(self, axis=axis, dtype=dtype, out=out) +def _cumsum(self: Array, /, axis: int | Sequence[int] | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Return the cumulative sum of the array. + + Refer to :func:`jax.numpy.cumsum` for the full documentation. + """ + return reductions.cumsum(self, axis=axis, dtype=dtype, out=out) + +def _diagonal(self: Array, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: + """Return the specified diagonal from the array. + + Refer to :func:`jax.numpy.diagonal` for the full documentation. + """ + return lax_numpy.diagonal(self, offset=offset, axis1=axis1, axis2=axis2) -def _item(a: Array, *args) -> bool | int | float | complex: +def _dot(self: Array, b: ArrayLike, *, precision: lax_internal.PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: + """Compute the dot product of two arrays. + + Refer to :func:`jax.numpy.dot` for the full documentation. + """ + return lax_numpy.dot(self, b, precision=precision, preferred_element_type=preferred_element_type) + +def _flatten(self: Array, order: str = "C") -> Array: + """Flatten array into a 1-dimensional shape. + + Refer to :func:`jax.numpy.ravel` for the full documentation. + """ + return lax_numpy.ravel(self, order=order) + +def _imag_property(self: Array) -> Array: + """Return the imaginary part of the array.""" + return ufuncs.imag(self) + +def _item(self: Array, *args: int) -> bool | int | float | complex: """Copy an element of an array to a standard Python scalar and return it.""" - arr = core.concrete_or_error(np.asarray, a, context="This occurred in the item() method of jax.Array") - if dtypes.issubdtype(a.dtype, dtypes.extended): - raise TypeError(f"No Python scalar type for {a.dtype=}") + arr = core.concrete_or_error(np.asarray, self, context="This occurred in the item() method of jax.Array") + if dtypes.issubdtype(self.dtype, dtypes.extended): + raise TypeError(f"No Python scalar type for {arr.dtype=}") return arr.item(*args) -def _itemsize(arr: ArrayLike) -> int: +def _itemsize_property(self: Array) -> int: """Length of one array element in bytes.""" - return dtypes.dtype(arr, canonicalize=True).itemsize + return dtypes.dtype(self, canonicalize=True).itemsize +def _matrix_transpose_property(self: Array): + """Compute the (batched) matrix transpose. -def _clip(number: ArrayLike, - min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: - """Return an array whose values are limited to a specified range. + Refer to :func:`jax.numpy.matrix_transpose` for details. + """ + return lax_numpy.matrix_transpose(self) - Refer to :func:`jax.numpy.clip` for full documentation.""" - return lax_numpy.clip(number, min=min, max=max) +def _max(self: Array, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Return the maximum of array elements along a given axis. + + Refer to :func:`jax.numpy.max` for the full documentation. + """ + return reductions.max(self, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) + + +def _mean(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, *, + where: ArrayLike | None = None) -> Array: + """Return the mean of array elements along a given axis. + + Refer to :func:`jax.numpy.mean` for the full documentation. + """ + return reductions.mean(self, axis=axis, dtype=dtype, out=out, + keepdims=keepdims, where=where) + +def _min(self: Array, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Return the minimum of array elements along a given axis. + + Refer to :func:`jax.numpy.min` for the full documentation. + """ + return reductions.min(self, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) + +def _nbytes_property(self: Array) -> int: + """Total bytes consumed by the elements of the array.""" + return np.size(self) * dtypes.dtype(self, canonicalize=True).itemsize + +def _nonzero(self: Array, *, size: int | None = None, + fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None + ) -> tuple[Array, ...]: + """Return indices of nonzero elements of an array. + + Refer to :func:`jax.numpy.nonzero` for the full documentation. + """ + return lax_numpy.nonzero(self, size=size, fill_value=fill_value) + +def _prod(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None, + promote_integers: bool = True) -> Array: + """Return product of the array elements over a given axis. + Refer to :func:`jax.numpy.prod` for the full documentation. + """ + return reductions.prod(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where=where, promote_integers=promote_integers) -def _transpose(a: Array, *args: Any) -> Array: - """Returns a view of the array with axes transposed. +def _ptp(self: Array, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False) -> Array: + """Return the peak-to-peak range along a given axis. + + Refer to :func:`jax.numpy.ptp` for the full documentation. + """ + return reductions.ptp(self, axis=axis, out=out, keepdims=keepdims) + +def _real_property(self: Array) -> Array: + """Return the real part of the array.""" + return ufuncs.real(self) + +def _repeat(self: Array, repeats: ArrayLike, axis: int | None = None, *, + total_repeat_length: int | None = None) -> Array: + """Construct an array from repeated elements. + + Refer to :func:`jax.numpy.repeat` for the full documentation. + """ + return lax_numpy.repeat(self, repeats=repeats, axis=axis, total_repeat_length=total_repeat_length) + +def _reshape(self: Array, *args: Any, order: str = "C") -> Array: + """Returns an array containing the same data with a new shape. + + Refer to :func:`jax.numpy.reshape` for full documentation. + """ + __tracebackhide__ = True + newshape = _compute_newshape(self, args[0] if len(args) == 1 else args) + if order == "C": + return lax.reshape(self, newshape, None) + elif order == "F": + dims = list(range(self.ndim)[::-1]) + return lax.reshape(self, newshape[::-1], dims).T + elif order == "A": + raise NotImplementedError("np.reshape order=A is not implemented.") + else: + raise ValueError(f"Unexpected value for 'order' argument: {order}.") + +def _round(self: Array, decimals: int = 0, out: None = None) -> Array: + """Round array elements to a given decimal. + + Refer to :func:`jax.numpy.round` for full documentation. + """ + return lax_numpy.round(self, decimals=decimals, out=out) + +def _searchsorted(self: Array, v: ArrayLike, side: str = 'left', + sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: + """Perform a binary search within a sorted array. + + Refer to :func:`jax.numpy.searchsorted` for full documentation.""" + return lax_numpy.searchsorted(self, v, side=side, sorter=sorter, method=method) + +def _sort(self: Array, axis: int | None = -1, *, kind: None = None, + order: None = None, stable: bool = True, descending: bool = False) -> Array: + """Return a sorted copy of an array. + + Refer to :func:`jax.numpy.sort` for full documentation. + """ + return lax_numpy.sort(self, axis=axis, kind=kind, order=order, + stable=stable, descending=descending) + +def _squeeze(self: Array, axis: int | Sequence[int] | None = None) -> Array: + """Remove one or more length-1 axes from array. + + Refer to :func:`jax.numpy.squeeze` for full documentation. + """ + return lax_numpy.squeeze(self, axis=axis) + +def _std(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + """Compute the standard deviation along a given axis. + + Refer to :func:`jax.numpy.std` for full documentation. + """ + return reductions.std(self, axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims, + where=where, correction=correction) + +def _sum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None, promote_integers: bool = True) -> Array: + """Sum of the elements of the array over a given axis. + + Refer to :func:`jax.numpy.sum` for full documentation. + """ + return reductions.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims, + where=where, promote_integers=promote_integers) + +def _swapaxes(self: Array, axis1: int, axis2: int) -> Array: + """Swap two axes of an array. + + Refer to :func:`jax.numpy.swapaxes` for full documentation. + """ + return lax_numpy.swapaxes(self, axis1=axis1, axis2=axis2) + + +def _take(self: Array, indices: ArrayLike, axis: int | None = None, out: None = None, + mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, + fill_value: StaticScalar | None = None) -> Array: + """Take elements from an array. + + Refer to :func:`jax.numpy.take` for full documentation. + """ + return lax_numpy.take(self, indices, axis=axis, out=out, mode=mode, unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, fill_value=fill_value) + +def _to_device(self: Array, device: xc.Device | Sharding, *, + stream: int | Any | None = None): + """Return a copy of the array on the specified device + + Args: + device: :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + stream: not implemented, passing a non-None value will lead to an error. + Returns: + copy of array placed on the specified device or devices. + """ + if stream is not None: + raise NotImplementedError("stream argument of array.to_device()") + return api.device_put(self, device) + + +def _trace(self: Array, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Return the sum along the diagonal. + + Refer to :func:`jax.numpy.trace` for full documentation. + """ + return lax_numpy.trace(self, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out) + +def _transpose(self: Array, *args: Any) -> Array: + """Returns a copy of the array with axes transposed. Refer to :func:`jax.numpy.transpose` for full documentation. """ @@ -112,10 +418,27 @@ def _transpose(a: Array, *args: Any) -> Array: axis = args[0] if args[0] is None else _ensure_index_tuple(args[0]) else: axis = _ensure_index_tuple(args) - return lax_numpy.transpose(a, axis) + return lax_numpy.transpose(self, axis) + +def _transpose_property(self: Array): + """Compute the all-axis array transpose. + + Refer to :func:`jax.numpy.transpose` for details. + """ + return lax_numpy.transpose(self) + +def _var(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + """Compute the variance along a given axis. + + Refer to :func:`jax.numpy.var` for full documentation. + """ + return reductions.var(self, axis=axis, dtype=dtype, out=out, ddof=ddof, + keepdims=keepdims, where=where, correction=correction) -def _compute_newshape(a: ArrayLike, newshape: DimSize | Shape) -> Shape: +def _compute_newshape(arr: Array, newshape: DimSize | Shape) -> Shape: """Fixes a -1 value in newshape, if present.""" orig_newshape = newshape # for error messages try: @@ -130,43 +453,24 @@ def _compute_newshape(a: ArrayLike, newshape: DimSize | Shape) -> Shape: if neg1s: i, = neg1s other_sizes = (*newshape[:i], *newshape[i+1:]) - if (all(isinstance(d, int) for d in (*np.shape(a), *other_sizes)) and - np.size(a) % math.prod(other_sizes) != 0): - raise TypeError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) " + if (all(isinstance(d, int) for d in (*arr.shape, *other_sizes)) and + arr.size % math.prod(other_sizes) != 0): + raise TypeError(f"cannot reshape array of shape {arr.shape} (size {arr.size}) " f"into shape {orig_newshape} because the product of " f"specified axis sizes ({math.prod(other_sizes)}) does " - f"not evenly divide {np.size(a)}") - sz = core.cancel_divide_tracers(np.shape(a), other_sizes) + f"not evenly divide {arr.size}") + sz = core.cancel_divide_tracers(arr.shape, other_sizes) if sz is not None: return (*newshape[:i], sz, *newshape[i+1:]) else: - if (all(isinstance(d, int) for d in (*np.shape(a), *newshape)) and - np.size(a) != math.prod(newshape)): - raise TypeError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) " + if (all(isinstance(d, int) for d in (*arr.shape, *newshape)) and + arr.size != math.prod(newshape)): + raise TypeError(f"cannot reshape array of shape {arr.shape} (size {arr.size}) " f"into shape {orig_newshape} (size {math.prod(newshape)})") - return tuple(-core.divide_shape_sizes(np.shape(a), newshape) + return tuple(-core.divide_shape_sizes(arr.shape, newshape) if core.definitely_equal(d, -1) else d for d in newshape) - -def _reshape(a: Array, *args: Any, order: str = "C") -> Array: - """Returns an array containing the same data with a new shape. - - Refer to :func:`jax.numpy.reshape` for full documentation. - """ - __tracebackhide__ = True - newshape = _compute_newshape(a, args[0] if len(args) == 1 else args) - if order == "C": - return lax.reshape(a, newshape, None) - elif order == "F": - dims = list(range(a.ndim)[::-1]) - return lax.reshape(a, newshape[::-1], dims).T - elif order == "A": - raise NotImplementedError("np.reshape order=A is not implemented.") - else: - raise ValueError(f"Unexpected value for 'order' argument: {order}.") - - -def _view(arr: Array, dtype: DTypeLike | None = None, type: None = None) -> Array: +def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Array: """Return a bitwise copy of the array, viewed as a new dtype. This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`. @@ -187,72 +491,70 @@ def _view(arr: Array, dtype: DTypeLike | None = None, type: None = None) -> Arra should only contain 0 or 1 bytes. Otherwise, results may be unpredictable or may change depending on how the result is used. - This conversion is guaranteed and safe: - >>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_) - Array([ True, False, True], dtype=bool) + This conversion is guaranteed and safe:: + + >>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_) + Array([ True, False, True], dtype=bool) However, there are no guarantees about the results of any expression involving a view such as this: `jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)`. In particular, the results may change between JAX releases and depending on the platform. To safely convert such an array to a boolean array, compare it - with `0`: + with `0`:: - >>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0 - Array([ True, True, False], dtype=bool) + >>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0 + Array([ True, True, False], dtype=bool) """ if type is not None: raise NotImplementedError("`type` argument of array.view() is not supported.") - util.check_arraylike("view", arr) - arr = lax_numpy.asarray(arr) - dtypes.check_user_dtype_supported(dtype, "view") dtype = dtypes.canonicalize_dtype(dtype) - if arr.ndim == 0: - if arr.dtype.itemsize != dtype.itemsize: + if self.ndim == 0: + if self.dtype.itemsize != dtype.itemsize: raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.") - return _view(lax.expand_dims(arr, (0,)), dtype).squeeze() + return _view(lax.expand_dims(self, (0,)), dtype).squeeze() - if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0: + if (self.shape[-1] * self.dtype.itemsize) % dtype.itemsize != 0: raise ValueError("When changing to a larger dtype, its size must be a divisor " "of the total size in bytes of the last axis of the array.") - if arr.dtype == dtype: - return arr + if self.dtype == dtype: + return self # lax.bitcast_convert_type does not support bool or complex; in these cases we # cast to a compatible type and recursively call _view for simplicity. - if arr.dtype == bool: - return _view(arr.astype('uint8'), dtype) + if self.dtype == bool: + return _view(self.astype('uint8'), dtype) - if lax_numpy.issubdtype(arr.dtype, np.complexfloating): - new_shape = (*arr.shape[:-1], arr.shape[-1] * 2) - new_dtype = lax_numpy.finfo(arr.dtype).dtype - arr = (lax_numpy.zeros(new_shape, new_dtype) - .at[..., 0::2].set(arr.real) - .at[..., 1::2].set(arr.imag)) - return _view(arr, dtype) + if lax_numpy.issubdtype(self.dtype, np.complexfloating): + new_shape = (*self.shape[:-1], self.shape[-1] * 2) + new_dtype = lax_numpy.finfo(self.dtype).dtype + self = (lax_numpy.zeros(new_shape, new_dtype) + .at[..., 0::2].set(self.real) + .at[..., 1::2].set(self.imag)) + return _view(self, dtype) if dtype == bool: - return _view(arr, np.uint8).astype(bool) + return _view(self, np.uint8).astype(bool) if lax_numpy.issubdtype(dtype, np.complexfloating): - out = _view(arr, lax_numpy.finfo(dtype).dtype).astype(dtype) + out = _view(self, lax_numpy.finfo(dtype).dtype).astype(dtype) return out[..., 0::2] + 1j * out[..., 1::2] # lax.bitcast_convert_type adds or subtracts dimensions depending on the # relative bitwidths of the dtypes; we account for that with reshapes. - if arr.dtype.itemsize < dtype.itemsize: - factor = dtype.itemsize // arr.dtype.itemsize - arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor) - return lax.bitcast_convert_type(arr, dtype) + if self.dtype.itemsize < dtype.itemsize: + factor = dtype.itemsize // self.dtype.itemsize + out = self.reshape(*self.shape[:-1], self.shape[-1] // factor, factor) + return lax.bitcast_convert_type(out, dtype) - if arr.dtype.itemsize > dtype.itemsize: - out = lax.bitcast_convert_type(arr, dtype) + if self.dtype.itemsize > dtype.itemsize: + out = lax.bitcast_convert_type(self, dtype) return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1]) - return lax.bitcast_convert_type(arr, dtype) + return lax.bitcast_convert_type(self, dtype) def _notimplemented_flat(self): @@ -291,9 +593,6 @@ def _operator_round(number: ArrayLike, ndigits: int | None = None) -> Array: # If `ndigits` is None, for a builtin float round(7.5) returns an integer. return out.astype(int) if ndigits is None else out -def _copy(self: Array) -> Array: - return self.copy() - def _deepcopy(self: Array, memo: Any) -> Array: del memo # unused return self.copy() @@ -311,19 +610,9 @@ def __array_module__(self, types): return NotImplemented -def _compress_method(a: ArrayLike, condition: ArrayLike, - axis: int | None = None, *, out: None = None, - size: int | None = None, fill_value: ArrayLike = 0) -> Array: - """Return selected slices of this array along given axis. - - Refer to :func:`jax.numpy.compress` for full documentation.""" - return lax_numpy.compress(condition, a, axis=axis, out=out, - size=size, fill_value=fill_value) - - @core.stash_axis_env() @partial(jax.jit, static_argnums=(1,2,3)) -def _multi_slice(arr: ArrayLike, +def _multi_slice(self: Array, start_indices: tuple[tuple[int, ...]], limit_indices: tuple[tuple[int, ...]], removed_dims: tuple[tuple[int, ...]]) -> list[Array]: @@ -334,13 +623,13 @@ def _multi_slice(arr: ArrayLike, """ results: list[Array] = [] for starts, limits, removed in zip(start_indices, limit_indices, removed_dims): - sliced = lax.slice(arr, starts, limits) + sliced = lax.slice(self, starts, limits) if removed: sliced = lax.squeeze(sliced, removed) results.append(sliced) return results -# The next two functions are related to iter(device_array), implemented here to +# The next two functions are related to iter(array), implemented here to # avoid circular imports. @jax.jit def _unstack(x: Array) -> list[Array]: @@ -667,46 +956,46 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, _array_methods = { "__array_namespace__": array_api_metadata.__array_namespace__, - "all": reductions.all, - "any": reductions.any, - "argmax": lax_numpy.argmax, - "argmin": lax_numpy.argmin, - "argpartition": lax_numpy.argpartition, - "argsort": lax_numpy.argsort, + "all": _all, + "any": _any, + "argmax": _argmax, + "argmin": _argmin, + "argpartition": _argpartition, + "argsort": _argsort, "astype": _astype, - "choose": lax_numpy.choose, + "choose": _choose, "clip": _clip, - "conj": ufuncs.conj, - "conjugate": ufuncs.conjugate, - "compress": _compress_method, - "copy": lax_numpy.copy, - "cumprod": reductions.cumprod, - "cumsum": reductions.cumsum, - "diagonal": lax_numpy.diagonal, - "dot": lax_numpy.dot, - "flatten": lax_numpy.ravel, + "compress": _compress, + "conj": _conj, + "conjugate": _conjugate, + "copy": _copy, + "cumprod": _cumprod, + "cumsum": _cumsum, + "diagonal": _diagonal, + "dot": _dot, + "flatten": _flatten, "item": _item, - "max": reductions.max, - "mean": reductions.mean, - "min": reductions.min, - "nonzero": lax_numpy.nonzero, - "prod": reductions.prod, - "ptp": reductions.ptp, - "ravel": lax_numpy.ravel, - "repeat": lax_numpy.repeat, + "max": _max, + "mean": _mean, + "min": _min, + "nonzero": _nonzero, + "prod": _prod, + "ptp": _ptp, + "ravel": _flatten, + "repeat": _repeat, "reshape": _reshape, - "round": lax_numpy.round, - "searchsorted": lax_numpy.searchsorted, - "sort": lax_numpy.sort, - "squeeze": lax_numpy.squeeze, - "std": reductions.std, - "sum": reductions.sum, - "swapaxes": lax_numpy.swapaxes, - "take": lax_numpy.take, + "round": _round, + "searchsorted": _searchsorted, + "sort": _sort, + "squeeze": _squeeze, + "std": _std, + "sum": _sum, + "swapaxes": _swapaxes, + "take": _take, "to_device": _to_device, - "trace": lax_numpy.trace, + "trace": _trace, "transpose": _transpose, - "var": reductions.var, + "var": _var, "view": _view, # Methods exposed in order to avoid circular imports @@ -721,12 +1010,12 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, _array_properties = { "flat": _notimplemented_flat, - "T": lax_numpy.transpose, - "mT": lax_numpy.matrix_transpose, - "real": ufuncs.real, - "imag": ufuncs.imag, - "nbytes": _nbytes, - "itemsize": _itemsize, + "T": _transpose_property, + "mT": _matrix_transpose_property, + "real": _real_property, + "imag": _imag_property, + "nbytes": _nbytes_property, + "itemsize": _itemsize_property, "at": _IndexUpdateHelper, } @@ -772,14 +1061,14 @@ def _set_tracer_aval_forwarding(tracer, exclude=()): if prop_name not in exclude: setattr(tracer, prop_name, _forward_property_to_aval(prop_name)) -def _set_array_base_attributes(device_array, include=None, exclude=None): +def _set_array_base_attributes(array_impl, include=None, exclude=None): # Forward operators, methods, and properties on Array to lax_numpy # functions (with no Tracers involved; this forwarding is direct) def maybe_setattr(attr_name, target): if exclude is not None and attr_name in exclude: return if not include or attr_name in include: - setattr(device_array, attr_name, target) + setattr(array_impl, attr_name, target) for operator_name, function in _array_operators.items(): maybe_setattr(f"__{operator_name}__", function) @@ -789,10 +1078,10 @@ def maybe_setattr(attr_name, target): maybe_setattr(prop_name, property(prop)) for name, func in _impl_only_array_methods.items(): - setattr(device_array, name, func) + setattr(array_impl, name, func) -def _set_array_attributes(device_array): - setattr(device_array, "__array_module__", __array_module__) +def _set_array_attributes(array_impl): + setattr(array_impl, "__array_module__", __array_module__) def _make_abstract_method(name, func): @abc.abstractmethod From 24394a1b03f01138219013f4773104b834e498b7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 16 Aug 2024 09:20:13 -0700 Subject: [PATCH 149/702] Implement initial vmap over pallas_call w/ ragged inputs (via jumbles) The plan here is to load it up with invariants, and start with a really simple kernel. After that, we can slowly relax the various invariants and implement support for others. Note - the work saving here is compute only, not memory yet. A fast-followup CL is adding memory savings via index-map rewriting PiperOrigin-RevId: 663752447 --- jax/_src/core.py | 24 ++- jax/_src/interpreters/batching.py | 6 +- jax/_src/pallas/core.py | 53 ++++-- jax/_src/pallas/mosaic/lowering.py | 28 ++- jax/_src/pallas/pallas_call.py | 275 ++++++++++++++++++++++++++--- tests/pallas/BUILD | 23 +++ tests/pallas/pallas_jumble_test.py | 201 +++++++++++++++++++++ 7 files changed, 560 insertions(+), 50 deletions(-) create mode 100644 tests/pallas/pallas_jumble_test.py diff --git a/jax/_src/core.py b/jax/_src/core.py index ebf29cf0b253..61ed81cdeea9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1954,6 +1954,7 @@ def __init__(self, aval, data): assert data.shape == pad_shape self._aval = aval self._data = data + shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) aval = property(lambda self: self._aval) @@ -1964,21 +1965,38 @@ def __repr__(self) -> str: dtypestr = _short_dtype_name(self._aval.dtype) shapestr = ','.join(map(str, self.shape)) - slices = tuple(slice(int(d._data)) if type(d) is DArray and - type(d.dtype) is bint else slice(None) for d in self.shape) - data = self._data[slices] + data = self.data return f'{dtypestr}[{shapestr}] with value: {data}' + def __hash__(self) -> int: if not self.shape: return hash((self._aval, int(self._data))) raise TypeError("unhashable type: DArray") + def __eq__(self, other): if isinstance(other, DArray) and self._aval == other._aval: return self._data == other._data return False + def __len__(self): return self.shape[0] + @property + def data(self): + if not self.shape and type(self.dtype) is bint: + # special-case scalar bints + return self._data + + slices = tuple( + slice(int(d._data)) + if type(d) is DArray and type(d.dtype) is bint + else slice(None) + for d in self.shape + ) + data = self._data[slices] + return data + + pytype_aval_mappings[DArray] = \ lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, x._data) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index fbcd2c4a7a30..27cde6d31d35 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -88,6 +88,7 @@ def _jumble_flatten(jumble): elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape)) aval = jumble.aval.replace(elt_ty=elt_ty) return (lengths, jumble.data), aval + def _jumble_unflatten(aval, x): lengths, data = x new_shape = [d.replace(lengths=lengths[d.lengths - 1]) @@ -251,7 +252,10 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: - assert False + # TODO(mvoz): This is a terrible place to fall into if you pass + # a non jumble type in, make it clearer what went wrong. + assert False, f'Unexpected type in ELT? {type(x)}' + to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 0ef208f755e5..09e02ea5c3a1 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -112,7 +112,10 @@ class AbstractMemoryRef(state.AbstractRef): def __init__(self, inner_aval: jax_core.AbstractValue, memory_space: Any): - assert isinstance(inner_aval, jax_core.ShapedArray) + + assert isinstance( + inner_aval, jax_core.ShapedArray + ), f"Illegal ref, got {type(inner_aval)}" self.inner_aval = inner_aval self.memory_space = memory_space @@ -167,9 +170,7 @@ class PallasGridContext: mapped_dims: tuple[int, ...] def size(self, axis: int) -> int | DynamicGridDim: - valid_grid = tuple( - s for i, s in enumerate(self.grid) if i not in self.mapped_dims - ) + valid_grid = tuple(self.grid) try: size = valid_grid[axis] except IndexError as e: @@ -338,7 +339,10 @@ def check_invariants(self) -> None: ) assert not self.index_map_jaxpr.consts - assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals) + assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), ( + self.block_shape, + self.index_map_jaxpr.out_avals, + ) assert all(ov.shape == () and (ov.dtype == jnp.int32 or ov.dtype == jnp.int64) for ov in self.index_map_jaxpr.out_avals), ( @@ -422,6 +426,8 @@ class GridMapping: num_inputs: int num_outputs: int num_scratch_operands: int + get_grid_indices: Callable | None = None + local_grid_env: Callable | None = None def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -442,8 +448,8 @@ def check_invariants(self) -> None: assert len(index_map_args) >= len(self.grid) for i in range(len(self.grid)): index_map_arg = index_map_args[i] - assert index_map_arg.shape == () - assert index_map_arg.dtype == jnp.int32 + assert index_map_arg.shape == (), f"index_map_arg: {index_map_arg}" + assert index_map_arg.dtype == jnp.int32, f"index_map_arg: {index_map_arg}" assert len(self.vmapped_dims) <= len(self.grid) for i in self.vmapped_dims: @@ -454,8 +460,11 @@ def check_invariants(self) -> None: for bm in self.block_mappings: bm.check_invariants() - assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), ( + assert tuple(self.index_map_avals) == tuple( + bm.index_map_jaxpr.in_avals + ), ( self.index_map_avals, + "|", bm.index_map_jaxpr.in_avals, ) @@ -547,6 +556,17 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool: return True return jax_core.is_dim(dim) + +def _max_shape_from_aval(array_aval: jax_core.ShapedArray): + array_aval_shape = list(array_aval.shape) + for i, s in enumerate(array_aval.shape): + aval = jax_core.get_aval(s) + if isinstance(aval, jax_core.DShapedArray): + array_aval_shape[i] = aval.dtype.bound + + return tuple(array_aval_shape) + + def _convert_block_spec_to_block_mapping( block_spec: BlockSpec, origin: OriginStr, @@ -575,8 +595,15 @@ def _convert_block_spec_to_block_mapping( f"array shape {array_aval.shape}.") unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape), - block_spec.memory_space) + block_array_aval = array_aval.update(shape=unmapped_block_shape) + if isinstance(array_aval, jax_core.DShapedArray): + # Get the "max" shape for the ragged array. + block_array_aval = jax_core.ShapedArray( + block_array_aval.shape, + block_array_aval.dtype, + block_array_aval.weak_type, + ) + block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space) if not jax_core.is_constant_shape(block_aval.shape): raise ValueError( @@ -609,12 +636,12 @@ def _convert_block_spec_to_block_mapping( f"{origin} must return integer scalars. Output[{i}] has type " f"{ov}.") - if consts: raise ValueError( f"Index map function {index_map_src_info} for " f"{origin} must not capture constants: {consts}") + array_aval_shape = _max_shape_from_aval(array_aval) mapping = BlockMapping( block_shape=mapped_block_shape, @@ -622,7 +649,9 @@ def _convert_block_spec_to_block_mapping( index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), index_map_src_info=index_map_src_info, indexing_mode=block_spec.indexing_mode, - array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype), + array_shape_dtype=jax.ShapeDtypeStruct( + array_aval_shape, array_aval.dtype + ), origin=origin, ) mapping.check_invariants() diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 86ce2f0b1b81..aee894ee1b7e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -298,6 +298,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, self.jaxpr = jaxpr self.block_mappings = grid_mapping.block_mappings self.mapped_dims = grid_mapping.vmapped_dims + # TODO(mvoz): Generalize to not need this user_grid = tuple( g for i, g in enumerate(self.grid) if i not in self.mapped_dims ) @@ -345,9 +346,19 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, for _ in range(len(self.grid)) ]) self._prepare_mesh_info(mesh) - def _get_grid_indices(indices): - return indices - self.get_grid_indices = _get_grid_indices + + if grid_mapping.get_grid_indices is None: + + def _get_grid_indices(indices, maybe_include_mapped_dims: bool): + if maybe_include_mapped_dims: + return indices + return tuple( + idx for i, idx in enumerate(indices) if i not in self.mapped_dims + ) + + self.get_grid_indices = _get_grid_indices + else: + self.get_grid_indices = grid_mapping.get_grid_indices def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): if not self.has_communication: @@ -595,7 +606,9 @@ def lower_jaxpr_to_transform_func( ] def body_func(*args): grid_indices, scalar_prefetch = split_list(args, [num_grid]) - jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices( + grid_indices, maybe_include_mapped_dims=True + ) arg_block_shapes = [ *[()] * len(jaxpr_indices), *mosaic_grid_mapping.scalar_prefetch_block_shapes, @@ -663,9 +676,9 @@ def lower_jaxpr_to_func( def body_func(*args): grid_indices, scalar_prefetch, operands_and_scratch = split_list( args, [num_grid, num_scalar_prefetch]) - grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) - jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices) - if i not in mosaic_grid_mapping.mapped_dims) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices( + grid_indices, maybe_include_mapped_dims=False + ) mesh_info = mosaic_grid_mapping.mesh_info if mesh_info is not None: mesh_context = MeshContext( @@ -2365,6 +2378,7 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): + if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 4f3c9918f664..bb1683e38b1c 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -228,6 +228,12 @@ def _pallas_call_impl_interpret( # Pad values to evenly divide into block dimensions. This matches the # behavior of the non-interpret mode. We pad with NaN, to make it easier # to catch OOB accesses. + for carry_element in carry: + aval = carry_element.aval + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype) + carry_element.aval = aval + carry = map(_pad_values_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) @@ -247,11 +253,16 @@ def cond(carry): return i < num_iterations def body(carry): i, loop_idx, *carry_blocks = carry - local_grid_env = tuple( - pallas_core.GridAxis(idx, b) - for dim, (idx, b) in enumerate(zip(loop_idx, grid)) - if dim not in grid_mapping.vmapped_dims - ) + + if grid_mapping.local_grid_env is not None: + local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + else: + local_grid_env = tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + if dim not in grid_mapping.vmapped_dims + ) + carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks]) with pallas_core.grid_env(local_grid_env): start_indices = [ @@ -268,8 +279,14 @@ def body(carry): len(blocks), len(scratch_values), ) - blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, - *blocks, *scratch) + for s in scalars: + aval = jax_core.get_aval(s) + if isinstance(aval, jax_core.DShapedArray): + s.aval = aval.update(dtype=jnp.int32) + + blocks = jax_core.eval_jaxpr( + discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch + ) _, out_inout, out_scratch = split_list( blocks, [grid_mapping.num_index_operands, num_inout_blocks]) @@ -390,19 +407,55 @@ def _pallas_call_jvp_rule( ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule -def _batch_block_mapping(grid_mapping: GridMapping, - axis_size: int, - aval: jax_core.ShapedArray, - dim: int | batching.NotMapped, - block_mapping: BlockMapping) -> BlockMapping: + +def _batch_block_mapping( + grid_mapping: GridMapping, + axis_size: int, + aval: jax_core.ShapedArray, + dim: int | batching.NotMapped, + block_mapping: BlockMapping, + for_ragged: bool, +) -> BlockMapping: def _block_map_function(new_idx, *args): - indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr, - block_mapping.index_map_jaxpr.consts, - *args) + if for_ragged: + drop_last_args = args[:-1] + else: + drop_last_args = args + + indices = jax_core.eval_jaxpr( + block_mapping.index_map_jaxpr.jaxpr, + block_mapping.index_map_jaxpr.consts, + *drop_last_args, + ) if dim is not batching.not_mapped: - indices.insert(dim, new_idx) + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + stacked_axis = dim.stacked_axis + indices.insert(stacked_axis, new_idx) + else: + indices.insert(dim, new_idx) return tuple(indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] + + if for_ragged: + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + _, _, ragged_axis_length = _ragged_axis_parts(dim) + aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) + lengths_aval = pallas_core.AbstractMemoryRef( + aval, + pallas_core.MemorySpace.INDEX, + ) + idx_avals = [*idx_avals, lengths_aval] + else: + i32_aval_memref = pallas_core.AbstractMemoryRef( + jax_core.ShapedArray(([axis_size]), jnp.int32), + pallas_core.MemorySpace.INDEX, + ) + idx_avals = [*idx_avals, i32_aval_memref] + with grid_mapping.trace_env(): block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_block_map_function), idx_avals) @@ -411,12 +464,27 @@ def _block_map_function(new_idx, *args): new_block_shape = shape new_array_shape_dtype = block_mapping.array_shape_dtype else: - new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + new_block_shape = shape + stacked_axis = dim.stacked_axis + new_block_shape = tuple_insert( + new_block_shape, stacked_axis, pallas_core.mapped + ) + else: + new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + + array_shape = block_mapping.array_shape_dtype.shape + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + stacked_axis = dim.stacked_axis + array_shape = tuple_insert(array_shape, stacked_axis, axis_size) + else: + array_shape = tuple_insert(array_shape, dim, axis_size) + new_array_shape_dtype = jax.ShapeDtypeStruct( - tuple_insert(block_mapping.array_shape_dtype.shape, - dim, - axis_size), - block_mapping.array_shape_dtype.dtype) + array_shape, block_mapping.array_shape_dtype.dtype + ) jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts) return block_mapping.replace(block_shape=new_block_shape, @@ -547,6 +615,16 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: return result, (0,) * len(result) +def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]: + stacked_axis = dim.stacked_axis + ragged_axes = dim.ragged_axes + if len(ragged_axes) != 1: + raise ValueError("Multiple ragged axes not yet implemented.") + ragged_axis_dim = ragged_axes[0][0] + ragged_axis_length = ragged_axes[0][1] + return stacked_axis, ragged_axis_dim, ragged_axis_length + + def _pallas_call_batching_rule( args, dims, @@ -567,8 +645,26 @@ def _maybe_squeeze_out_bdim( return x return jnp.squeeze(x, axis=bdim) + all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)] + if len(all_ragged_axes) > 1: + raise ValueError("Multiple ragged dimensions not yet implemented.") + + if all_ragged_axes: + stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts( + all_ragged_axes[0] + ) + else: + stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None + + def get_size(i, x, d): + if not isinstance(d, batching.RaggedAxis): + return x.shape[d] + return x.aval.shape[i] + (axis_size,) = { - x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped + get_size(i=i, x=x, d=d) + for i, (x, d) in enumerate(zip(args, dims)) + if d is not batching.not_mapped } if axis_size == 1: # Why are we even vmapping? @@ -670,12 +766,27 @@ def _maybe_squeeze_out_bdim( num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands + lengths_aval = None + if ragged_axis_length is not None: + aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) + lengths_aval = pallas_core.AbstractMemoryRef( + aval, + pallas_core.MemorySpace.INDEX, + ) + # Only add a batch dimension for the avals that actually have a grid mapping. # This excludes scalar prefetch inputs (the first in the list) and scratch # operands (the last in the list). avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( - partial(_batch_block_mapping, grid_mapping, axis_size), + partial( + _batch_block_mapping, + grid_mapping, + axis_size, + for_ragged=lengths_aval is not None, + ), avals_to_batch, all_dims[num_index_operands:], block_mappings, @@ -685,15 +796,23 @@ def _maybe_squeeze_out_bdim( grid_mapping.index_map_avals) assert not index_map_tree_kwargs batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args + + if lengths_aval: + batched_index_map_args = batched_index_map_args + (lengths_aval,) + num_index_operands += 1 + batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( (batched_index_map_args, {})) + batched_grid_mapping = grid_mapping.replace( grid=(axis_size, *grid_mapping.grid), block_mappings=tuple(batched_block_mappings), - index_map_avals=batched_index_map_avals, + index_map_avals=tuple(batched_index_map_avals), index_map_tree=batched_index_map_tree, + num_index_operands=num_index_operands, vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims), ) + if cost_estimate is not None: batched_cost_estimate = CostEstimate( flops=cost_estimate.flops * axis_size, @@ -702,6 +821,103 @@ def _maybe_squeeze_out_bdim( ) else: batched_cost_estimate = None + + if lengths_aval: + batched_grid_mapping = batched_grid_mapping.replace( + get_grid_indices=lambda indices, maybe_include_mapped_dims: indices, + local_grid_env=lambda loop_idx, grid: tuple( + pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid) + ), + ) + + # Note - on zero filling counterfactuals + # A debug util to produce a counterfactual version of the when + # gating, where for all values that don't pass the @when check, + # we write 0s. This is useful for debugging, as certain lowering paths + # like mosaic will write the last data as passthrough, leading to + # potentially confusing results. + debug_zero_fill_counterfactual = debug + + first_block_mapping = batched_grid_mapping.block_mappings[0] + for block_mapping in batched_grid_mapping.block_mappings: + # This invariant may already be checked elsewhere, but lets reaffirm it + assert block_mapping.block_shape == first_block_mapping.block_shape, ( + f"block_mapping.block_shape: {block_mapping.block_shape}, " + f"first_block_mapping.block_shape: {first_block_mapping.block_shape}" + ) + assert ( + block_mapping.array_shape_dtype + == first_block_mapping.array_shape_dtype + ), ( + f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype}," + " first_block_mapping.array_shape_dtype:" + f" {first_block_mapping.array_shape_dtype}" + ) + + mapped_dim_idxs = [ + i + for i, d in enumerate(first_block_mapping.block_shape) + if d is pallas_core.mapped + ] + assert len(mapped_dim_idxs) == 1 + mapped_dim_idx = mapped_dim_idxs[0] + if stacked_axis != mapped_dim_idx: + raise ValueError( + f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}" + ) + + assert ragged_axis_dim is not None, "Invariant violation" + # This is the blockspec size of the dimension + val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim] + + def when_wrapped_kernel(lengths_ref, *args, **kwargs): + b_idx = jax.experimental.pallas.program_id(stacked_axis) + i_idx = ( + jax.experimental.pallas.program_id(ragged_axis_dim) + * val_at_ragged_dim + ) + b_len = lengths_ref[b_idx] + + # TODO(mvoz): Unimplemented primitive in pallas + # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) + # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") + + @jax.experimental.pallas.when(i_idx < b_len) + def f(): + # Important! This allows us to trace the inner kernel with the correct + # grid to preserve user program_id semantics. Ex: program_id(0) will + # always be analogous to program_id(1) in the outer kernel. + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs) + + if debug_zero_fill_counterfactual: + + @jax.experimental.pallas.when(i_idx >= b_len) + def g(): + for arg_ref in args: + arg_ref[...] = jnp.zeros_like(arg_ref) + + kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars] + flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten( + list(kernel_avals) + ) + # Important! This allows us to trace the outer kernel with the correct grid + # to enable accessing the batch program_id. + with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): + kernel_src_info: pallas_core.SrcInfoStr = "" + + jaxpr = _trace_kernel_to_jaxpr( + when_wrapped_kernel, + kernel_src_info, + batched_grid_mapping, + tuple(flat_kernel_avals), + kernel_in_tree, + interpret=interpret, + ) + + assert ragged_axis_length is not None + args = (ragged_axis_length, *args) + out = pallas_call_p.bind( *dynamic_grid_args, *args, @@ -1097,12 +1313,14 @@ def pallas_call( out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths) flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore for x in flat_out_shapes] + @jax.jit def wrapped(*args): flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args) in_paths, flat_args = unzip2(flat_args_with_paths) flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) for a in flat_args) + flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype) for v in flat_out_shapes) @@ -1172,15 +1390,18 @@ def wrapped(*args): return wrapped -def in_path_to_input_origin(in_path: tree_util.KeyPath, - arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr: +def in_path_to_input_origin( + in_path: tree_util.KeyPath, arg_names: tuple[str, ...] | None +) -> pallas_core.OriginStr: """Converts `args[k]` into `arg_k_name`.""" if arg_names is None: return f"args{tree_util.keystr(in_path)}" if len(in_path) == 0: return "args" arg_idx, *rest_path = in_path - if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names): + if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len( + arg_names + ): return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path)) else: return f"args{tree_util.keystr(tuple(in_path))}" diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index c0cf61387cbb..5559a0552f9f 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -62,6 +62,29 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_test( + name = "pallas_jumble_test", + srcs = [ + "pallas_jumble_test.py", + ], + disable_configs = [ + "gpu", + "gpu_x32", + "gpu_a100", + "gpu_p100", + "gpu_p100_x32", + "gpu_h100", + ], + shard_count = { + "tpu": 1, + }, + deps = [ + "//jax:pallas", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_test( name = "ops_test", srcs = [ diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py new file mode 100644 index 000000000000..ee176a0363aa --- /dev/null +++ b/tests/pallas/pallas_jumble_test.py @@ -0,0 +1,201 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" + +from absl.testing import absltest +import jax +from jax import lax +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.interpreters import batching +from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + + +# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs. +# pylint: disable=no-value-for-parameter + +config.parse_flags_with_absl() + + +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: + self.skipTest("On CPU the test works only in interpret mode") + if jtu.test_device_matches( + ["cuda"] + ) and not jtu.is_cuda_compute_capability_at_least("8.0"): + self.skipTest("Only works on GPU with capability >= sm80") + if sys.platform == "win32" and not self.INTERPRET: + self.skipTest("Only works on non-Windows platforms") + + super().setUp() + _trace_kernel_to_jaxpr.cache_clear() + + def pallas_call(self, *args, **kwargs): + return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + + +@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") +class PallasCallRaggedVmapTest(PallasBaseTest): + + def test_vmap_jumble_over_sin_kernel(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct( + (8, col_grid_size * 128), dtype=jnp.float32 + ), + grid=(1, col_grid_size), + interpret=self.INTERPRET, + # See note - on zero filling counterfactuals + debug=True, + )(x) + + res = jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + res = res.data + total = len(ragged_shape) * row_count * col_grid_size * 128 + res_total = np.prod(res.shape) + self.assertEqual(res_total, total) + ragged_total = 0 + for dim in ragged_shape: + ragged_total += row_count * dim * 128 + # See note - on zero filling counterfactuals + self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total) + + def test_vmap_jumble_over_sin_kernel_grid_remapping(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) * pl.program_id(2) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), + grid=(1, 5), + interpret=False, + )(x) + + with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): + jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + self.skipTest("Checkify NYI") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([(128 * x) - 1 for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), + grid=(1, 5), + interpret=False, + )(x) + + with self.assertRaisesRegex( + ValueError, + "Ragged input shape must be evenly divisble by the grid" # noqa: W605 + " size at the ragged dimension 2", + ): + jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + +class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest): + INTERPRET = True + + +if __name__ == "__main__": + absltest.main() From 60cc0411330d5c3652fe5eb8455408e7abf65912 Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Fri, 16 Aug 2024 10:53:24 -0700 Subject: [PATCH 150/702] Minor fixes and documentation update for custom hermetic Python interpreter support. PiperOrigin-RevId: 663784042 --- docs/developer.md | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/docs/developer.md b/docs/developer.md index 78471a530c99..954cf7982a3a 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -353,17 +353,24 @@ sudo apt-get install libopenblas-dev -y has `custom_python_interpreter()` entry there, pointing to the version of Python you want to build. -3) Run `bazel build @python_dev//:python_dev` to build Python interpreter. By default it will - be built with GCC compiler. If you wish to build with clang, you need to set - corresponding env variables to do so ( +3) Run `bazel build @python_dev//:python_dev -repo_env=HERMETIC_PYTHON_VERSION=3.12` + to build Python interpreter. Note, it is easy to confuse Python version used + to conduct the build (which is needed for technical reasons and is defined by + `HERMETIC_PYTHON_VERSION=3.12`) and the version of Python you are building + (defined by whichever version you specified in `custom_python_interpreter()` + on step 2). For build to succeed, please make sure that hermetic Python you + choose to conduct the build already exists in your configuraiton (the actual + version does not matter, as long as it is a working one). By default, Python + binary will be built with GCC compiler. If you wish to build it with clang, + you need to set corresponding env variables to do so ( e.g. `--repo_env=CC=/usr/lib/llvm-17/bin/clang --repo_env=CXX=/usr/lib/llvm-17/bin/clang++`). 4) Check the output of the previous command. At the very end of it you will find a code snippet for `python_register_toolchains()` entry with your newly built Python in it. Copy that code snippet in your `WORKSPACE` file either right after `python_init_toolchains()` entry (to add the new version of Python) or - instead of it (to replace an existing version, like replacing 3.12 with - custom built variant of 3.12). The code snippet is generated to match your + instead of it (to replace an existing version, like replacing `3.12` with + custom built variant of `3.12`). The code snippet is generated to match your actual setup, so it should work as is, but you can customize it if you choose so (for example to change location of Python's `.tgz` file so it could be downloaded remotely instead of being on local machine). @@ -371,7 +378,11 @@ sudo apt-get install libopenblas-dev -y 5) Make sure there is an entry for your Python's version in `requirements` parameter for `python_init_repositories()` in your WORKSPACE file. For example for `Python 3.13` it should have something - like `"3.13": "//build:requirements_lock_3_13.txt"`. + like `"3.13": "//build:requirements_lock_3_13.txt"`. Note, the key in the + `requirements` parameter must always be in `"major.minor"` version format, so + even if you are building Python version `3.13.0rc1` the corresponding + `requirements` entry must still be `"3.13": "//build:requirements_lock_3_13.txt"`, + **not** `"3.13.0rc1": "//build:requirements_lock_3_13_0rc1.txt"`. 6) For unstable versions of Python, optionally (but highly recommended) run `bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13"`, From 8b831b89604e411d99577a8646fd457b138f83a1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 16 Aug 2024 14:13:51 -0700 Subject: [PATCH 151/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/075859a60b9ba002c9f1712798c297d3828abebe. PiperOrigin-RevId: 663861515 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b87cb1fc345c..cd9c8d571bf8 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "1db3272ca01754dd38827f4ea332a2f136df5d05" -XLA_SHA256 = "ca5821e7f95e1d26f420619daed6ee6449bbd58e76f6f91dbde63fc72c1ced1c" +XLA_COMMIT = "075859a60b9ba002c9f1712798c297d3828abebe" +XLA_SHA256 = "1720ee6e194714539e00c55b3c7a73b393a21b8c5a97b135d3720a33220b01fe" def repo(): tf_http_archive( From 0c543aef1d90f87ae20ace7a0537a30ac63d6267 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 16 Aug 2024 17:21:10 -0700 Subject: [PATCH 152/702] Match the argument name with the name in `Args` section in docstring PiperOrigin-RevId: 663926739 --- jax/_src/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 7aaec18c9c89..118f2bfe8851 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -883,11 +883,11 @@ def make_array_from_process_local_data( Args: sharding: sharding of the global tensor. - host_local_data: data on the host to be placed on local devices. Each + local_data: data on the host to be placed on local devices. Each dimension should either match global_shape, or match num_addressable_indices(dim). global_shape: the target shape of the global tensor. If None, - will infer from host_local_data and sharding. + will infer from local_data and sharding. Returns: Tensor that will have sharding=sharding and of shape global_shape. From ce4dae1c9c2f9a70c9dbf104fa46745ac700ffc4 Mon Sep 17 00:00:00 2001 From: John QiangZhang Date: Tue, 13 Aug 2024 10:46:13 -0700 Subject: [PATCH 153/702] Add logging the jax2tf `mlir_module_serialized` module size. PiperOrigin-RevId: 662574156 diag_docstring_added diag_docstring_desc_fixed white_space_fixed white_space_fixed white_space_fixed linting fixed --- jax/_src/numpy/lax_numpy.py | 127 +++++++++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0f27cb5ff409..49eb57b266c7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5648,10 +5648,42 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: return diag_indices(s[0], ndim=nd) -@util.implements(np.diagonal, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('offset', 'axis1', 'axis2')) def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: + """Returns the specified diagonal of an array. + + JAX implementation of :func:`numpy.diagonal`. + + The JAX version always returns a copy of the input, although if this is used + within a JIT compilation, the compiler may avoid the copy. + + Args: + a: Input array. Must be at least 2-dimensional. + offset: optional, default=0. Diagonal offset from the main diagonal. + Must be a static integer value. Can be positive or negative. + axis1: optional, default=0. The first axis along which to take the diagonal. + axis2: optional, default=1. The second axis along which to take the diagonal. + + Returns: + A 1D array for 2D input, and in general a N-1 dimensional array + for N-dimensional input. + + See also: + - :func:`jax.numpy.diag` + - :func:`jax.numpy.diagflat` + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6], + ... [7, 8, 9]]) + >>> jnp.diagonal(x) + Array([1, 5, 9], dtype=int32) + >>> jnp.diagonal(x, offset=1) + Array([2, 6], dtype=int32) + >>> jnp.diagonal(x, offset=-1) + Array([4, 8], dtype=int32) + """ util.check_arraylike("diagonal", a) a_shape = shape(a) if ndim(a) < 2: @@ -5667,8 +5699,53 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, return a[..., i, j] if offset >= 0 else a[..., j, i] -@util.implements(np.diag, lax_description=_ARRAY_VIEW_DOC) def diag(v: ArrayLike, k: int = 0) -> Array: + """Returns the specified diagonal or constructs a diagonal array. + + JAX implementation of :func:`numpy.diag`. + + The JAX version always returns a copy of the input, although if this is used + within a JIT compilation, the compiler may avoid the copy. + + Args: + v: Input array. Can be a 1-D array to create a diagonal matrix or a + 2-D array to extract a diagonal. + k: optional, default=0. Diagonal offset. Positive values place the diagonal + above the main diagonal, negative values place it below the main diagonal. + + Returns: + If `v` is a 2-D array, a 1-D array containing the diagonal elements. + If `v` is a 1-D array, a 2-D array with the input elements placed along the + specified diagonal. + + See also: + - :func:`jax.numpy.diagflat` + - :func:`jax.numpy.diagonal` + + Examples: + Creating a diagonal matrix from a 1-D array: + + >>> jnp.diag(jnp.array([1, 2, 3])) + Array([[1, 0, 0], + [0, 2, 0], + [0, 0, 3]], dtype=int32) + + Specifying a diagonal offset: + + >>> jnp.diag(jnp.array([1, 2, 3]), k=1) + Array([[0, 1, 0, 0], + [0, 0, 2, 0], + [0, 0, 0, 3], + [0, 0, 0, 0]], dtype=int32) + + Extracting a diagonal from a 2-D array: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6], + ... [7, 8, 9]]) + >>> jnp.diag(x) + Array([1, 5, 9], dtype=int32) + """ return _diag(v, operator.index(k)) @partial(jit, static_argnames=('k',)) @@ -5685,14 +5762,46 @@ def _diag(v, k): else: raise ValueError("diag input must be 1d or 2d") -_SCALAR_VALUE_DOC = """\ -This differs from np.diagflat for some scalar values of v, -jax always returns a two-dimensional array, whereas numpy may -return a scalar depending on the type of v. -""" - -@util.implements(np.diagflat, lax_description=_SCALAR_VALUE_DOC) def diagflat(v: ArrayLike, k: int = 0) -> Array: + """Return a 2-D array with the flattened input array laid out on the diagonal. + + JAX implementation of :func:`numpy.diagflat`. + + This differs from `np.diagflat` for some scalar values of `v`. JAX always returns + a two-dimensional array, whereas NumPy may return a scalar depending on the type + of `v`. + + Args: + v: Input array. Can be N-dimensional but is flattened to 1D. + k: optional, default=0. Diagonal offset. Positive values place the diagonal + above the main diagonal, negative values place it below the main diagonal. + + Returns: + A 2D array with the input elements placed along the diagonal with the + specified offset (k). The remaining entries are filled with zeros. + + See also: + - :func:`jax.numpy.diag` + - :func:`jax.numpy.diagonal` + + Examples: + >>> jnp.diagflat(jnp.array([1, 2, 3])) + Array([[1, 0, 0], + [0, 2, 0], + [0, 0, 3]], dtype=int32) + >>> jnp.diagflat(jnp.array([1, 2, 3]), k=1) + Array([[0, 1, 0, 0], + [0, 0, 2, 0], + [0, 0, 0, 3], + [0, 0, 0, 0]], dtype=int32) + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.diagflat(a) + Array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]], dtype=int32) + """ util.check_arraylike("diagflat", v) v_ravel = ravel(v) v_length = len(v_ravel) From dd697a9abc2a8afba32d6dbc0eeefde600623189 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sat, 17 Aug 2024 05:49:25 -0700 Subject: [PATCH 154/702] Improve type annotations for jax.Array methods --- jax/_src/basearray.pyi | 127 ++++++++++++++++++++------------ jax/_src/numpy/array_methods.py | 19 +++-- 2 files changed, 87 insertions(+), 59 deletions(-) diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 32e1d27dcdf5..e4c7fdc4a9ab 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -14,17 +14,29 @@ import abc from collections.abc import Callable, Sequence from types import ModuleType -from typing import Any, Union +from typing import Any, Protocol, Union, runtime_checkable import numpy as np from jax._src.sharding import Sharding +# TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py. +# We redefine these here to prevent circular imports. +@runtime_checkable +class SupportsDType(Protocol): + @property + def dtype(self) -> np.dtype: ... +DTypeLike = Union[str, type[Any], np.dtype, SupportsDType] + +Axis = Union[int, Sequence[int], None] Shard = Any # TODO: alias this to xla_client.Traceback Device = Any Traceback = Any +# TODO(jakevdp): fix import cycles and import this from jax._src.lax. +PrecisionLike = Any + class Array(abc.ABC): aval: Any @@ -117,72 +129,89 @@ class Array(abc.ABC): def __release_buffer__(self, view: memoryview) -> None: ... # np.ndarray methods: - def all(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... - def any(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... - def argmax(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... - def argmin(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... - def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ... - def argsort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... - def astype(self, dtype) -> Array: ... - def choose(self, choices, out=None, mode='raise') -> Array: ... - def clip(self, min=None, max=None, out=None) -> Array: ... - def compress(self, condition, axis: int | None = None, out=None) -> Array: ... + def all(self, axis: Axis = None, out: None = None, + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ... + def any(self: Array, axis: Axis = None, out: None = None, + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ... + def argmax(self: Array, axis: int | None = None, out: None = None, + keepdims: bool | None = None) -> Array: ... + def argmin(self, axis: int | None = None, out: None = None, + keepdims: bool | None = None) -> Array: ... + def argpartition(self, kth, axis=-1, kind='introselect', order: None = None) -> Array: ... + def argsort(self, axis: int | None = -1, kind='quicksort', order: None = None) -> Array: ... + def astype(self, dtype: DTypeLike | None = None, max: ArrayLike | None = None) -> Array: ... + def choose(self, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: ... + def clip(self, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: ... + def compress(self, condition: ArrayLike, + axis: int | None = None, *, out: None = None, + size: int | None = None, fill_value: ArrayLike = 0) -> Array: ... def conj(self) -> Array: ... def conjugate(self) -> Array: ... def copy(self) -> Array: ... def cumprod(self, axis: int | Sequence[int] | None = None, - dtype=None, out=None) -> Array: ... + dtype: DTypeLike | None = None, out: None = None) -> Array: ... def cumsum(self, axis: int | Sequence[int] | None = None, - dtype=None, out=None) -> Array: ... - def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ... - def dot(self, b, *, precision=None) -> Array: ... - def flatten(self) -> Array: ... + dtype: DTypeLike | None = None, out: None = None) -> Array: ... + def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: ... + def dot(self, b: ArrayLike, *, precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: ... + def flatten(self, order: str = "C") -> Array: ... @property def imag(self) -> Array: ... - def item(self, *args) -> Any: ... - def max(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=None, initial=None, where=None) -> Array: ... - def mean(self, axis: int | Sequence[int] | None = None, dtype=None, - out=None, keepdims=False, *, where=None,) -> Array: ... - def min(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=None, initial=None, where=None) -> Array: ... + def item(self, *args: int) -> Any: ... + def max(self, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... + def mean(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, *, + where: ArrayLike | None = None) -> Array: ... + def min(self, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... @property def nbytes(self) -> int: ... - def nonzero(self, *, size=None, fill_value=None) -> Array: ... - def prod(self, axis: int | Sequence[int] | None = None, dtype=None, - out=None, keepdims=None, initial=None, where=None) -> Array: ... - def ptp(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=False,) -> Array: ... - def ravel(self, order='C') -> Array: ... + def nonzero(self, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None, + size: int | None = None,) -> tuple[Array, ...]: ... + def prod(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None, + promote_integers: bool = True) -> Array: ... + def ptp(self, axis: Axis = None, out: None = None, + keepdims: bool = False) -> Array: ... + def ravel(self, order: str = 'C') -> Array: ... @property def real(self) -> Array: ... - def repeat(self, repeats, axis: int | None = None, *, - total_repeat_length=None) -> Array: ... - def reshape(self, *args, order='C') -> Array: ... - def round(self, decimals=0, out=None) -> Array: ... - def searchsorted(self, v, side='left', sorter=None) -> Array: ... - def sort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... + def repeat(self, repeats: ArrayLike, axis: int | None = None, *, + total_repeat_length: int | None = None) -> Array: ... + def reshape(self, *args: Any, order: str = "C") -> Array: ... + def round(self, decimals: int = 0, out: None = None) -> Array: ... + def searchsorted(self, v: ArrayLike, side: str = 'left', + sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ... + def sort(self, axis: int | None = -1, *, kind: None = None, + order: None = None, stable: bool = True, descending: bool = False) -> Array: ... def squeeze(self, axis: int | Sequence[int] | None = None) -> Array: ... - def std(self, axis: int | Sequence[int] | None = None, - dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... - def sum(self, axis: int | Sequence[int] | None = None, dtype=None, - out=None, keepdims=None, initial=None, where=None) -> Array: ... + def std(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ... + def sum(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None, promote_integers: bool = True) -> Array: ... def swapaxes(self, axis1: int, axis2: int) -> Array: ... - def take(self, indices, axis: int | None = None, out=None, - mode=None) -> Array: ... - def tobytes(self, order='C') -> bytes: ... + def take(self, indices: ArrayLike, axis: int | None = None, out: None = None, + mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, + fill_value: StaticScalar | None = None) -> Array: ... + def tobytes(self, order: str = 'C') -> bytes: ... def tolist(self) -> list[Any]: ... - def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, - out=None) -> Array: ... - def transpose(self, *args) -> Array: ... + def trace(self, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, + dtype: DTypeLike | None = None, out: None = None) -> Array: ... + def transpose(self, *args: Any) -> Array: ... @property def T(self) -> Array: ... @property def mT(self) -> Array: ... - def var(self, axis: int | Sequence[int] | None = None, - dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... + def var(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ... def view(self, dtype=None, type=None) -> Array: ... # Even though we don't always support the NumPy array protocol, e.g., for diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 03745a7dcd45..a0222b5c586d 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -58,7 +58,7 @@ # functions, which can themselves handle instances from any of these classes. -def _all(self: ArrayLike, axis: reductions.Axis = None, out: None = None, +def _all(self: Array, axis: reductions.Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: """Test whether all array elements along a given axis evaluate to True. @@ -107,7 +107,8 @@ def _argsort(self: Array, axis: int | None = -1, *, kind: None = None, order: No return lax_numpy.argsort(self, axis=axis, kind=kind, order=order, stable=stable, descending=descending) -def _astype(self: Array, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: +def _astype(self: Array, dtype: DTypeLike | None, copy: bool = False, + device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. This is implemented via :func:`jax.lax.convert_element_type`, which may @@ -124,13 +125,12 @@ def _choose(self: Array, choices: Sequence[ArrayLike], out: None = None, mode: s """ return lax_numpy.choose(self, choices=choices) -def _clip(number: ArrayLike, - min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: +def _clip(self: Array, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: """Return an array whose values are limited to a specified range. Refer to :func:`jax.numpy.clip` for full documentation. """ - return lax_numpy.clip(number, min=min, max=max) + return lax_numpy.clip(self, min=min, max=max) def _compress(self: Array, condition: ArrayLike, axis: int | None = None, *, out: None = None, @@ -163,7 +163,7 @@ def _copy(self: Array) -> Array: """ return lax_numpy.copy(self) -def _cumprod(self: Array, /, axis: int | Sequence[int] | None = None, +def _cumprod(self: Array, axis: int | Sequence[int] | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: """Return the cumulative product of the array. @@ -171,7 +171,7 @@ def _cumprod(self: Array, /, axis: int | Sequence[int] | None = None, """ return reductions.cumprod(self, axis=axis, dtype=dtype, out=out) -def _cumsum(self: Array, /, axis: int | Sequence[int] | None = None, +def _cumsum(self: Array, axis: int | Sequence[int] | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: """Return the cumulative sum of the array. @@ -258,9 +258,8 @@ def _nbytes_property(self: Array) -> int: """Total bytes consumed by the elements of the array.""" return np.size(self) * dtypes.dtype(self, canonicalize=True).itemsize -def _nonzero(self: Array, *, size: int | None = None, - fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None - ) -> tuple[Array, ...]: +def _nonzero(self: Array, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None, + size: int | None = None) -> tuple[Array, ...]: """Return indices of nonzero elements of an array. Refer to :func:`jax.numpy.nonzero` for the full documentation. From e39c8685d09434ea1c917e4b1cca8a1e437bf576 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Sat, 17 Aug 2024 22:39:10 +0530 Subject: [PATCH 155/702] Remove the unused _CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION --- jax/_src/numpy/lax_numpy.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0f27cb5ff409..d6c11d181025 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -385,13 +385,6 @@ def trunc(x: ArrayLike) -> Array: return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) -_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """ -preferred_element_type : dtype, optional - If specified, accumulate results and return a result of the given data type. - If not specified, the function instead follows the numpy convention of always - accumulating results and returning an inexact dtype. -""" - @partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type']) def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, preferred_element_type: DTypeLike | None = None) -> Array: From b957f8baab287f1a0e1e880b885f89b1f4272b50 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 17 Aug 2024 13:49:37 -0700 Subject: [PATCH 156/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/1a810717944fef76920d1b718df92aed4abdfc57. PiperOrigin-RevId: 664212400 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index cd9c8d571bf8..8c123f0a279a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "075859a60b9ba002c9f1712798c297d3828abebe" -XLA_SHA256 = "1720ee6e194714539e00c55b3c7a73b393a21b8c5a97b135d3720a33220b01fe" +XLA_COMMIT = "1a810717944fef76920d1b718df92aed4abdfc57" +XLA_SHA256 = "845562ca0b222f54d693fd86e1c2609758081ba1ca5af522262bd734e692a3df" def repo(): tf_http_archive( From ba5b081571d4558b05f92dfe736331ce92647dbe Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 18 Aug 2024 09:09:00 -0700 Subject: [PATCH 157/702] [numpy] Fix test failures under NumPy 2.0. PiperOrigin-RevId: 664465687 --- jax/_src/lax/lax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ca0096c2e05c..e81c6c157627 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4994,13 +4994,15 @@ def padtype_to_pads(in_shape, window_shape, window_strides, padding): for d in (out_shape - 1) * window_strides + window_shape - in_shape) if padding == PaddingType.SAME: - return [ + pads = [ (pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes ] else: - return [ + pads = [ (pad_size - pad_size // 2, pad_size // 2) for pad_size in pad_sizes ] + # Avoids verbose numpy scalars in jaxprs. + return [p.item() if isinstance(p, np.generic) else p for p in pads] elif padding == PaddingType.VALID: return [(0, 0)] * len(in_shape) else: From 05792c952fdc12b05731154667c663c34f084e50 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 18 Aug 2024 14:35:12 -0700 Subject: [PATCH 158/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b43657003ce347f83a36fb6f56f528ce3ce982c4. PiperOrigin-RevId: 664534707 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8c123f0a279a..6419cee31204 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "1a810717944fef76920d1b718df92aed4abdfc57" -XLA_SHA256 = "845562ca0b222f54d693fd86e1c2609758081ba1ca5af522262bd734e692a3df" +XLA_COMMIT = "b43657003ce347f83a36fb6f56f528ce3ce982c4" +XLA_SHA256 = "d58acd72ffb049756690600c6752f054ff5274e2b097fb92fbcc0f661b26d4d8" def repo(): tf_http_archive( From dad2f576ac5621ee4c3135646c62c55f4086dd3e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 19 Aug 2024 01:04:52 -0700 Subject: [PATCH 159/702] Add support for shape polymorphism in ffi_lowering and move lu_pivots_to_permutation lowering out of jaxlib. The lowering logic for all jaxlib custom calls are currently split between JAX and jaxlib for reasons that are harder to justify now that the compiled calls are split between jaxlib and the relevant plugins. As part of my project to update these calls and simplify the lowering logic, it makes sense to consolidate the lowering rules in JAX instead of jaxlib since the logic is now the same for both GPU and CPU. This update tackles a simple kernel as a test case for what this would look like. Since the full lowering rule is now implemented in JAX, we can take advantage of the MLIR helpers that are included there, including `jex.ffi.ffi_lowering`, which I needed to update to support shape polymorphism. Of note: I think it is safe (in a compatibility sense) to delete the lowering code from jaxlib, but it does mean that it won't be possible to lower this operation when `jax.__version__ < jaxlib.__version__`. I think this is okay given our compatibility guarantees, but I'd love a sanity check on that! Another note, this doesn't actually change the lowered HLO for this op, so we don't need to worry about export compatibility. PiperOrigin-RevId: 664680250 --- jax/_src/extend/ffi.py | 5 +++ jax/_src/lax/linalg.py | 22 +++++------- .../jax2tf/tests/primitives_test.py | 2 ++ jaxlib/gpu_linalg.py | 34 +------------------ 4 files changed, 17 insertions(+), 46 deletions(-) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 66af4f331d78..c39e05d9335d 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -142,6 +142,11 @@ def _lowering( kwargs["operand_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_in) # pytype: disable=attribute-error if result_layouts is None: kwargs["result_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_out) + if "result_shapes" not in kwargs and not all( + core.is_constant_shape(aval.shape) for aval in ctx.avals_out): + kwargs["result_shapes"] = [ + mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, aval.shape)) + for aval in ctx.avals_out] return mlir.custom_call(call_target_name, operands=operands, **kwargs).results # type: ignore diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index b1e92b89af29..f771b83b372f 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -31,6 +31,7 @@ from jax._src import dtypes from jax._src.core import ( Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape) +from jax._src.extend import ffi from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -1190,16 +1191,13 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *, return lu_pivots_to_permutation_p.bind( x, permutation_size=permutation_size), 0 -def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *, +def _lu_pivots_to_permutation_gpu_lowering(platform, ctx, pivots, *, permutation_size): - # TODO(danfm): Remove once jaxlib 0.4.32 is the minimum version. - if jaxlib_version >= (0, 4, 32): - pivots_aval, = ctx.avals_in - pivots_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, pivots_aval.shape) - kwargs = dict(pivots_shape_vals=pivots_shape_vals) - else: - kwargs = {} - return lowering(pivots, permutation_size=permutation_size, **kwargs) + rule = ffi.ffi_lowering(f"{platform}_lu_pivots_to_permutation") + return rule(ctx, pivots, + # TODO(b/358275922): remove unused parameter 12 weeks after + # the release of jaxlib v0.4.32. + permutation_size=np.int32(permutation_size)) lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation') @@ -1215,13 +1213,11 @@ def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *, mlir.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False)) mlir.register_lowering( lu_pivots_to_permutation_p, - partial(_lu_pivots_to_permutation_gpu_lowering, - gpu_linalg.cuda_lu_pivots_to_permutation), + partial(_lu_pivots_to_permutation_gpu_lowering, "cu"), platform='cuda') mlir.register_lowering( lu_pivots_to_permutation_p, - partial(_lu_pivots_to_permutation_gpu_lowering, - gpu_linalg.hip_lu_pivots_to_permutation), + partial(_lu_pivots_to_permutation_gpu_lowering, "hip"), platform='rocm') # LU decomposition diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 51a6d45556bd..5169ba8ab252 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -183,6 +183,8 @@ def test_primitive_coverage(self): continue if p.name == "pallas_call": continue + if p.name == "ffi_call": + continue if p.name == "tpu_custom_call": continue if p.name == "custom_partitioning": diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index f392cc690046..39b3aaea2072 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -20,7 +20,7 @@ import jaxlib.mlir.ir as ir -from .hlo_helpers import custom_call, mk_result_types_and_shapes +from .hlo_helpers import custom_call from .gpu_common_utils import GpuLibNotLinkedError from jaxlib import xla_client @@ -61,38 +61,6 @@ _prod = lambda xs: functools.reduce(operator.mul, xs, 1) -def _lu_pivots_to_permutation_hlo(platform, pivots, *, permutation_size, - pivots_shape_vals): - """Kernel for the transformation of pivots to permutations on GPU.""" - typ = ir.RankedTensorType(pivots.type) - i32_type = ir.IntegerType.get_signless(32) - assert typ.element_type == i32_type, typ - assert len(pivots_shape_vals) >= 1 - - pivots_layout = tuple(range(len(pivots_shape_vals) - 1, -1, -1)) - permutations_layout = pivots_layout - permutations_dims = (*pivots_shape_vals[:-1], permutation_size) - result_types, result_shapes = mk_result_types_and_shapes( - [(permutations_dims, i32_type)]) - return custom_call( - f"{platform}_lu_pivots_to_permutation", - api_version=4, - operands=[pivots], - operand_layouts=[pivots_layout], - result_types=result_types, - result_shapes=result_shapes, - result_layouts=[permutations_layout], - # TODO(b/358275922): remove backend_config 12 weeks after release of - # jaxlib v0.4.32. - backend_config=dict( - permutation_size=ir.IntegerAttr.get(i32_type, permutation_size), - ), - ).results - -cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu") -hip_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "hip") - - def _cholesky_update_hlo(platform, gpu_linalg, r_matrix, w_vector, dtype): """Cholesky update.""" del platform From 66a3f87a24016594794c2ee289826baed5e979a4 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 19 Aug 2024 04:28:06 -0700 Subject: [PATCH 160/702] Rollback for: Implement initial vmap over pallas_call w/ ragged inputs (via jumbles) It can cause issues in x32 when trying to get the aval for array dimension sizes that are larger than i32. Reverts 24394a1b03f01138219013f4773104b834e498b7 PiperOrigin-RevId: 664742891 --- jax/_src/core.py | 24 +-- jax/_src/interpreters/batching.py | 6 +- jax/_src/pallas/core.py | 53 ++---- jax/_src/pallas/mosaic/lowering.py | 28 +-- jax/_src/pallas/pallas_call.py | 275 +++-------------------------- tests/pallas/BUILD | 23 --- tests/pallas/pallas_jumble_test.py | 201 --------------------- 7 files changed, 50 insertions(+), 560 deletions(-) delete mode 100644 tests/pallas/pallas_jumble_test.py diff --git a/jax/_src/core.py b/jax/_src/core.py index 61ed81cdeea9..ebf29cf0b253 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1954,7 +1954,6 @@ def __init__(self, aval, data): assert data.shape == pad_shape self._aval = aval self._data = data - shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) aval = property(lambda self: self._aval) @@ -1965,38 +1964,21 @@ def __repr__(self) -> str: dtypestr = _short_dtype_name(self._aval.dtype) shapestr = ','.join(map(str, self.shape)) - data = self.data + slices = tuple(slice(int(d._data)) if type(d) is DArray and + type(d.dtype) is bint else slice(None) for d in self.shape) + data = self._data[slices] return f'{dtypestr}[{shapestr}] with value: {data}' - def __hash__(self) -> int: if not self.shape: return hash((self._aval, int(self._data))) raise TypeError("unhashable type: DArray") - def __eq__(self, other): if isinstance(other, DArray) and self._aval == other._aval: return self._data == other._data return False - def __len__(self): return self.shape[0] - @property - def data(self): - if not self.shape and type(self.dtype) is bint: - # special-case scalar bints - return self._data - - slices = tuple( - slice(int(d._data)) - if type(d) is DArray and type(d.dtype) is bint - else slice(None) - for d in self.shape - ) - data = self._data[slices] - return data - - pytype_aval_mappings[DArray] = \ lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, x._data) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 27cde6d31d35..fbcd2c4a7a30 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -88,7 +88,6 @@ def _jumble_flatten(jumble): elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape)) aval = jumble.aval.replace(elt_ty=elt_ty) return (lengths, jumble.data), aval - def _jumble_unflatten(aval, x): lengths, data = x new_shape = [d.replace(lengths=lengths[d.lengths - 1]) @@ -252,10 +251,7 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: - # TODO(mvoz): This is a terrible place to fall into if you pass - # a non jumble type in, make it clearer what went wrong. - assert False, f'Unexpected type in ELT? {type(x)}' - + assert False to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 09e02ea5c3a1..0ef208f755e5 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -112,10 +112,7 @@ class AbstractMemoryRef(state.AbstractRef): def __init__(self, inner_aval: jax_core.AbstractValue, memory_space: Any): - - assert isinstance( - inner_aval, jax_core.ShapedArray - ), f"Illegal ref, got {type(inner_aval)}" + assert isinstance(inner_aval, jax_core.ShapedArray) self.inner_aval = inner_aval self.memory_space = memory_space @@ -170,7 +167,9 @@ class PallasGridContext: mapped_dims: tuple[int, ...] def size(self, axis: int) -> int | DynamicGridDim: - valid_grid = tuple(self.grid) + valid_grid = tuple( + s for i, s in enumerate(self.grid) if i not in self.mapped_dims + ) try: size = valid_grid[axis] except IndexError as e: @@ -339,10 +338,7 @@ def check_invariants(self) -> None: ) assert not self.index_map_jaxpr.consts - assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), ( - self.block_shape, - self.index_map_jaxpr.out_avals, - ) + assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals) assert all(ov.shape == () and (ov.dtype == jnp.int32 or ov.dtype == jnp.int64) for ov in self.index_map_jaxpr.out_avals), ( @@ -426,8 +422,6 @@ class GridMapping: num_inputs: int num_outputs: int num_scratch_operands: int - get_grid_indices: Callable | None = None - local_grid_env: Callable | None = None def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -448,8 +442,8 @@ def check_invariants(self) -> None: assert len(index_map_args) >= len(self.grid) for i in range(len(self.grid)): index_map_arg = index_map_args[i] - assert index_map_arg.shape == (), f"index_map_arg: {index_map_arg}" - assert index_map_arg.dtype == jnp.int32, f"index_map_arg: {index_map_arg}" + assert index_map_arg.shape == () + assert index_map_arg.dtype == jnp.int32 assert len(self.vmapped_dims) <= len(self.grid) for i in self.vmapped_dims: @@ -460,11 +454,8 @@ def check_invariants(self) -> None: for bm in self.block_mappings: bm.check_invariants() - assert tuple(self.index_map_avals) == tuple( - bm.index_map_jaxpr.in_avals - ), ( + assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), ( self.index_map_avals, - "|", bm.index_map_jaxpr.in_avals, ) @@ -556,17 +547,6 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool: return True return jax_core.is_dim(dim) - -def _max_shape_from_aval(array_aval: jax_core.ShapedArray): - array_aval_shape = list(array_aval.shape) - for i, s in enumerate(array_aval.shape): - aval = jax_core.get_aval(s) - if isinstance(aval, jax_core.DShapedArray): - array_aval_shape[i] = aval.dtype.bound - - return tuple(array_aval_shape) - - def _convert_block_spec_to_block_mapping( block_spec: BlockSpec, origin: OriginStr, @@ -595,15 +575,8 @@ def _convert_block_spec_to_block_mapping( f"array shape {array_aval.shape}.") unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_array_aval = array_aval.update(shape=unmapped_block_shape) - if isinstance(array_aval, jax_core.DShapedArray): - # Get the "max" shape for the ragged array. - block_array_aval = jax_core.ShapedArray( - block_array_aval.shape, - block_array_aval.dtype, - block_array_aval.weak_type, - ) - block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space) + block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape), + block_spec.memory_space) if not jax_core.is_constant_shape(block_aval.shape): raise ValueError( @@ -636,12 +609,12 @@ def _convert_block_spec_to_block_mapping( f"{origin} must return integer scalars. Output[{i}] has type " f"{ov}.") + if consts: raise ValueError( f"Index map function {index_map_src_info} for " f"{origin} must not capture constants: {consts}") - array_aval_shape = _max_shape_from_aval(array_aval) mapping = BlockMapping( block_shape=mapped_block_shape, @@ -649,9 +622,7 @@ def _convert_block_spec_to_block_mapping( index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), index_map_src_info=index_map_src_info, indexing_mode=block_spec.indexing_mode, - array_shape_dtype=jax.ShapeDtypeStruct( - array_aval_shape, array_aval.dtype - ), + array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype), origin=origin, ) mapping.check_invariants() diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index aee894ee1b7e..86ce2f0b1b81 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -298,7 +298,6 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, self.jaxpr = jaxpr self.block_mappings = grid_mapping.block_mappings self.mapped_dims = grid_mapping.vmapped_dims - # TODO(mvoz): Generalize to not need this user_grid = tuple( g for i, g in enumerate(self.grid) if i not in self.mapped_dims ) @@ -346,19 +345,9 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, for _ in range(len(self.grid)) ]) self._prepare_mesh_info(mesh) - - if grid_mapping.get_grid_indices is None: - - def _get_grid_indices(indices, maybe_include_mapped_dims: bool): - if maybe_include_mapped_dims: - return indices - return tuple( - idx for i, idx in enumerate(indices) if i not in self.mapped_dims - ) - - self.get_grid_indices = _get_grid_indices - else: - self.get_grid_indices = grid_mapping.get_grid_indices + def _get_grid_indices(indices): + return indices + self.get_grid_indices = _get_grid_indices def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): if not self.has_communication: @@ -606,9 +595,7 @@ def lower_jaxpr_to_transform_func( ] def body_func(*args): grid_indices, scalar_prefetch = split_list(args, [num_grid]) - jaxpr_indices = mosaic_grid_mapping.get_grid_indices( - grid_indices, maybe_include_mapped_dims=True - ) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) arg_block_shapes = [ *[()] * len(jaxpr_indices), *mosaic_grid_mapping.scalar_prefetch_block_shapes, @@ -676,9 +663,9 @@ def lower_jaxpr_to_func( def body_func(*args): grid_indices, scalar_prefetch, operands_and_scratch = split_list( args, [num_grid, num_scalar_prefetch]) - jaxpr_indices = mosaic_grid_mapping.get_grid_indices( - grid_indices, maybe_include_mapped_dims=False - ) + grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) + jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices) + if i not in mosaic_grid_mapping.mapped_dims) mesh_info = mosaic_grid_mapping.mesh_info if mesh_info is not None: mesh_context = MeshContext( @@ -2378,7 +2365,6 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): - if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index bb1683e38b1c..4f3c9918f664 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -228,12 +228,6 @@ def _pallas_call_impl_interpret( # Pad values to evenly divide into block dimensions. This matches the # behavior of the non-interpret mode. We pad with NaN, to make it easier # to catch OOB accesses. - for carry_element in carry: - aval = carry_element.aval - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype) - carry_element.aval = aval - carry = map(_pad_values_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) @@ -253,16 +247,11 @@ def cond(carry): return i < num_iterations def body(carry): i, loop_idx, *carry_blocks = carry - - if grid_mapping.local_grid_env is not None: - local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) - else: - local_grid_env = tuple( - pallas_core.GridAxis(idx, b) - for dim, (idx, b) in enumerate(zip(loop_idx, grid)) - if dim not in grid_mapping.vmapped_dims - ) - + local_grid_env = tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + if dim not in grid_mapping.vmapped_dims + ) carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks]) with pallas_core.grid_env(local_grid_env): start_indices = [ @@ -279,14 +268,8 @@ def body(carry): len(blocks), len(scratch_values), ) - for s in scalars: - aval = jax_core.get_aval(s) - if isinstance(aval, jax_core.DShapedArray): - s.aval = aval.update(dtype=jnp.int32) - - blocks = jax_core.eval_jaxpr( - discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch - ) + blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, + *blocks, *scratch) _, out_inout, out_scratch = split_list( blocks, [grid_mapping.num_index_operands, num_inout_blocks]) @@ -407,55 +390,19 @@ def _pallas_call_jvp_rule( ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule - -def _batch_block_mapping( - grid_mapping: GridMapping, - axis_size: int, - aval: jax_core.ShapedArray, - dim: int | batching.NotMapped, - block_mapping: BlockMapping, - for_ragged: bool, -) -> BlockMapping: +def _batch_block_mapping(grid_mapping: GridMapping, + axis_size: int, + aval: jax_core.ShapedArray, + dim: int | batching.NotMapped, + block_mapping: BlockMapping) -> BlockMapping: def _block_map_function(new_idx, *args): - if for_ragged: - drop_last_args = args[:-1] - else: - drop_last_args = args - - indices = jax_core.eval_jaxpr( - block_mapping.index_map_jaxpr.jaxpr, - block_mapping.index_map_jaxpr.consts, - *drop_last_args, - ) + indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr, + block_mapping.index_map_jaxpr.consts, + *args) if dim is not batching.not_mapped: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - stacked_axis = dim.stacked_axis - indices.insert(stacked_axis, new_idx) - else: - indices.insert(dim, new_idx) + indices.insert(dim, new_idx) return tuple(indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] - - if for_ragged: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - _, _, ragged_axis_length = _ragged_axis_parts(dim) - aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) - lengths_aval = pallas_core.AbstractMemoryRef( - aval, - pallas_core.MemorySpace.INDEX, - ) - idx_avals = [*idx_avals, lengths_aval] - else: - i32_aval_memref = pallas_core.AbstractMemoryRef( - jax_core.ShapedArray(([axis_size]), jnp.int32), - pallas_core.MemorySpace.INDEX, - ) - idx_avals = [*idx_avals, i32_aval_memref] - with grid_mapping.trace_env(): block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_block_map_function), idx_avals) @@ -464,27 +411,12 @@ def _block_map_function(new_idx, *args): new_block_shape = shape new_array_shape_dtype = block_mapping.array_shape_dtype else: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - new_block_shape = shape - stacked_axis = dim.stacked_axis - new_block_shape = tuple_insert( - new_block_shape, stacked_axis, pallas_core.mapped - ) - else: - new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) - - array_shape = block_mapping.array_shape_dtype.shape - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - stacked_axis = dim.stacked_axis - array_shape = tuple_insert(array_shape, stacked_axis, axis_size) - else: - array_shape = tuple_insert(array_shape, dim, axis_size) - + new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) new_array_shape_dtype = jax.ShapeDtypeStruct( - array_shape, block_mapping.array_shape_dtype.dtype - ) + tuple_insert(block_mapping.array_shape_dtype.shape, + dim, + axis_size), + block_mapping.array_shape_dtype.dtype) jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts) return block_mapping.replace(block_shape=new_block_shape, @@ -615,16 +547,6 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: return result, (0,) * len(result) -def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]: - stacked_axis = dim.stacked_axis - ragged_axes = dim.ragged_axes - if len(ragged_axes) != 1: - raise ValueError("Multiple ragged axes not yet implemented.") - ragged_axis_dim = ragged_axes[0][0] - ragged_axis_length = ragged_axes[0][1] - return stacked_axis, ragged_axis_dim, ragged_axis_length - - def _pallas_call_batching_rule( args, dims, @@ -645,26 +567,8 @@ def _maybe_squeeze_out_bdim( return x return jnp.squeeze(x, axis=bdim) - all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)] - if len(all_ragged_axes) > 1: - raise ValueError("Multiple ragged dimensions not yet implemented.") - - if all_ragged_axes: - stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts( - all_ragged_axes[0] - ) - else: - stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None - - def get_size(i, x, d): - if not isinstance(d, batching.RaggedAxis): - return x.shape[d] - return x.aval.shape[i] - (axis_size,) = { - get_size(i=i, x=x, d=d) - for i, (x, d) in enumerate(zip(args, dims)) - if d is not batching.not_mapped + x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped } if axis_size == 1: # Why are we even vmapping? @@ -766,27 +670,12 @@ def get_size(i, x, d): num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands - lengths_aval = None - if ragged_axis_length is not None: - aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) - lengths_aval = pallas_core.AbstractMemoryRef( - aval, - pallas_core.MemorySpace.INDEX, - ) - # Only add a batch dimension for the avals that actually have a grid mapping. # This excludes scalar prefetch inputs (the first in the list) and scratch # operands (the last in the list). avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( - partial( - _batch_block_mapping, - grid_mapping, - axis_size, - for_ragged=lengths_aval is not None, - ), + partial(_batch_block_mapping, grid_mapping, axis_size), avals_to_batch, all_dims[num_index_operands:], block_mappings, @@ -796,23 +685,15 @@ def get_size(i, x, d): grid_mapping.index_map_avals) assert not index_map_tree_kwargs batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args - - if lengths_aval: - batched_index_map_args = batched_index_map_args + (lengths_aval,) - num_index_operands += 1 - batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( (batched_index_map_args, {})) - batched_grid_mapping = grid_mapping.replace( grid=(axis_size, *grid_mapping.grid), block_mappings=tuple(batched_block_mappings), - index_map_avals=tuple(batched_index_map_avals), + index_map_avals=batched_index_map_avals, index_map_tree=batched_index_map_tree, - num_index_operands=num_index_operands, vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims), ) - if cost_estimate is not None: batched_cost_estimate = CostEstimate( flops=cost_estimate.flops * axis_size, @@ -821,103 +702,6 @@ def get_size(i, x, d): ) else: batched_cost_estimate = None - - if lengths_aval: - batched_grid_mapping = batched_grid_mapping.replace( - get_grid_indices=lambda indices, maybe_include_mapped_dims: indices, - local_grid_env=lambda loop_idx, grid: tuple( - pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid) - ), - ) - - # Note - on zero filling counterfactuals - # A debug util to produce a counterfactual version of the when - # gating, where for all values that don't pass the @when check, - # we write 0s. This is useful for debugging, as certain lowering paths - # like mosaic will write the last data as passthrough, leading to - # potentially confusing results. - debug_zero_fill_counterfactual = debug - - first_block_mapping = batched_grid_mapping.block_mappings[0] - for block_mapping in batched_grid_mapping.block_mappings: - # This invariant may already be checked elsewhere, but lets reaffirm it - assert block_mapping.block_shape == first_block_mapping.block_shape, ( - f"block_mapping.block_shape: {block_mapping.block_shape}, " - f"first_block_mapping.block_shape: {first_block_mapping.block_shape}" - ) - assert ( - block_mapping.array_shape_dtype - == first_block_mapping.array_shape_dtype - ), ( - f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype}," - " first_block_mapping.array_shape_dtype:" - f" {first_block_mapping.array_shape_dtype}" - ) - - mapped_dim_idxs = [ - i - for i, d in enumerate(first_block_mapping.block_shape) - if d is pallas_core.mapped - ] - assert len(mapped_dim_idxs) == 1 - mapped_dim_idx = mapped_dim_idxs[0] - if stacked_axis != mapped_dim_idx: - raise ValueError( - f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}" - ) - - assert ragged_axis_dim is not None, "Invariant violation" - # This is the blockspec size of the dimension - val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim] - - def when_wrapped_kernel(lengths_ref, *args, **kwargs): - b_idx = jax.experimental.pallas.program_id(stacked_axis) - i_idx = ( - jax.experimental.pallas.program_id(ragged_axis_dim) - * val_at_ragged_dim - ) - b_len = lengths_ref[b_idx] - - # TODO(mvoz): Unimplemented primitive in pallas - # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) - # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") - - @jax.experimental.pallas.when(i_idx < b_len) - def f(): - # Important! This allows us to trace the inner kernel with the correct - # grid to preserve user program_id semantics. Ex: program_id(0) will - # always be analogous to program_id(1) in the outer kernel. - with pallas_core.tracing_grid_env(grid_mapping.grid, ()): - jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs) - - if debug_zero_fill_counterfactual: - - @jax.experimental.pallas.when(i_idx >= b_len) - def g(): - for arg_ref in args: - arg_ref[...] = jnp.zeros_like(arg_ref) - - kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars] - flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten( - list(kernel_avals) - ) - # Important! This allows us to trace the outer kernel with the correct grid - # to enable accessing the batch program_id. - with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): - kernel_src_info: pallas_core.SrcInfoStr = "" - - jaxpr = _trace_kernel_to_jaxpr( - when_wrapped_kernel, - kernel_src_info, - batched_grid_mapping, - tuple(flat_kernel_avals), - kernel_in_tree, - interpret=interpret, - ) - - assert ragged_axis_length is not None - args = (ragged_axis_length, *args) - out = pallas_call_p.bind( *dynamic_grid_args, *args, @@ -1313,14 +1097,12 @@ def pallas_call( out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths) flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore for x in flat_out_shapes] - @jax.jit def wrapped(*args): flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args) in_paths, flat_args = unzip2(flat_args_with_paths) flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) for a in flat_args) - flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype) for v in flat_out_shapes) @@ -1390,18 +1172,15 @@ def wrapped(*args): return wrapped -def in_path_to_input_origin( - in_path: tree_util.KeyPath, arg_names: tuple[str, ...] | None -) -> pallas_core.OriginStr: +def in_path_to_input_origin(in_path: tree_util.KeyPath, + arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr: """Converts `args[k]` into `arg_k_name`.""" if arg_names is None: return f"args{tree_util.keystr(in_path)}" if len(in_path) == 0: return "args" arg_idx, *rest_path = in_path - if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len( - arg_names - ): + if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names): return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path)) else: return f"args{tree_util.keystr(tuple(in_path))}" diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 5559a0552f9f..c0cf61387cbb 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -62,29 +62,6 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( - name = "pallas_jumble_test", - srcs = [ - "pallas_jumble_test.py", - ], - disable_configs = [ - "gpu", - "gpu_x32", - "gpu_a100", - "gpu_p100", - "gpu_p100_x32", - "gpu_h100", - ], - shard_count = { - "tpu": 1, - }, - deps = [ - "//jax:pallas", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), -) - jax_test( name = "ops_test", srcs = [ diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py deleted file mode 100644 index ee176a0363aa..000000000000 --- a/tests/pallas/pallas_jumble_test.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" - -from absl.testing import absltest -import jax -from jax import lax -from jax._src import config -from jax._src import core -from jax._src import dtypes -from jax._src import test_util as jtu -from jax._src.interpreters import batching -from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr -from jax.experimental import pallas as pl -import jax.numpy as jnp -import numpy as np - - -# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs. -# pylint: disable=no-value-for-parameter - -config.parse_flags_with_absl() - - -intx = dtypes.canonicalize_dtype(jnp.int64) -floatx = dtypes.canonicalize_dtype(jnp.float64) - - -@jtu.with_config(jax_traceback_filtering="off") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False - - def setUp(self): - if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: - self.skipTest("On CPU the test works only in interpret mode") - if jtu.test_device_matches( - ["cuda"] - ) and not jtu.is_cuda_compute_capability_at_least("8.0"): - self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: - self.skipTest("Only works on non-Windows platforms") - - super().setUp() - _trace_kernel_to_jaxpr.cache_clear() - - def pallas_call(self, *args, **kwargs): - return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) - - -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") -class PallasCallRaggedVmapTest(PallasBaseTest): - - def test_vmap_jumble_over_sin_kernel(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct( - (8, col_grid_size * 128), dtype=jnp.float32 - ), - grid=(1, col_grid_size), - interpret=self.INTERPRET, - # See note - on zero filling counterfactuals - debug=True, - )(x) - - res = jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - ragged_total = 0 - for dim in ragged_shape: - ragged_total += row_count * dim * 128 - # See note - on zero filling counterfactuals - self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total) - - def test_vmap_jumble_over_sin_kernel_grid_remapping(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) * pl.program_id(2) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), - grid=(1, 5), - interpret=False, - )(x) - - with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): - jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - self.skipTest("Checkify NYI") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([(128 * x) - 1 for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), - grid=(1, 5), - interpret=False, - )(x) - - with self.assertRaisesRegex( - ValueError, - "Ragged input shape must be evenly divisble by the grid" # noqa: W605 - " size at the ragged dimension 2", - ): - jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - -class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest): - INTERPRET = True - - -if __name__ == "__main__": - absltest.main() From 79c222eee6aa5085ac963e12604fe8383df7baec Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 19 Aug 2024 07:19:22 -0700 Subject: [PATCH 161/702] Fix bug in ffi_lowering where custom layouts were ignored. PiperOrigin-RevId: 664795687 --- jax/_src/extend/ffi.py | 4 ++++ tests/extend_test.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index c39e05d9335d..3965c8b72c67 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -140,8 +140,12 @@ def _lowering( kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] if operand_layouts is None: kwargs["operand_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_in) # pytype: disable=attribute-error + else: + kwargs["operand_layouts"] = operand_layouts if result_layouts is None: kwargs["result_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_out) + else: + kwargs["result_layouts"] = result_layouts if "result_shapes" not in kwargs and not all( core.is_constant_shape(aval.shape) for aval in ctx.avals_out): kwargs["result_shapes"] = [ diff --git a/tests/extend_test.py b/tests/extend_test.py index 3194a3ef9073..cdf3af8fbc4d 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -25,6 +25,7 @@ from jax._src import abstract_arrays from jax._src import api +from jax._src import core from jax._src import linear_util from jax._src import prng from jax._src import test_util as jtu @@ -100,6 +101,29 @@ def testHeadersExist(self): for header in ["c_api.h", "api.h", "ffi.h"]: self.assertTrue(os.path.exists(os.path.join(base_dir, header))) + def testLoweringLayouts(self): + # Regression test to ensure that the lowering rule properly captures + # layouts. + def lowering_rule(ctx, x): + aval, = ctx.avals_in + ndim = len(aval.shape) + layout = tuple(range(ndim)) + return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout], + result_layouts=[layout])(ctx, x) + prim = core.Primitive("test_ffi") + prim.def_impl(lambda x: x) + prim.def_abstract_eval(lambda x: x) + mlir.register_lowering(prim, lowering_rule) + x = jnp.linspace(0, 1, 5) + lowered = jax.jit(prim.bind).lower(x) + module = lowered.compiler_ir("stablehlo") + for func in module.body.operations: + for block in func.body.blocks: + for op in block.operations: + if op.OPERATION_NAME == "stablehlo.custom_call": + self.assertIn("operand_layouts", op.attributes) + self.assertIn("result_layouts", op.attributes) + @parameterized.parameters([ (True, mlir.ir.BoolAttr.get), (1, mlir.i64_attr), From 30d54ec6ffe29654e680f061b12ec5fef507c6dc Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 19 Aug 2024 07:40:53 -0700 Subject: [PATCH 162/702] Refactor FFI shape inference functions to include dimension check. Previously we always had two steps when extracting the batch size: (1) check the buffer has enough dimensions, (2) get the shape. And, in a few cases, this first check was missing. Now these steps are combined into one function that returns a StatusOr. As part of this, I needed to fix our implementation of the `ASSIGN_OR_RETURN` macro to properly handle parentheses. PiperOrigin-RevId: 664803225 --- jaxlib/cpu/lapack_kernels.cc | 36 +++++++----- jaxlib/ffi_helpers.h | 94 ++++++++++++++++++++++++++------ jaxlib/gpu/solver_kernels_ffi.cc | 4 +- 3 files changed, 99 insertions(+), 35 deletions(-) diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 8b260d1408e4..9765b227d403 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -80,9 +80,8 @@ inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) { template void CopyIfDiffBuffer(ffi::Buffer x, ffi::ResultBuffer x_out) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); if (x.typed_data() != x_out->typed_data()) { - const auto x_size = batch_count * x_rows * x_cols; + const auto x_size = x.element_count(); std::copy_n(x.typed_data(), x_size, x_out->typed_data()); } } @@ -150,8 +149,8 @@ ffi::Error TriMatrixEquationSolver::Kernel( MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, MatrixParams::Diag diag) { CopyIfDiffBuffer(y, y_out); - - auto [batch_count, y_rows, y_cols] = SplitBatch2D(y.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, y_rows, y_cols]), + SplitBatch2D(y.dimensions())); auto* y_out_data = y_out->typed_data(); lapack_int x_leading_dim_v = side == MatrixParams::Side::kLeft ? y_rows : y_cols; @@ -226,8 +225,8 @@ ffi::Error LuDecomposition::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer ipiv, ffi::ResultBuffer info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions())); - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* ipiv_data = ipiv->typed_data(); auto* info_data = info->typed_data(); @@ -310,7 +309,8 @@ template ffi::Error QrFactorization::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer tau, ffi::ResultBuffer info) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* tau_data = tau->typed_data(); auto* info_data = info->typed_data(); @@ -412,7 +412,8 @@ ffi::Error OrthogonalQr::Kernel(ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer info, ffi::ResultBuffer work) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* tau_data = tau.typed_data(); auto* x_out_data = x_out->typed_data(); auto* info_data = info->typed_data(); @@ -500,8 +501,8 @@ template ffi::Error CholeskyFactorization::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, ffi::ResultBuffer x_out, ffi::ResultBuffer info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions())); - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* info_data = info->typed_data(); @@ -698,7 +699,8 @@ static ffi::Error SvdKernel( XLA_FFI_Error_Code_UNIMPLEMENTED, "Current implementation does not support this computation mode"); } - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* singular_values_data = singular_values->typed_data(); auto* u_data = u->typed_data(); @@ -977,7 +979,8 @@ ffi::Error EigenvalueDecompositionSymmetric::Kernel( ffi::ResultBuffer x_out, ffi::ResultBuffer eigenvalues, ffi::ResultBuffer info, ffi::ResultBuffer work, ffi::ResultBuffer iwork, eig::ComputationMode mode) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* eigenvalues_data = eigenvalues->typed_data(); auto* info_data = info->typed_data(); @@ -1039,7 +1042,8 @@ ffi::Error EigenvalueDecompositionHermitian::Kernel( ffi::ResultBuffer info, ffi::ResultBuffer work, ffi::ResultBuffer rwork, ffi::ResultBuffer iwork, eig::ComputationMode mode) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* eigenvalues_data = eigenvalues->typed_data(); auto* info_data = info->typed_data(); @@ -1265,7 +1269,8 @@ ffi::Error EigenvalueDecomposition::Kernel( ffi::ResultBuffer info, ffi::ResultBuffer x_work, ffi::ResultBuffer work_eigvecs_left, ffi::ResultBuffer work_eigvecs_right) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); const auto* x_data = x.typed_data(); auto* x_work_data = x_work->typed_data(); @@ -1339,7 +1344,8 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( ffi::ResultBuffer eigvecs_right, ffi::ResultBuffer info, ffi::ResultBuffer x_work, ffi::ResultBuffer rwork) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); const auto* x_data = x.typed_data(); auto* x_work_data = x_work->typed_data(); auto* eigvecs_left_data = eigvecs_left->typed_data(); diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 69c63a4ba000..bedfdca2f11b 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/base/optimization.h" @@ -20,12 +21,7 @@ namespace jax { -#define FFI_ASSIGN_OR_RETURN(lhs, rhs) \ - if (ABSL_PREDICT_FALSE(!rhs.ok())) { \ - return ::jax::AsFfiError(rhs.status()); \ - } \ - lhs = rhs.value() - +// Returns from the function if the argument is an ffi::Error. #define FFI_RETURN_IF_ERROR(...) \ do { \ ::xla::ffi::Error err = (__VA_ARGS__); \ @@ -34,6 +30,8 @@ namespace jax { } \ } while (0) +// Returns from the function with an ffi::Error if the argument is an +// absl::Status. #define FFI_RETURN_IF_ERROR_STATUS(...) \ do { \ ::absl::Status status = (__VA_ARGS__); \ @@ -42,6 +40,57 @@ namespace jax { } \ } while (0) +// Returns from the function with an ffi::Error if the RHS is an absl::Status, +// otherwise assigns to the LHS. Most of the complication here stems from the +// fact that we want to support having the LHS wrapped in parentheses (when +// unpacking a tuple, for example). +#define FFI_ASSIGN_OR_RETURN(lhs, rhs) \ + FFI_ASSIGN_OR_RETURN_IMPL_( \ + FFI_ASSIGN_OR_RETURN_CONCAT_(_status_or_value, __LINE__), lhs, rhs) + +#define FFI_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rhs) \ + auto statusor = (rhs); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + return ::jax::AsFfiError(statusor.status()); \ + } \ + FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \ + (*std::move(statusor)) + +#define FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) x##y +#define FFI_ASSIGN_OR_RETURN_CONCAT_(x, y) \ + FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) + +// All the macros below here are to handle the case in FFI_ASSIGN_OR_RETURN +// where the LHS is wrapped in parentheses. +#define FFI_ASSIGN_OR_RETURN_EAT(...) +#define FFI_ASSIGN_OR_RETURN_REM(...) __VA_ARGS__ +#define FFI_ASSIGN_OR_RETURN_EMPTY() + +#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(...) \ + FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER((__VA_ARGS__, 0, 1)) +#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER(args) \ + FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I args +#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty + +#define FFI_ASSIGN_OR_RETURN_IS_EMPTY(...) \ + FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(__VA_ARGS__) +#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(...) \ + FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(_, ##__VA_ARGS__) + +#define FFI_ASSIGN_OR_RETURN_IF_1(_Then, _Else) _Then +#define FFI_ASSIGN_OR_RETURN_IF_0(_Then, _Else) _Else +#define FFI_ASSIGN_OR_RETURN_IF(_Cond, _Then, _Else) \ + FFI_ASSIGN_OR_RETURN_CONCAT_(FFI_ASSIGN_OR_RETURN_IF_, _Cond)(_Then, _Else) + +#define FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(...) \ + FFI_ASSIGN_OR_RETURN_IS_EMPTY(FFI_ASSIGN_OR_RETURN_EAT __VA_ARGS__) + +#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(...) \ + FFI_ASSIGN_OR_RETURN_IF(FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(__VA_ARGS__), \ + FFI_ASSIGN_OR_RETURN_REM, \ + FFI_ASSIGN_OR_RETURN_EMPTY()) \ + __VA_ARGS__ + template inline absl::StatusOr MaybeCastNoOverflow( std::int64_t value, const std::string& source = __FILE__) { @@ -67,21 +116,30 @@ inline ::xla::ffi::Error AsFfiError(const absl::Status& status) { } } -template -::xla::ffi::Error CheckMatrixDimensions(::xla::ffi::Span dims) { - if (dims.size() < 2) { - return ::xla::ffi::Error(::xla::ffi::ErrorCode::kInvalidArgument, - "Matrix must have at least 2 dimensions"); +inline int64_t GetBatchSize(::xla::ffi::Span dims) { + return absl::c_accumulate(dims, 1, std::multiplies()); +} + +inline absl::StatusOr> SplitBatch1D( + ::xla::ffi::Span dims, + const std::string& source = __FILE__) { + if (dims.size() < 1) { + return absl::InvalidArgumentError( + absl::StrFormat("%s: Argument must have at least 1 dimension", source)); } - return ::xla::ffi::Error::Success(); + return std::make_pair(GetBatchSize(dims.first(dims.size() - 1)), dims.back()); } -template -std::tuple SplitBatch2D(::xla::ffi::Span dims) { - auto matrix_dims = dims.last(2); - return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1, - std::multiplies()), - matrix_dims.front(), matrix_dims.back()); +inline absl::StatusOr> SplitBatch2D( + ::xla::ffi::Span dims, + const std::string& source = __FILE__) { + if (dims.size() < 2) { + return absl::InvalidArgumentError(absl::StrFormat( + "%s: Argument must have at least 2 dimensions", source)); + } + auto trailingDims = dims.last(2); + return std::make_tuple(GetBatchSize(dims.first(dims.size() - 2)), + trailingDims.front(), trailingDims.back()); } template <::xla::ffi::DataType dtype> diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 6deb89144ec7..7b4b673bf9ec 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -163,8 +163,8 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::ErrorCode::kInvalidArgument, "The input and output to getrf must have the same element type"); } - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions())); - auto [batch, rows, cols] = SplitBatch2D(a.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); if (batch > 1 && rows == cols && rows / batch <= 128) { if (dataType == ffi::DataType::F32) { return GetrfBatchedImpl(batch, cols, stream, scratch, a, out, ipiv, From 351b780c48160c6b9d13fd68b36733a6cf358bea Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 19 Aug 2024 22:25:53 +0530 Subject: [PATCH 163/702] Better docs for jnp.fmin and fmax --- jax/_src/numpy/lax_numpy.py | 93 ++++++++++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d6c11d181025..965b36d73632 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -349,14 +349,103 @@ def load(*args: Any, **kwargs: Any) -> Array: ### implementations of numpy functions in terms of lax -@util.implements(np.fmin, module='numpy') @jit def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: + """Return element-wise minimum of the input arrays. + + JAX implemtentation of :func:`numpy.fmin`. + + Args: + x1: input array or scalar. + x2: input array or scalar. x1 and x2 must either have same shape or be + broadcast compatible. + + Returns: + An array containing the element-wise minimum of x1 and x2. + + Note: + For each pair of elements, ``jnp.fmin`` returns: + - the smaller of the two if both elements are finite numbers. + - finite number if one element is ``nan``. + - ``-inf`` if one element is ``-inf`` and the other is finite or ``nan``. + - ``inf`` if one element is ``inf`` and the other is ``nan``. + - ``nan`` if both elements are ``nan``. + + Examples: + >>> jnp.fmin(2, 3) + Array(2, dtype=int32, weak_type=True) + >>> jnp.fmin(2, jnp.array([1, 4, 2, -1])) + Array([ 1, 2, 2, -1], dtype=int32) + + >>> x1 = jnp.array([1, 3, 2]) + >>> x2 = jnp.array([2, 1, 4]) + >>> jnp.fmin(x1, x2) + Array([1, 1, 2], dtype=int32) + + >>> x3 = jnp.array([1, 5, 3]) + >>> x4 = jnp.array([[2, 3, 1], + ... [5, 6, 7]]) + >>> jnp.fmin(x3, x4) + Array([[1, 3, 1], + [1, 5, 3]], dtype=int32) + + >>> nan = jnp.nan + >>> x5 = jnp.array([jnp.inf, 5, nan]) + >>> x6 = jnp.array([[2, 3, nan], + ... [nan, 6, 7]]) + >>> jnp.fmin(x5, x6) + Array([[ 2., 3., nan], + [inf, 5., 7.]], dtype=float32) + """ return where(ufuncs.less(x1, x2) | ufuncs.isnan(x2), x1, x2) -@util.implements(np.fmax, module='numpy') + @jit def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: + """Return element-wise maximum of the input arrays. + + JAX implementation of :func:`numpy.fmax`. + + Args: + x1: input array or scalar + x2: input array or scalar. x1 and x1 must either have same shape or be + broadcast compatible. + + Returns: + An array containing the element-wise maximum of x1 and x2. + + Note: + For each pair of elements, ``jnp.fmax`` returns: + - the larger of the two if both elements are finite numbers. + - finite number if one element is ``nan``. + - ``nan`` if both elements are ``nan``. + - ``inf`` if one element is ``inf`` and the other is finite or ``nan``. + - ``-inf`` if one element is ``-inf`` and the other is ``nan``. + + Examples: + >>> jnp.fmax(3, 7) + Array(7, dtype=int32, weak_type=True) + >>> jnp.fmax(5, jnp.array([1, 7, 9, 4])) + Array([5, 7, 9, 5], dtype=int32) + + >>> x1 = jnp.array([1, 3, 7, 8]) + >>> x2 = jnp.array([-1, 4, 6, 9]) + >>> jnp.fmax(x1, x2) + Array([1, 4, 7, 9], dtype=int32) + + >>> x3 = jnp.array([[2, 3, 5, 10], + ... [11, 9, 7, 5]]) + >>> jnp.fmax(x1, x3) + Array([[ 2, 3, 7, 10], + [11, 9, 7, 8]], dtype=int32) + + >>> x4 = jnp.array([jnp.inf, 6, -jnp.inf, nan]) + >>> x5 = jnp.array([[3, 5, 7, nan], + ... [nan, 9, nan, -1]]) + >>> jnp.fmax(x4, x5) + Array([[ inf, 6., 7., nan], + [ inf, 9., -inf, -1.]], dtype=float32) + """ return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2) @util.implements(np.issubdtype) From a213d2fa305a301b51346d6a84e2fc77b7af0f13 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 19 Aug 2024 10:00:37 -0700 Subject: [PATCH 164/702] Improve documentation for jnp.copy --- jax/_src/numpy/lax_numpy.py | 44 ++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0f27cb5ff409..1dfbc52422e7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3736,8 +3736,50 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) -@util.implements(np.copy, lax_description=_ARRAY_DOC) def copy(a: ArrayLike, order: str | None = None) -> Array: + """Return a copy of the array. + + JAX implementation of :func:`numpy.copy`. + + Args: + a: arraylike object to copy + order: not implemented in JAX + + Returns: + a copy of the input array ``a``. + + See Also: + - :func:`jax.numpy.array`: create an array with or without a copy. + - :meth:`jax.Array.copy`: same function accessed as an array method. + + Examples: + Since JAX arrays are immutable, in most cases explicit array copies + are not necessary. One exception is when using a function with donated + arguments (see the ``donate_argnums`` argument to :func:`jax.jit`). + + >>> f = jax.jit(lambda x: 2 * x, donate_argnums=0) + >>> x = jnp.arange(4) + >>> y = f(x) + >>> print(y) + [0 2 4 6] + + Because we marked ``x`` as being donated, the original array is no longer + available: + + >>> print(x) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + RuntimeError: Array has been deleted with shape=int32[4]. + + In situations like this, an explicit copy will let you keep access to the + original buffer: + + >>> x = jnp.arange(4) + >>> y = f(x.copy()) + >>> print(y) + [0 2 4 6] + >>> print(x) + [0 1 2 3] + """ util.check_arraylike("copy", a) return array(a, copy=True, order=order) From 7a4eecda613b7702b0d99de7d4e5f2052de63a38 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 19 Aug 2024 10:01:34 -0700 Subject: [PATCH 165/702] remove trailing comma --- jax/_src/basearray.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index e4c7fdc4a9ab..bd546169d341 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -171,7 +171,7 @@ class Array(abc.ABC): @property def nbytes(self) -> int: ... def nonzero(self, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None, - size: int | None = None,) -> tuple[Array, ...]: ... + size: int | None = None) -> tuple[Array, ...]: ... def prod(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, From 36739e84cec945db52293554245b6dbf829b9047 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 19 Aug 2024 10:39:38 -0700 Subject: [PATCH 166/702] Normalize "interpreter mode" to "interpret mode", and "InterpreterTest" to "InterpretTest" This is because both "interpret mode" and "interpreter mode" occur in code, and "interpret mode" is more frequent. PiperOrigin-RevId: 664873359 --- docs/pallas/CHANGELOG.md | 4 ++-- jax/_src/pallas/pallas_call.py | 2 +- tests/pallas/gpu_attention_test.py | 2 +- tests/pallas/gpu_ops_test.py | 8 ++++---- tests/pallas/indexing_test.py | 6 +++--- tests/pallas/ops_test.py | 22 +++++++++++----------- tests/pallas/pallas_test.py | 18 +++++++++--------- tests/pallas/pallas_vmap_test.py | 4 ++-- tests/pallas/tpu_ops_test.py | 2 +- tests/pallas/tpu_pallas_test.py | 10 +++++----- 10 files changed, 39 insertions(+), 39 deletions(-) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index e43a178db50e..c1ed1385bbbc 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -36,9 +36,9 @@ Remember to align the itemized text with the first line of an item within a list * The method `compute_index` of {class}`jax.experimental.pallas.GridSpec` has been removed because it is private. Similarly, the `get_grid_mapping` and `unzip_dynamic_bounds` have been removed from `BlockSpec` ({jax-issue}`#22593`). - * Fixed the interpreter mode to work with BlockSpec that involve padding + * Fixed the interpret mode to work with BlockSpec that involve padding ({jax-issue}`#22275`). - Padding in interpreter mode will be with NaN, to help debug out-of-bounds + Padding in interpret mode will be with NaN, to help debug out-of-bounds errors, but this behavior is not present when running in custom kernel mode, and should not be depended on. * Previously it was possible to import many APIs that are meant to be diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 4f3c9918f664..3a780cdca617 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -177,7 +177,7 @@ def _pallas_call_impl_interpret( cost_estimate: CostEstimate, ): del compiler_params, cost_estimate - # If we're in interpreter mode, we *scan* over the grid and eval the + # If we're in interpret mode, we *scan* over the grid and eval the # discharged jaxpr. dynamic_grid_args, args = split_list( # type: ignore args, [grid_mapping.num_dynamic_grid_bounds] diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index 571e1348a7a8..e7bc88cab811 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -148,7 +148,7 @@ def test_gqa( o_ref = decode_attention.gqa_reference(q, k, v) np.testing.assert_allclose(o, o_ref, atol=0.05) -class DecodeAttentionInterpreterTest(DecodeAttentionTest): +class DecodeAttentionInterpretTest(DecodeAttentionTest): INTERPRET = True diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index 4ab957a3a2a1..a18051b002fe 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -292,7 +292,7 @@ def f_ref(q, k, v): np.testing.assert_allclose(dv, dv_ref, atol=0.05) -class FusedAttentionInterpreterTest(FusedAttentionTest): +class FusedAttentionInterpretTest(FusedAttentionTest): INTERPRET = True @@ -340,7 +340,7 @@ def f_ref(x, w, b): np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2) -class FusedLayerNormInterpreterTest(FusedLayerNormTest): +class FusedLayerNormInterpretTest(FusedLayerNormTest): INTERPRET = True @@ -388,7 +388,7 @@ def f_ref(x, w, b): np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2) -class RmsNormInterpreterTest(RmsNormTest): +class RmsNormInterpretTest(RmsNormTest): INTERPRET = True @@ -422,7 +422,7 @@ def test_softmax(self, shape, dtype): ) -class SoftmaxInterpreterTest(SoftmaxTest): +class SoftmaxInterpretTest(SoftmaxTest): INTERPRET = True diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 696f12b0ed72..f706b36c5f90 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -253,7 +253,7 @@ class IndexerOpsTest(PallasBaseTest): def test_multi_indexing_interpreter_only(self): if not self.INTERPRET: self.skipTest("Only supported in interpret mode") - # Interpreter only test! YMMV actually compiling this. + # Interpret only test! YMMV actually compiling this. def permute(left, right, left_out_ref, right_out_ref): left_out = jnp.zeros_like(left) left_out = left_out.at[:, 0].set(left[:, 0]) @@ -302,7 +302,7 @@ def invoke_permutes(x_ref, y_ref, x_out_ref, y_out_ref): def test_ellipsis_indexing_iterpret_only(self): if not self.INTERPRET: self.skipTest("Only supported in interpret mode") - # Interpreter only test! YMMV actually compiling this. + # Interpret only test! YMMV actually compiling this. def permute_columns_in_row_kernel(left, right, new_left, new_right): shape = left.shape k = shape[-1] @@ -616,7 +616,7 @@ def kernel(x_ref, indices, y_ref): self.assertAllClose(res[:, start : start + 1, :], x, atol=0., rtol=0.) -class IndexerOpsInterpreterTest(IndexerOpsTest): +class IndexerOpsInterpretTest(IndexerOpsTest): INTERPRET = True diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 2a776b6347f1..e106a56e588e 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -249,7 +249,7 @@ def setUp(self): self.skipTest("Only works in 32-bit") if not self.INTERPRET: if jtu.device_under_test() == "cpu": - self.skipTest("Only interpreter mode supported on CPU") + self.skipTest("Only interpret mode supported on CPU") if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPUs with capability >= sm80") @@ -664,7 +664,7 @@ def run(interpret=False): self.assertAllClose(actual, expected) -class OpsInterpreterTest(OpsTest): +class OpsInterpretTest(OpsTest): INTERPRET = True def test_debug_print(self): @@ -691,7 +691,7 @@ class OpsExtraTest(PallasBaseTest): def setUp(self): super().setUp() if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - # TODO: most tests fail on TPU in non-interpreter mode + # TODO: most tests fail on TPU in non-interpret mode self.skipTest("On TPU the test works only in interpret mode") ELEMENTWISE_OPS = [ @@ -841,7 +841,7 @@ def kernel(x_ref, y_ref, o_ref): @parameterized.parameters("float16", "bfloat16") def test_true_divide_unsupported(self, dtype): if self.INTERPRET: - self.skipTest("No lowering in interpreter mode") + self.skipTest("No lowering in interpret mode") @functools.partial( self.pallas_call, @@ -911,7 +911,7 @@ def kernel(o_ref): @parameterized.parameters("float16", "bfloat16", "float32") def test_approx_tanh(self, dtype): if self.INTERPRET: - self.skipTest("approx_tanh is not supported in interpreter mode") + self.skipTest("approx_tanh is not supported in interpret mode") if (dtype == "bfloat16" and not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") @@ -935,7 +935,7 @@ def kernel(x_ref, o_ref): def test_elementwise_inline_asm(self): if self.INTERPRET: self.skipTest( - "elementwise_inline_asm is not supported in interpreter mode" + "elementwise_inline_asm is not supported in interpret mode" ) @functools.partial( @@ -1163,7 +1163,7 @@ def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): def test_strided_load(self): if self.INTERPRET: # TODO(b/329733289): Remove this once the bug is fixed. - self.skipTest("Strided load not yet supported in interpreter mode") + self.skipTest("Strided load not yet supported in interpret mode") # Reproducer from https://github.com/google/jax/issues/20895. @functools.partial( @@ -1199,7 +1199,7 @@ def load(x_ref, o_ref): ) def test_invalid_broadcasted_load(self, x_shape, mask_shape): if self.INTERPRET: - self.skipTest("No broadcasting checks in pl.load in interpreter mode") + self.skipTest("No broadcasting checks in pl.load in interpret mode") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32) @@ -1383,7 +1383,7 @@ def swap(_, lock_ref, out_ref): @parameterized.parameters(1, 2, 3, 4, 8) def test_atomic_counter(self, num_threads): if self.INTERPRET: - self.skipTest("While loop not supported in interpreter mode.") + self.skipTest("While loop not supported in interpret mode.") @functools.partial( self.pallas_call, out_shape=( @@ -1501,7 +1501,7 @@ def reduce(x_ref, y_ref): np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) -class OpsExtraInterpreterTest(OpsExtraTest): +class OpsExtraInterpretTest(OpsExtraTest): INTERPRET = True @@ -1558,7 +1558,7 @@ def body(x_ref): self.assertIn(expected, jaxpr.pretty_print(use_color=False)) -class PallasPrimitivesInterpreterTest(PallasPrimitivesTest): +class PallasPrimitivesInterpretTest(PallasPrimitivesTest): INTERPRET = True diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 63779319b0b9..7227c1b91ef3 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -687,7 +687,7 @@ def f(x): self.assertEqual(trace_count, 1) -class PallasCallInterpreterTest(PallasCallTest): +class PallasCallInterpretTest(PallasCallTest): INTERPRET = True @@ -921,7 +921,7 @@ def the_kernel(): return None self.assertEqual("", ns5.src_info) -class ApiErrorInterpreterTest(ApiErrorTest): +class ApiErrorInterpretTest(ApiErrorTest): INTERPRET = True @@ -957,7 +957,7 @@ def f(x): self.assertEqual(mem_analysis.temp_size_in_bytes, 0) -class PallasCallInputOutputAliasingInterpreterTest(PallasBaseTest): +class PallasCallInputOutputAliasingInterpretTest(PallasBaseTest): INTERPRET = True @@ -966,7 +966,7 @@ class PallasControlFlowTest(PallasBaseTest): def setUp(self): super().setUp() if self.INTERPRET: - self.skipTest("Control flow not supported in interpreter mode yet.") + self.skipTest("Control flow not supported in interpret mode yet.") def test_loop_with_float64_carry(self): if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: @@ -1690,7 +1690,7 @@ def outer_body(carry): np.testing.assert_equal(sizes[0, 4], jnp.asarray(key_count - real_keys)) -class PallasControlFlowInterpreterTest(PallasControlFlowTest): +class PallasControlFlowInterpretTest(PallasControlFlowTest): INTERPRET = True AD_TEST_CASES = [ @@ -1713,7 +1713,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest): def setUp(self): super().setUp() if jtu.test_device_matches(["tpu"]): - # TODO: most tests fail on TPU in non-interpreter mode + # TODO: most tests fail on TPU in non-interpret mode self.skipTest("On TPU the test works only in interpret mode") # TODO: improve tolerance setting self.tol = 1e-5 @@ -1819,11 +1819,11 @@ def softmax_kernel(x_ref, y_ref): # jtu.check_grads(mm, (x, y), modes=["fwd"], order=1) -class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest): +class PallasCallAutodifferentiationInterpretTest(PallasCallAutodifferentiationTest): INTERPRET = True -class PallasOutOfBoundsInterpreterTest(PallasBaseTest): +class PallasOutOfBoundsInterpretTest(PallasBaseTest): INTERPRET = True def test_interpret_mode_out_of_bounds_access(self): @@ -1901,7 +1901,7 @@ def _(): np.testing.assert_allclose(out, expected, atol=atol) -class PallasCheckifyInterpreterTest(PallasBaseTest): +class PallasCheckifyInterpretTest(PallasBaseTest): # TODO(b/346651778): Support non-interpret mode checkify. INTERPRET = True diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index af8299e31689..3c33702b63e3 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -63,7 +63,7 @@ class PallasCallVmapTest(PallasBaseTest): def setUp(self): super().setUp() if jtu.test_device_matches(["tpu"]): - # TODO: most tests fail on TPU in non-interpreter mode + # TODO: most tests fail on TPU in non-interpret mode self.skipTest("On TPU the test works only in interpret mode") def test_vmap_of_simple_kernel(self): @@ -250,7 +250,7 @@ def add_one(x_ref, o_ref): np.testing.assert_allclose(out, out_ref) -class PallasCallVmapInterpreterTest(PallasCallVmapTest): +class PallasCallVmapInterpretTest(PallasCallVmapTest): INTERPRET = True def setUp(self): diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 75aab92af909..d9a66d1b2b34 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -175,7 +175,7 @@ def kernel(x_ref, y_ref, out_ref): np.testing.assert_array_equal(out, inp.reshape(m * 2, n)) -class OpsInterpreterTest(OpsTest): +class OpsInterpretTest(OpsTest): INTERPRET = True diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 20de5b585ac4..bfefe720a383 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -453,7 +453,7 @@ def f(x): self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes) -class PallasCallScalarPrefetchInterpreterTest(PallasCallScalarPrefetchTest): +class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest): INTERPRET: bool = True @@ -703,7 +703,7 @@ def dynamic_kernel(steps, x): np.testing.assert_array_equal(dynamic_kernel(np.int32(4), x), x[8:16]) -class PallasCallDynamicGridInterpreterTest(PallasCallDynamicGridTest): +class PallasCallDynamicGridInterpretTest(PallasCallDynamicGridTest): INTERPRET = True @@ -1467,7 +1467,7 @@ def kernel(index, x, y, sem): del y -class PallasCallDMAInterpreterTest(PallasCallDMATest): +class PallasCallDMAInterpretTest(PallasCallDMATest): INTERPRET = True def test_interpret_local_dma(self): @@ -1725,7 +1725,7 @@ def kernel(x_ref, y_ref): np.testing.assert_array_equal(y, x) -class PallasCallUnblockedIndexingInterpreterTest( +class PallasCallUnblockedIndexingInterpretTest( PallasCallUnblockedIndexingTest ): INTERPRET = True @@ -2463,7 +2463,7 @@ def kernel(x_ref, out_ref): ) -class MiscellaneousInterpreterTest(MiscellaneousTest): +class MiscellaneousInterpretTest(MiscellaneousTest): INTERPRET: bool = True From ef82cb21aeed3363330b0c5454722664fc3222a4 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 19 Jul 2024 00:24:25 +0000 Subject: [PATCH 167/702] fix basic scan bug with attrs --- jax/_src/lax/control_flow/loops.py | 3 ++- tests/attrs_test.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index aa707386c5db..a2a6d71f55d6 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -268,7 +268,8 @@ def _create_jaxpr(init): if len(out_tree_children) != 2: msg = "scan body output must be a pair, got {}." raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) - carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves] + _, carry_avals_out, _ = split_list( + jaxpr.out_avals, [len(attrs_tracked), out_tree_children[0].num_leaves]) return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked) diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 5c834f314270..4378a3c7526d 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -344,6 +344,21 @@ def jitted(): jax.jit(jitted)() # don't crash + def test_scan_carry(self): + class A: + ... + + a = A() + + jax_setattr(a, 'x', jnp.zeros(3)) + + def body(i, _): + x = jax_getattr(a, 'x') + x = x.at[i].set(x[i] + 1) + jax_setattr(a, 'x', x) + return i + 1, None + _, _ = jax.lax.scan(body, 0, None, length=3) # don't crash + class AttrsJVPTest(jtu.JaxTestCase): From e06be544d441065115a5fb265a9a41a6df37ed26 Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Wed, 31 Jul 2024 15:04:58 -0500 Subject: [PATCH 168/702] [ROCm] improve gpu script --- build/rocm/Dockerfile.ms | 4 +-- build/rocm/run_single_gpu.py | 53 ++++++++++++++++++++++++++++++++---- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index 9d19486b6557..dffe42de77f6 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -49,8 +49,7 @@ RUN eval "$(pyenv init -)" && \ numpy setuptools build wheel six auditwheel scipy \ pytest pytest-html pytest_html_merger pytest-reportlog \ pytest-rerunfailures cloudpickle portpicker matplotlib absl-py \ - flatbuffers hypothesis - + flatbuffers hypothesis pytest-json-report pytest-csv ################################################################################ FROM rocm_base AS rt_build @@ -68,3 +67,4 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ RUN --mount=type=bind,source=wheelhouse,target=/wheelhouse \ pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt + diff --git a/build/rocm/run_single_gpu.py b/build/rocm/run_single_gpu.py index 4eedc8d4e2a5..4e7660ca1f15 100755 --- a/build/rocm/run_single_gpu.py +++ b/build/rocm/run_single_gpu.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import csv import json import argparse import threading @@ -29,6 +30,34 @@ def extract_filename(path): file_name, _ = os.path.splitext(base_name) return file_name + +def combine_json_reports(): + all_json_files = [f for f in os.listdir(base_dir) if f.endswith('_log.json')] + combined_data = [] + for json_file in all_json_files: + with open(os.path.join(base_dir, json_file), 'r') as infile: + data = json.load(infile) + combined_data.append(data) + combined_json_file = f"{base_dir}/final_compiled_report.json" + with open(combined_json_file, 'w') as outfile: + json.dump(combined_data, outfile, indent=4) + + +def combine_csv_reports(): + all_csv_files = [f for f in os.listdir(base_dir) if f.endswith('_log.csv')] + combined_csv_file = f"{base_dir}/final_compiled_report.csv" + with open(combined_csv_file, mode='w', newline='') as outfile: + csv_writer = csv.writer(outfile) + for i, csv_file in enumerate(all_csv_files): + with open(os.path.join(base_dir, csv_file), mode='r') as infile: + csv_reader = csv.reader(infile) + if i == 0: + # write headers only once + csv_writer.writerow(next(csv_reader)) + for row in csv_reader: + csv_writer.writerow(row) + + def generate_final_report(shell=False, env_vars={}): env = os.environ env = {**env, **env_vars} @@ -41,7 +70,10 @@ def generate_final_report(shell=False, env_vars={}): print("FAILED - {}".format(" ".join(cmd))) print(result.stderr.decode()) - return result.returncode, result.stderr.decode(), result.stdout.decode() + # Generate json reports. + combine_json_reports() + # Generate csv reports. + combine_csv_reports() def run_shell_command(cmd, shell=False, env_vars={}): @@ -66,7 +98,7 @@ def parse_test_log(log_file): report = json.loads(line) if "nodeid" in report: module = report["nodeid"].split("::")[0] - if module: + if module and ".py" in module: test_files.add(os.path.abspath(module)) return test_files @@ -100,9 +132,20 @@ def run_test(testmodule, gpu_tokens, continue_on_fail): } testfile = extract_filename(testmodule) if continue_on_fail: - cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-v", testmodule] + cmd = ["python3", "-m", "pytest", + "--json-report", f"--json-report-file={base_dir}/{testfile}_log.json", + f"--csv={base_dir}/{testfile}_log.csv", + "--csv-columns", "id,module,name,file,status,duration", + f"--html={base_dir}/{testfile}_log.html", + "--reruns", "3", "-v", testmodule] else: - cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", "-v", testmodule] + cmd = ["python3", "-m", "pytest", + "--json-report", f"--json-report-file={base_dir}/{testfile}_log.json", + f"--csv={base_dir}/{testfile}_log.csv", + "--csv-columns", "id,module,name,file,status,duration", + f"--html={base_dir}/{testfile}_log.html", + "--reruns", "3", "-x", "-v", testmodule] + return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars) with GPU_LOCK: gpu_tokens.append(target_gpu) @@ -115,7 +158,7 @@ def run_test(testmodule, gpu_tokens, continue_on_fail): def run_parallel(all_testmodules, p, c): - print(f"Running tests with parallelism=", p) + print(f"Running tests with parallelism = {p}") available_gpu_tokens = list(range(p)) executor = ThreadPoolExecutor(max_workers=p) # walking through test modules. From 3e764f617add2d5a59b57d23f3787059791118e9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 19 Aug 2024 13:45:06 -0700 Subject: [PATCH 169/702] [shard_map docs]: Fix doc typos PiperOrigin-RevId: 664960613 --- docs/notebooks/shard_map.ipynb | 2 +- docs/notebooks/shard_map.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index aa3c0e276fb1..157d6c567b24 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -531,7 +531,7 @@ "\n", "```python\n", "def f_shmapped_ref(x):\n", - " x_blocks = jnp.array_split(x, mesh.shape[0])\n", + " x_blocks = jnp.array_split(x, mesh.shape['i'])\n", " y_blocks = [f(x_blk) for x_blk in x_blocks]\n", " return jnp.concatenate(y_blocks)\n", "```\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 8b2c2d6fbdcd..5b40e78dcfc3 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -378,7 +378,7 @@ values, as this reference function: ```python def f_shmapped_ref(x): - x_blocks = jnp.array_split(x, mesh.shape[0]) + x_blocks = jnp.array_split(x, mesh.shape['i']) y_blocks = [f(x_blk) for x_blk in x_blocks] return jnp.concatenate(y_blocks) ``` From 292161ab4db444af6f611af38370e61a608b90ea Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 19 Aug 2024 14:42:38 -0700 Subject: [PATCH 170/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0c5475a11c47fd3aa7afdfa57f533ed9323133dd. PiperOrigin-RevId: 664987560 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 6419cee31204..62bdf98cf731 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "b43657003ce347f83a36fb6f56f528ce3ce982c4" -XLA_SHA256 = "d58acd72ffb049756690600c6752f054ff5274e2b097fb92fbcc0f661b26d4d8" +XLA_COMMIT = "0c5475a11c47fd3aa7afdfa57f533ed9323133dd" +XLA_SHA256 = "0306f260e83960a121fab59c1ea7ede09b898251b2b042913941fa161f1423a3" def repo(): tf_http_archive( From b1b3ea276b94c070db42dc88a3531fbed551189f Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 20 Aug 2024 00:03:56 +0200 Subject: [PATCH 171/702] Added py::mod_gil_not_used() to PYBIND11_MODULE register_jax_dialects --- jaxlib/mlir/_mlir_libs/register_jax_dialects.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index e1958c211b33..2e10062945b5 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -1,5 +1,6 @@ // Registers MLIR dialects used by JAX. // This module is called by mlir/__init__.py during initialization. +#include #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" @@ -14,11 +15,13 @@ #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +namespace py = pybind11; + #define REGISTER_DIALECT(name) \ MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ mlirDialectHandleInsertDialect(name##_dialect, registry) -PYBIND11_MODULE(register_jax_dialects, m) { +PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) { m.doc() = "Registers upstream MLIR dialects used by JAX."; m.def("register_dialects", [](MlirDialectRegistry registry) { From 6e1c23610d4460958e3f893145c68ad112fe32d5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 19 Aug 2024 15:10:00 -0700 Subject: [PATCH 172/702] If input layouts are specified via `in_shardings` to `jit` and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user. Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good. Fixes: https://github.com/google/jax/issues/23100 PiperOrigin-RevId: 665000157 --- jax/_src/api.py | 2 +- jax/_src/array.py | 32 ++++---- jax/_src/dispatch.py | 7 +- jax/_src/earray.py | 5 +- jax/_src/interpreters/mlir.py | 2 +- jax/_src/interpreters/pxla.py | 126 ++++++++++++++++++++--------- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/layout.py | 4 +- jax/_src/pjit.py | 15 ++-- jax/_src/prng.py | 5 +- tests/lax_test.py | 2 +- tests/layout_test.py | 23 ++++++ tests/pmap_test.py | 2 +- 13 files changed, 155 insertions(+), 72 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 493f48a88624..5a773783b877 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1827,7 +1827,7 @@ def cache_miss(*args, **kwargs): cpp_mapped_f = pmap_lib.pmap( fun, cache_miss, static_broadcasted_tuple, - lambda x, s: pxla.shard_args([s], [x])[0], + lambda x, s: pxla.shard_args([s], [None], [x])[0], pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) diff --git a/jax/_src/array.py b/jax/_src/array.py index 118f2bfe8851..0f554a86a655 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1086,9 +1086,8 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): # Look up all buffers that contain the correct slice of the logical array. candidates_list = candidates[hashed_index(idx)] if not candidates_list: - # This array isn't sharded correctly. Reshard it via host roundtrip. - # TODO(skye): more efficient reshard? - return pxla.shard_args([sharding], [x._value], canonicalize=False)[0] + return pxla.shard_args([sharding], [None], [x._value], + canonicalize=False)[0] # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. for buf in candidates_list: @@ -1097,7 +1096,6 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): break else: bufs.append(buf) - return pxla.batched_device_put(x.aval, sharding, bufs, devices) @@ -1107,24 +1105,30 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): dst_indices = dst_sharding.addressable_devices_indices_map(shape).values() return dst_indices, tuple(src_indices) == tuple(dst_indices) +def _layout_eq(x, dst_layout, sharding): + if pxla.is_default_layout(dst_layout, sharding, x.aval): + return True + return x.layout.device_local_layout == dst_layout + -def _array_shard_arg(xs, shardings): +def _array_shard_arg(xs, shardings, layouts): results = [] batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], [] - for i, (x, sharding) in enumerate(safe_zip(xs, shardings)): + + for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)): x._check_if_deleted() + indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) + same_layout = _layout_eq(x, layout, sharding) - indices, same_indices = _sharding_indices_and_eq( - x.sharding, x.shape, sharding) if not x.is_fully_addressable: - if same_indices: + if same_indices and same_layout: results.append(x) else: raise NotImplementedError( "Cannot reshard an input that is not fully addressable") else: devices = sharding._addressable_device_assignment - if same_indices: + if same_indices and same_layout: # Add a placeholder result that will be filled in later. results.append(None) # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`. @@ -1133,6 +1137,8 @@ def _array_shard_arg(xs, shardings): batch_shardings.append(sharding) batch_indices.append(i) # Resharding starts here: + elif not same_layout: + results.append(api.device_put(x, Layout(layout, sharding))) elif dispatch.is_single_device_sharding(x.sharding): results.append(shard_device_array(x, devices, indices, sharding)) else: @@ -1145,8 +1151,6 @@ def _array_shard_arg(xs, shardings): assert results[i] is None results[i] = copy_out return results - - pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg @@ -1178,8 +1182,8 @@ def _array_local_result_handler(aval, sharding, indices): # Token handlers -def _token_shard_arg(xs, shardings): - return _array_shard_arg([x._buf for x in xs], shardings) +def _token_shard_arg(xs, shardings, layouts): + return _array_shard_arg([x._buf for x in xs], shardings, layouts) pxla.shard_arg_handlers[core.Token] = _token_shard_arg diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index bb6f5f4110b6..59739f4130f3 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -134,7 +134,7 @@ def get_token_input( # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. s = jax.sharding.GSPMDSharding.get_replicated(devices) - sharded_tok = core.Token(pxla.shard_args([s], [tok])[0]) + sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0]) self.current_tokens[eff] = sharded_tok return sharded_tok @@ -515,7 +515,10 @@ def _batched_device_put_impl( if shard_arg_xs: # Batch shard_arg calls. Helps improve efficiency for backends that support # efficient batch transfer. - shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs) + # device_put handles `Layout` via a different path, so just pass `None` as + # the layout here. + shard_arg_results = pxla.shard_args( + shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs) for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results): assert isinstance(ys[i], _DeferredShardArg) ys[i] = ys[i].result_handler(shard_arg_result) diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 36c8dc80c8ca..6598df01330a 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -104,11 +104,12 @@ def global_shards(self): # TODO(mattjj): _set_array_base_attributes -def _earray_shard_arg_handler(xs, shardings): +def _earray_shard_arg_handler(xs, shardings, layouts): arrs = [x._data for x in xs] phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] - return pxla.shard_args(phys_shardings, arrs) + # TODO(yashkatariya): `layouts` should be converted to physical layouts. + return pxla.shard_args(phys_shardings, layouts, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 814c6a9886d7..e798a6fbdba9 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1000,7 +1000,7 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout, return "auto" if aval is core.abstract_token: return "default" - return layout._to_xla_layout(aval.dtype) # type: ignore + return str(layout._to_xla_layout(aval.dtype)) # type: ignore def _get_mem_kind(s: JSharding | None) -> str | None: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index baf475592f80..afb0addc2fef 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -32,6 +32,7 @@ import jax +from jax._src import api from jax._src import api_util from jax._src import compiler from jax._src import config @@ -60,6 +61,7 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -106,39 +108,67 @@ class WeakRefList(list): def identity(x): return x @profiler.annotate_function -def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]: +def shard_args(shardings: Sequence[JSharding], layouts, args, + canonicalize=True) -> Sequence[xc.ArrayImpl]: # Fast path for one argument. if len(args) == 1: arg = args[0] if canonicalize: arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)]([arg], shardings) + return shard_arg_handlers[type(arg)]([arg], shardings, layouts) - # type(arg) -> (indices, args, shardings) - batches = collections.defaultdict(lambda: ([], [], [])) # type: ignore - for i, (arg, sharding) in enumerate(safe_zip(args, shardings)): + # type(arg) -> (list[indices], list[args], list[shardings]) + batches = collections.defaultdict(lambda: ([], [], [], [])) # type: ignore + for i, (arg, sharding, layout) in enumerate(safe_zip(args, shardings, layouts)): if canonicalize: arg = xla.canonicalize_dtype(arg) batch = batches[type(arg)] batch[0].append(i) batch[1].append(arg) batch[2].append(sharding) + batch[3].append(layout) # Call `shard_arg_handlers` per batch and build a flat list of arrays returned # from each call in the same order as `args`. Since `batches` is grouped by # types, we cannot simply flatten the results and we have to use the original # indices to put each array back to its original position. results: list[jax.Array | None] = [None] * len(args) - for t, (indices, a, s) in batches.items(): - outs = shard_arg_handlers[t](a, s) + for t, (indices, a, s, l) in batches.items(): + outs = shard_arg_handlers[t](a, s, l) for i, out in safe_zip(indices, outs): results[i] = out - assert all(result is not None for result in results) return results -shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {} +shard_arg_handlers: dict[ + Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any]], Sequence[Any]] +] = {} + + +def is_default_layout(curr_layout, sharding, aval): + if curr_layout is None or sharding is None: + return True + if (aval is core.abstract_token or aval.dtype == dtypes.float0 or + dtypes.issubdtype(aval.dtype, dtypes.extended)): + return True + if isinstance(curr_layout, AutoLayout): + return False + d = sharding._device_assignment[0] + shard_shape = sharding.shard_shape(aval.shape) + try: + # TODO(yashkatariya): Replace this with normal `==` check once CPU supports + # int4. + return is_user_xla_layout_equal( + curr_layout, + DeviceLocalLayout.from_pjrt_layout( + d.client.get_default_layout(aval.dtype, shard_shape, d))) + except xe.XlaRuntimeError as e: + msg, *_ = e.args + if isinstance(msg, str) and msg.startswith("UNIMPLEMENTED"): + return True + else: + raise @lru_cache(maxsize=1024) @@ -146,34 +176,37 @@ def _get_replicated_slices(num_addressable_devices: int): return ((slice(None),),) * num_addressable_devices -def _masked_array_error(xs, shardings): +def _masked_array_error(xs, shardings, layouts): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_array(xs, shardings): +def _shard_np_array(xs, shardings, layouts): results = [] - for x, sharding in safe_zip(xs, shardings): + for x, sharding, layout in safe_zip(xs, shardings, layouts): devices = sharding._addressable_device_assignment if x.dtype == dtypes.float0: x = np.zeros(x.shape, dtype=np.dtype(bool)) aval = api_util.shaped_abstractify(x) - if sharding.is_fully_replicated: - shards = [x] * len(devices) + if not is_default_layout(layout, sharding, aval): + results.append(api.device_put(x, Layout(layout, sharding))) else: - indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) - shards = [x[i] for i in indices] - results.append(batched_device_put(aval, sharding, shards, devices)) + if sharding.is_fully_replicated: + shards = [x] * len(devices) + else: + indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) + shards = [x[i] for i in indices] + results.append(batched_device_put(aval, sharding, shards, devices)) return results for _t in array_types: - shard_arg_handlers[_t] = _shard_array + shard_arg_handlers[_t] = _shard_np_array -def _shard_darray(xs, shardings): - return shard_args(shardings, [x._data for x in xs]) +def _shard_darray(xs, shardings, layouts): + return shard_args(shardings, layouts, [x._data for x in xs]) shard_arg_handlers[core.DArray] = _shard_darray -def _shard_mutable_array(xs, shardings): - return shard_args(shardings, [x._buf for x in xs]) +def _shard_mutable_array(xs, shardings, layouts): + return shard_args(shardings, layouts, [x._buf for x in xs]) shard_arg_handlers[core.MutableArray] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, @@ -931,6 +964,7 @@ def build_execute_fun(self): handle_outs = local_avals_to_results_handler(self.local_output_avals, self.output_shardings) handle_args = InputsHandler(self.input_shardings, + [None] * len(self.input_shardings), self.compiled.local_devices(), input_indices) execute_fun = ExecuteReplicated(self.compiled, "parallel computation", self.backend, handle_args, handle_outs, @@ -1109,12 +1143,15 @@ def _get_pmap_sharding(devices, specs): class InputsHandler: - __slots__ = ("handler", "local_devices", "in_shardings", "input_indices") + __slots__ = ("handler", "in_shardings", "in_layouts", "local_devices", + "input_indices") - def __init__(self, in_shardings, local_devices=None, input_indices=None): - self.handler = partial(shard_args, in_shardings) - self.local_devices = local_devices + def __init__(self, in_shardings, in_layouts, local_devices=None, + input_indices=None): + self.handler = partial(shard_args, in_shardings, in_layouts) self.in_shardings = in_shardings + self.in_layouts = in_layouts + self.local_devices = local_devices self.input_indices = input_indices def __call__(self, input_buffers): @@ -1122,8 +1159,9 @@ def __call__(self, input_buffers): def __str__(self): return ("InputsHandler(\n" - f"local_devices={self.local_devices},\n" f"in_shardings={self.in_shardings},\n" + f"in_layouts={self.in_layouts},\n" + f"local_devices={self.local_devices},\n" f"input_indices={self.input_indices})") @@ -1849,7 +1887,7 @@ def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval if is_unspecified_or_auto(sharding): return None # TODO(yashkatariya): Figure out how layouts work with extended dtypes. - if dtypes.issubdtype(aval.dtype, dtypes.extended): + if aval is core.abstract_token or dtypes.issubdtype(aval.dtype, dtypes.extended): return None if not core.is_constant_shape(aval.shape): return None @@ -2505,7 +2543,7 @@ def maybe_recover_user_shardings( def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, xl: DeviceLocalLayout) -> bool: - if isinstance(ul, DeviceLocalLayout) and ul._tiling is None: + if isinstance(ul, DeviceLocalLayout) and not ul._tiling: return ul.major_to_minor == xl.major_to_minor else: return ul == xl @@ -2742,7 +2780,7 @@ class UnloadedMeshExecutable: pgle_profiler: profiler.PGLEProfiler | None def build_unsafe_call(self): - handle_args = InputsHandler(self.input_shardings) + handle_args = InputsHandler(self.input_shardings, self.in_layouts) handle_outs = global_avals_to_results_handler( self.output_avals, self.output_shardings, self.committed) @@ -2882,9 +2920,7 @@ class MeshExecutableFastpathData(NamedTuple): out_avals: Sequence[ShapedArray] out_committed: Sequence[bool] kept_var_bitvec: Iterable[bool] - # TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24 - arg_handler_devices: Sequence[xc.Device] - arg_handler_indices: Sequence[tuple[Index | None, ...]] + in_device_local_layouts: Sequence[DeviceLocalLayout | None] def reflatten_outputs_for_dispatch(out_tree, out_flat): @@ -2992,18 +3028,36 @@ def aot_cache_miss(*args, **kwargs): else s for s, a in zip(self._in_shardings, self.in_avals) ] + in_dlls = get_layouts_for_fasthpath_data( + self._in_layouts, in_shardings, self.in_avals) fastpath_data = MeshExecutableFastpathData( self.xla_executable, out_tree_dispatch, in_shardings, self._out_shardings, out_avals, out_committed, kept_var_bitvec, - self.unsafe_call.in_handler.local_devices, - self.unsafe_call.in_handler.input_indices) + in_dlls) else: fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry return xc._xla.pjit( self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0]) + tree_util.dispatch_registry, cc_shard_arg) + +if xla_extension_version < 282: + def cc_shard_arg(x, sharding): + return shard_args([sharding], [None], [x])[0] +else: + def cc_shard_arg(x, sharding, layout): # type: ignore + return shard_args([sharding], [layout], [x])[0] + + +def get_layouts_for_fasthpath_data(in_layouts, in_shardings, in_avals): + in_dlls = [] + for l, s, a in zip(in_layouts, in_shardings, in_avals): + if is_default_layout(l, s, a): + in_dlls.append(None) + else: + in_dlls.append(l) + return in_dlls def check_arg_avals_for_call(ref_avals, arg_avals, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index dd643b050c8f..443470e129fa 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -688,7 +688,7 @@ def _maybe_put(x): aval = shaped_abstractify(x) s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]) result_handler = pxla.global_aval_to_result_handler(aval, s, False) - return result_handler(pxla.shard_args([s], [x])) + return result_handler(pxla.shard_args([s], [None], [x])) else: return x diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 84708555041f..64bbd3268b16 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -69,7 +69,7 @@ def __eq__(self, other): self._tiling == other._tiling and self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits) - def _to_xla_layout(self, dtype) -> str: + def _to_xla_layout(self, dtype) -> xc.Layout: if self._tiling is None: xla_layout = xc.Layout(self.major_to_minor[::-1]) else: @@ -81,7 +81,7 @@ def _to_xla_layout(self, dtype) -> str: sub_byte_size = 0 xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, sub_byte_size) - return str(xla_layout) + return xla_layout def check_compatible_aval(self, aval_shape: Shape): if len(self.major_to_minor) != len(aval_shape): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a09a958ab8a3..63c2cedbe935 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -279,11 +279,12 @@ def _get_fastpath_data( else s for s, a in zip(executable._in_shardings, executable.in_avals) ] + in_dlls = pxla.get_layouts_for_fasthpath_data( + executable._in_layouts, in_shardings, executable.in_avals) fastpath_data = pxla.MeshExecutableFastpathData( executable.xla_executable, out_tree, in_shardings, executable._out_shardings, out_avals, out_committed, kept_var_bitvec, - executable.unsafe_call.in_handler.local_devices, - executable.unsafe_call.in_handler.input_indices) + in_dlls) else: fastpath_data = None return fastpath_data @@ -302,9 +303,7 @@ def _read_most_recent_pjit_call_executable(jaxpr): def _read_pgle_profiler(jaxpr): - return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get( - jaxpr, None - ) + return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None) def _cpp_pjit_evict_fn(self): self._clear_cache() @@ -343,8 +342,7 @@ def cache_miss(*args, **kwargs): cpp_pjit_f = xc._xla.pjit( fun_name(fun), fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, - jit_info.donate_argnums, tree_util.dispatch_registry, - lambda x, sharding: pxla.shard_args([sharding], [x])[0], + jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg, _get_cpp_global_cache(jit_info.has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) @@ -1729,8 +1727,7 @@ def call_impl_cache_miss(*args_, **kwargs_): in_shardings, out_shardings, None, None) return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, - lambda x, sharding: pxla.shard_args([sharding], [x])[0], + tree_util.dispatch_registry, pxla.cc_shard_arg, _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 7091305824ce..c4d6683c0262 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -466,11 +466,12 @@ def __hash__(self) -> int: xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x -def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings): +def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts): arrs = [x._base_array for x in xs] phys_shardings = [physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] - return pxla.shard_args(phys_shardings, arrs) + # TODO(yashkatariya): `layouts` should be converted to physical layouts. + return pxla.shard_args(phys_shardings, layouts, arrs) pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler diff --git a/tests/lax_test.py b/tests/lax_test.py index 73b21d12923e..7ed17adf45bc 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3373,7 +3373,7 @@ def __repr__(self) -> str: size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) -def shard_foo_array_handler(xs, shardings): +def shard_foo_array_handler(xs, shardings, layouts): results = [] for x, sharding in safe_zip(xs, shardings): device, = sharding._addressable_device_assignment diff --git a/tests/layout_test.py b/tests/layout_test.py index c72082d0a16c..c390bdc9f186 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -500,6 +500,29 @@ def g(x): 'Layout passed to jit does not match the layout on the respective arg'): g(arr) + def test_in_layouts_jit_jnp_input(self): + major_last_layout = DLL(major_to_minor=(1, 0)) + sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) + + f = jax.jit(lambda x: x + 1, + in_shardings=Layout(major_last_layout, sharding)) + + arr = jnp.arange(8 * 128).reshape(8, 128) + out = f(arr) + self.assertArraysEqual(out, arr + 1) + + # cpp dispatch should call into shard_args from cpp. + out2 = f(arr) + self.assertArraysEqual(out2, arr + 1) + + np_inp = np.arange(8 * 128).reshape(8, 128) + out3 = f(np_inp) + self.assertArraysEqual(out3, np_inp + 1) + + # cpp dispatch should call into shard_args from cpp. + out4 = f(np_inp) + self.assertArraysEqual(out4, np_inp + 1) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index c0a3d27dadef..8b121d91ae85 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3015,7 +3015,7 @@ def testShardArgs(self, shape, spec, make_arg): x = np.arange(math.prod(shape)).reshape(shape) arg = make_arg(x) sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec) - results = pxla.shard_args([sharding], [arg]) + results = pxla.shard_args([sharding], [None], [arg]) self.assertEqual(len(results), 1) if isinstance(results[0], array.ArrayImpl): bufs = results[0]._arrays From 6546c4810bff0d7e91eecca8cab5fb25d380df90 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 20 Aug 2024 00:10:17 +0200 Subject: [PATCH 173/702] Added PyUnstable_Module_SetGIL to PyInit_cpu_feature_guard --- jaxlib/cpu_feature_guard.c | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/jaxlib/cpu_feature_guard.c b/jaxlib/cpu_feature_guard.c index 7c8ff2951a79..d18478eb57d5 100644 --- a/jaxlib/cpu_feature_guard.c +++ b/jaxlib/cpu_feature_guard.c @@ -172,5 +172,12 @@ static struct PyModuleDef cpu_feature_guard_module = { #endif EXPORT_SYMBOL PyMODINIT_FUNC PyInit_cpu_feature_guard(void) { - return PyModule_Create(&cpu_feature_guard_module); + PyObject *module = PyModule_Create(&cpu_feature_guard_module); + if (module == NULL) { + return NULL; + } +#ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED); +#endif + return module; } From 1ab6279d4fda6cb38f5ac06e4c5edac70dff10d0 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 19 Aug 2024 18:42:45 -0700 Subject: [PATCH 174/702] Skip the global jit cpp cache if in/out_layouts are not None PiperOrigin-RevId: 665085182 --- jax/_src/interpreters/pxla.py | 23 +++++++++++----------- jax/_src/pjit.py | 30 +++++++++++++---------------- jax/experimental/multihost_utils.py | 12 ++++++------ 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index afb0addc2fef..1398d58fc787 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -146,6 +146,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, args, ] = {} +@lru_cache(maxsize=2048) def is_default_layout(curr_layout, sharding, aval): if curr_layout is None or sharding is None: return True @@ -2548,12 +2549,6 @@ def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, else: return ul == xl -def _check_user_xla_layout(ul, xl, what: str): - if not is_user_xla_layout_equal(ul, xl): - raise AssertionError( - f"Unexpected XLA layout override: (XLA) {xl} != {ul} " - f"(User {what} layout)") - def _get_layouts_from_executable( xla_executable, in_layouts, out_layouts, num_ordered_effects @@ -2569,19 +2564,23 @@ def _get_layouts_from_executable( out_layouts_xla = out_layouts_xla[num_ordered_effects:] new_in_layouts = [] - for x, i in safe_zip(in_layouts_xla, in_layouts): + for x, l in safe_zip(in_layouts_xla, in_layouts): x = DeviceLocalLayout.from_pjrt_layout(x) - if isinstance(i, DeviceLocalLayout): - _check_user_xla_layout(i, x, "input") + if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x): + raise AssertionError( + f"Unexpected XLA layout override: (XLA) {x} != {l} " + f"(User input layout)") # Always append the XLA layout because it has the full information # (tiling, etc) even if the user layout does not specify tiling. new_in_layouts.append(x) new_out_layouts = [] - for x, o in safe_zip(out_layouts_xla, out_layouts): + for x, l in safe_zip(out_layouts_xla, out_layouts): x = DeviceLocalLayout.from_pjrt_layout(x) - if isinstance(o, DeviceLocalLayout): - _check_user_xla_layout(o, x, "output") + if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x): + raise AssertionError( + f"Unexpected XLA layout override: (XLA) {x} != {l} " + f"(User output layout)") # Always append the XLA layout because it has the full information # (tiling, etc) even if the user layout does not specify tiling. new_out_layouts.append(x) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 63c2cedbe935..9383c26bf7c4 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -351,14 +351,15 @@ def cache_miss(*args, **kwargs): return cpp_pjitted_f -def _pjit_explicit_sharding(in_shardings, out_shardings, device, - backend) -> bool: - in_shardings_flat, _ = tree_flatten(in_shardings) - out_shardings_flat, _ = tree_flatten(out_shardings) +def _pjit_explicit_sharding_and_layout( + in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, + device, backend) -> bool: return (device is not None or backend is not None or any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(i) for i in out_shardings_flat)) + any(not is_unspecified(o) for o in out_shardings_flat) or + any(i is not None for i in in_layouts_flat) or + any(o is not None for o in out_layouts_flat)) def _split_layout_and_sharding(entries): @@ -444,8 +445,9 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) - has_explicit_sharding = _pjit_explicit_sharding( - in_shardings, out_shardings, device, backend) + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings_leaves, out_shardings_leaves, in_layouts_leaves, + out_layouts_leaves, device, backend) return PjitInfo( fun_sourceinfo=fun_sourceinfo, @@ -1723,8 +1725,8 @@ def call_impl_cache_miss(*args_, **kwargs_): jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) donated_argnums = [i for i, d in enumerate(donated_invars) if d] - has_explicit_sharding = _pjit_explicit_sharding( - in_shardings, out_shardings, None, None) + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings, out_shardings, in_layouts, out_layouts, None, None) return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], donated_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg, @@ -1753,14 +1755,8 @@ def _pjit_lower_cached( lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): - if resource_env is not None: - mesh = resource_env.physical_mesh - api_name = 'pjit' - else: - # resource_env is `None` in the jit wrapper around pjit. - mesh = None - api_name = 'jit' - + mesh, api_name = ((resource_env.physical_mesh, 'pjit') + if resource_env is not None else (None, 'jit')) return pxla.lower_sharding_computation( jaxpr, api_name, name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 1ca601da3942..554bf2641769 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -32,7 +32,6 @@ from jax._src.interpreters import pxla from jax.interpreters import xla from jax._src import pjit as pjit_lib -from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P from jax._src import distributed from jax._src.util import safe_zip @@ -91,17 +90,19 @@ def sync_global_devices(name: str): assert_equal(h, f"sync_global_devices name mismatch ('{name}')") -# Identity function is at the top level so that `process_allgather` doesn't -# recompile on every invocation. def _identity_fn(x): return x +@lru_cache(maxsize=128) +def _jitted_identity_fn(sharding): + return jax.jit(_identity_fn, out_shardings=sharding) + def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: reps = sharding_impls.GSPMDSharding.get_replicated( inp.sharding._device_assignment) - out = pjit(_identity_fn, out_shardings=reps)(inp) + out = _jitted_identity_fn(reps)(inp) else: # All inputs here will be fully addressable. if jax.process_count() == 1: @@ -124,8 +125,7 @@ def _handle_array_process_allgather(inp, tiled): bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( global_aval.shape, s, bufs) - with global_mesh: - out = pjit(_identity_fn, out_shardings=None)(global_arr) + out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr) return np.asarray(out.addressable_data(0)) From 71a93d0c8706222088b626cb4ffe55ca17d898b2 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 20 Aug 2024 05:06:08 -0700 Subject: [PATCH 175/702] Port QR factorization GPU kernel to FFI. The biggest change here is that we now ignore the `info` parameter that is returned by `getrf`. In the previous implementation, we would return an error in the batched implementation, or set the relevant matrix entries to NaN in the non-batched version if `info != 0`. But, since info is only used for shape checking (see LAPACK, cuBLAS and cuSolver docs), I argue that we will never see `info != 0`, because we're including all the shape checks in the kernel already. PiperOrigin-RevId: 665307128 --- jaxlib/cuda/BUILD | 1 + jaxlib/ffi_helpers.h | 54 +++++++ jaxlib/gpu/gpu_kernels.cc | 2 + jaxlib/gpu/solver.cc | 1 + jaxlib/gpu/solver_kernels_ffi.cc | 257 +++++++++++++++++++++++++------ jaxlib/gpu/solver_kernels_ffi.h | 1 + jaxlib/rocm/BUILD.bazel | 1 + 7 files changed, 269 insertions(+), 48 deletions(-) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index b31ed78e3b58..07cb21078714 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -244,6 +244,7 @@ cc_library( "@xla//xla/tsl/cuda:cusolver", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index bedfdca2f11b..fba57d11b9f2 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -142,6 +143,59 @@ inline absl::StatusOr> SplitBatch2D( trailingDims.front(), trailingDims.back()); } +inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, + int64_t expected_batch, + std::string_view name, + std::string_view op) { + auto batch = GetBatchSize(dimensions); + if (batch != expected_batch) { + return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( + "Invalid total batch size for input %s to %s. Expected %d, got %d.", + name, op, expected_batch, batch)); + } + return ::xla::ffi::Error::Success(); +} + +inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, + std::tuple shape, + std::string_view name, + std::string_view op) { + FFI_ASSIGN_OR_RETURN((auto [batch, size]), SplitBatch1D(dimensions)); + auto [expected_batch, expected_size] = shape; + if (batch != expected_batch) { + return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( + "Invalid total batch size for input %s to %s. Expected %d, got %d.", + name, op, expected_batch, batch)); + } + if (batch != expected_batch || size != expected_size) { + return ::xla::ffi::Error::InvalidArgument( + absl::StrFormat("Invalid trailing dimension for input %s " + "to %s. Expected %d, got %d.", + name, op, expected_size, size)); + } + return ::xla::ffi::Error::Success(); +} + +inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, + std::tuple shape, + std::string_view name, + std::string_view op) { + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), SplitBatch2D(dimensions)); + auto [expected_batch, expected_rows, expected_cols] = shape; + if (batch != expected_batch) { + return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( + "Invalid total batch size for input %s to %s. Expected %d, got %d.", + name, op, expected_batch, batch)); + } + if (rows != expected_rows || cols != expected_cols) { + return ::xla::ffi::Error::InvalidArgument( + absl::StrFormat("Invalid matrix dimensions for input %s to %s. " + "Expected (%d, %d), got (%d, %d).", + name, op, expected_rows, expected_cols, rows, cols)); + } + return ::xla::ffi::Error::Success(); +} + template <::xla::ffi::DataType dtype> auto AllocateScratchMemory(std::size_t size) -> std::unique_ptr>[]> { diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index b76cea19ea2e..6eb426b7fb43 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -47,6 +47,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", + GeqrfFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 223c8a9798be..4ee7a9f1dbf7 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -476,6 +476,7 @@ nb::dict Registrations() { #endif // JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); + dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); return dict; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 7b4b673bf9ec..91124a847121 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -17,9 +17,11 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "jaxlib/ffi_helpers.h" #include "jaxlib/gpu/blas_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" @@ -33,6 +35,34 @@ namespace JAX_GPU_NAMESPACE { namespace ffi = ::xla::ffi; +namespace { +template +inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, + int64_t size, + std::string_view name) { + auto maybe_workspace = scratch.Allocate(sizeof(T) * size); + if (!maybe_workspace.has_value()) { + return absl::Status( + absl::StatusCode::kResourceExhausted, + absl::StrFormat("Unable to allocate workspace for %s", name)); + } + return static_cast(maybe_workspace.value()); +} +} // namespace + +#define SOLVER_DISPATCH_IMPL(impl, ...) \ + if (dataType == ffi::DataType::F32) { \ + return impl(__VA_ARGS__); \ + } else if (dataType == ffi::DataType::F64) { \ + return impl(__VA_ARGS__); \ + } else if (dataType == ffi::DataType::C64) { \ + return impl(__VA_ARGS__); \ + } else if (dataType == ffi::DataType::C128) { \ + return impl(__VA_ARGS__); \ + } + +// LU decomposition: getrf + namespace { #define GETRF_KERNEL_IMPL(type, name) \ template <> \ @@ -72,13 +102,8 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(int lwork, GetrfKernel::BufferSize(handle.get(), m, n)); - - auto maybe_workspace = scratch.Allocate(sizeof(T) * lwork); - if (!maybe_workspace.has_value()) { - return ffi::Error(ffi::ErrorCode::kUnknown, - "Unable to allocate workspace for getrf"); - } - auto workspace = static_cast(maybe_workspace.value()); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "getrf")); auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); @@ -90,11 +115,12 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, gpuMemcpyDeviceToDevice, stream))); } + int ipiv_step = std::min(m, n); for (int i = 0; i < batch; ++i) { FFI_RETURN_IF_ERROR_STATUS(GetrfKernel::Run( handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data)); out_data += m * n; - ipiv_data += std::min(m, n); + ipiv_data += ipiv_step; ++info_data; } return ffi::Error::Success(); @@ -125,13 +151,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); - - auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch); - if (!maybe_workspace.has_value()) { - return ffi::Error(ffi::ErrorCode::kUnknown, - "Unable to allocate workspace for batched getrf"); - } - auto workspace = maybe_workspace.value(); + FFI_ASSIGN_OR_RETURN(auto batch_ptrs, + AllocateWorkspace(scratch, batch, "batched getrf")); auto a_data = a.untyped_data(); auto out_data = out->untyped_data(); @@ -143,10 +164,10 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, gpuMemcpyDeviceToDevice, stream))); } - MakeBatchPointersAsync(stream, out_data, workspace, batch, sizeof(T) * n * n); + MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch, + sizeof(T) * n * n); FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); - auto batch_ptrs = static_cast(workspace); FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel::Run( handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch)); @@ -159,43 +180,24 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::Result> info) { auto dataType = a.element_type(); if (dataType != out->element_type()) { - return ffi::Error( - ffi::ErrorCode::kInvalidArgument, + return ffi::Error::InvalidArgument( "The input and output to getrf must have the same element type"); } FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), SplitBatch2D(a.dimensions())); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "getrf")); + FFI_RETURN_IF_ERROR(CheckShape( + ipiv->dimensions(), {batch, std::min(rows, cols)}, "ipiv", "getrf")); + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "getrf")); if (batch > 1 && rows == cols && rows / batch <= 128) { - if (dataType == ffi::DataType::F32) { - return GetrfBatchedImpl(batch, cols, stream, scratch, a, out, ipiv, - info); - } else if (dataType == ffi::DataType::F64) { - return GetrfBatchedImpl(batch, cols, stream, scratch, a, out, - ipiv, info); - } else if (dataType == ffi::DataType::C64) { - return GetrfBatchedImpl(batch, cols, stream, scratch, a, - out, ipiv, info); - } else if (dataType == ffi::DataType::C128) { - return GetrfBatchedImpl( - batch, cols, stream, scratch, a, out, ipiv, info); - } + SOLVER_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a, out, + ipiv, info); } else { - if (dataType == ffi::DataType::F32) { - return GetrfImpl(batch, rows, cols, stream, scratch, a, out, ipiv, - info); - } else if (dataType == ffi::DataType::F64) { - return GetrfImpl(batch, rows, cols, stream, scratch, a, out, ipiv, - info); - } else if (dataType == ffi::DataType::C64) { - return GetrfImpl(batch, rows, cols, stream, scratch, a, out, - ipiv, info); - } else if (dataType == ffi::DataType::C128) { - return GetrfImpl(batch, rows, cols, stream, scratch, a, - out, ipiv, info); - } - } - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "Unsupported element type for getrf"); + SOLVER_DISPATCH_IMPL(GetrfImpl, batch, rows, cols, stream, scratch, a, out, + ipiv, info); + } + return ffi::Error::InvalidArgument("Unsupported element type for getrf"); } } // namespace @@ -210,5 +212,164 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret>() // info ); +// QR decomposition: geqrf + +namespace { +#define GEQRF_KERNEL_IMPL(type, name) \ + template <> \ + struct GeqrfKernel { \ + static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ + int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \ + return lwork; \ + } \ + static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \ + type* tau, type* workspace, int lwork, \ + int* info) { \ + return JAX_AS_STATUS( \ + name(handle, m, n, a, m, tau, workspace, lwork, info)); \ + } \ + } + +template +struct GeqrfKernel; +GEQRF_KERNEL_IMPL(float, gpusolverDnSgeqrf); +GEQRF_KERNEL_IMPL(double, gpusolverDnDgeqrf); +GEQRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgeqrf); +GEQRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgeqrf); +#undef GEQRF_KERNEL_IMPL + +template +ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result tau) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(int lwork, + GeqrfKernel::BufferSize(handle.get(), m, n)); + + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "geqrf")); + // Note: We ignore the returned value of info because it is only used for + // shape checking (which we already do ourselves), but it is expected to be + // in device memory, so we need to allocate it. + FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace(scratch, 1, "geqrf")); + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto tau_data = static_cast(tau->untyped_data()); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( + gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols, + gpuMemcpyDeviceToDevice, stream))); + } + + int out_step = m * n; + int tau_step = std::min(m, n); + for (int i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel::Run( + handle.get(), m, n, out_data, tau_data, workspace, lwork, info)); + out_data += out_step; + tau_data += tau_step; + } + return ffi::Error::Success(); +} + +#define GEQRF_BATCHED_KERNEL_IMPL(type, name) \ + template <> \ + struct GeqrfBatchedKernel { \ + static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \ + type** tau, int* info, int batch) { \ + return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \ + } \ + } + +template +struct GeqrfBatchedKernel; +GEQRF_BATCHED_KERNEL_IMPL(float, gpublasSgeqrfBatched); +GEQRF_BATCHED_KERNEL_IMPL(double, gpublasDgeqrfBatched); +GEQRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgeqrfBatched); +GEQRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgeqrfBatched); +#undef GEQRF_BATCHED_KERNEL_IMPL + +template +ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result tau) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(auto out_batch_ptrs, + AllocateWorkspace(scratch, batch, "batched geqrf")); + FFI_ASSIGN_OR_RETURN(auto tau_batch_ptrs, + AllocateWorkspace(scratch, batch, "batched geqrf")); + + auto a_data = a.untyped_data(); + auto out_data = out->untyped_data(); + auto tau_data = tau->untyped_data(); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( + gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols, + gpuMemcpyDeviceToDevice, stream))); + } + + MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch, + sizeof(T) * m * n); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + MakeBatchPointersAsync(stream, tau_data, tau_batch_ptrs, batch, + sizeof(T) * std::min(m, n)); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + + // We ignore the output value of `info` because it is only used for shape + // checking. + int info; + FFI_RETURN_IF_ERROR_STATUS(GeqrfBatchedKernel::Run( + handle.get(), m, n, out_batch_ptrs, tau_batch_ptrs, &info, batch)); + + return ffi::Error::Success(); +} + +ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result tau) { + auto dataType = a.element_type(); + if (dataType != out->element_type() || dataType != tau->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to geqrf must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "geqrf")); + FFI_RETURN_IF_ERROR(CheckShape( + tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqrf")); + if (batch > 1 && rows / batch <= 128 && cols / batch <= 128) { + SOLVER_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream, scratch, + a, out, tau); + } else { + SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out, + tau); + } + return ffi::Error::InvalidArgument("Unsupported element type for geqrf"); +} +} // namespace + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // a + .Ret() // out + .Ret() // tau +); + +#undef SOLVER_DISPATCH_IMPL + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 64fb1baba56a..d9c3da47655a 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -23,6 +23,7 @@ namespace jax { namespace JAX_GPU_NAMESPACE { XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index ce733d827e35..7f90d6ee35d5 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -177,6 +177,7 @@ cc_library( "//jaxlib:ffi_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", From bd90968a2595068f728744980c41b088aa2f604e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 20 Aug 2024 05:45:19 -0700 Subject: [PATCH 176/702] Port the GPU Cholesky update custom call to the FFI. PiperOrigin-RevId: 665319689 --- jax/_src/lax/linalg.py | 38 +++++++++------- jaxlib/cuda/BUILD | 10 +--- jaxlib/gpu/gpu_kernels.cc | 2 + jaxlib/gpu/linalg.cc | 2 + jaxlib/gpu/linalg_kernels.cc | 81 ++++++++++++++++++++++++--------- jaxlib/gpu/linalg_kernels.cu.cc | 62 +++++++++++++++---------- jaxlib/gpu/linalg_kernels.h | 8 ++-- jaxlib/gpu/vendor.h | 14 +++++- jaxlib/gpu_linalg.py | 9 +++- jaxlib/rocm/BUILD.bazel | 2 - 10 files changed, 150 insertions(+), 78 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index f771b83b372f..37b2ccf61ec4 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -505,20 +505,26 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector): r_matrix.shape, w_vector.shape)) return ShapedArray(r_matrix.shape, r_matrix.dtype) -def _cholesky_update_cuda_lowering_rule(ctx, r_matrix, w_vector): - r_matrix_aval, _ = ctx.avals_in - try: - [platform] = ctx.module_context.platforms - except ValueError: - raise ValueError( - "Can only lower cholesky_update on a single platform." - ) from None - if platform != "cuda": - raise NotImplementedError( - "Can only lower fast cholesky_update on CUDA." - ) - return gpu_linalg.cuda_cholesky_update( - r_matrix, w_vector, r_matrix_aval.dtype) +def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector): + # TODO(b/360781533): Remove guard after 3 week forward compatibility period. + if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32): + r_matrix_aval, _ = ctx.avals_in + try: + [platform] = ctx.module_context.platforms + except ValueError: + raise ValueError( + "Can only lower cholesky_update on a single platform." + ) from None + if platform != "cuda": + raise NotImplementedError( + "Can only lower fast cholesky_update on CUDA." + ) + return gpu_linalg.cuda_cholesky_update( + r_matrix, w_vector, r_matrix_aval.dtype) + rule = ffi.ffi_lowering(f"{target_name_prefix}_cholesky_update_ffi", + operand_output_aliases={0: 0, 1: 1}) + sub_ctx = ctx.replace(avals_out=ctx.avals_in) + return rule(sub_ctx, r_matrix, w_vector)[:1] def _cholesky_update_jax_fn(R, z): @@ -557,8 +563,8 @@ def _drot( cholesky_update_p.def_impl(partial(dispatch.apply_primitive, cholesky_update_p)) mlir.register_lowering( - cholesky_update_p, _cholesky_update_cuda_lowering_rule, platform='cuda') - + cholesky_update_p, partial(_cholesky_update_gpu_lowering_rule, "cu"), + platform='cuda') mlir.register_lowering( cholesky_update_p, mlir.lower_fun(_cholesky_update_jax_fn, multiple_results=False)) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 07cb21078714..bd74be6732fd 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -357,7 +357,6 @@ cc_library( "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -367,18 +366,13 @@ cc_library( cuda_library( name = "cuda_linalg_kernels_impl", - srcs = [ - "//jaxlib/gpu:linalg_kernels.cu.cc", - ], - hdrs = [ - "//jaxlib/gpu:linalg_kernels.h", - ], + srcs = ["//jaxlib/gpu:linalg_kernels.cu.cc"], + hdrs = ["//jaxlib/gpu:linalg_kernels.h"], deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", - "@local_config_cuda//cuda:cuda_headers", ], ) diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 6eb426b7fb43..1814641bb4fb 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -57,6 +57,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_cholesky_update_ffi", "CUDA", + CholeskyUpdateFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_lu_pivots_to_permutation", "CUDA", LuPivotsToPermutation); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_threefry2x32_ffi", "CUDA", diff --git a/jaxlib/gpu/linalg.cc b/jaxlib/gpu/linalg.cc index 189d3a01e382..0ab2b87a290d 100644 --- a/jaxlib/gpu/linalg.cc +++ b/jaxlib/gpu/linalg.cc @@ -41,6 +41,8 @@ NB_MODULE(_linalg, m) { EncapsulateFfiHandler(LuPivotsToPermutation); dict[JAX_GPU_PREFIX "_cholesky_update"] = EncapsulateFunction(CholeskyUpdate); + dict[JAX_GPU_PREFIX "_cholesky_update_ffi"] = + EncapsulateFunction(CholeskyUpdateFfi); return dict; }); m.def("build_cholesky_update_descriptor", &BuildCholeskyUpdateDescriptor); diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index 6b143e893264..b22248409b60 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -16,13 +16,9 @@ limitations under the License. #include "jaxlib/gpu/linalg_kernels.h" #include -#include -#include #include #include -#include -#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -60,32 +56,73 @@ void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque, } namespace { -absl::StatusOr> GetDimensions( - ffi::Span dims, const std::string& arg_name) { - if (dims.size() < 1) { - return absl::InvalidArgumentError( - absl::StrFormat("%s must have at least one dimension", arg_name)); +ffi::Error CholeskyUpdateFfiImpl(gpuStream_t stream, ffi::AnyBuffer matrix_in, + ffi::AnyBuffer vector_in, + ffi::Result matrix_out, + ffi::Result vector_out) { + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(matrix_in.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The matrix input to Cholesky update must be square."); } - std::int64_t batch_size = 1; - if (dims.size() >= 2) { - batch_size = - absl::c_accumulate(dims.first(dims.size() - 1), 1, std::multiplies<>()); + FFI_RETURN_IF_ERROR(CheckShape(vector_in.dimensions(), {batch, cols}, + "vector", "cholesky_update")); + FFI_RETURN_IF_ERROR(CheckShape(matrix_out->dimensions(), {batch, rows, cols}, + "matrix_out", "cholesky_update")); + FFI_RETURN_IF_ERROR(CheckShape(vector_out->dimensions(), {batch, cols}, + "vector_out", "cholesky_update")); + FFI_ASSIGN_OR_RETURN(auto size, MaybeCastNoOverflow(cols)); + auto dtype = matrix_in.element_type(); + if (dtype != ffi::F32 && dtype != ffi::F64) { + return ffi::Error::InvalidArgument( + "Invalid input type for Cholesky update; must be float32 or float64."); } - JAX_ASSIGN_OR_RETURN(auto size, - MaybeCastNoOverflow(dims.back())); - return std::make_pair(batch_size, size); + if (vector_in.element_type() != dtype || + matrix_out->element_type() != dtype || + vector_out->element_type() != dtype) { + return ffi::Error::InvalidArgument( + "All input and output types for Cholesky update must match."); + } + bool is_single_precision = dtype == ffi::F32; + auto matrix = matrix_out->untyped_data(); + if (matrix_in.untyped_data() != matrix) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( + gpuMemcpyAsync(matrix, matrix_in.untyped_data(), matrix_in.size_bytes(), + gpuMemcpyDeviceToDevice, stream))); + } + auto vector = vector_out->untyped_data(); + if (vector_in.untyped_data() != vector) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( + gpuMemcpyAsync(vector, vector_in.untyped_data(), vector_in.size_bytes(), + gpuMemcpyDeviceToDevice, stream))); + } + for (auto n = 0; n < batch; ++n) { + LaunchCholeskyUpdateFfiKernel(stream, matrix, vector, size, + is_single_precision); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + } + return ffi::Error::Success(); } +} // namespace +XLA_FFI_DEFINE_HANDLER_SYMBOL(CholeskyUpdateFfi, CholeskyUpdateFfiImpl, + ffi::Ffi::Bind() + .Ctx>() + .Arg() + .Arg() + .Ret() + .Ret()); + +namespace { ffi::Error LuPivotsToPermutationImpl( gpuStream_t stream, ffi::Dictionary /* unused */, ffi::Buffer pivots, ffi::Result> permutation) { - FFI_ASSIGN_OR_RETURN(auto pivots_dims, - GetDimensions(pivots.dimensions(), "pivots")); - FFI_ASSIGN_OR_RETURN(auto permutation_dims, - GetDimensions(permutation->dimensions(), "permutation")); - auto [batch_size, pivot_size] = pivots_dims; - auto [permutation_batch, permutation_size] = permutation_dims; + FFI_ASSIGN_OR_RETURN((auto [batch_size, pivot_size]), + SplitBatch1D(pivots.dimensions())); + FFI_ASSIGN_OR_RETURN((auto [permutation_batch, permutation_size]), + SplitBatch1D(permutation->dimensions())); if (permutation_batch != batch_size) { return ffi::Error(ffi::ErrorCode::kInvalidArgument, "pivots and permutation must have the same batch size."); diff --git a/jaxlib/gpu/linalg_kernels.cu.cc b/jaxlib/gpu/linalg_kernels.cu.cc index 8aa769bb5735..50c653d8cf16 100644 --- a/jaxlib/gpu/linalg_kernels.cu.cc +++ b/jaxlib/gpu/linalg_kernels.cu.cc @@ -15,18 +15,11 @@ limitations under the License. #include "jaxlib/gpu/linalg_kernels.h" -#include +#include #include -#include #include "jaxlib/gpu/vendor.h" -#ifdef JAX_GPU_HIP -#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h" -#else // JAX_GPU_CUDA -#include "third_party/gpus/cuda/include/cooperative_groups.h" -#endif - namespace cg = cooperative_groups; namespace jax { @@ -47,7 +40,6 @@ __device__ void drotg(T* da, T* db, T* c, T* s) { T rh = rhypot(a, b); *c = a * rh; *s = -(b * rh); - return; } template @@ -85,15 +77,9 @@ void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, reinterpret_cast(&uVector), reinterpret_cast(&nSize), }; -#ifdef JAX_GPU_HIP - hipLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, + gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, block_dim, arg_ptrs, /*dynamic_shared_mem_bytes=*/0, stream); -#else // JAX_GPU_CUDA - cudaLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, - block_dim, arg_ptrs, - /*dynamic_shared_mem_bytes=*/0, stream); -#endif } void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, @@ -102,13 +88,8 @@ void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, LinalgType type = descriptor.linalg_type; int dev = 0; -#ifdef JAX_GPU_HIP - hipDeviceProp_t deviceProp; - hipGetDeviceProperties(&deviceProp, dev); -#else // JAX_GPU_CUDA - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, dev); -#endif + gpuDeviceProp deviceProp; + gpuGetDeviceProperties(&deviceProp, dev); int block_dim = deviceProp.maxThreadsPerBlock; int grid_dim = deviceProp.multiProcessorCount; @@ -125,6 +106,41 @@ void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, } } +template +void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix, + void* vector, int grid_dim, + int block_dim, int nSize) { + T* rMatrix = reinterpret_cast(matrix); + T* uVector = reinterpret_cast(vector); + + void* arg_ptrs[3] = { + reinterpret_cast(&rMatrix), + reinterpret_cast(&uVector), + reinterpret_cast(&nSize), + }; + gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, + block_dim, arg_ptrs, + /*dynamic_shared_mem_bytes=*/0, stream); +} + +void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, + void* vector, int size, + bool is_single_precision) { + int dev = 0; + gpuDeviceProp deviceProp; + gpuGetDeviceProperties(&deviceProp, dev); + int block_dim = deviceProp.maxThreadsPerBlock; + int grid_dim = deviceProp.multiProcessorCount; + + if (is_single_precision) { + LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, grid_dim, + block_dim, size); + } else { + LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, grid_dim, + block_dim, size); + } +} + namespace { __device__ void ComputePermutation(const std::int32_t* pivots, diff --git a/jaxlib/gpu/linalg_kernels.h b/jaxlib/gpu/linalg_kernels.h index 73a0ac173d41..47ada398c3a2 100644 --- a/jaxlib/gpu/linalg_kernels.h +++ b/jaxlib/gpu/linalg_kernels.h @@ -26,8 +26,6 @@ limitations under the License. namespace jax { namespace JAX_GPU_NAMESPACE { -namespace ffi = xla::ffi; - enum LinalgType { F32 = 0, F64 = 1, @@ -44,13 +42,17 @@ void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); +void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, + void* vector, int size, + bool is_single_precision); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CholeskyUpdateFfi); + void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, std::int64_t batch_size, std::int32_t pivot_size, std::int32_t permutation_size, const std::int32_t* pivots, std::int32_t* permutation); - XLA_FFI_DECLARE_HANDLER_SYMBOL(LuPivotsToPermutation); } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 96266ca93378..ef635bebd401 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -23,6 +23,7 @@ limitations under the License. #if defined(JAX_GPU_CUDA) #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" // IWYU pragma: export +#include "third_party/gpus/cuda/include/cooperative_groups.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cuComplex.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cublas_v2.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: export @@ -31,8 +32,8 @@ limitations under the License. #include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export #include "third_party/gpus/cuda/include/cusolver_common.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export -#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export +#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export +#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export #if CUDA_VERSION < 11080 #error "JAX requires CUDA 11.8 or newer." @@ -292,6 +293,10 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuStreamWaitEvent cudaStreamWaitEvent #define gpuSuccess cudaSuccess +#define gpuDeviceProp cudaDeviceProp +#define gpuGetDeviceProperties cudaGetDeviceProperties +#define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel + namespace jax::JAX_GPU_NAMESPACE { namespace { constexpr uint32_t kNumThreadsPerWarp = 32; @@ -300,6 +305,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #elif defined(JAX_GPU_HIP) +#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h" #include "rocm/include/hip/hip_runtime_api.h" #include "rocm/include/hipblas/hipblas.h" #include "rocm/include/hipsolver/hipsolver.h" @@ -541,6 +547,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES #define GPU_EVENT_DEFAULT hipEventDefault +#define gpuDeviceProp hipDeviceProp_t +#define gpuGetDeviceProperties hipGetDeviceProperties +#define gpuLaunchCooperativeKernel hipLaunchCooperativeKernel + namespace jax::JAX_GPU_NAMESPACE { namespace { constexpr uint32_t kNumThreadsPerWarp = 64; diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 39b3aaea2072..88b7ff463800 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -37,7 +37,9 @@ if _cuda_linalg: for _name, _value in _cuda_linalg.registrations().items(): - api_version = 0 if _name == "cu_cholesky_update" else 1 + api_version = (1 + if _name.endswith("lu_pivots_to_permutation") + or _name.endswith("_ffi") else 0) xla_client.register_custom_call_target( _name, _value, platform="CUDA", api_version=api_version ) @@ -54,8 +56,11 @@ if _hip_linalg: for _name, _value in _hip_linalg.registrations().items(): + api_version = (1 + if _name.endswith("lu_pivots_to_permutation") + or _name.endswith("_ffi") else 0) xla_client.register_custom_call_target( - _name, _value, platform="ROCM", api_version=1 + _name, _value, platform="ROCM", api_version=api_version ) _prod = lambda xs: functools.reduce(operator.mul, xs, 1) diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index 7f90d6ee35d5..342e65ea2c6f 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -271,7 +271,6 @@ cc_library( ":hip_vendor", "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -288,7 +287,6 @@ rocm_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", ], From da77b710b8350fc910f1bc6c92c2a2888cc8bdde Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 20 Aug 2024 15:08:36 +0200 Subject: [PATCH 177/702] Added `py::mod_gil_not_used()` to `PYBIND11_MODULE` for `_triton_ext` and `_tpu_ext` Description: - Added `py::mod_gil_not_used()` to `PYBIND11_MODULE` for `_triton_ext` and `_tpu_ext`. Refs: - https://py-free-threading.github.io/porting/#__tabbed_1_2 Context: - https://github.com/google/jax/issues/23073 --- jaxlib/mlir/_mlir_libs/tpu_ext.cc | 2 +- jaxlib/mlir/_mlir_libs/triton_ext.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index b09e5744b619..a50aef1ca6d4 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -314,7 +314,7 @@ MlirContext getDefaultContext() { } // namespace -PYBIND11_MODULE(_tpu_ext, m) { +PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { mlirRegisterTPUPasses(); // Register all passes on load. py::class_(m, "ApplyVectorLayoutCtx", diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 4b900b7c1cbf..e02e4f3d86e4 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -22,7 +22,7 @@ limitations under the License. namespace py = pybind11; -PYBIND11_MODULE(_triton_ext, m) { +PYBIND11_MODULE(_triton_ext, m, py::mod_gil_not_used()) { // // Dialects. // From 16eb13e9db0560fec2aa2c19aadd18974af52d1e Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 20 Aug 2024 09:58:09 -0700 Subject: [PATCH 178/702] Fix empty mesh `size` and `abstract_mesh` * Fix `size` to return 0 rather than 1 for the empty mesh. * Fix `abstract_mesh` to return an empty abstract mesh. PiperOrigin-RevId: 665408468 --- jax/_src/mesh.py | 11 +++++++---- tests/array_test.py | 13 +++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index fec1f5ef1779..b30286b36a76 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -247,11 +247,11 @@ def shape_tuple(self): @property def size(self): - return math.prod(self.shape.values()) + return math.prod(self.shape.values()) if self.devices.ndim else 0 @property def empty(self): - return self.devices.ndim == 0 + return self.size == 0 @functools.cached_property def is_multi_process(self): @@ -337,7 +337,10 @@ class AbstractMesh: def __init__(self, shape_tuple: tuple[tuple[str, int], ...]): self.shape_tuple = shape_tuple - self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple)) + if self.shape_tuple: + self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple)) + else: + self._axis_names, self._axis_sizes = (), () def __hash__(self): return hash(self.shape_tuple) @@ -358,7 +361,7 @@ def axis_names(self): @functools.cached_property def size(self): - return math.prod(self._axis_sizes) + return math.prod(self._axis_sizes) if self._axis_sizes else 0 @functools.cached_property def shape(self): diff --git a/tests/array_test.py b/tests/array_test.py index f13ecbb51adb..c2e11268d714 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1409,6 +1409,19 @@ def f(x): y_ref1 = f(jax.device_put(x, jax.devices()[0])) self.assertArraysEqual(y, y_ref1) + def test_empty_mesh_creation(self): + mesh = jax.sharding.Mesh(devices=np.empty([]), axis_names=[]) + self.assertTrue(mesh.empty) + self.assertEqual(mesh.size, 0) + + abstract_mesh = mesh.abstract_mesh + self.assertTrue(abstract_mesh.empty) + self.assertEqual(abstract_mesh.size, 0) + + abstract_mesh2 = jax.sharding.AbstractMesh(()) + self.assertTrue(abstract_mesh2.empty) + self.assertEqual(abstract_mesh2.size, 0) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 3b2ce682a8322d9572d826af4beffe87c1ebbfca Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 20 Aug 2024 11:54:11 -0700 Subject: [PATCH 179/702] Fix index_swap_array with multiple indexers for the destination ref. When the destination ref has multiple indexers, the indexing needs to be undone in reverse order, not forward order as originally implemented. PiperOrigin-RevId: 665463297 --- jax/_src/state/discharge.py | 24 ++++++++++++++++++------ tests/pallas/indexing_test.py | 17 +++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index e6ac0db98d08..8666f4cb08f4 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -262,8 +262,11 @@ def index_array(x, indexers): def index_swap_array(x, indexers, val): result = x result_val = val + # Compute updated "val" (result). + _results = [x] for indexer in indexers: if _is_trivial_indexer(indexer): + _results.append(None) continue # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. @@ -271,15 +274,24 @@ def index_swap_array(x, indexers, val): if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice result_old = lax_slicing.dynamic_slice(result, starts, sizes) - result_val = lax.expand_dims(result_val, squeeze_dims) - y = lax_slicing.dynamic_update_slice(result, result_val, starts) result = lax.squeeze(result_old, squeeze_dims) - result_val = y else: indexer = _convert_to_array_indexer(indexer) - result_old = _prepend_gather(result, indexer) - result_val = _prepend_scatter(result, indexer, result_val) - result = result_old + result = _prepend_gather(result, indexer) + _results.append(result) + + # Compute updated "x" (result_val) + for i, indexer in reversed(list(enumerate(indexers))): + if _is_trivial_indexer(indexer): + continue + if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, _, squeeze_dims = maybe_slice + result_val = lax.expand_dims(result_val, squeeze_dims) + result_val = lax_slicing.dynamic_update_slice( + _results[i], result_val, starts) + else: + indexer = _convert_to_array_indexer(indexer) + result_val = _prepend_scatter(_results[i], indexer, result_val) return result, result_val def _get_discharge(x, idx, tree): diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index f706b36c5f90..9bf48f609215 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -299,6 +299,23 @@ def invoke_permutes(x_ref, y_ref, x_out_ref, y_out_ref): interpret=True, )(x, y) + def test_multi_indexing_destination_ref(self): + if not self.INTERPRET: + self.skipTest("Only supported in interpret mode") + def kernel(x_ref, o_ref): + o_ref[...] = jnp.zeros_like(o_ref) + new_o_ref = o_ref.at[pl.ds(0, 8)].at[0].at[pl.ds(0, 4), pl.ds(0, 4)] + new_o_ref[...] = x_ref[...] + + x = jax.random.normal(jax.random.key(0), shape=(4, 4)) + result = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 16, 16), x.dtype), + interpret=True, + )(x) + expected = jnp.zeros((16, 16, 16)).at[0, 0:4, 0:4].set(x) + np.testing.assert_array_equal(result, expected) + def test_ellipsis_indexing_iterpret_only(self): if not self.INTERPRET: self.skipTest("Only supported in interpret mode") From f04a35bda50ff1d05b8de5c16d5e9fa771efebd4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 20 Aug 2024 13:03:08 -0700 Subject: [PATCH 180/702] Update jax landing page --- docs/_static/style.css | 9 +++++ docs/index.rst | 92 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 95 insertions(+), 6 deletions(-) diff --git a/docs/_static/style.css b/docs/_static/style.css index 7a5c647052f0..296912ace2c8 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -20,6 +20,15 @@ background-color: rgba(171, 0, 182, var(--block-bg-opacity)); } +.ecosystem-grid { + font-size: smaller; +} + +.ecosystem-grid ul { + list-style-type: none; + padding-inline-start: 0.5em; +} + div.red-background pre { background-color: rgba(244, 204, 204, var(--block-bg-opacity)); } diff --git a/docs/index.rst b/docs/index.rst index 11d2807bf77e..92422edc069f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -4,10 +4,6 @@ JAX: High performance array computing JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. -If you're looking to train neural networks, use Flax_ and start with its documentation. -Some associated tools are Optax_ and Orbax_. -For an end-to-end transformer library built on JAX, see MaxText_. - .. grid:: 3 :margin: 0 :padding: 0 @@ -54,6 +50,69 @@ For an end-to-end transformer library built on JAX, see MaxText_. :link-type: ref :class-card: developer-docs +If you're looking to train neural networks, use Flax_ and start with its tutorials. +For an end-to-end transformer library built on JAX, see MaxText_. + +Ecosystem +--------- +JAX itself is narrowly-scoped and focuses on efficient array operations & program +transformations. Built around JAX is an evolving ecosystem of machine learning and +numerical computing tools; the following is just a small sample of what is out there: + +.. grid:: 4 + :class-container: ecosystem-grid + + .. grid-item:: :material-outlined:`hub;2em` **Neural networks** + + - Flax_ + - NNX_ + - Equinox_ + - Keras_ + + .. grid-item:: :material-regular:`show_chart;2em` **Optimizers & solvers** + + - Optax_ + - Optimistix_ + - Lineax_ + - Diffrax_ + + .. grid-item:: :material-outlined:`storage;2em` **Data loading** + + - Grain_ + - `Tensorflow datasets`_ + - `Hugging Face datasets`_ + + .. grid-item:: :material-regular:`construction;2em` **Miscellaneous tools** + + - Orbax_ + - Chex_ + + .. grid-item:: :material-regular:`lan;2em` **Probabilistic programming** + + - Blackjax_ + - Numpyro_ + - PyMC_ + + .. grid-item:: :material-regular:`bar_chart;2em` **Probabilistic modeling** + + - `Tensorflow probabilty`_ + - Distrax_ + + .. grid-item:: :material-outlined:`animation;2em` **Physics & simulation** + + - `JAX MD`_ + - Brax_ + + .. grid-item:: :material-regular:`language;2em` **LLMs** + + - MaxText_ + - AXLearn_ + - Levanter_ + - EasyLM_ + + +Many more JAX-based libraries have been developed; the community-run `Awesome JAX`_ page +maintains an up-to-date list. .. toctree:: :hidden: @@ -93,7 +152,28 @@ For an end-to-end transformer library built on JAX, see MaxText_. glossary +.. _Awesome JAX: https://github.com/n2cholas/awesome-jax +.. _AXLearn: https://github.com/apple/axlearn +.. _Blackjax: https://blackjax-devs.github.io/blackjax/ +.. _Brax: https://github.com/google/brax/ +.. _Chex: https://chex.readthedocs.io/ +.. _Diffrax: https://docs.kidger.site/diffrax/ +.. _Distrax: https://github.com/google-deepmind/distrax +.. _EasyLM: https://github.com/young-geng/EasyLM +.. _Equinox: https://docs.kidger.site/equinox/ .. _Flax: https://flax.readthedocs.io/ -.. _Orbax: https://orbax.readthedocs.io/ -.. _Optax: https://optax.readthedocs.io/ +.. _Grain: https://github.com/google/grain +.. _Hugging Face datasets: https://huggingface.co/docs/datasets/ +.. _JAX MD: https://jax-md.readthedocs.io/ +.. _Keras: https://keras.io/ +.. _Levanter: https://github.com/stanford-crfm/levanter +.. _Lineax: https://github.com/patrick-kidger/lineax .. _MaxText: https://github.com/google/maxtext/ +.. _NNX: https://flax.readthedocs.io/en/latest/nnx/ +.. _Numpyro: https://num.pyro.ai/en/latest/index.html +.. _Optax: https://optax.readthedocs.io/ +.. _Optimistix: https://github.com/patrick-kidger/optimistix +.. _Orbax: https://orbax.readthedocs.io/ +.. _PyMC: https://www.pymc.io/ +.. _Tensorflow datasets: https://www.tensorflow.org/datasets +.. _Tensorflow probabilty: https://www.tensorflow.org/probability From b241d73592fefed51100e70852db732170d4fb60 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 20 Aug 2024 13:17:32 -0700 Subject: [PATCH 181/702] Fix argument names for jnp.arctan2/jnp.atan2 --- jax/_src/numpy/ufuncs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 531d0bec813f..2893b14f7059 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -252,8 +252,8 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: @implements(np.arctan2, module='numpy') @partial(jit, inline=True) -def arctan2(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.atan2(*promote_args_inexact("arctan2", x, y)) +def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: + return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) @implements(np.minimum, module='numpy') @partial(jit, inline=True) @@ -357,9 +357,9 @@ def atanh(x: ArrayLike, /) -> Array: return arctanh(*promote_args('atanh', x)) @partial(jit, inline=True) -def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: +def atan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan2`""" - return arctan2(*promote_args('atan2', x, y)) + return arctan2(*promote_args('atan2', x1, x2)) @jit def bitwise_count(x: ArrayLike, /) -> Array: From 21d57030dc3453eaee761c9e7ec07bb7789f089b Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 20 Aug 2024 15:05:31 -0700 Subject: [PATCH 182/702] [Pallas] Skip TPU-specific tests on win32 PiperOrigin-RevId: 665543351 --- tests/pallas/indexing_test.py | 8 ++++++++ tests/pallas/ops_test.py | 6 ++++++ tests/pallas/pallas_test.py | 6 +++--- tests/pallas/tpu_ops_test.py | 5 ++--- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 9bf48f609215..2cad1d064e87 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -531,6 +531,8 @@ def body(x_ref, y_ref1, y_ref2): ) def test_load_with_dynamic_2nd_minor_index(self): + if pltpu is None: + self.skipTest("No TPU module available.") # We can take any dynamic index on the 2nd minor dimension as long as # the minormost dimsize is vreg lane count. m, n = 32, 128 @@ -557,6 +559,8 @@ def kernel(x_ref, indices, y_ref): self.assertAllClose(res, x[start : start + k, :], atol=0., rtol=0.) def test_store_with_dynamic_2nd_minor_index(self): + if pltpu is None: + self.skipTest("No TPU module available.") # We can take any dynamic index on the 2nd minor dimension as long as # the minormost dimsize is vreg lane count. m, n = 10, 128 @@ -583,6 +587,8 @@ def kernel(x_ref, indices, y_ref): self.assertAllClose(res[start : start + m, :], x, atol=0., rtol=0.) def test_load_one_row_with_dynamic_2nd_minor_index(self): + if pltpu is None: + self.skipTest("No TPU module available.") # This test triggers strided load. We can take any dynamic index on the # 2nd minor dimension as long as we load one row on the 2nd minor dim. b, m, n = 4, 16, 256 @@ -608,6 +614,8 @@ def kernel(x_ref, indices, y_ref): self.assertAllClose(res, x[:, start : start + 1, :], atol=0., rtol=0.) def test_store_one_row_with_dynamic_2nd_minor_index(self): + if pltpu is None: + self.skipTest("No TPU module available.") # This test triggers strided store. We can take any dynamic index on the # 2nd minor dimension as long as we store one row on the 2nd minor dim. b, m, n = 4, 16, 256 diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index e106a56e588e..cf247ac3f6a4 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -611,6 +611,8 @@ def kernel(x_ref, y_ref): dtype=(jnp.int32, jnp.int16, jnp.int8), ) def test_scalar_map(self, shape, dtype): + if pltpu is None: + self.skipTest("No TPU module available.") if dtype != jnp.int32 and len(shape) < 2: # TODO(b/299280718): Implement this. self.skipTest( @@ -633,6 +635,8 @@ def kernel(x_ref, y_ref): @jtu.skip_on_devices("gpu") # TODO: not implemented def test_extract_scalar(self): + if pltpu is None: + self.skipTest("No TPU module available.") def kernel(x_ref, y_ref): y_ref[0, 0] = x_ref[:][0, 0] f = self.pallas_call( @@ -645,6 +649,8 @@ def kernel(x_ref, y_ref): @jtu.skip_on_devices("gpu") # TODO: not implemented def test_concat_constant(self): + if pltpu is None: + self.skipTest("No TPU module available.") def kernel(out): result = [] for i in range(16): diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 7227c1b91ef3..2762eb28755e 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -768,7 +768,7 @@ def my_index_map(): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - "Index map function my_index_map at .*/pallas_test.py:.* for " + f"Index map function my_index_map at .*{os.sep}pallas_test.py:.* for " "x_ref must return 1 values to match .*" "Currently returning 2 values."): f(a) @@ -783,7 +783,7 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - "Index map function my_index_map at .*/pallas_test.py:.* for " + f"Index map function my_index_map at .*{os.sep}pallas_test.py:.* for " "x_ref must return integer scalars. Output\\[0\\] has " "type .*float"): f(a) @@ -798,7 +798,7 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - "Index map function my_index_map at .*/pallas_test.py:.* for " + f"Index map function my_index_map at .*{os.sep}pallas_test.py:.* for " "x_ref must return integer scalars. Output\\[0\\] has " "type .*int32\\[4\\]"): f(a) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index d9a66d1b2b34..cc39c879b121 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -54,9 +54,8 @@ class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False def setUp(self): - if not self.INTERPRET: - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only interpret mode supported on non-TPU") + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test only supported on TPU.") super().setUp() From 9bebf577ddee945a1cf50753937e9ed1b0124258 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 20 Aug 2024 15:06:27 -0700 Subject: [PATCH 183/702] Reverts 66a3f87a24016594794c2ee289826baed5e979a4 PiperOrigin-RevId: 665543713 --- jax/_src/core.py | 24 ++- jax/_src/interpreters/batching.py | 6 +- jax/_src/pallas/core.py | 61 +++++-- jax/_src/pallas/mosaic/lowering.py | 28 ++- jax/_src/pallas/pallas_call.py | 275 ++++++++++++++++++++++++++--- tests/pallas/BUILD | 23 +++ tests/pallas/pallas_jumble_test.py | 201 +++++++++++++++++++++ 7 files changed, 568 insertions(+), 50 deletions(-) create mode 100644 tests/pallas/pallas_jumble_test.py diff --git a/jax/_src/core.py b/jax/_src/core.py index ebf29cf0b253..61ed81cdeea9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1954,6 +1954,7 @@ def __init__(self, aval, data): assert data.shape == pad_shape self._aval = aval self._data = data + shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) aval = property(lambda self: self._aval) @@ -1964,21 +1965,38 @@ def __repr__(self) -> str: dtypestr = _short_dtype_name(self._aval.dtype) shapestr = ','.join(map(str, self.shape)) - slices = tuple(slice(int(d._data)) if type(d) is DArray and - type(d.dtype) is bint else slice(None) for d in self.shape) - data = self._data[slices] + data = self.data return f'{dtypestr}[{shapestr}] with value: {data}' + def __hash__(self) -> int: if not self.shape: return hash((self._aval, int(self._data))) raise TypeError("unhashable type: DArray") + def __eq__(self, other): if isinstance(other, DArray) and self._aval == other._aval: return self._data == other._data return False + def __len__(self): return self.shape[0] + @property + def data(self): + if not self.shape and type(self.dtype) is bint: + # special-case scalar bints + return self._data + + slices = tuple( + slice(int(d._data)) + if type(d) is DArray and type(d.dtype) is bint + else slice(None) + for d in self.shape + ) + data = self._data[slices] + return data + + pytype_aval_mappings[DArray] = \ lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, x._data) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index fbcd2c4a7a30..27cde6d31d35 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -88,6 +88,7 @@ def _jumble_flatten(jumble): elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape)) aval = jumble.aval.replace(elt_ty=elt_ty) return (lengths, jumble.data), aval + def _jumble_unflatten(aval, x): lengths, data = x new_shape = [d.replace(lengths=lengths[d.lengths - 1]) @@ -251,7 +252,10 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: - assert False + # TODO(mvoz): This is a terrible place to fall into if you pass + # a non jumble type in, make it clearer what went wrong. + assert False, f'Unexpected type in ELT? {type(x)}' + to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 0ef208f755e5..1d99680b646e 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -112,7 +112,10 @@ class AbstractMemoryRef(state.AbstractRef): def __init__(self, inner_aval: jax_core.AbstractValue, memory_space: Any): - assert isinstance(inner_aval, jax_core.ShapedArray) + + assert isinstance( + inner_aval, jax_core.ShapedArray + ), f"Illegal ref, got {type(inner_aval)}" self.inner_aval = inner_aval self.memory_space = memory_space @@ -167,9 +170,7 @@ class PallasGridContext: mapped_dims: tuple[int, ...] def size(self, axis: int) -> int | DynamicGridDim: - valid_grid = tuple( - s for i, s in enumerate(self.grid) if i not in self.mapped_dims - ) + valid_grid = tuple(self.grid) try: size = valid_grid[axis] except IndexError as e: @@ -338,7 +339,10 @@ def check_invariants(self) -> None: ) assert not self.index_map_jaxpr.consts - assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals) + assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), ( + self.block_shape, + self.index_map_jaxpr.out_avals, + ) assert all(ov.shape == () and (ov.dtype == jnp.int32 or ov.dtype == jnp.int64) for ov in self.index_map_jaxpr.out_avals), ( @@ -422,6 +426,8 @@ class GridMapping: num_inputs: int num_outputs: int num_scratch_operands: int + get_grid_indices: Callable | None = None + local_grid_env: Callable | None = None def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -442,8 +448,8 @@ def check_invariants(self) -> None: assert len(index_map_args) >= len(self.grid) for i in range(len(self.grid)): index_map_arg = index_map_args[i] - assert index_map_arg.shape == () - assert index_map_arg.dtype == jnp.int32 + assert index_map_arg.shape == (), f"index_map_arg: {index_map_arg}" + assert index_map_arg.dtype == jnp.int32, f"index_map_arg: {index_map_arg}" assert len(self.vmapped_dims) <= len(self.grid) for i in self.vmapped_dims: @@ -454,8 +460,11 @@ def check_invariants(self) -> None: for bm in self.block_mappings: bm.check_invariants() - assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), ( + assert tuple(self.index_map_avals) == tuple( + bm.index_map_jaxpr.in_avals + ), ( self.index_map_avals, + "|", bm.index_map_jaxpr.in_avals, ) @@ -547,6 +556,25 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool: return True return jax_core.is_dim(dim) + +def _max_shape_from_aval(array_aval: jax_core.ShapedArray): + array_aval_shape = list(array_aval.shape) + for i, s in enumerate(array_aval.shape): + try: + aval = jax_core.get_aval(s) + if isinstance(aval, jax_core.DShapedArray): + array_aval_shape[i] = aval.dtype.bound + except OverflowError as e: + # Note - there are annoying cases where on 32 bit hardware, + # a flattened index space may overflow - for these cases, + # we just take the shape as is. + # In most places, this is totally sound to do. + # For ragged/jumble inputs, this will fail downstream. + return array_aval.shape + + return tuple(array_aval_shape) + + def _convert_block_spec_to_block_mapping( block_spec: BlockSpec, origin: OriginStr, @@ -575,8 +603,15 @@ def _convert_block_spec_to_block_mapping( f"array shape {array_aval.shape}.") unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape), - block_spec.memory_space) + block_array_aval = array_aval.update(shape=unmapped_block_shape) + if isinstance(array_aval, jax_core.DShapedArray): + # Get the "max" shape for the ragged array. + block_array_aval = jax_core.ShapedArray( + block_array_aval.shape, + block_array_aval.dtype, + block_array_aval.weak_type, + ) + block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space) if not jax_core.is_constant_shape(block_aval.shape): raise ValueError( @@ -609,12 +644,12 @@ def _convert_block_spec_to_block_mapping( f"{origin} must return integer scalars. Output[{i}] has type " f"{ov}.") - if consts: raise ValueError( f"Index map function {index_map_src_info} for " f"{origin} must not capture constants: {consts}") + array_aval_shape = _max_shape_from_aval(array_aval) mapping = BlockMapping( block_shape=mapped_block_shape, @@ -622,7 +657,9 @@ def _convert_block_spec_to_block_mapping( index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), index_map_src_info=index_map_src_info, indexing_mode=block_spec.indexing_mode, - array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype), + array_shape_dtype=jax.ShapeDtypeStruct( + array_aval_shape, array_aval.dtype + ), origin=origin, ) mapping.check_invariants() diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 86ce2f0b1b81..aee894ee1b7e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -298,6 +298,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, self.jaxpr = jaxpr self.block_mappings = grid_mapping.block_mappings self.mapped_dims = grid_mapping.vmapped_dims + # TODO(mvoz): Generalize to not need this user_grid = tuple( g for i, g in enumerate(self.grid) if i not in self.mapped_dims ) @@ -345,9 +346,19 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, for _ in range(len(self.grid)) ]) self._prepare_mesh_info(mesh) - def _get_grid_indices(indices): - return indices - self.get_grid_indices = _get_grid_indices + + if grid_mapping.get_grid_indices is None: + + def _get_grid_indices(indices, maybe_include_mapped_dims: bool): + if maybe_include_mapped_dims: + return indices + return tuple( + idx for i, idx in enumerate(indices) if i not in self.mapped_dims + ) + + self.get_grid_indices = _get_grid_indices + else: + self.get_grid_indices = grid_mapping.get_grid_indices def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): if not self.has_communication: @@ -595,7 +606,9 @@ def lower_jaxpr_to_transform_func( ] def body_func(*args): grid_indices, scalar_prefetch = split_list(args, [num_grid]) - jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices( + grid_indices, maybe_include_mapped_dims=True + ) arg_block_shapes = [ *[()] * len(jaxpr_indices), *mosaic_grid_mapping.scalar_prefetch_block_shapes, @@ -663,9 +676,9 @@ def lower_jaxpr_to_func( def body_func(*args): grid_indices, scalar_prefetch, operands_and_scratch = split_list( args, [num_grid, num_scalar_prefetch]) - grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) - jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices) - if i not in mosaic_grid_mapping.mapped_dims) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices( + grid_indices, maybe_include_mapped_dims=False + ) mesh_info = mosaic_grid_mapping.mesh_info if mesh_info is not None: mesh_context = MeshContext( @@ -2365,6 +2378,7 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): + if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 3a780cdca617..e948ff374acc 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -228,6 +228,12 @@ def _pallas_call_impl_interpret( # Pad values to evenly divide into block dimensions. This matches the # behavior of the non-interpret mode. We pad with NaN, to make it easier # to catch OOB accesses. + for carry_element in carry: + aval = carry_element.aval + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype) + carry_element.aval = aval + carry = map(_pad_values_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) @@ -247,11 +253,16 @@ def cond(carry): return i < num_iterations def body(carry): i, loop_idx, *carry_blocks = carry - local_grid_env = tuple( - pallas_core.GridAxis(idx, b) - for dim, (idx, b) in enumerate(zip(loop_idx, grid)) - if dim not in grid_mapping.vmapped_dims - ) + + if grid_mapping.local_grid_env is not None: + local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + else: + local_grid_env = tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + if dim not in grid_mapping.vmapped_dims + ) + carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks]) with pallas_core.grid_env(local_grid_env): start_indices = [ @@ -268,8 +279,14 @@ def body(carry): len(blocks), len(scratch_values), ) - blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, - *blocks, *scratch) + for s in scalars: + aval = jax_core.get_aval(s) + if isinstance(aval, jax_core.DShapedArray): + s.aval = aval.update(dtype=jnp.int32) + + blocks = jax_core.eval_jaxpr( + discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch + ) _, out_inout, out_scratch = split_list( blocks, [grid_mapping.num_index_operands, num_inout_blocks]) @@ -390,19 +407,55 @@ def _pallas_call_jvp_rule( ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule -def _batch_block_mapping(grid_mapping: GridMapping, - axis_size: int, - aval: jax_core.ShapedArray, - dim: int | batching.NotMapped, - block_mapping: BlockMapping) -> BlockMapping: + +def _batch_block_mapping( + grid_mapping: GridMapping, + axis_size: int, + aval: jax_core.ShapedArray, + dim: int | batching.NotMapped, + block_mapping: BlockMapping, + for_ragged: bool, +) -> BlockMapping: def _block_map_function(new_idx, *args): - indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr, - block_mapping.index_map_jaxpr.consts, - *args) + if for_ragged: + drop_last_args = args[:-1] + else: + drop_last_args = args + + indices = jax_core.eval_jaxpr( + block_mapping.index_map_jaxpr.jaxpr, + block_mapping.index_map_jaxpr.consts, + *drop_last_args, + ) if dim is not batching.not_mapped: - indices.insert(dim, new_idx) + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + stacked_axis = dim.stacked_axis + indices.insert(stacked_axis, new_idx) + else: + indices.insert(dim, new_idx) return tuple(indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] + + if for_ragged: + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + _, _, ragged_axis_length = _ragged_axis_parts(dim) + aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) + lengths_aval = pallas_core.AbstractMemoryRef( + aval, + pallas_core.MemorySpace.INDEX, + ) + idx_avals = [*idx_avals, lengths_aval] + else: + i32_aval_memref = pallas_core.AbstractMemoryRef( + jax_core.ShapedArray(([axis_size]), jnp.int32), + pallas_core.MemorySpace.INDEX, + ) + idx_avals = [*idx_avals, i32_aval_memref] + with grid_mapping.trace_env(): block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_block_map_function), idx_avals) @@ -411,12 +464,27 @@ def _block_map_function(new_idx, *args): new_block_shape = shape new_array_shape_dtype = block_mapping.array_shape_dtype else: - new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + new_block_shape = shape + stacked_axis = dim.stacked_axis + new_block_shape = tuple_insert( + new_block_shape, stacked_axis, pallas_core.mapped + ) + else: + new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + + array_shape = block_mapping.array_shape_dtype.shape + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + stacked_axis = dim.stacked_axis + array_shape = tuple_insert(array_shape, stacked_axis, axis_size) + else: + array_shape = tuple_insert(array_shape, dim, axis_size) + new_array_shape_dtype = jax.ShapeDtypeStruct( - tuple_insert(block_mapping.array_shape_dtype.shape, - dim, - axis_size), - block_mapping.array_shape_dtype.dtype) + array_shape, block_mapping.array_shape_dtype.dtype + ) jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts) return block_mapping.replace(block_shape=new_block_shape, @@ -547,6 +615,16 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: return result, (0,) * len(result) +def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]: + stacked_axis = dim.stacked_axis + ragged_axes = dim.ragged_axes + if len(ragged_axes) != 1: + raise ValueError("Multiple ragged axes not yet implemented.") + ragged_axis_dim = ragged_axes[0][0] + ragged_axis_length = ragged_axes[0][1] + return stacked_axis, ragged_axis_dim, ragged_axis_length + + def _pallas_call_batching_rule( args, dims, @@ -567,8 +645,26 @@ def _maybe_squeeze_out_bdim( return x return jnp.squeeze(x, axis=bdim) + all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)] + if len(all_ragged_axes) > 1: + raise ValueError("Multiple ragged dimensions not yet implemented.") + + if all_ragged_axes: + stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts( + all_ragged_axes[0] + ) + else: + stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None + + def get_size(i, x, d): + if not isinstance(d, batching.RaggedAxis): + return x.shape[d] + return x.aval.shape[i] + (axis_size,) = { - x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped + get_size(i=i, x=x, d=d) + for i, (x, d) in enumerate(zip(args, dims)) + if d is not batching.not_mapped } if axis_size == 1: # Why are we even vmapping? @@ -670,12 +766,27 @@ def _maybe_squeeze_out_bdim( num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands + lengths_aval = None + if ragged_axis_length is not None: + aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) + lengths_aval = pallas_core.AbstractMemoryRef( + aval, + pallas_core.MemorySpace.INDEX, + ) + # Only add a batch dimension for the avals that actually have a grid mapping. # This excludes scalar prefetch inputs (the first in the list) and scratch # operands (the last in the list). avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( - partial(_batch_block_mapping, grid_mapping, axis_size), + partial( + _batch_block_mapping, + grid_mapping, + axis_size, + for_ragged=lengths_aval is not None, + ), avals_to_batch, all_dims[num_index_operands:], block_mappings, @@ -685,15 +796,23 @@ def _maybe_squeeze_out_bdim( grid_mapping.index_map_avals) assert not index_map_tree_kwargs batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args + + if lengths_aval: + batched_index_map_args = batched_index_map_args + (lengths_aval,) + num_index_operands += 1 + batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( (batched_index_map_args, {})) + batched_grid_mapping = grid_mapping.replace( grid=(axis_size, *grid_mapping.grid), block_mappings=tuple(batched_block_mappings), - index_map_avals=batched_index_map_avals, + index_map_avals=tuple(batched_index_map_avals), index_map_tree=batched_index_map_tree, + num_index_operands=num_index_operands, vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims), ) + if cost_estimate is not None: batched_cost_estimate = CostEstimate( flops=cost_estimate.flops * axis_size, @@ -702,6 +821,103 @@ def _maybe_squeeze_out_bdim( ) else: batched_cost_estimate = None + + if lengths_aval: + batched_grid_mapping = batched_grid_mapping.replace( + get_grid_indices=lambda indices, maybe_include_mapped_dims: indices, + local_grid_env=lambda loop_idx, grid: tuple( + pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid) + ), + ) + + # Note - on zero filling counterfactuals + # A debug util to produce a counterfactual version of the when + # gating, where for all values that don't pass the @when check, + # we write 0s. This is useful for debugging, as certain lowering paths + # like mosaic will write the last data as passthrough, leading to + # potentially confusing results. + debug_zero_fill_counterfactual = debug + + first_block_mapping = batched_grid_mapping.block_mappings[0] + for block_mapping in batched_grid_mapping.block_mappings: + # This invariant may already be checked elsewhere, but lets reaffirm it + assert block_mapping.block_shape == first_block_mapping.block_shape, ( + f"block_mapping.block_shape: {block_mapping.block_shape}, " + f"first_block_mapping.block_shape: {first_block_mapping.block_shape}" + ) + assert ( + block_mapping.array_shape_dtype + == first_block_mapping.array_shape_dtype + ), ( + f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype}," + " first_block_mapping.array_shape_dtype:" + f" {first_block_mapping.array_shape_dtype}" + ) + + mapped_dim_idxs = [ + i + for i, d in enumerate(first_block_mapping.block_shape) + if d is pallas_core.mapped + ] + assert len(mapped_dim_idxs) == 1 + mapped_dim_idx = mapped_dim_idxs[0] + if stacked_axis != mapped_dim_idx: + raise ValueError( + f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}" + ) + + assert ragged_axis_dim is not None, "Invariant violation" + # This is the blockspec size of the dimension + val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim] + + def when_wrapped_kernel(lengths_ref, *args, **kwargs): + b_idx = jax.experimental.pallas.program_id(stacked_axis) + i_idx = ( + jax.experimental.pallas.program_id(ragged_axis_dim) + * val_at_ragged_dim + ) + b_len = lengths_ref[b_idx] + + # TODO(mvoz): Unimplemented primitive in pallas + # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) + # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") + + @jax.experimental.pallas.when(i_idx < b_len) + def f(): + # Important! This allows us to trace the inner kernel with the correct + # grid to preserve user program_id semantics. Ex: program_id(0) will + # always be analogous to program_id(1) in the outer kernel. + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs) + + if debug_zero_fill_counterfactual: + + @jax.experimental.pallas.when(i_idx >= b_len) + def g(): + for arg_ref in args: + arg_ref[...] = jnp.zeros_like(arg_ref) + + kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars] + flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten( + list(kernel_avals) + ) + # Important! This allows us to trace the outer kernel with the correct grid + # to enable accessing the batch program_id. + with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): + kernel_src_info: pallas_core.SrcInfoStr = "" + + jaxpr = _trace_kernel_to_jaxpr( + when_wrapped_kernel, + kernel_src_info, + batched_grid_mapping, + tuple(flat_kernel_avals), + kernel_in_tree, + interpret=interpret, + ) + + assert ragged_axis_length is not None + args = (ragged_axis_length, *args) + out = pallas_call_p.bind( *dynamic_grid_args, *args, @@ -1097,12 +1313,14 @@ def pallas_call( out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths) flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore for x in flat_out_shapes] + @jax.jit def wrapped(*args): flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args) in_paths, flat_args = unzip2(flat_args_with_paths) flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) for a in flat_args) + flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype) for v in flat_out_shapes) @@ -1172,15 +1390,18 @@ def wrapped(*args): return wrapped -def in_path_to_input_origin(in_path: tree_util.KeyPath, - arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr: +def in_path_to_input_origin( + in_path: tree_util.KeyPath, arg_names: tuple[str, ...] | None +) -> pallas_core.OriginStr: """Converts `args[k]` into `arg_k_name`.""" if arg_names is None: return f"args{tree_util.keystr(in_path)}" if len(in_path) == 0: return "args" arg_idx, *rest_path = in_path - if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names): + if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len( + arg_names + ): return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path)) else: return f"args{tree_util.keystr(tuple(in_path))}" diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index c0cf61387cbb..5559a0552f9f 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -62,6 +62,29 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_test( + name = "pallas_jumble_test", + srcs = [ + "pallas_jumble_test.py", + ], + disable_configs = [ + "gpu", + "gpu_x32", + "gpu_a100", + "gpu_p100", + "gpu_p100_x32", + "gpu_h100", + ], + shard_count = { + "tpu": 1, + }, + deps = [ + "//jax:pallas", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_test( name = "ops_test", srcs = [ diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py new file mode 100644 index 000000000000..5ed15fe964dd --- /dev/null +++ b/tests/pallas/pallas_jumble_test.py @@ -0,0 +1,201 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" + +from absl.testing import absltest +import jax +from jax import lax +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.interpreters import batching +from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + + +# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs. +# pylint: disable=no-value-for-parameter + +config.parse_flags_with_absl() + + +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: + self.skipTest("On CPU the test works only in interpret mode") + if jtu.test_device_matches( + ["cuda"] + ) and not jtu.is_cuda_compute_capability_at_least("8.0"): + self.skipTest("Only works on GPU with capability >= sm80") + if sys.platform == "win32" and not self.INTERPRET: + self.skipTest("Only works on non-Windows platforms") + + super().setUp() + _trace_kernel_to_jaxpr.cache_clear() + + def pallas_call(self, *args, **kwargs): + return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + + +@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_dtype_promotion="standard") +class PallasCallRaggedVmapTest(PallasBaseTest): + + def test_vmap_jumble_over_sin_kernel(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct( + (8, col_grid_size * 128), dtype=jnp.float32 + ), + grid=(1, col_grid_size), + interpret=self.INTERPRET, + # See note - on zero filling counterfactuals + debug=True, + )(x) + + res = jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + res = res.data + total = len(ragged_shape) * row_count * col_grid_size * 128 + res_total = np.prod(res.shape) + self.assertEqual(res_total, total) + ragged_total = 0 + for dim in ragged_shape: + ragged_total += row_count * dim * 128 + # See note - on zero filling counterfactuals + self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total) + + def test_vmap_jumble_over_sin_kernel_grid_remapping(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) * pl.program_id(2) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), + grid=(1, 5), + interpret=False, + )(x) + + with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): + jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + self.skipTest("Checkify NYI") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([(128 * x) - 1 for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), + grid=(1, 5), + interpret=False, + )(x) + + with self.assertRaisesRegex( + ValueError, + "Ragged input shape must be evenly divisble by the grid" # noqa: W605 + " size at the ragged dimension 2", + ): + jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + +class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest): + INTERPRET = True + + +if __name__ == "__main__": + absltest.main() From 9b3c19c5dfd28250a892e4314f3c4475dfca6cc6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 20 Aug 2024 15:19:01 -0700 Subject: [PATCH 184/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ac89370dce9df3d850bb51a1576ca39f1efec63b. PiperOrigin-RevId: 665549202 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 62bdf98cf731..291de1db18c2 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "0c5475a11c47fd3aa7afdfa57f533ed9323133dd" -XLA_SHA256 = "0306f260e83960a121fab59c1ea7ede09b898251b2b042913941fa161f1423a3" +XLA_COMMIT = "ac89370dce9df3d850bb51a1576ca39f1efec63b" +XLA_SHA256 = "054a56ecd26babe32deebdf0782e1090b6a1f2a6442c2752602b65ce87747d9a" def repo(): tf_http_archive( From 2bd8c3f6911f0fb30907c5cb7cf63b3f47c0c3be Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 20 Aug 2024 15:38:03 -0700 Subject: [PATCH 185/702] [Pallas] Add explicit TPU compiler params interface with docstrings. PiperOrigin-RevId: 665557475 --- jax/_src/pallas/core.py | 7 +++- jax/_src/pallas/mosaic/core.py | 34 ++++++++++++++++++- .../pallas/mosaic/pallas_call_registration.py | 11 +----- jax/_src/pallas/pallas_call.py | 15 ++++++-- jax/experimental/pallas/tpu.py | 1 + tests/pallas/tpu_pallas_test.py | 14 ++++---- 6 files changed, 62 insertions(+), 20 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 1d99680b646e..f1d79e1dbf85 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -23,7 +23,7 @@ import functools import itertools import threading -from typing import Any, Hashable, Union +from typing import Any, ClassVar, Hashable, Union import warnings import jax @@ -60,6 +60,11 @@ def __repr__(self): OriginStr = str # The origin of a block spec, e.g. input[2]["field"] +@dataclasses.dataclass(frozen=True) +class CompilerParams: + """Base class for compiler parameters.""" + PLATFORM: ClassVar[str] = "unspecified" + @dataclasses.dataclass(frozen=True) class NameAndSrcInfo: #: The name of the pallas_call or the name of the kernel function. diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index ca333f97626c..abc7aca59cc7 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -19,7 +19,7 @@ import dataclasses import enum import functools -from typing import Any, Hashable +from typing import Any, ClassVar, Hashable import jax from jax._src import core as jax_core @@ -44,6 +44,38 @@ _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping split_list = util.split_list +@dataclasses.dataclass(frozen=True) +class TPUCompilerParams(pallas_core.CompilerParams): + """Mosaic TPU compiler parameters. + + Attributes: + dimension_semantics: A list of dimension semantics for each grid + dimension of the kernel. Either "parallel" for dimensions that can + execute in any order, or "arbitrary" for dimensions that must be + executed sequentially. + allow_input_fusion: A list of booleans indicating whether input fusion is + allowed for each argument. + vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note + that this must be used in conjunction with the + --xla_tpu_scoped_vmem_limit_kib=N flag with N*1kib > vmem_limit_bytes. + collective_id: Indicates which barrier semaphore to use for the kernel. + Note that using the same collective_id does not guarantee that + the same barrier semaphore will be allocated between kernels. + internal_scratch_in_bytes: The size of the internal scratch space used by + Mosaic. + flags: A dictionary of command line flags for the kernel. + serialization_format: The serialization format for the kernel body. + device_type: The device type to compile for. + """ + PLATFORM: ClassVar[str] = "mosaic" + dimension_semantics: list[str] | None = None + allow_input_fusion: list[bool] | None = None + vmem_limit_bytes: int | None = None + collective_id: int | None = None + flags: dict[str, Any] | None = None + internal_scratch_in_bytes: int | None = None + serialization_format: int = 1 + device_type: str | None = None class TPUMemorySpace(enum.Enum): ANY = "any" diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index c6edddca035b..cfb55240a876 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -80,16 +80,7 @@ def pallas_call_tpu_lowering_rule( if debug: print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") print(jaxpr) - if "mosaic_params" in compiler_params: - # TODO(slebedev): Remove this branch after July 12th 2024. - warnings.warn( - "Passing Mosaic parameters via compiler_params=dict(mosaic_params=...)" - " is deprecated. Use compiler_params=dict(mosaic=...) instead.", - DeprecationWarning, - ) - assert "mosaic" not in compiler_params - mosaic_params = compiler_params["mosaic_params"] - elif "mosaic" in compiler_params: + if "mosaic" in compiler_params: mosaic_params = compiler_params["mosaic"] else: mosaic_params = {} diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index e948ff374acc..5bbf37dc3663 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -16,6 +16,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Sequence +import dataclasses from functools import partial, reduce import itertools from typing import Any @@ -1232,7 +1233,7 @@ def pallas_call( debug: bool = False, interpret: bool = False, name: str | None = None, - compiler_params: dict[str, Any] | None = None, + compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, cost_estimate: CostEstimate | None = None, ) -> Callable[..., Any]: """Invokes a Pallas kernel on some inputs. @@ -1274,7 +1275,10 @@ def pallas_call( where the kernel function is defined, .e.g: `{name} for kernel function {kernel_name} at {file}:{line}`. If missing, then we use `{kernel_name} at {file}:{line}`. - compiler_params: TO BE DOCUMENTED. + compiler_params: Optional compiler parameters. If a dict is provided, it + should be of the form {platform: {param_name: param_value}}, where + platform is either 'mosaic' or 'triton'. For TPUs, it is also possible + to pass in a pallas.tpu.TPUCompilerParams struct. Returns: A function that can be called on a number of positional array arguments to @@ -1286,6 +1290,13 @@ def pallas_call( name, kernel_src_info) if compiler_params is None: compiler_params = {} + if isinstance(compiler_params, pallas_core.CompilerParams): + if compiler_params.PLATFORM not in ["mosaic", "triton"]: + raise ValueError( + f"Unknown platform in compiler params: {compiler_params.PLATFORM}") + compiler_params = { + compiler_params.PLATFORM: dataclasses.asdict(compiler_params) + } if grid_spec is None: grid_spec = GridSpec(grid, in_specs, out_specs) diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 79d773379f9b..e7fa25a3fc0d 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -21,6 +21,7 @@ from jax._src.pallas.mosaic.core import semaphore from jax._src.pallas.mosaic.core import SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace +from jax._src.pallas.mosaic.core import TPUCompilerParams from jax._src.pallas.mosaic.lowering import LoweringException from jax._src.pallas.mosaic.pipeline import ARBITRARY from jax._src.pallas.mosaic.pipeline import BufferedRef diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index bfefe720a383..94d169713c34 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -413,7 +413,9 @@ def kernel(s, x): ), grid=8, ), - compiler_params=dict(mosaic=dict(allow_input_fusion=[False, True])), + compiler_params=pltpu.TPUCompilerParams( + allow_input_fusion=[False, True] + ), )(s, x) first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:]) @@ -1556,12 +1558,12 @@ def kernel(x_ref, y_ref): self.pallas_call( kernel, out_shape=x, - compiler_params=dict(mosaic=dict(vmem_limit_bytes=256)), + compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=256), )(x) self.pallas_call( kernel, out_shape=x, - compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))), + compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=int(2**18)), )(x) def test_allow_input_fusion(self): @@ -1578,7 +1580,7 @@ def f(x, y): in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))], out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)), out_shape=x, - compiler_params=dict(mosaic=dict(allow_input_fusion=[True])), + compiler_params=pltpu.TPUCompilerParams(allow_input_fusion=[True]), )(z) x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) @@ -1606,8 +1608,8 @@ def kernel(x_ref, y_ref): self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=dict( - mosaic=dict(internal_scratch_in_bytes=requested_bytes) + compiler_params=pltpu.TPUCompilerParams( + internal_scratch_in_bytes=requested_bytes, ), )(x) From 7cd10d8854071f74b8dd7e9946ce44a8613c3b4e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 20 Aug 2024 15:48:00 -0700 Subject: [PATCH 186/702] Skip test_in_layouts_jit_jnp_input if xla_extension_version < 282 PiperOrigin-RevId: 665561830 --- tests/layout_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/layout_test.py b/tests/layout_test.py index c390bdc9f186..3cfc117b925e 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -15,6 +15,7 @@ import contextlib import math from functools import partial +import unittest from absl.testing import absltest import numpy as np @@ -25,6 +26,7 @@ from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip +from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -500,6 +502,8 @@ def g(x): 'Layout passed to jit does not match the layout on the respective arg'): g(arr) + @unittest.skipIf(xla_extension_version < 282, + "Requires xla_extension_version >= 282") def test_in_layouts_jit_jnp_input(self): major_last_layout = DLL(major_to_minor=(1, 0)) sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) From 82c9da020a78997862a8f7ccd494bed363f7ed01 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 20 Aug 2024 16:18:21 -0700 Subject: [PATCH 187/702] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums. This allows us to get more cache hits globally. For example: Before: ``` jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache miss ``` After: ``` jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache hit ``` Also, we can remove the hack (which I didn't like) in multihost_utils.py. PiperOrigin-RevId: 665574475 --- jax/_src/api.py | 6 +- jax/_src/interpreters/pxla.py | 40 +++++++++- jax/_src/pjit.py | 115 +++++++++++++++++++--------- jax/experimental/multihost_utils.py | 11 ++- tests/pjit_test.py | 32 +++++--- 5 files changed, 148 insertions(+), 56 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 5a773783b877..f83a8d73165d 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2965,7 +2965,8 @@ def clear_backends(): pjit._infer_params_cached.cache_clear() pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() xc._xla.PjitFunctionCache.clear_all() @atexit.register @@ -2993,7 +2994,8 @@ def clear_caches(): util.clear_all_weakref_lru_caches() # Clear all C++ compiled executable caches for pjit - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() pjit._infer_params_cached.cache_clear() xc._xla.PjitFunctionCache.clear_all() diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 1398d58fc787..7a231c06fd61 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -22,6 +22,7 @@ from collections.abc import Callable, Sequence, Iterable, Iterator import dataclasses from functools import partial, lru_cache, cached_property +import functools import itertools as it import logging import math @@ -89,6 +90,7 @@ class WeakRefList(list): logger = logging.getLogger(__name__) Index = Union[int, slice, tuple[Union[int, slice], ...]] +PyTreeDef = tree_util.PyTreeDef NoSharding = sharding_specs.NoSharding Chunked = sharding_specs.Chunked @@ -2922,6 +2924,33 @@ class MeshExecutableFastpathData(NamedTuple): in_device_local_layouts: Sequence[DeviceLocalLayout | None] +@dataclasses.dataclass(frozen=True) +class JitGlobalCppCacheKeys: + donate_argnums: tuple[int, ...] | None = None + donate_argnames: tuple[str, ...] | None = None + device: xc.Device | None = None + backend: str | None = None + in_shardings_treedef: PyTreeDef | None = None + in_shardings_leaves: tuple[Any, ...] | None = None + out_shardings_treedef: PyTreeDef | None = None + out_shardings_leaves: tuple[Any, ...] | None = None + in_layouts_treedef: PyTreeDef | None = None + in_layouts_leaves: tuple[Any, ...] | None = None + out_layouts_treedef: PyTreeDef | None = None + out_layouts_leaves: tuple[Any, ...] | None = None + + @functools.cached_property + def contains_explicit_attributes(self): + return (self.donate_argnums is not None or + self.donate_argnames is not None or + self.device is not None or + self.backend is not None or + any(not is_unspecified(i) for i in self.in_shardings_leaves) or + any(not is_unspecified(o) for o in self.out_shardings_leaves) or + any(i is not None for i in self.in_layouts_leaves) or + any(o is not None for o in self.out_layouts_leaves)) + + def reflatten_outputs_for_dispatch(out_tree, out_flat): # We arrive at dispatch having flattened according to the default # pytree registry, but we want to re-flatten according to our @@ -3037,9 +3066,14 @@ def aot_cache_miss(*args, **kwargs): fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, cc_shard_arg) + if xla_extension_version >= 283: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], + JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) + else: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], [], + tree_util.dispatch_registry, cc_shard_arg) if xla_extension_version < 282: def cc_shard_arg(x, sharding): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 9383c26bf7c4..d8784848fe58 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -63,6 +63,7 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src import sharding from jax._src.mesh import AbstractMesh from jax._src.sharding_impls import ( @@ -165,7 +166,6 @@ class PjitInfo(NamedTuple): keep_unused: bool inline: bool abstracted_axes: Any | None - has_explicit_sharding: bool use_resource_env: bool # False for jit, True for pjit # Hash and compare PjitInfo by identity when used as a cache key. @@ -314,14 +314,39 @@ def _cpp_pjit_evict_fn(self): # The entries are doubled here from the default 4096 because _pjit_call_impl # also has a cpp dispatch path and that would double the number of entries in # the global shared cache. -_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is only used for jit's with only fun. For example: jax.jit(f) +_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is used for jit where extra arguments are defined other than the +# fun. For example: jax.jit(f, donate_argnums=...) OR +# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the +# capacity might get full very fast because of all the jitted function in JAX +# which might evict train_step for example. +_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192) -def _get_cpp_global_cache(pjit_has_explicit_sharding): - if pjit_has_explicit_sharding: - return xc._xla.PjitFunctionCache() - else: - return _cpp_pjit_cache + +if xla_extension_version < 283: + def _get_cpp_global_cache(pjit_has_explicit_sharding): + if pjit_has_explicit_sharding: + return xc._xla.PjitFunctionCache() + else: + return _cpp_pjit_cache_fun_only + + def _pjit_explicit_sharding_and_layout( + in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, + device, backend) -> bool: + return (device is not None or + backend is not None or + any(not is_unspecified(i) for i in in_shardings_flat) or + any(not is_unspecified(o) for o in out_shardings_flat) or + any(i is not None for i in in_layouts_flat) or + any(o is not None for o in out_layouts_flat)) +else: + def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore + if contains_explicit_attributes: + return _cpp_pjit_cache_explicit_attributes + else: + return _cpp_pjit_cache_fun_only def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @@ -339,11 +364,34 @@ def cache_miss(*args, **kwargs): return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), - fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, - jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(jit_info.has_explicit_sharding)) + if xla_extension_version >= 283: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=jit_info.donate_argnums, + donate_argnames=jit_info.donate_argnames, + device=jit_info.device, backend=jit_info.backend, + in_shardings_treedef=jit_info.in_shardings_treedef, + in_shardings_leaves=jit_info.in_shardings_leaves, + out_shardings_treedef=jit_info.out_shardings_treedef, + out_shardings_leaves=jit_info.out_shardings_leaves, + in_layouts_treedef=jit_info.in_layouts_treedef, + in_layouts_leaves=jit_info.in_layouts_leaves, + out_layouts_treedef=jit_info.out_layouts_treedef, + out_layouts_leaves=jit_info.out_layouts_leaves) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore + pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes)) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + jit_info.in_shardings_leaves, jit_info.out_shardings_leaves, + jit_info.in_layouts_leaves, jit_info.out_layouts_leaves, + jit_info.device, jit_info.backend) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, jit_info.donate_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -351,17 +399,6 @@ def cache_miss(*args, **kwargs): return cpp_pjitted_f -def _pjit_explicit_sharding_and_layout( - in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, - device, backend) -> bool: - return (device is not None or - backend is not None or - any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(o) for o in out_shardings_flat) or - any(i is not None for i in in_layouts_flat) or - any(o is not None for o in out_layouts_flat)) - - def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) layouts, shardings = [], [] @@ -445,10 +482,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings_leaves, out_shardings_leaves, in_layouts_leaves, - out_layouts_leaves, device, backend) - return PjitInfo( fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, @@ -466,7 +499,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, abstracted_axes=abstracted_axes, - has_explicit_sharding=has_explicit_sharding, use_resource_env=use_resource_env) @@ -1724,13 +1756,26 @@ def call_impl_cache_miss(*args_, **kwargs_): f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) - donated_argnums = [i for i, d in enumerate(donated_invars) if d] - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings, out_shardings, in_layouts, out_layouts, None, None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding))(*args) + donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) + if xla_extension_version >= 283: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=donated_argnums, donate_argnames=None, + device=None, backend=None, + in_shardings_treedef=None, in_shardings_leaves=in_shardings, + out_shardings_treedef=None, out_shardings_leaves=out_shardings, + in_layouts_treedef=None, in_layouts_leaves=in_layouts, + out_layouts_treedef=None, out_layouts_leaves=out_layouts) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], cache_key, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings, out_shardings, in_layouts, out_layouts, None, None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], donated_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 554bf2641769..56003ea7af5d 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -90,19 +90,17 @@ def sync_global_devices(name: str): assert_equal(h, f"sync_global_devices name mismatch ('{name}')") +# Identity function is at the top level so that `process_allgather` doesn't +# recompile on every invocation. def _identity_fn(x): return x -@lru_cache(maxsize=128) -def _jitted_identity_fn(sharding): - return jax.jit(_identity_fn, out_shardings=sharding) - def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: reps = sharding_impls.GSPMDSharding.get_replicated( inp.sharding._device_assignment) - out = _jitted_identity_fn(reps)(inp) + out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. if jax.process_count() == 1: @@ -125,7 +123,8 @@ def _handle_array_process_allgather(inp, tiled): bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( global_aval.shape, s, bufs) - out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr) + out = jax.jit(_identity_fn, + out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr) return np.asarray(out.addressable_data(0)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index db105b527246..b1ab9a613b58 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -635,18 +635,16 @@ def testAutodiff(self, mesh, resources): @jtu.with_mesh([('x', 2), ('y', 1)]) def testAutodiffCache(self): - f = pjit( - lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None - ) + f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None) x = jnp.arange(16, dtype=jnp.float32) - jax.grad(f)(x) # Warm up the cache. - before = pjit_lib._pjit_lower_cached.cache_info() - jax.grad(f)(x) - after = pjit_lib._pjit_lower_cached.cache_info() - # One hit for the forward pass, one hit for backward. - self.assertEqual(after.hits, before.hits + 2) - self.assertEqual(after.misses, before.misses) + jax.grad(f)(x) # Warm up the cache. + with jtu.count_pjit_cpp_cache_miss() as count: + jax.grad(f)(x) + if xla_extension_version >= 283: + self.assertEqual(count[0], 0) # no cache miss i.e. cache hit + else: + self.assertEqual(count[0], 2) @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): @@ -4467,6 +4465,20 @@ def test_wsc_abstract_mesh_errors(self): ' match the mesh shape of the target sharding.*'): with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y'))) + @unittest.skipIf(xla_extension_version < 283, + "Requires xla_extension_version >= 283") + def test_global_jit_cpp_cache_hit_out_shardings(self): + mesh = jtu.create_global_mesh((2,), 'x') + s = NamedSharding(mesh, P('x')) + + def f(x): + return x * 2 + + with jtu.count_pjit_cpp_cache_miss() as count: + jax.jit(f, out_shardings=s)(np.arange(8)) + jax.jit(f, out_shardings=s)(np.arange(8)) + self.assertEqual(count[0], 1) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") From e51848ea3d6d1f4122527ad4e1736b2e1453147c Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 21 Aug 2024 05:07:58 -0700 Subject: [PATCH 188/702] Activate GPU kernel for LU decomposition. This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests. One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names). PiperOrigin-RevId: 665829252 --- jax/_src/export/_export.py | 4 +- .../cuda_lu_cusolver_getrf.py | 196 ++++++++++++++++++ jax/_src/lax/linalg.py | 86 ++++---- jaxlib/gpu_solver.py | 63 ++---- jaxlib/lapack.py | 71 +++---- tests/export_back_compat_test.py | 23 +- tests/shape_poly_test.py | 35 +++- 7 files changed, 351 insertions(+), 127 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index e2c60d3778fe..1bc9d8ab1c8c 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -969,9 +969,7 @@ def _check_lowering(lowering) -> None: "cu_lu_pivots_to_permutation", # "cublas_getrf_batched", "cusolver_getrf", # "hipblas_getrf_batched", "hipsolver_getrf", - # TODO(b/357034884): This can be added once the mimimum version of jaxlib - # (v0.4.32) includes this new FFI call. - # "cusolver_getrf_ffi", + "cusolver_getrf_ffi", # lu on TPU "LuDecomposition", # ApproxTopK on TPU diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py new file mode 100644 index 000000000000..47da841aec0a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py @@ -0,0 +1,196 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +from numpy import array, int32, float32, complex64 + +data_2024_08_19 = {} + +data_2024_08_19["f32"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation', 'cusolver_getrf_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[ 8. , 9. , 10. , 11. ], + [ 0. , 1. , 2. , 3. ], + [ 0.5, 0.5, 0. , 0. ]], dtype=float32), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<12xf32> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<12xf32>) -> tensor<3x4xf32> loc(#loc5) + %2:3 = stablehlo.custom_call @cusolver_getrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<3x4xf32>) -> (tensor<3x4xf32>, tensor<3xi32>, tensor) loc(#loc6) + %c = stablehlo.constant dense<1> : tensor loc(#loc6) + %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32> loc(#loc6) + %4 = stablehlo.subtract %2#1, %3 : tensor<3xi32> loc(#loc6) + %c_0 = stablehlo.constant dense<0> : tensor loc(#loc6) + %5 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor loc(#loc6) + %6 = stablehlo.compare GE, %2#2, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc6) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc6) + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc6) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x4xf32> loc(#loc6) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x4xi1> loc(#loc6) + %10 = stablehlo.select %9, %2#0, %8 : tensor<3x4xi1>, tensor<3x4xf32> loc(#loc6) + %11 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%4) {mhlo.backend_config = {}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<3xi32>) -> tensor<3xi32> loc(#loc7) + return %10, %4, %11 : tensor<3x4xf32>, tensor<3xi32>, tensor<3xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":442:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=float32 shape=(12,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu"(#loc3)) +#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe9\xab+\x01c\x0f\x13\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17S\x0b\x13\x13\x1b\x0b\x0b\x13\x13S\x0f\x0b\x03I\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0bO/\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x17\x17\x0f\x1f\x0f\x1f\x0b\x0b\x1fO\x0b\x01\x05\x0b\x0f\x03'\x13\x0f\x17\x07\x07\x07\x07\x07\x0f\x1b\x13\x13\x13\x13\x13\x0f\x17\x17\x13\x02^\x06\x1dM!\x03\x03#\x9d\x1f\x05\x1b\x05\x1d\x11\x03\x05\x05\x1f\x05!\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x17\x07\xea\x06\x17\x053\x03\t')+\x0b-\x0b\r/\x055\x11\x01\x00\x057\x059\x05;\x03\x0b3c5y7{\r\x899\x8b\x05=\x05?\x05A\x05C\x03\x03=\x8d\x05E\x1dAC\x05G\x17\x07\xe6\x065\x1dGI\x05I\x17\x07\xe6\x06\x1d\x03\x13\x0fk\x11m\x13\x8f\x15c\x17o\x19q\x1b\x91\x1d\x93\x1f\x97\x05K\x03\x03\t\x9b\x03\x03\t\x9f\x03\x05U\xa1W\xa3\x05M\x05O\x03\x03\t\xa5\x03\x03#\xa7\x03\x13\x0fk\x11m\x13\xa9\x15c\x17o\x19q\x1bw\x1dc\x1fw\x1da!\x05Q\x03\x01\x1dS\x1dU\x1dW\x0b\x03\x1dY\x05\x01\r\x01\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03u#\x17\x03\x07}\x81\x85\r\x05e\x7fgi\x1d[\r\x05e\x83gi\x1d]\r\x05e\x87gi\x1d_\x1da\x1dc\x13\r\x01\x1de\x03\x03s\x03\x03\x95\x15\x03\x01\x01\x01\x03\x07su\x99\x1f\x1f\x01\x1f\x07\t\x01\x00\x00\x00\x1f!\x01\x1f\x07\t\x00\x00\x00\x00\t\x07\x07\x05\x1f\x15\t\x00\x00\xc0\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dg\x01\t\x01\x02\x02)\x03\r\x13)\x01\x13)\x05\r\x11\x0b\t\x1d\x13\x01\x1b)\x01\x0b\x11\x01\x07\t\x05\x05)\x031\x0b)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0f)\x03\x01\r)\x01\x11)\x05\x05\x05\x11)\x05\r\x11\x11)\x03\t\r\x04.\x02\x05\x01\x11\x05%\x07\x03\x01\x05\t\x11\x051\x07\x03#A\x0b\x03?;\x03\x19\r\x06E\x03\t\x03\x01\x07\x07\x01K\x07\t\x05\x07\x03\x03\x05\x03\x01O\x03\x07\x03\x07\x01\x03\x03\x05\x03\x0b\x0f\x06\x01\x03\x05\x05\x07\r\x05\x03\x01Q\x03\x07\x03\x07\x01\x03\x03\x07\x03\x11\x11\x07\x01S\x03#\x05\t\x13\x03\x07\x01\x03\x03%\x03\x15\x05\x03\x01Y\x03\x15\x03\x07\x01\x03\x03\t\x03\x19\x03\x07\x01[\x03'\x03\x17\x13\x06\x01\x03\t\x07\x1d\x05\x1b\x07\x07_]\x03\x05\x03\x0f\x15\x04\x05\x07\x1f\x0f!\x06\x03\x01\x05\x01\x00\xe2\x0ei9'\x0f\x0b\t\t\t\x03\x11#!\x8b+\x1b7\x85\x89\x1f\x1f\x15\x1d\x15\x1b%)9+\x1f/!)!)#\x1f\x19\x13\ri\x15\x15\x17\x19\x17\x11\x11\x1f\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00subtract_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float32 shape=(12,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]\x00jit()/jit(main)/lu\x00compare_type\x00comparison_direction\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00[2]\x00main\x00public\x00cusolver_getrf_ffi\x00cu_lu_pivots_to_permutation\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_08_19["f64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation', 'cusolver_getrf_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[ 8. , 9. , 10. , 11. ], + [ 0. , 1. , 2. , 3. ], + [ 0.5, 0.5, 0. , 0. ]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<12xf64> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<12xf64>) -> tensor<3x4xf64> loc(#loc5) + %2:3 = stablehlo.custom_call @cusolver_getrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<3x4xf64>) -> (tensor<3x4xf64>, tensor<3xi32>, tensor) loc(#loc6) + %c = stablehlo.constant dense<1> : tensor loc(#loc6) + %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32> loc(#loc6) + %4 = stablehlo.subtract %2#1, %3 : tensor<3xi32> loc(#loc6) + %c_0 = stablehlo.constant dense<0> : tensor loc(#loc6) + %5 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor loc(#loc6) + %6 = stablehlo.compare GE, %2#2, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc6) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc6) + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc6) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x4xf64> loc(#loc6) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x4xi1> loc(#loc6) + %10 = stablehlo.select %9, %2#0, %8 : tensor<3x4xi1>, tensor<3x4xf64> loc(#loc6) + %11 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%4) {mhlo.backend_config = {}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<3xi32>) -> tensor<3xi32> loc(#loc7) + return %10, %4, %11 : tensor<3x4xf64>, tensor<3xi32>, tensor<3xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":442:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=float64 shape=(12,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu"(#loc3)) +#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe9\xab+\x01c\x0f\x13\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17S\x0b\x13\x13\x1b\x0b\x0b\x13\x13S\x0f\x0b\x03I\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0bO/\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x17\x17\x0f\x1f\x0f\x1f\x0b\x0b/O\x0b\x01\x05\x0b\x0f\x03'\x13\x0f\x17\x07\x07\x07\x07\x07\x0f\x1b\x13\x13\x13\x13\x13\x0f\x17\x17\x13\x02n\x06\x1dM!\x03\x03#\x9d\x1f\x05\x1b\x05\x1d\x11\x03\x05\x05\x1f\x05!\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x17\x07\xea\x06\x17\x053\x03\t')+\x0b-\x0b\r/\x055\x11\x01\x00\x057\x059\x05;\x03\x0b3c5y7{\r\x899\x8b\x05=\x05?\x05A\x05C\x03\x03=\x8d\x05E\x1dAC\x05G\x17\x07\xe6\x065\x1dGI\x05I\x17\x07\xe6\x06\x1d\x03\x13\x0fk\x11m\x13\x8f\x15c\x17o\x19q\x1b\x91\x1d\x93\x1f\x97\x05K\x03\x03\t\x9b\x03\x03\t\x9f\x03\x05U\xa1W\xa3\x05M\x05O\x03\x03\t\xa5\x03\x03#\xa7\x03\x13\x0fk\x11m\x13\xa9\x15c\x17o\x19q\x1bw\x1dc\x1fw\x1da!\x05Q\x03\x01\x1dS\x1dU\x1dW\x0b\x03\x1dY\x05\x01\r\x01\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03u#\x17\x03\x07}\x81\x85\r\x05e\x7fgi\x1d[\r\x05e\x83gi\x1d]\r\x05e\x87gi\x1d_\x1da\x1dc\x13\r\x01\x1de\x03\x03s\x03\x03\x95\x15\x03\x01\x01\x01\x03\x07su\x99\x1f\x1f\x01\x1f\x07\t\x01\x00\x00\x00\x1f!\x01\x1f\x07\t\x00\x00\x00\x00\t\x07\x07\x05\x1f\x15\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dg\x01\t\x01\x02\x02)\x03\r\x13)\x01\x13)\x05\r\x11\x0b\x0b\x1d\x13\x01\x1b)\x01\x0b\x11\x01\x07\t\x05\x05)\x031\x0b)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0f)\x03\x01\r)\x01\x11)\x05\x05\x05\x11)\x05\r\x11\x11)\x03\t\r\x04.\x02\x05\x01\x11\x05%\x07\x03\x01\x05\t\x11\x051\x07\x03#A\x0b\x03?;\x03\x19\r\x06E\x03\t\x03\x01\x07\x07\x01K\x07\t\x05\x07\x03\x03\x05\x03\x01O\x03\x07\x03\x07\x01\x03\x03\x05\x03\x0b\x0f\x06\x01\x03\x05\x05\x07\r\x05\x03\x01Q\x03\x07\x03\x07\x01\x03\x03\x07\x03\x11\x11\x07\x01S\x03#\x05\t\x13\x03\x07\x01\x03\x03%\x03\x15\x05\x03\x01Y\x03\x15\x03\x07\x01\x03\x03\t\x03\x19\x03\x07\x01[\x03'\x03\x17\x13\x06\x01\x03\t\x07\x1d\x05\x1b\x07\x07_]\x03\x05\x03\x0f\x15\x04\x05\x07\x1f\x0f!\x06\x03\x01\x05\x01\x00\xe2\x0ei9'\x0f\x0b\t\t\t\x03\x11#!\x8b+\x1b7\x85\x89\x1f\x1f\x15\x1d\x15\x1b%)9+\x1f/!)!)#\x1f\x19\x13\ri\x15\x15\x17\x19\x17\x11\x11\x1f\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00subtract_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float64 shape=(12,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]\x00jit()/jit(main)/lu\x00compare_type\x00comparison_direction\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00[2]\x00main\x00public\x00cusolver_getrf_ffi\x00cu_lu_pivots_to_permutation\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_08_19["c64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation', 'cusolver_getrf_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[ 8. +0.j, 9. +0.j, 10. +0.j, 11. +0.j], + [ 0. +0.j, 1. +0.j, 2. +0.j, 3. +0.j], + [ 0.5+0.j, 0.5+0.j, 0. +0.j, 0. +0.j]], dtype=complex64), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<12xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<12xcomplex>) -> tensor<3x4xcomplex> loc(#loc5) + %2:3 = stablehlo.custom_call @cusolver_getrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<3x4xcomplex>) -> (tensor<3x4xcomplex>, tensor<3xi32>, tensor) loc(#loc6) + %c = stablehlo.constant dense<1> : tensor loc(#loc6) + %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32> loc(#loc6) + %4 = stablehlo.subtract %2#1, %3 : tensor<3xi32> loc(#loc6) + %c_0 = stablehlo.constant dense<0> : tensor loc(#loc6) + %5 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor loc(#loc6) + %6 = stablehlo.compare GE, %2#2, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc6) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc6) + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc6) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x4xcomplex> loc(#loc6) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x4xi1> loc(#loc6) + %10 = stablehlo.select %9, %2#0, %8 : tensor<3x4xi1>, tensor<3x4xcomplex> loc(#loc6) + %11 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%4) {mhlo.backend_config = {}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<3xi32>) -> tensor<3xi32> loc(#loc7) + return %10, %4, %11 : tensor<3x4xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":442:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(12,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu"(#loc3)) +#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xeb\xab-\x01c\x0f\x13\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17S\x0b\x13\x13\x1b\x0b\x0b\x13\x13S\x0f\x0b\x03I\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0bO/\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x17\x17\x0f\x1f\x0f\x1f\x0b\x0b/O\x0b\x01\x05\x0b\x0f\x03)\x13\x0f\x17\x0b\x07\x07\x07\x07\x0f\x1b\x07\x13\x13\x13\x13\x13\x0f\x17\x17\x13\x02v\x06\x1dM!\x03\x03#\x9d\x1f\x05\x1b\x05\x1d\x11\x03\x05\x05\x1f\x05!\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x17\x07\xea\x06\x17\x053\x03\t')+\x0b-\x0b\r/\x055\x11\x01\x00\x057\x059\x05;\x03\x0b3c5y7{\r\x899\x8b\x05=\x05?\x05A\x05C\x03\x03=\x8d\x05E\x1dAC\x05G\x17\x07\xe6\x065\x1dGI\x05I\x17\x07\xe6\x06\x1d\x03\x13\x0fk\x11m\x13\x8f\x15c\x17o\x19q\x1b\x91\x1d\x93\x1f\x97\x05K\x03\x03\t\x9b\x03\x03\t\x9f\x03\x05U\xa1W\xa3\x05M\x05O\x03\x03\t\xa5\x03\x03#\xa7\x03\x13\x0fk\x11m\x13\xa9\x15c\x17o\x19q\x1bw\x1dc\x1fw\x1da!\x05Q\x03\x01\x1dS\x1dU\x1dW\x0b\x03\x1dY\x05\x01\r\x01\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03u#\x17\x03\x07}\x81\x85\r\x05e\x7fgi\x1d[\r\x05e\x83gi\x1d]\r\x05e\x87gi\x1d_\x1da\x1dc\x13\r\x01\x1de\x03\x03s\x03\x03\x95\x15\x03\x01\x01\x01\x03\x07su\x99\x1f!\x01\x1f\x07\t\x01\x00\x00\x00\x1f#\x01\x1f\x07\t\x00\x00\x00\x00\t\x07\x07\x05\x1f\x15\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dg\x01\t\x01\x02\x02)\x03\r\x13)\x01\x13)\x05\r\x11\x0b\x03\x19\x1d\x13\x01\x1b)\x01\x0b\x11\x01\x07\t\x05\x05\t)\x031\x0b)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0f)\x03\x01\r)\x01\x11)\x05\x05\x05\x11)\x05\r\x11\x11)\x03\t\r\x04.\x02\x05\x01\x11\x05%\x07\x03\x01\x05\t\x11\x051\x07\x03#A\x0b\x03?;\x03\x1b\r\x06E\x03\t\x03\x01\x07\x07\x01K\x07\t\x05\x07\x03\x03\x05\x03\x01O\x03\x07\x03\x07\x01\x03\x03\x05\x03\x0b\x0f\x06\x01\x03\x05\x05\x07\r\x05\x03\x01Q\x03\x07\x03\x07\x01\x03\x03\x07\x03\x11\x11\x07\x01S\x03%\x05\t\x13\x03\x07\x01\x03\x03'\x03\x15\x05\x03\x01Y\x03\x15\x03\x07\x01\x03\x03\t\x03\x19\x03\x07\x01[\x03)\x03\x17\x13\x06\x01\x03\t\x07\x1d\x05\x1b\x07\x07_]\x03\x05\x03\x0f\x15\x04\x05\x07\x1f\x0f!\x06\x03\x01\x05\x01\x00\xea\x0ei9'\x0f\x0b\t\t\t\x03\x11#!\x8b+\x1b7\x85\x8d\x1f\x1f\x15\x1d\x15\x1b%)9+\x1f/!)!)#\x1f\x19\x13\ri\x15\x15\x17\x19\x17\x11\x11\x1f\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00subtract_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex64 shape=(12,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]\x00jit()/jit(main)/lu\x00compare_type\x00comparison_direction\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00[2]\x00main\x00public\x00cusolver_getrf_ffi\x00cu_lu_pivots_to_permutation\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_08_19["c128"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation', 'cusolver_getrf_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[ 8. +0.j, 9. +0.j, 10. +0.j, 11. +0.j], + [ 0. +0.j, 1. +0.j, 2. +0.j, 3. +0.j], + [ 0.5+0.j, 0.5+0.j, 0. +0.j, 0. +0.j]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<12xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<12xcomplex>) -> tensor<3x4xcomplex> loc(#loc5) + %2:3 = stablehlo.custom_call @cusolver_getrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<3x4xcomplex>) -> (tensor<3x4xcomplex>, tensor<3xi32>, tensor) loc(#loc6) + %c = stablehlo.constant dense<1> : tensor loc(#loc6) + %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32> loc(#loc6) + %4 = stablehlo.subtract %2#1, %3 : tensor<3xi32> loc(#loc6) + %c_0 = stablehlo.constant dense<0> : tensor loc(#loc6) + %5 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor loc(#loc6) + %6 = stablehlo.compare GE, %2#2, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc6) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc6) + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc6) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x4xcomplex> loc(#loc6) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x4xi1> loc(#loc6) + %10 = stablehlo.select %9, %2#0, %8 : tensor<3x4xi1>, tensor<3x4xcomplex> loc(#loc6) + %11 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%4) {mhlo.backend_config = {}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<3xi32>) -> tensor<3xi32> loc(#loc7) + return %10, %4, %11 : tensor<3x4xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":442:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(12,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu"(#loc3)) +#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xeb\xab-\x01c\x0f\x13\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17S\x0b\x13\x13\x1b\x0b\x0b\x13\x13S\x0f\x0b\x03I\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0bO/\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x17\x17\x0f\x1f\x0f\x1f\x0b\x0bOO\x0b\x01\x05\x0b\x0f\x03)\x13\x0f\x17\x0b\x07\x07\x07\x07\x0f\x1b\x07\x13\x13\x13\x13\x13\x0f\x17\x17\x13\x02\x96\x06\x1dM!\x03\x03#\x9d\x1f\x05\x1b\x05\x1d\x11\x03\x05\x05\x1f\x05!\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x17\x07\xea\x06\x17\x053\x03\t')+\x0b-\x0b\r/\x055\x11\x01\x00\x057\x059\x05;\x03\x0b3c5y7{\r\x899\x8b\x05=\x05?\x05A\x05C\x03\x03=\x8d\x05E\x1dAC\x05G\x17\x07\xe6\x065\x1dGI\x05I\x17\x07\xe6\x06\x1d\x03\x13\x0fk\x11m\x13\x8f\x15c\x17o\x19q\x1b\x91\x1d\x93\x1f\x97\x05K\x03\x03\t\x9b\x03\x03\t\x9f\x03\x05U\xa1W\xa3\x05M\x05O\x03\x03\t\xa5\x03\x03#\xa7\x03\x13\x0fk\x11m\x13\xa9\x15c\x17o\x19q\x1bw\x1dc\x1fw\x1da!\x05Q\x03\x01\x1dS\x1dU\x1dW\x0b\x03\x1dY\x05\x01\r\x01\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03u#\x17\x03\x07}\x81\x85\r\x05e\x7fgi\x1d[\r\x05e\x83gi\x1d]\r\x05e\x87gi\x1d_\x1da\x1dc\x13\r\x01\x1de\x03\x03s\x03\x03\x95\x15\x03\x01\x01\x01\x03\x07su\x99\x1f!\x01\x1f\x07\t\x01\x00\x00\x00\x1f#\x01\x1f\x07\t\x00\x00\x00\x00\t\x07\x07\x05\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dg\x01\t\x01\x02\x02)\x03\r\x13)\x01\x13)\x05\r\x11\x0b\x03\x19\x1d\x13\x01\x1b)\x01\x0b\x11\x01\x07\t\x05\x05\x0b)\x031\x0b)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0f)\x03\x01\r)\x01\x11)\x05\x05\x05\x11)\x05\r\x11\x11)\x03\t\r\x04.\x02\x05\x01\x11\x05%\x07\x03\x01\x05\t\x11\x051\x07\x03#A\x0b\x03?;\x03\x1b\r\x06E\x03\t\x03\x01\x07\x07\x01K\x07\t\x05\x07\x03\x03\x05\x03\x01O\x03\x07\x03\x07\x01\x03\x03\x05\x03\x0b\x0f\x06\x01\x03\x05\x05\x07\r\x05\x03\x01Q\x03\x07\x03\x07\x01\x03\x03\x07\x03\x11\x11\x07\x01S\x03%\x05\t\x13\x03\x07\x01\x03\x03'\x03\x15\x05\x03\x01Y\x03\x15\x03\x07\x01\x03\x03\t\x03\x19\x03\x07\x01[\x03)\x03\x17\x13\x06\x01\x03\t\x07\x1d\x05\x1b\x07\x07_]\x03\x05\x03\x0f\x15\x04\x05\x07\x1f\x0f!\x06\x03\x01\x05\x01\x00\xee\x0ei9'\x0f\x0b\t\t\t\x03\x11#!\x8b+\x1b7\x85\x8f\x1f\x1f\x15\x1d\x15\x1b%)9+\x1f/!)!)#\x1f\x19\x13\ri\x15\x15\x17\x19\x17\x11\x11\x1f\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00subtract_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex128 shape=(12,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]\x00jit()/jit(main)/lu\x00compare_type\x00comparison_direction\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00[2]\x00main\x00public\x00cusolver_getrf_ffi\x00cu_lu_pivots_to_permutation\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 37b2ccf61ec4..0d0db8eb38c1 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -27,6 +27,7 @@ from jax._src import ad_util from jax._src import api +from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src.core import ( @@ -199,7 +200,7 @@ def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array: An int32 array of shape (..., permutation_size). """ permutation = lu_pivots_to_permutation_p.bind( - pivots, permutation_size=int(permutation_size)) + pivots, permutation_size=permutation_size) return permutation @@ -1164,7 +1165,7 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): permutation = lax.broadcasted_iota(jnp.int32, batch_dims + (m,), len(batch_dims)) - if m == 0: + if m == 0 or k == 0: return permutation upper = np.array(k, np.int32) if is_constant_dim(k) else k result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn, @@ -1200,10 +1201,12 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *, def _lu_pivots_to_permutation_gpu_lowering(platform, ctx, pivots, *, permutation_size): rule = ffi.ffi_lowering(f"{platform}_lu_pivots_to_permutation") - return rule(ctx, pivots, - # TODO(b/358275922): remove unused parameter 12 weeks after - # the release of jaxlib v0.4.32. - permutation_size=np.int32(permutation_size)) + # TODO(b/358275922): remove unused once jaxlib v0.4.32 is the minimum version. + if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32): + kwargs = dict(permutation_size=np.int32(permutation_size)) + else: + kwargs = {} + return rule(ctx, pivots, **kwargs) lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation') @@ -1313,7 +1316,8 @@ def _lu_abstract_eval(operand): batch_dims = operand.shape[:-2] m = operand.shape[-2] n = operand.shape[-1] - pivot = operand.update(shape=batch_dims + (min(m, n),), dtype=jnp.int32) + pivot = operand.update(shape=batch_dims + (core.min_dim(m, n),), + dtype=jnp.int32) perm = operand.update(shape=batch_dims + (m,), dtype=jnp.int32) else: pivot = operand @@ -1375,39 +1379,51 @@ def _lu_batching_rule(batched_args, batch_dims): x = batching.moveaxis(x, bd, 0) return lu_p.bind(x), (0, 0, 0) -def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, - platform: str): +def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str, + target_name_prefix: str): operand_aval, = ctx.avals_in - # It should be possible to support fully-dynamic shapes, but since - # the last two dimensions (m, n) are used in more involved ways, we only - # support dynamic dimensions for the batch size for now. - if not is_constant_shape(operand_aval.shape[-2:]): - raise NotImplementedError( - "Shape polymorphism for native lowering for lu on CPU and GPU is " - f"implemented only for the batch dimensions: {operand_aval.shape}") - - # TODO(b/357034884): Remove once jaxlib 0.4.32 is the minimum version. - ctx_arg = (ctx,) if jaxlib_version >= (0, 4, 32) else () - out_aval, pivot_aval, perm_aval = ctx.avals_out batch_dims = operand_aval.shape[:-2] + info_aval = ShapedArray(batch_dims, np.dtype(np.int32)) m = operand_aval.shape[-2] - if platform in ["cuda", "rocm"]: - # TODO(necula): remove the platform kwarg when we implement GPU support. - if not is_constant_shape(operand_aval.shape): + + # TODO(b/357034884): Remove version gate once jaxlib 0.4.32 is the minimum + # version and the forward compat flag after the 3 week compatibility window. + if jaxlib_version < (0, 4, 32) or ctx.is_forward_compat(): + if not is_constant_shape(operand_aval.shape[-2:]): raise NotImplementedError( - "Shape polymorphism for native serialization for lu on GPU is not " - f"implemented; b/261671778; {operand_aval.shape}") - lu, pivot, info = getrf_impl(*ctx_arg, operand_aval.dtype, operand) + "Shape polymorphism for native lowering for lu on CPU and GPU is " + f"implemented only for the batch dimensions: {operand_aval.shape}") + if platform in ["cuda", "rocm"]: + if not is_constant_shape(operand_aval.shape): + raise NotImplementedError( + "Shape polymorphism for native serialization for lu on GPU is not " + f"implemented; b/261671778; {operand_aval.shape}") + lu, pivot, info = getrf_impl(operand_aval.dtype, operand) + else: + op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) + lu, pivot, info = getrf_impl( + operand_aval.dtype, operand, a_shape_vals=op_shape_vals) else: - op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - # TODO(b/344892332): Remove the conditional after the compatibility period. - lu, pivot, info = getrf_impl( - *ctx_arg, operand_aval.dtype, operand, a_shape_vals=op_shape_vals) + if target_name_prefix == "cpu": + target_name = lapack.prepare_lapack_call("getrf_ffi", operand_aval.dtype) + else: + target_name = f"{target_name_prefix}solver_getrf_ffi" + # We manually construct the layouts because the input and output are + # expected to be in Fortran order. + nb = len(batch_dims) + layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1)) + result_layouts = [layout, tuple(range(nb, -1, -1)), + tuple(range(nb - 1, -1, -1))] + rule = ffi.ffi_lowering(target_name, operand_layouts=[layout], + result_layouts=result_layouts, + operand_output_aliases={0: 0}) + sub_ctx = ctx.replace(avals_out=[out_aval, pivot_aval, info_aval]) + lu, pivot, info = rule(sub_ctx, operand) + # Subtract 1 from the pivot to get 0-based indices. pivot = hlo.subtract(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)) - ok = mlir.compare_hlo( - info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), + ok = mlir.compare_hlo(info, mlir.full_like_aval(ctx, 0, info_aval), "GE", "SIGNED") select_lu_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) lu = _broadcasting_select_hlo( @@ -1452,16 +1468,16 @@ def _lu_tpu_lowering_rule(ctx, operand): mlir.register_lowering(lu_p, partial(_lu_cpu_gpu_lowering, lapack.getrf_hlo, - platform='cpu'), + platform='cpu', target_name_prefix="cpu"), platform='cpu') mlir.register_lowering( lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf, - platform='cuda'), + platform='cuda', target_name_prefix="cu"), platform='cuda') mlir.register_lowering( lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.rocm_getrf, - platform='rocm'), + platform='rocm', target_name_prefix="hip"), platform='rocm') mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu') diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index baa84e8eb9de..ff1e5570bb04 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -99,7 +99,8 @@ def _real_type(dtype): return np.finfo(dtype).dtype -def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a): +# TODO(b/357034884): Remove this function after the forward compat window. +def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a): """LU decomposition.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -110,60 +111,40 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a): i32_type = ir.IntegerType.get_signless(32) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - # TODO(b/357034884): Remove after 3 week forward compatibility window. - if ctx.is_forward_compat(): - if not gpu_blas: - raise GpuLibNotLinkedError() - - batch = math.prod(batch_dims) - if batch > 1 and m == n and m // batch <= 128: - lwork, opaque = gpu_blas.build_getrf_batched_descriptor( - np.dtype(dtype), batch, m) - workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)) - kernel = f"{platform}blas_getrf_batched" - else: - lwork, opaque = gpu_solver.build_getrf_descriptor( - np.dtype(dtype), batch, m, n) - workspace = ir.RankedTensorType.get([lwork], a_type.element_type) - kernel = f"{platform}solver_getrf" + if not gpu_blas: + raise GpuLibNotLinkedError() - out = custom_call( - kernel, - result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type), - ir.RankedTensorType.get(batch_dims, i32_type), - workspace, - ], - operands=[a], - backend_config=opaque, - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={0: 0}).results - return out[:3] + batch = math.prod(batch_dims) + if batch > 1 and m == n and m // batch <= 128: + lwork, opaque = gpu_blas.build_getrf_batched_descriptor( + np.dtype(dtype), batch, m) + workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)) + kernel = f"{platform}blas_getrf_batched" + else: + lwork, opaque = gpu_solver.build_getrf_descriptor( + np.dtype(dtype), batch, m, n) + workspace = ir.RankedTensorType.get([lwork], a_type.element_type) + kernel = f"{platform}solver_getrf" - return custom_call( - f"{platform}solver_getrf_ffi", + out = custom_call( + kernel, result_types=[ a.type, ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type), ir.RankedTensorType.get(batch_dims, i32_type), + workspace, ], operands=[a], + backend_config=opaque, operand_layouts=[layout], result_layouts=[ layout, tuple(range(num_bd, -1, -1)), tuple(range(num_bd - 1, -1, -1)), + [0], ], - operand_output_aliases={0: 0}, - backend_config={}, - api_version=4).results + operand_output_aliases={0: 0}).results + return out[:3] cuda_getrf = partial(_getrf_hlo, "cu", _cublas, _cusolver) diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index d43eef8c8fc3..09bb75597904 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -95,6 +95,12 @@ def _svd_computation_attr( } +def prepare_lapack_call(fn_base, dtype): + """Initializes the LAPACK library and returns the LAPACK target name.""" + _lapack.initialize() + return build_lapack_fn_target(fn_base, dtype) + + def build_lapack_fn_target(fn_base: str, dtype) -> str: """Builds the target name for a LAPACK function custom call.""" try: @@ -157,15 +163,13 @@ def trsm_hlo(dtype, alpha, a, b, # # ?getrf: LU decomposition -def getrf_hlo(ctx, dtype, a: ir.Value, *, - a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() +def getrf_hlo(dtype, a: ir.Value, *, a_shape_vals: tuple[DimensionSize, ...]): a_type = ir.RankedTensorType(a.type) assert len(a_shape_vals) >= 2 batch_dims_vals = a_shape_vals[:-2] num_bd = len(a_shape_vals) - 2 m, n = a_shape_vals[-2:] - fn_base = build_lapack_fn_target(fn_base="getrf", dtype=dtype) + fn = prepare_lapack_call(fn_base="getrf", dtype=dtype) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) @@ -177,43 +181,24 @@ def getrf_hlo(ctx, dtype, a: ir.Value, *, ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - if ctx.is_forward_compat(): - fn = fn_base - scalar_layout = [] - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + scalar_layout = [] + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - return custom_call( - fn, - result_types=result_types, - operands=[batch_size_val, ensure_hlo_s32(m), ensure_hlo_s32(n), a], - operand_layouts=[scalar_layout] * 3 + [layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - ], - operand_output_aliases={3: 0}, - result_shapes=result_shapes, - ).results - else: - fn = fn_base + "_ffi" - return custom_call( - fn, - result_types=result_types, - operands=[a], - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes, - backend_config={}, - api_version=4, - ).results + return custom_call( + fn, + result_types=result_types, + operands=[batch_size_val, ensure_hlo_s32(m), ensure_hlo_s32(n), a], + operand_layouts=[scalar_layout] * 3 + [layout], + result_layouts=[ + layout, + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + ], + operand_output_aliases={3: 0}, + result_shapes=result_shapes, + ).results # # ?geqrf: QR decomposition @@ -344,9 +329,8 @@ def orgqr_hlo(dtype, a: ir.Value, tau, *, def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False, a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() a_type = ir.RankedTensorType(a.type) - fn_base = build_lapack_fn_target(fn_base="potrf", dtype=dtype) + fn_base = prepare_lapack_call(fn_base="potrf", dtype=dtype) batch_dims_vals = a_shape_vals[:-2] num_bd = len(batch_dims_vals) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) @@ -395,7 +379,6 @@ def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False, def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() a_type = ir.RankedTensorType(a.type) assert len(a_shape_vals) >= 2 m, n = a_shape_vals[-2:] @@ -403,7 +386,7 @@ def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, assert type(n) is int batch_dims_vals = a_shape_vals[:-2] num_bd = len(batch_dims_vals) - fn_base = build_lapack_fn_target(fn_base="gesdd", dtype=dtype) + fn_base = prepare_lapack_call(fn_base="gesdd", dtype=dtype) i32_type = ir.IntegerType.get_signless(32) workspace: list[ShapeTypePair] diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 7fdf15c598ea..92072d0a0168 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -45,6 +45,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32 from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation +from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK @@ -127,7 +128,9 @@ def test_custom_call_coverage(self): cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30, cpu_lu_lapack_getrf.data_2023_06_14, cuda_lu_pivots_to_permutation.data_2024_08_08, - cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17, + cuda_lu_cusolver_getrf.data_2024_08_19, + cuda_qr_cusolver_geqrf.data_2023_03_18, + cuda_eigh_cusolver_syev.data_2023_03_17, rocm_qr_hipsolver_geqrf.data_2024_08_05, rocm_eigh_hipsolver_syev.data_2024_08_05, cpu_schur_lapack_gees.data_2023_07_16, @@ -356,6 +359,24 @@ def test_cuda_lu_pivots_to_permutation(self): data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08) self.run_one_test(func, data) + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", + dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + def test_cuda_lu_lapack_getrf(self, dtype_name:str): + if not config.enable_x64.value and dtype_name in ["f64", "c128"]: + self.skipTest("Test disabled for x32 mode") + if jaxlib_version < (0, 4, 32): + self.skipTest("Not implemented in older versions of jaxlib") + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + shape = (3, 4) + func = lambda: CompatTest.lu_harness(shape, dtype) + # TODO(b/360788062): Clean up after the compatibility period. + with config.export_ignore_forward_compatibility(True): + data = self.load_testdata(cuda_lu_cusolver_getrf.data_2024_08_19[dtype_name]) + self.run_one_test(func, data) + @staticmethod def qr_harness(shape, dtype): # In order to keep inputs small, we construct the input programmatically diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index d5b32cdbd7fc..c830678b1869 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2767,6 +2767,32 @@ def test_vmap_error(self): ((2, 3, 4), "b1, b2, b3", 8, ["b3 >= 9"]), ] ], + [ + PolyHarness( # pylint: disable=g-complex-comprehension + "lu", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}", + lax.linalg.lu, + arg_descriptors=[RandArg(shape, dtype)], + polymorphic_shapes=[poly], + # TODO(b/360788062): Remove once the forward compatibility window is + # closed. + override_jax_config_flags={ + "jax_export_ignore_forward_compatibility": True}) + for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() + for shape, poly in [ + ((5, 4), "m, n"), + ((2, 0, 4), "b, ..."), + ((2, 4, 0), "b, ..."), + ((2, 3, 4, 4), "b1, b2, ..."), + ((2, 3, 4, 5), "b1, b2, ..."), + ((2, 3, 8, 4), "b1, b2, ..."), + ((2, 3, 4, 5), "b1, b2, m, n"), + ] + # TODO(danfm): Remove once jaxlib v0.4.32 is the minimum version. + # jaxlib versions before 0.4.32 require a static shape for the non-batch + # dimensions because these are used for computing the "permuation_size" + # which is passed to lu_pivots_to_permutation. + if jaxlib_version >= (0, 4, 32) or not poly.endswith("m, n") + ], [ # The random primitive tests, with threefry (both partitionable and # non-partitionable), and unsafe_rbg. @@ -3390,9 +3416,6 @@ def test_harness(self, harness: PolyHarness): custom_call_harnesses = { "householder_product:gpu", "vmap_geqrf:gpu", # used for linalg.qr - "vmap_lu:gpu", - # custom_linear_solve works as long as lu works. - "vmap_custom_linear_solve:gpu", "vmap_qr:gpu", "qr:gpu", "vmap_svd:gpu", } @@ -3462,6 +3485,12 @@ def test_harness(self, harness: PolyHarness): if "cholesky" in harness.group_name and jtu.test_device_matches(["tpu"]): harness.tol = 5e-5 + # TODO(b/360788062): Clean up after the compatibility period. + if harness.group_name in [ + "lu", "vmap_lu", "custom_linear_solve", "vmap_custom_linear_solve" + ] and jtu.test_device_matches(["gpu"]): + config_flags = {**config_flags, "jax_export_ignore_forward_compatibility": True} + with jtu.global_config_context(**config_flags): harness.run_test(self) From c41d644886674fe868b3d28b55e0986b848a2a0b Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Tue, 20 Aug 2024 23:59:40 -0500 Subject: [PATCH 189/702] [ROCm] Fix bazel build issue --- jaxlib/rocm/BUILD.bazel | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index ce733d827e35..0c1fe2582603 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -60,8 +60,8 @@ cc_library( rocm_library( name = "hip_make_batch_pointers", - srcs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.cu.cc"], - hdrs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.h"], + srcs = ["//jaxlib/gpu:make_batch_pointers.cu.cc"], + hdrs = ["//jaxlib/gpu:make_batch_pointers.h"], deps = [ ":hip_vendor", "@local_config_rocm//rocm:rocm_headers", From 1e3c079821f5b4811dff37235f1e776eef1b14e4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 21 Aug 2024 07:55:12 -0700 Subject: [PATCH 190/702] Strip primitive params from location info because the amount of metadata included leads to huge HLO size increase and causes compilation cache misses in some other setting too. PiperOrigin-RevId: 665879688 --- jax/_src/interpreters/mlir.py | 7 +++---- jax/_src/pallas/mosaic/lowering.py | 4 +--- jax/_src/pallas/triton/lowering.py | 4 +--- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index e798a6fbdba9..bc1c00948943 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -483,10 +483,9 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location: return loc def _source_info_to_location( - ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any], + ctx: ModuleContext, primitive: core.Primitive, source_info: source_info_util.SourceInfo) -> ir.Location: - eqn_str = (f'{source_info.name_stack}/' - f'{core.str_eqn_compact(primitive, params)}') + eqn_str = f'{source_info.name_stack}/{primitive.name}' if config.include_full_tracebacks_in_locations.value: if source_info.traceback is None: loc = ir.Location.unknown() @@ -1745,7 +1744,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None in_nodes = map(read, eqn.invars) source_info = eqn.source_info.replace( name_stack=name_stack + eqn.source_info.name_stack) - loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info) + loc = _source_info_to_location(ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: override_rule = get_override_lowering_rule(eqn.primitive) platform_rules: dict[str, LoweringRule] = {} diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index aee894ee1b7e..2c5854876009 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -788,9 +788,7 @@ def write_env(var: jax_core.Var, val): source_info = eqn.source_info.replace( name_stack=ctx.name_stack + eqn.source_info.name_stack ) - loc = mlir._source_info_to_location( - ctx, eqn.primitive, eqn.params, source_info - ) + loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: if eqn.primitive in lowering_rules: if eqn.primitive not in skip_mlir_conversions: diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 852ac714d3c9..6db00a53671e 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -382,9 +382,7 @@ def write_env(var: jax_core.Var, val): avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] eqn_block_infos = map(read_block_info_env, eqn.invars) - loc = mlir._source_info_to_location( - ctx, eqn.primitive, eqn.params, eqn.source_info - ) + loc = mlir._source_info_to_location(ctx, eqn.primitive, eqn.source_info) rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos) try: with source_info_util.user_context(eqn.source_info.traceback), loc: From d49d070f0eef80f8690fcaf5365eafa9b3b24f63 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 21 Aug 2024 08:34:52 -0700 Subject: [PATCH 191/702] Skip shape polymorphism tests that are incompatible with released jaxlib version. PiperOrigin-RevId: 665893050 --- tests/shape_poly_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index c830678b1869..0a6f955cd5c3 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -3431,6 +3431,9 @@ def test_harness(self, harness: PolyHarness): # TODO(danfm): remove these checks when jaxlib 0.4.32 is released. "lu_pivots_to_permutation:gpu": (0, 4, 32), "lu_pivots_to_permutation_error:gpu": (0, 4, 32), + "lu:gpu": (0, 4, 32), + "vmap_lu:gpu": (0, 4, 32), + "vmap_custom_linear_solve:gpu": (0, 4, 32), } if version_gated.get(name_device_key, jaxlib_version) > jaxlib_version: raise unittest.SkipTest(f"shape polymorphism not supported by jaxlib version {jaxlib_version}") From ce3ea109a43fb159b42b566d19780e70def89dfb Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 21 Aug 2024 08:36:24 -0700 Subject: [PATCH 192/702] [Mosaic GPU] Add a fast type conversion from s8 vectors to bf16 vectors Regular conversion instructions have a ridiculously low throughput on Hopper, so replacing them with some bit tricks yields a much faster implementation. Co-authored-by: Benjamin Chetioui PiperOrigin-RevId: 665893696 --- .../mosaic/gpu/fragmented_array.py | 41 +++++++++++++++++++ tests/mosaic/gpu_test.py | 20 +++++++++ 2 files changed, 61 insertions(+) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 259cfe4ae430..15dab2eeabc1 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -422,9 +422,50 @@ def __getitem__(self, idx): # TODO(apaszke): Support JAX dtypes here as well? def astype(self, new_dtype: ir.Type): + i8 = ir.IntegerType.get_signless(8) + i16 = ir.IntegerType.get_signless(16) + i32 = ir.IntegerType.get_signless(32) + bf16 = ir.BF16Type.get() + cur_dtype = self.mlir_dtype if cur_dtype == new_dtype: return self + reg_type = self.registers.flat[0].type + is_vector_reg = ir.VectorType.isinstance(reg_type) + reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else () + if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2,): + new_registers = np.empty_like(self.registers) + for idx, reg in np.ndenumerate(self.registers): + reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg) + val_16 = llvm.extractelement(reg_16, c(0, i32)) + # We first embed the s8 into a bf16 with the exponent equal to + # bias + mantissa bits. Then, we zero the msb that didn't fit into the + # mantissa, zero out all bits other than msb, and subtract the last + # two values from each other. This takes advantage of the fact that the + # lsb of the exponent (msb of the second byte) is zero, which allows us + # to losslesly pack the msb there. When 1, it doubles the value of s2, + # making the result negative. + new_val_32 = llvm.inline_asm( + i32, + [val_16], + """ + { + .reg .b32 s<3>; + prmt.b32 s0, $1, 0x43, 0x4140; + and.b32 s1, s0, 0xff7fff7f; + and.b32 s2, s0, 0xff80ff80; + sub.bf16x2 $0, s1, s2; + } + """, + "=r,r", + ) + new_vec = llvm.mlir_undef(ir.VectorType.get((1,), i32)) + new_vec = llvm.insertelement(new_vec, new_val_32, c(0, i32)) + new_registers[idx] = vector.bitcast( + ir.VectorType.get((2,), new_dtype), new_vec + ) + return FragmentedArray(_registers=new_registers, _layout=self.layout) + # Generic path. from_float = ir.FloatType.isinstance(cur_dtype) to_float = ir.FloatType.isinstance(new_dtype) from_integer = ir.IntegerType.isinstance(cur_dtype) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index baad759df38d..f737ce721510 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1192,6 +1192,26 @@ def kernel(ctx, out, *_): np.testing.assert_array_equal(result, x) + @parameterized.named_parameters( + ("_bf16", jnp.bfloat16) + ) + def test_fast_i8_convert(self, jax_dtype_to): + jax_dtype_to = jnp.dtype(jax_dtype_to) + jax_dtype_from = jnp.dtype(jnp.int8) + mlir_dtype_to = mlir.dtype_to_ir_type(jax_dtype_to) + def kernel(ctx, inp, out, smem): + del ctx, smem + arr = mgpu.FragmentedArray.load_strided(inp) + arr.astype(mlir_dtype_to).store_untiled(out) + + x = jnp.arange(-128, 128, dtype=jax_dtype_from) + reference = x.astype(jax_dtype_to) + + result = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, reference, None, + )(x) + np.testing.assert_array_equal(result, reference) + class ProfilerTest(TestCase): From 8c7e798bd22e6e26999be7a6679d4eb942ded0e1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 21 Aug 2024 09:02:32 -0700 Subject: [PATCH 193/702] Fix MSAN use-of-uninitialized-value failure in array_test PiperOrigin-RevId: 665902448 --- tests/array_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/array_test.py b/tests/array_test.py index c2e11268d714..9d62b68bb0a5 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1410,7 +1410,7 @@ def f(x): self.assertArraysEqual(y, y_ref1) def test_empty_mesh_creation(self): - mesh = jax.sharding.Mesh(devices=np.empty([]), axis_names=[]) + mesh = jax.sharding.Mesh(devices=np.empty((), dtype=object), axis_names=[]) self.assertTrue(mesh.empty) self.assertEqual(mesh.size, 0) From d3fd262c9c24859fd3f4e53d69e8b3204ac7230b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 21 Aug 2024 09:39:19 -0700 Subject: [PATCH 194/702] [Mosaic GPU] Replace block barriers with warpgroup barriers Block barriers don't work in warp-specialized kernels. Also, expose the `when` syntax sugar. PiperOrigin-RevId: 665916133 --- jax/experimental/mosaic/gpu/__init__.py | 3 +-- jax/experimental/mosaic/gpu/dsl.py | 2 ++ .../mosaic/gpu/examples/flash_attention.py | 5 +---- jax/experimental/mosaic/gpu/fragmented_array.py | 6 +++--- jax/experimental/mosaic/gpu/utils.py | 15 ++++++++++++++- 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 88e80a79cc76..f5a92cc67f21 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -497,8 +497,7 @@ def await_async_copy( self, allow_groups: int, await_read_only: bool = False ): nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) - # TODO(apaszke): Use a warpgroup barrier!!! - gpu.barrier() # Groups are supposedly tracked per-thread + utils.warpgroup_barrier() # ShapeTrees currently can not contain unions. diff --git a/jax/experimental/mosaic/gpu/dsl.py b/jax/experimental/mosaic/gpu/dsl.py index 82e0aa4abb12..a12e5bc18803 100644 --- a/jax/experimental/mosaic/gpu/dsl.py +++ b/jax/experimental/mosaic/gpu/dsl.py @@ -47,7 +47,9 @@ thread_idx, tile_shape, warp_idx, + warpgroup_barrier, warpgroup_idx, + when, ) from .wgmma import ( WGMMAAccumulator, diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 0675844227ba..a9a533ca361c 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -287,10 +287,7 @@ def kv_loop(kv_step, carry): with ctx.named_region("Acc store"): acc.astype(f16).store_tiled(qo_smem, swizzle=128) - gpu.barrier() - nvvm.fence_proxy( - nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta - ) # Make sure the store is visible to the TMA. + commit_shared() # Make sure the store is visible to the TMA. with ctx.named_region("GMEM store"): ctx.async_copy( diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 15dab2eeabc1..406f5eaba3b2 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -522,9 +522,9 @@ def reduce_sum(self, scratch) -> ir.Value: warp_result = utils.warp_tree_reduce(result, op, 32) warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index)) memref.store(warp_result, scratch, [warp_id]) - utils.commit_shared() + utils.warpgroup_barrier() zero_index = c(0, index) - with mgpu.single_thread(): + with mgpu.single_thread(per_block=False): scratch_vec = vector.load( ir.VectorType.get((4,), self.mlir_dtype), scratch, @@ -534,7 +534,7 @@ def reduce_sum(self, scratch) -> ir.Value: self.mlir_dtype, vector.CombiningKind.ADD, scratch_vec ) memref.store(scratch_sum, scratch, [zero_index]) - utils.commit_shared() + utils.warpgroup_barrier() return memref.load(scratch, [zero_index]) def reduce(self, op, axis): diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index c892c672ffe2..64c9f409ef7f 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -503,12 +503,25 @@ def parse_indices( def commit_shared(): - gpu.barrier() + warpgroup_barrier() nvvm.fence_proxy( nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta ) +def warpgroup_barrier(): + # gpu.barrier() uses barrier number 0, and it would be unsafe to reuse it, + # so we shift the warpgroup index by 1. + i32 = ir.IntegerType.get_signless(32) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [arith.addi(warpgroup_idx(sync=False), c(1, i32))], + f"bar.sync $0, {WARPGROUP_SIZE};", + "r", + has_side_effects=True, + ) + + @dataclasses.dataclass(frozen=True) class BarrierRef: base_address: ir.Value From 9d1cc33e39526b951f3119b58795cbf4c84fd296 Mon Sep 17 00:00:00 2001 From: Zhuo Peng Date: Wed, 21 Aug 2024 09:48:59 -0700 Subject: [PATCH 195/702] Relaxed the assertion for is_same_structure in `jax2tf.call_tf` so that `tf_fun` may mutate the structure of its input parameters. PiperOrigin-RevId: 665919824 --- jax/experimental/jax2tf/call_tf.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 9018f781198c..04dd9f17933d 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -255,7 +255,11 @@ def replace_non_float_or_none(arg_tf): lambda x: x if x is None else tf.convert_to_tensor(x), dres_darg, ) - tf.nest.assert_same_structure(dres_darg, args_tf) + + # callable_tf may mutate (the structure of) args_tf, thus we check against + # watched_args_tf which should be structurally the same as the original + # args_tf. + tf.nest.assert_same_structure(dres_darg, watched_args_tf) return dres_darg # Use call_tf to call the VJP function From abd442b12a967f1738691b145c557df5df555dcc Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 21 Aug 2024 10:56:14 -0700 Subject: [PATCH 196/702] Reverts 1e3c079821f5b4811dff37235f1e776eef1b14e4 PiperOrigin-RevId: 665947283 --- jax/_src/interpreters/mlir.py | 7 ++++--- jax/_src/pallas/mosaic/lowering.py | 4 +++- jax/_src/pallas/triton/lowering.py | 4 +++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index bc1c00948943..e798a6fbdba9 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -483,9 +483,10 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location: return loc def _source_info_to_location( - ctx: ModuleContext, primitive: core.Primitive, + ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any], source_info: source_info_util.SourceInfo) -> ir.Location: - eqn_str = f'{source_info.name_stack}/{primitive.name}' + eqn_str = (f'{source_info.name_stack}/' + f'{core.str_eqn_compact(primitive, params)}') if config.include_full_tracebacks_in_locations.value: if source_info.traceback is None: loc = ir.Location.unknown() @@ -1744,7 +1745,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None in_nodes = map(read, eqn.invars) source_info = eqn.source_info.replace( name_stack=name_stack + eqn.source_info.name_stack) - loc = _source_info_to_location(ctx, eqn.primitive, source_info) + loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: override_rule = get_override_lowering_rule(eqn.primitive) platform_rules: dict[str, LoweringRule] = {} diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 2c5854876009..aee894ee1b7e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -788,7 +788,9 @@ def write_env(var: jax_core.Var, val): source_info = eqn.source_info.replace( name_stack=ctx.name_stack + eqn.source_info.name_stack ) - loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info) + loc = mlir._source_info_to_location( + ctx, eqn.primitive, eqn.params, source_info + ) with source_info_util.user_context(eqn.source_info.traceback), loc: if eqn.primitive in lowering_rules: if eqn.primitive not in skip_mlir_conversions: diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 6db00a53671e..852ac714d3c9 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -382,7 +382,9 @@ def write_env(var: jax_core.Var, val): avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] eqn_block_infos = map(read_block_info_env, eqn.invars) - loc = mlir._source_info_to_location(ctx, eqn.primitive, eqn.source_info) + loc = mlir._source_info_to_location( + ctx, eqn.primitive, eqn.params, eqn.source_info + ) rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos) try: with source_info_util.user_context(eqn.source_info.traceback), loc: From ce2306bbc18696627b8d448dcdd795be448fa80f Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 21 Aug 2024 11:10:32 -0700 Subject: [PATCH 197/702] [Pallas] Add interpret mode rules for semaphores (local signal, wait, read, DMAs). PiperOrigin-RevId: 665953666 --- jax/_src/pallas/core.py | 6 + jax/_src/pallas/mosaic/core.py | 4 +- jax/_src/pallas/mosaic/primitives.py | 168 ++++++++++++++++++++++----- jax/_src/pallas/primitives.py | 5 + jax/_src/state/discharge.py | 4 + tests/pallas/tpu_pallas_test.py | 109 +++++++++++------ 6 files changed, 229 insertions(+), 67 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index f1d79e1dbf85..03bfd28d0b9a 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -59,6 +59,12 @@ def __repr__(self): GridMappingGrid = tuple[int | DynamicGridDim, ...] OriginStr = str # The origin of a block spec, e.g. input[2]["field"] +# Datatype for semaphore values in interpret mode. +# For now, we choose a relatively uncommon datatype (i16) so it is more easily +# identifiable in kernels. +# TODO(justinfu): Handle semaphores with a custom extended dtype. +SEMAPHORE_INTERPRET_DTYPE = jnp.int16 + @dataclasses.dataclass(frozen=True) class CompilerParams: diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index abc7aca59cc7..94b53a4067f0 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -99,7 +99,7 @@ class barrier_semaphore(semaphore_dtype): pass class AbstractSemaphoreTyRules: @staticmethod def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: - return pallas_core.index_map_grid_aval + return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE) class AbstractSemaphoreTy(dtypes.ExtendedDType): name: str @@ -142,7 +142,7 @@ def __call__(self, shape: tuple[int, ...]): else: dtype = SemaphoreTy() if pallas_core.is_interpret_mode(): - dtype = jnp.int32 + dtype = pallas_core.SEMAPHORE_INTERPRET_DTYPE return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) def get_aval(self) -> AbstractMemoryRef: diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 7ddd8fb9d8c5..60d3d23cb884 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -159,7 +159,10 @@ class DeviceIdType(enum.Enum): def check_sem_avals(sem_aval, sem_indexers_avals, name, allowed_semaphore_types=None): if allowed_semaphore_types is None: - allowed_semaphore_types = {tpu_core.semaphore, tpu_core.barrier_semaphore} + allowed_semaphore_types = {tpu_core.semaphore, + tpu_core.barrier_semaphore, + # For interpret mode. + pl_core.SEMAPHORE_INTERPRET_DTYPE} if not isinstance(sem_aval, state.AbstractRef): raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") sem_shape = sem_aval.shape @@ -174,9 +177,20 @@ def check_sem_avals(sem_aval, sem_indexers_avals, name, allowed_semaphore_types= ): raise ValueError( f"Must {name} semaphores of the following types:" - f" {allowed_semaphore_types}" + f" {allowed_semaphore_types}." ) +def _index_semaphore(ref_value, indexers, ref_aval): + """Helper function for indexing into a semaphore during state_discharge.""" + if ref_value.shape == ref_aval.shape: + return state_discharge.index_array(ref_value, indexers) + elif len(ref_value.shape) == 0: + return ref_value + else: + raise ValueError( + f"Semaphore value shape {ref_value.shape} does not match aval shape" + f" {ref_aval.shape}" + ) semaphore_read_p = jax_core.Primitive("semaphore_read") semaphore_read_p.multiple_results = False @@ -199,11 +213,27 @@ def _semaphore_read_abstract_eval( sem_indexers_avals, "read", allowed_semaphore_types={ - tpu_core.dma_semaphore, tpu_core.semaphore, tpu_core.barrier_semaphore + tpu_core.dma_semaphore, + tpu_core.semaphore, + tpu_core.barrier_semaphore, + pl_core.SEMAPHORE_INTERPRET_DTYPE, }, ) return jax_core.ShapedArray((), jnp.dtype("int32")) +def _semaphore_read_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, indexers] = args_tree.unflatten(flat_args) + sem_value = _index_semaphore(ref, indexers, in_avals[0]) + sem_value = sem_value.astype(jnp.int32) + return (None,) * len(in_avals), sem_value +state_discharge.register_discharge_rule(semaphore_read_p)( + _semaphore_read_discharge_rule +) + semaphore_signal_p = jax_core.Primitive('semaphore_signal') semaphore_signal_p.multiple_results = True @@ -281,6 +311,29 @@ def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, return out jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn + +def _semaphore_signal_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree, + device_id_type): + del out_avals, device_id_type + [ref, indexers, inc, device_id, core_index] = args_tree.unflatten(flat_args) + if device_id is not None: + raise NotImplementedError("Remote signal not implemented.") + if core_index is not None: + raise NotImplementedError("Multiple core support not implemented.") + sem_value = _index_semaphore(ref, indexers, in_avals[0]) + inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.index_swap_array( + ref, indexers, sem_value + inc + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_signal_p)( + _semaphore_signal_discharge_rule +) + + semaphore_wait_p = jax_core.Primitive('semaphore_wait') semaphore_wait_p.multiple_results = True @@ -319,6 +372,22 @@ def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, ]) jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn +def _semaphore_wait_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, indexers, dec] = args_tree.unflatten(flat_args) + sem_value = _index_semaphore(ref, indexers, in_avals[0]) + dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.index_swap_array( + ref, indexers, sem_value -dec + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_wait_p)( + _semaphore_wait_discharge_rule +) + @dataclasses.dataclass class AsyncCopyDescriptor: @@ -464,22 +533,26 @@ def dma_start_discharge_rule(in_avals, out_avals, src_indexers_avals, _, dst_indexers_avals, - *_ + dst_sem_aval, + dst_sem_indexers_avals, + src_sem_aval, + src_sem_indexers_avals, + _, ) = tree_util.tree_unflatten(tree, in_avals) - del out_avals, dst_sem, dst_sem_indexers + del out_avals is_remote = device_id is not None if not is_remote: # Local async copies only use one semaphore. assert src_sem is None assert src_sem_indexers is None + num_src_sem_indexers = len(tree_util.tree_leaves(src_sem_indexers_avals)) + num_dst_sem_indexers = len(tree_util.tree_leaves(dst_sem_indexers_avals)) num_src_index_vals = len(tree_util.tree_leaves(src_indexers_avals)) num_dst_index_vals = len(tree_util.tree_leaves(dst_indexers_avals)) - if src_indexers: - updates = state_discharge.index_array(src_ref, src_indexers) - else: - updates = src_ref + updates = state_discharge.index_array(src_ref, src_indexers) + local_src = updates if is_remote: # Note that this code only works in SPMD mode. If not all devices execute @@ -520,6 +593,9 @@ def dma_start_discharge_rule(in_avals, out_avals, global_updates = jax.lax.all_gather(updates, shard_axis) updates = jax.lax.dynamic_index_in_dim( global_updates, index, axis=0, keepdims=False) + global_dst_sem = jax.lax.all_gather(dst_sem, shard_axis) + dst_sem = jax.lax.dynamic_index_in_dim( + global_dst_sem, index, axis=0, keepdims=False) # Handle asymmetrical indexing when devices do not share the same # dst_indexer. @@ -528,25 +604,45 @@ def dma_start_discharge_rule(in_avals, out_avals, dst_indexers = tree_util.tree_map( lambda x: jax.lax.dynamic_index_in_dim( x, index, axis=0, keepdims=False), global_dst_indexers) + global_dst_sem_indexers = tree_util.tree_map( + lambda x: jax.lax.all_gather(x, shard_axis), dst_sem_indexers) + dst_sem_indexers = tree_util.tree_map( + lambda x: jax.lax.dynamic_index_in_dim( + x, index, axis=0, keepdims=False), global_dst_sem_indexers) - if dst_indexers: - _, new_dst = state_discharge.index_swap_array( - dst_ref, dst_indexers, updates + _, new_dst = state_discharge.index_swap_array( + dst_ref, dst_indexers, updates + ) + + # Update semaphore values. + recv_size = jnp.array(updates.size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) + dst_sem_value = _index_semaphore(dst_sem, dst_sem_indexers, dst_sem_aval) + _, new_dst_sem = state_discharge.index_swap_array( + dst_sem, dst_sem_indexers, dst_sem_value + recv_size + ) + if is_remote: + send_size = jnp.array( + local_src.size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) + src_sem_value = _index_semaphore(src_sem, src_sem_indexers, src_sem_aval) + _, new_src_sem = state_discharge.index_swap_array( + src_sem, src_sem_indexers, src_sem_value + send_size ) else: - new_dst = updates - - # TODO(b/345505876): Implement semaphore counting. - new_avals = (None,) # src_aval - new_avals += (None,) * num_src_index_vals - new_avals += (new_dst,) # dst_aval - new_avals += (None,) * num_dst_index_vals - new_avals += (None,) # dst_sem_aval + new_src_sem = None + + new_vals = (None,) # src_val + new_vals += (None,) * num_src_index_vals + new_vals += (new_dst,) # dst_val + new_vals += (None,) * num_dst_index_vals + new_vals += (new_dst_sem,) # dst_sem + new_vals += (None,) * num_dst_sem_indexers if is_remote: - new_avals += (None, None) # src_sem_aval, device_id - assert (len(new_avals) == - len(in_avals)), f"{len(new_avals), new_avals} != {len(in_avals)}" - return new_avals, [] + new_vals += (new_src_sem,) # src_sem + new_vals += (None,) * num_src_sem_indexers + new_vals += (None,) # device_id + assert (len(new_vals) == + len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}" + return new_vals, [] state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule) @@ -578,9 +674,27 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, def dma_wait_discharge_rule(in_avals, out_avals, *args, tree, device_id_type): - del out_avals, args, tree, device_id_type - # TODO(justinfu): Implement semaphore counting. - return (None,) * len(in_avals), [] + del out_avals, device_id_type + (sem, sem_indexers, ref, ref_indexers) = tree_util.tree_unflatten(tree, args) + ( + sem_aval, + sem_indexers_avals, + _, + ref_indexers_avals, + ) = tree_util.tree_unflatten(tree, in_avals) + num_sem_indexers = len(tree_util.tree_leaves(sem_indexers_avals)) + num_indexers = len(tree_util.tree_leaves(ref_indexers_avals)) + updates = state_discharge.index_array(ref, ref_indexers) + copy_size = jnp.array(updates.size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) + sem_value = _index_semaphore(sem, sem_indexers, sem_aval) + _, new_sem = state_discharge.index_swap_array( + sem, sem_indexers, sem_value - copy_size + ) + new_vals = (new_sem,) # sem + new_vals += (None,) * num_sem_indexers + new_vals += (None,) # ref + new_vals += (None,) * num_indexers + return new_vals, [] state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule) def _get_ref_and_indexers(ref): diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index db364820e443..2d4ca1b8ca5b 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -421,6 +421,11 @@ def _load_jvp(primals, tangents, args_tree, **params): def uninitialized_value(shape, dtype): if jnp.issubdtype(dtype, jnp.floating): return jnp.full(shape, jnp.nan, dtype) + # Note: Currently semaphore is i16[], meaning this case needs to be + # handled before the general case for integers. + # TODO(justinfu): Handle semaphores with a custom extended dtype. + elif jnp.issubdtype(dtype, pallas_core.SEMAPHORE_INTERPRET_DTYPE): + return jnp.full(shape, 0, dtype) elif jnp.issubdtype(dtype, jnp.integer): return jnp.full(shape, jnp.iinfo(dtype).min, dtype) elif jnp.issubdtype(dtype, jnp.bool): diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 8666f4cb08f4..1feac75eb530 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -241,6 +241,8 @@ def _prepend_scatter(x, indexer, val, *, add=False): def index_array(x, indexers): + if indexers is None: + indexers = [] result = x for indexer in indexers: if _is_trivial_indexer(indexer): @@ -260,6 +262,8 @@ def index_array(x, indexers): return result def index_swap_array(x, indexers, val): + if indexers is None: + indexers = [] result = x result_val = val # Compute updated "val" (result). diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 94d169713c34..e1c01a4e84ec 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -886,8 +886,8 @@ def body(dma_sems, sems): self.assertTupleEqual(dma_sems.shape, (4,)) self.assertTupleEqual(sems.shape, (3,)) if self.INTERPRET: - self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.int32)) - self.assertTrue(jnp.issubdtype(sems.dtype, jnp.int32)) + self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.integer)) + self.assertTrue(jnp.issubdtype(sems.dtype, jnp.integer)) else: self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) @@ -905,8 +905,8 @@ def kernel(y_ref, dma_sems, sems): self.assertTupleEqual(dma_sems.shape, (4,)) self.assertTupleEqual(sems.shape, (3,)) if self.INTERPRET: - self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.int32)) - self.assertTrue(jnp.issubdtype(sems.dtype, jnp.int32)) + self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.integer)) + self.assertTrue(jnp.issubdtype(sems.dtype, jnp.integer)) else: self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) @@ -926,10 +926,6 @@ def kernel(y_ref, dma_sems, sems): ) def test_can_wait_on_semaphore(self): - # TODO(b/345534352): Add interpret support for semaphore signal/wait. - if self.INTERPRET: - self.skipTest('Semaphore signal/wait not supported in interpret mode.') - def kernel(y_ref): def body(sem): pltpu.semaphore_signal(sem) @@ -956,10 +952,6 @@ def body3(sem): )()) def test_can_wait_on_semaphore_array(self): - # TODO(b/345534352): Add interpret support for semaphore signal/wait. - if self.INTERPRET: - self.skipTest('Semaphore signal/wait not supported in interpret mode.') - def kernel(y_ref): def body(sems): pltpu.semaphore_signal(sems.at[0]) @@ -984,10 +976,6 @@ def body(sems): )()) def test_can_wait_on_semaphore_array_with_dynamic_index(self): - # TODO(b/345534352): Add interpret support for semaphore signal/wait. - if self.INTERPRET: - self.skipTest('Semaphore signal/wait not supported in interpret mode.') - def kernel(y_ref): i = pl.program_id(0) def body(sems): @@ -1017,10 +1005,6 @@ def body(sems): ) def test_can_read_semaphore(self): - # TODO(b/345534352): Add interpret support for semaphore signal/wait. - if self.INTERPRET: - self.skipTest('Semaphore signal/wait not supported in interpret mode.') - m, n = 2, 3 def kernel(y_ref): @@ -1034,7 +1018,6 @@ def body(sems): pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n))) - # TODO(b/345534352): Add interpret support for semaphore signal/wait. y = jax.block_until_ready( self.pallas_call( kernel, @@ -1047,16 +1030,12 @@ def body(sems): ) def test_can_read_dma_semaphore(self): - # TODO(b/345534352): Add interpret support for semaphore signal/wait. - if self.INTERPRET: - self.skipTest('Semaphore signal/wait not supported in interpret mode.') - def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): sem_val_ref[0, 0] = 123 pltpu.async_copy(x_hbm_ref, y_hbm_ref, dma_sem).wait() sem_val_ref[0, 0] = pltpu.semaphore_read(dma_sem) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) - # TODO(b/345534352): Add interpret support for semaphore signal/wait. y, sem_val = jax.block_until_ready( self.pallas_call( kernel, @@ -1102,10 +1081,9 @@ def body(sem): sem).wait() pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) - # TODO(b/345534352): Add interpret support for nonscalar semaphores. with self.assertRaisesRegex(ValueError, 'Cannot signal'): x = jnp.arange(8 * 128.).reshape((8, 128)) - pl.pallas_call( + self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -1122,8 +1100,7 @@ def body(sem): pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) x = jnp.arange(8 * 128.).reshape((8, 128)) - # TODO(b/345534352): Add interpret support for nonscalar semaphores. - y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -1425,9 +1402,8 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): pltpu.semaphore_wait(sem) pltpu.async_copy(x_bbm_ref, y_ref, dma_sem).wait() - # TODO(b/345534352): Add interpret support for semaphore signal/wait. x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, @@ -1473,27 +1449,36 @@ class PallasCallDMAInterpretTest(PallasCallDMATest): INTERPRET = True def test_interpret_local_dma(self): + # We run this test in interpret mode to test semaphore counting. + # On a physical device the values update asynchronously so we cannot + # deterministically check the values. def test_kernel(x_ref, o_ref, + sem_out_ref, copy_sem, ): o_ref[...] = jnp.zeros_like(o_ref[...]) input_to_output_copy = pltpu.make_async_copy( src_ref=x_ref.at[0:8], dst_ref=o_ref.at[0:8], - sem=copy_sem, + sem=copy_sem.at[0], ) input_to_output_copy.start() + sem_out_ref[0, :] = jnp.ones_like( + sem_out_ref[0, :]) * pltpu.semaphore_read(copy_sem.at[0]) input_to_output_copy.wait() + sem_out_ref[1, :] = jnp.ones_like( + sem_out_ref[0, :]) * pltpu.semaphore_read(copy_sem.at[0]) - out_shape = (jax.ShapeDtypeStruct((9, 128), jnp.float32)) + out_shape = (jax.ShapeDtypeStruct((16, 128), jnp.int32), + jax.ShapeDtypeStruct((2, 1), jnp.int32)) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), ], scratch_shapes=( - [pltpu.SemaphoreType.DMA] + [pltpu.SemaphoreType.DMA(2,)] ) ) @@ -1501,13 +1486,61 @@ def test_kernel(x_ref, test_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=True + interpret=True, ) - x = jax.random.normal(jax.random.key(0), shape=(16, 128)) - result = kernel(x) + x = jax.random.randint( + jax.random.key(0), shape=(16, 128), minval=0, maxval=128) + + result, semaphores = kernel(x) np.testing.assert_array_equal(result[0:8], x[0:8]) np.testing.assert_array_equal(result[8:], jnp.zeros_like(result[8:])) + # Make sure semaphores have the correct value before and after DMA wait. + result_sem_pre_wait = semaphores[0, 0] + np.testing.assert_array_equal(result_sem_pre_wait, result[0:8].size) + result_sem_post_wait = semaphores[1, 0] + np.testing.assert_array_equal(result_sem_post_wait, 0) + + def test_interpreter_semaphore_counting(self): + # We run this test in interpret mode because the kernel exits with + # non-zero values. In normal Pallas this would crash the kernel. + def test_kernel(o_ref, + sem_ref, + ): + o_ref[...] = jnp.zeros_like(o_ref) + pltpu.semaphore_signal(sem_ref.at[0], 1) + pltpu.semaphore_signal(sem_ref.at[1], 2) + pltpu.semaphore_signal(sem_ref.at[2], 3) + pltpu.semaphore_signal(sem_ref.at[3], 4) + o_ref[0, 0] = pltpu.semaphore_read(sem_ref.at[0]) + o_ref[1, 0] = pltpu.semaphore_read(sem_ref.at[1]) + o_ref[2, 0] = pltpu.semaphore_read(sem_ref.at[2]) + o_ref[3, 0] = pltpu.semaphore_read(sem_ref.at[3]) + pltpu.semaphore_wait(sem_ref.at[0], 4) + pltpu.semaphore_wait(sem_ref.at[1], 3) + pltpu.semaphore_wait(sem_ref.at[2], 2) + pltpu.semaphore_wait(sem_ref.at[3], 1) + o_ref[4, 0] = pltpu.semaphore_read(sem_ref.at[0]) + o_ref[5, 0] = pltpu.semaphore_read(sem_ref.at[1]) + o_ref[6, 0] = pltpu.semaphore_read(sem_ref.at[2]) + o_ref[7, 0] = pltpu.semaphore_read(sem_ref.at[3]) + + out_shape = jax.ShapeDtypeStruct((8, 1), jnp.int32) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + scratch_shapes=( + [pltpu.SemaphoreType.DMA(4,)] + ) + ) + results = pl.pallas_call( + test_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=True, + )() + expected = jnp.array([1, 2, 3, 4, -3, -1, 1, 3]).reshape(out_shape.shape) + np.testing.assert_array_equal(results, expected) + class PallasCallTest(PallasBaseTest): From 558000df7c0fed74e86eac3e37cfa6ff0f97a4a1 Mon Sep 17 00:00:00 2001 From: kaixih Date: Thu, 18 Jul 2024 17:03:49 +0000 Subject: [PATCH 198/702] Support variable sequence lengths --- jax/_src/nn/functions.py | 128 ++++++++++++++++++++++++++++----------- tests/nn_test.py | 76 +++++++++++++++++------ 2 files changed, 149 insertions(+), 55 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 821c4a413796..b49e16d95408 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -35,7 +35,7 @@ from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention as cudnn_dot_product_attention, MaskType) from jax._src.numpy import util as numpy_util -from jax._src.typing import Array, ArrayLike +from jax._src.typing import Array, ArrayLike, DType from jax._src.ops.special import logsumexp as _logsumexp @@ -781,13 +781,48 @@ def _get_large_negative(dtype): dtype_max = jnp.finfo(dtype).max return jnp.asarray(-0.7 * dtype_max, dtype=dtype) -def _get_causal_mask(T, S, dtype): - pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) - mask = jnp.where(pred, jnp.asarray(0.0, dtype), _get_large_negative(dtype)) - return mask +def _get_causal_mask(T, S): + mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) + return mask[None, None, :, :] + +def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen): + q_indices = jnp.arange(0, T)[None, :, None] + kv_indices = jnp.arange(0, S)[None, None, :] + q_mask = q_indices < q_seqlen[:, None, None] + kv_mask = kv_indices < kv_seqlen[:, None, None] + mask = jnp.logical_and(q_mask, kv_mask) + return mask[:, None, :, :] + +def _get_padding_mask_encoded(T, q_seqlen): + q_indices = jnp.arange(0, T)[None, :] + mask = q_indices < q_seqlen[:, None] + return mask[:, :, None, None] + +def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen): + if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None: + return logits + + combined_mask = jnp.ones_like(logits, dtype=jnp.bool_) + if mask is not None: + assert mask.dtype == jnp.bool_ + combined_mask = jnp.logical_and(combined_mask, mask) + + T, S = logits.shape[2], logits.shape[3] + + if is_causal: + mask = _get_causal_mask(T, S) + combined_mask = jnp.logical_and(combined_mask, mask) + + if q_seqlen is not None and kv_seqlen is not None: + mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen) + combined_mask = jnp.logical_and(combined_mask, mask) + + large_negative_number = _get_large_negative(logits.dtype) + padded_logits = jnp.where(combined_mask, logits, large_negative_number) + return padded_logits def _dot_product_attention_core(query, key, value, bias, mask, is_causal, - scale): + scale, q_seqlen, kv_seqlen): logits_dtype = jnp.promote_types(query.dtype, jnp.float32) logits = jnp.einsum('BTNH,BSNH->BNTS', query, key, preferred_element_type=logits_dtype) @@ -797,24 +832,16 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal, if bias is not None: logits = (logits + bias).astype(logits.dtype) - if mask is not None: - assert mask.dtype == jnp.bool_ - large_negative_number = _get_large_negative(logits.dtype) - padded_logits = jnp.where(mask, logits, large_negative_number) - else: - padded_logits = logits - - if is_causal: - T, S = query.shape[1], key.shape[1] - mask = jnp.broadcast_to(_get_causal_mask(T, S, logits.dtype), - padded_logits.shape) - padded_logits = padded_logits + mask + padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen) # Softmax and it is always carried out in fp32. padded_logits = padded_logits.astype(jnp.float32) probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype) encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value) + if q_seqlen is not None and kv_seqlen is not None: + mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen) + encoded *= mask.astype(encoded.dtype) return encoded def _dot_product_attention_xla( @@ -824,7 +851,9 @@ def _dot_product_attention_xla( bias: Array | None, mask: Array | None, is_causal: bool, - scale: float): + scale: float, + q_seqlen: Array | None, + kv_seqlen: Array | None): B, T, N, H = query.shape _, S, K, _ = key.shape @@ -843,9 +872,10 @@ def _reshape_to_grouped(t): bias = _reshape_to_grouped(bias) mask = _reshape_to_grouped(mask) vmapped_fn = jax.vmap(_dot_product_attention_core, - in_axes=(3, None, None, 2, 2, None, None), + in_axes=(3, None, None, 2, 2, None, None, None, None), out_axes=3) - encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale) + encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale, + q_seqlen, kv_seqlen) encoded = jnp.reshape(encoded, (B, T, N, H)) return encoded @@ -858,6 +888,8 @@ def dot_product_attention( *, scale: float | None = None, is_causal: bool = False, + query_seq_lengths: ArrayLike | None = None, + key_value_seq_lengths: ArrayLike | None = None, implementation: Literal['xla', 'cudnn'] | None = None) -> Array: r"""Scaled dot product attention function. @@ -903,6 +935,10 @@ def dot_product_attention( logits to mask out the non-causal parts of the attention matrix, but other implementations like `cudnn` will avoid computing the non-causal regions, providing speedups. + query_seq_lengths: `int32` array of sequence lengths for query; shape + :code:`(B)` + key_value_seq_lengths: `int32` array of sequence lengths for key and value; + shape :code:`(B)` implementation: A string to control which implementation backend to use. Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults to `None`, which will automatically select the best available backend. @@ -925,46 +961,64 @@ def _ensure_4d(t): value_arr = _ensure_4d(value) bias = _ensure_4d(bias) if bias is not None else None mask = _ensure_4d(mask) if mask is not None else None - - def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: + if query_seq_lengths is not None: + query_seq_lengths = jnp.asarray(query_seq_lengths) + if key_value_seq_lengths is not None: + key_value_seq_lengths = jnp.asarray(key_value_seq_lengths) + + def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], + dtype: DType | None, name: str) -> None: + if t is None: + return if t.ndim != len(shape): raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}") + if dtype is not None and t.dtype != dtype: + raise ValueError(f"{name} dtype should be {dtype}, but got {t.dtype}") for i in range(t.ndim): if shape[i] != -1 and t.shape[i] != shape[i]: raise ValueError(f"{name} shape should be {shape}: but got {t.shape}") B, S, K, H = key_arr.shape - _check_has_shape(value_arr, [B, S, K, H], 'value') - _check_has_shape(query_arr, [B, -1, -1, H], 'query') + _check_shape_and_dtype(value_arr, [B, S, K, H], key_arr.dtype, 'value') + _check_shape_and_dtype(query_arr, [B, -1, -1, H], key_arr.dtype, 'query') + _check_shape_and_dtype(mask, [-1] * 4, jnp.bool_, 'mask') + _check_shape_and_dtype(bias, [-1] * 4, None, 'bias') + _check_shape_and_dtype(query_seq_lengths, [B], jnp.int32, + 'query_seq_lengths') + _check_shape_and_dtype(key_value_seq_lengths, [B], jnp.int32, + 'key_value_seq_lengths') if query_arr.shape[-2] % K != 0: raise ValueError(f"The number of query heads must be a multiple of " f"key/value heads, but got {query_arr.shape[-2]} vs {K}") - if not (query_arr.dtype == key_arr.dtype == value_arr.dtype): - raise ValueError(f"query/key/value should have the same dtype, but got " - f"{query_arr.dtype} vs {key_arr.dtype} vs {value_arr.dtype}.") - if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4: - raise ValueError(f"Mask must be a 4D boolean tensor, but got " - f"rank={mask.ndim}, dtype={mask.dtype}.") - if bias is not None and bias.ndim != 4: - raise ValueError(f"Bias must be a 4D tensor, but got rank={bias.ndim}.") scale_val = (1.0 / np.sqrt(H)) if scale is None else scale match implementation: case 'xla': out = _dot_product_attention_xla( - query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val, + query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, + scale=scale_val, q_seqlen=query_seq_lengths, + kv_seqlen=key_value_seq_lengths, ) case 'cudnn': - mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK + mask_type = MaskType.NO_MASK + if query_seq_lengths is not None and is_causal: + mask_type = MaskType.PADDING_CAUSAL + elif is_causal: + mask_type = MaskType.CAUSAL + elif query_seq_lengths is not None: + mask_type = MaskType.PADDING out = cudnn_dot_product_attention( - query_arr, key_arr, value_arr, bias, mask, scale=scale_val, mask_type=mask_type + query_arr, key_arr, value_arr, bias, mask, query_seq_lengths, + key_value_seq_lengths, scale=scale_val, mask_type=mask_type ) case None: # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select # best backend. out = _dot_product_attention_xla( - query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val, + query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, + scale=scale_val, q_seqlen=query_seq_lengths, + kv_seqlen=key_value_seq_lengths, ) case _: raise ValueError(f"Unsupported implementation option: {implementation}") diff --git a/tests/nn_test.py b/tests/nn_test.py index 802ed1b2f1e2..a79cf738714b 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -55,17 +55,21 @@ class NNFunctionsTest(jtu.JaxTestCase): @parameterized.product( dtype=[jnp.float32, jnp.bfloat16, jnp.float16], use_bias=[False, True], - causal_mode=[None, 'is_causal', 'is_mask'], + causal_mode=[None, 'attr', 'mask'], group_num=[1, 2, 4], use_vmap=[False, True], + use_seqlen=[False, True], impl=['xla', 'cudnn'], ) def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, - group_num, use_vmap, impl): + group_num, use_vmap, use_seqlen, impl): if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") + if use_vmap and use_seqlen: + raise unittest.SkipTest("vmap cannot be used together with variable " + "seqence lengths") sdpa = nn.dot_product_attention B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num @@ -77,41 +81,60 @@ def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, bias = random.normal(keys[3], (1, N, T, S), dtype) else: bias = None + if use_seqlen: + q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32) + kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32) + else: + q_seqlen = None + kv_seqlen = None - is_causal = causal_mode == 'is_causal' - causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None + is_causal = causal_mode == 'attr' + causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) if impl == 'cudnn': - lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask) + lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask, + query_seq_lengths=q_seqlen, + key_value_seq_lengths=kv_seqlen) hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) self.assertIn('__cudnn$fmha', hlo) - if use_vmap: - sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V - out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask) + out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask, + query_seq_lengths=q_seqlen, + key_value_seq_lengths=kv_seqlen) + if use_vmap: + sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) - out_ans = sdpa_ans(Q, K, V, bias, causal_mask) + out_ans = sdpa_ans(Q, K, V, bias, causal_mask, + query_seq_lengths=q_seqlen, + key_value_seq_lengths=kv_seqlen) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) @parameterized.product( dtype=[jnp.float32, jnp.bfloat16, jnp.float16], use_bias=[False, True], - causal_mode=[None, 'is_causal', 'is_mask'], + causal_mode=[None, 'attr', 'mask'], group_num=[1, 2, 4], use_vmap=[False, True], + use_seqlen=[False, True], impl=['xla', 'cudnn'], ) def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, - group_num, use_vmap, impl): + group_num, use_vmap, use_seqlen, impl): if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") + if use_vmap and use_seqlen: + raise unittest.SkipTest("vmap cannot be used together with variable " + "seqence lengths") + if use_seqlen and use_bias and impl == 'cudnn': + raise unittest.SkipTest("cudnn has limited support for dbias when using " + "variable seqence lengths") sdpa = nn.dot_product_attention B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num @@ -124,24 +147,41 @@ def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, bias = random.normal(keys[4], (1, N, T, S), dtype) else: bias = None + if use_seqlen: + q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32) + kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32) + else: + q_seqlen = None + kv_seqlen = None - is_causal = causal_mode == 'is_causal' - causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None + is_causal = causal_mode == 'attr' + causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) - _, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, causal_mask) - dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad) + # Convert the keyword arguments to positional ones. + fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref( + q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs + ) + _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask, + q_seqlen, kv_seqlen) + dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4] if G != 1: dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) - if use_vmap: + if use_vmap and not use_seqlen: sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) - _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask) - dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad) + _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask) + else: + fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans( + q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs + ) + _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask, q_seqlen, + kv_seqlen) + dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4] if impl == 'cudnn': lowered = jax.jit(sdpa_vjp_ans).lower(grad) From 0105254ab1da55647889485ab7611328772f6e8b Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 21 Aug 2024 11:58:30 -0700 Subject: [PATCH 199/702] Unbreak Mosaic after https://github.com/llvm/llvm-project/commit/42944da5ba7617bbc02f341e9ef401c325310a73 PiperOrigin-RevId: 665973530 --- jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index 4d5e62049098..7006c1c2402a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -477,6 +477,7 @@ struct LinalgVectorizationPass // contract ops will help to sustain the structure through various // transformations. vector::populateVectorReductionToContractPatterns(patterns); + vector::populateSinkVectorOpsPatterns(patterns); // Pull in patterns to canonicalize transfer ops. vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); From 08fc5c024363324880da1a02d290783f533f4c5a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 21 Aug 2024 14:10:48 -0700 Subject: [PATCH 200/702] Reverts abd442b12a967f1738691b145c557df5df555dcc PiperOrigin-RevId: 666026942 --- jax/_src/interpreters/mlir.py | 7 +++---- jax/_src/pallas/mosaic/lowering.py | 4 +--- jax/_src/pallas/triton/lowering.py | 4 +--- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index e798a6fbdba9..bc1c00948943 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -483,10 +483,9 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location: return loc def _source_info_to_location( - ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any], + ctx: ModuleContext, primitive: core.Primitive, source_info: source_info_util.SourceInfo) -> ir.Location: - eqn_str = (f'{source_info.name_stack}/' - f'{core.str_eqn_compact(primitive, params)}') + eqn_str = f'{source_info.name_stack}/{primitive.name}' if config.include_full_tracebacks_in_locations.value: if source_info.traceback is None: loc = ir.Location.unknown() @@ -1745,7 +1744,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None in_nodes = map(read, eqn.invars) source_info = eqn.source_info.replace( name_stack=name_stack + eqn.source_info.name_stack) - loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info) + loc = _source_info_to_location(ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: override_rule = get_override_lowering_rule(eqn.primitive) platform_rules: dict[str, LoweringRule] = {} diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index aee894ee1b7e..2c5854876009 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -788,9 +788,7 @@ def write_env(var: jax_core.Var, val): source_info = eqn.source_info.replace( name_stack=ctx.name_stack + eqn.source_info.name_stack ) - loc = mlir._source_info_to_location( - ctx, eqn.primitive, eqn.params, source_info - ) + loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: if eqn.primitive in lowering_rules: if eqn.primitive not in skip_mlir_conversions: diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 852ac714d3c9..6db00a53671e 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -382,9 +382,7 @@ def write_env(var: jax_core.Var, val): avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] eqn_block_infos = map(read_block_info_env, eqn.invars) - loc = mlir._source_info_to_location( - ctx, eqn.primitive, eqn.params, eqn.source_info - ) + loc = mlir._source_info_to_location(ctx, eqn.primitive, eqn.source_info) rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos) try: with source_info_util.user_context(eqn.source_info.traceback), loc: From 810a91968a853b4ae15aa5c5282e5673136bb980 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 21 Aug 2024 15:51:55 -0700 Subject: [PATCH 201/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/6cdbb866c613947b22607664c32a7e06a23fe666. PiperOrigin-RevId: 666067632 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 291de1db18c2..f260f2dc65ac 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ac89370dce9df3d850bb51a1576ca39f1efec63b" -XLA_SHA256 = "054a56ecd26babe32deebdf0782e1090b6a1f2a6442c2752602b65ce87747d9a" +XLA_COMMIT = "6cdbb866c613947b22607664c32a7e06a23fe666" +XLA_SHA256 = "273b9d3e13f9c922357df4d0fabb1e4e3fc0a80f3848abafa60f33cd49185c10" def repo(): tf_http_archive( From 3713b966c2a868e948a663193282deba7ba14842 Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Wed, 21 Aug 2024 20:44:52 -0700 Subject: [PATCH 202/702] Fix a potential segfault in triton kernel call caching It is possible that a null pointer is inserted into the cache and not updated with a valid kernel call in case there is an error later during initialization. This change updates the cache to store either an error or a valid kernel call. PiperOrigin-RevId: 666161091 --- jaxlib/gpu/triton_kernels.cc | 63 ++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 89d804511313..c96c6b5c54b0 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -31,7 +31,6 @@ #include "jaxlib/gpu/triton_utils.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" -#include "tsl/platform/env.h" #ifdef JAX_GPU_CUDA #include "xla/stream_executor/cuda/cuda_asm_compiler.h" @@ -137,14 +136,18 @@ absl::StatusOr GetKernelCall(absl::string_view opaque, gpuStream_t stream, void** buffers) { static absl::Mutex mutex; static auto& kernel_calls = - *new absl::flat_hash_map> + *new absl::flat_hash_map>> ABSL_GUARDED_BY(mutex); { // Fast path uses reader lock (as hash map look-up is relatively slow). absl::ReaderMutexLock lock(&mutex); auto it = kernel_calls.find(opaque); - if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) return it->second.get(); + if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) { + JAX_RETURN_IF_ERROR(it->second.status()); + return it->second->get(); + } } if (opaque.empty()) { @@ -152,37 +155,41 @@ absl::StatusOr GetKernelCall(absl::string_view opaque, } absl::MutexLock lock(&mutex); - std::unique_ptr& kernel_call = kernel_calls[opaque]; - // We released the reader lock, so it may have been written by another thread. - if (kernel_call != nullptr) return kernel_call.get(); - // The opaque data is a zlib compressed protobuf. - JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); + auto get_kernel_call = [&]() -> absl::StatusOr> { + // The opaque data is a zlib compressed protobuf. + JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); - jax_triton::TritonAnyKernelCall proto; - if (!proto.ParseFromString(serialized)) { - return absl::InvalidArgumentError("Failed to parse serialized data."); - } + jax_triton::TritonAnyKernelCall proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError("Failed to parse serialized data."); + } - if (proto.has_kernel_call()) { - JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_, - KernelCall::FromProto(proto.kernel_call())); - kernel_call = std::make_unique(std::move(kernel_call_)); - } else if (proto.has_autotuned_kernel_call()) { - JAX_ASSIGN_OR_RETURN( - AutotunedKernelCall autotuned_call, - AutotunedKernelCall::FromProto(proto.autotuned_kernel_call())); - { + if (proto.has_kernel_call()) { JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_, - AutotunedKernelCall::Autotune( - std::move(autotuned_call), stream, buffers)); - kernel_call = std::make_unique(std::move(kernel_call_)); + KernelCall::FromProto(proto.kernel_call())); + return std::make_unique(std::move(kernel_call_)); + } else if (proto.has_autotuned_kernel_call()) { + JAX_ASSIGN_OR_RETURN( + AutotunedKernelCall autotuned_call, + AutotunedKernelCall::FromProto(proto.autotuned_kernel_call())); + { + JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_, + AutotunedKernelCall::Autotune( + std::move(autotuned_call), stream, buffers)); + return std::make_unique(std::move(kernel_call_)); + } + } else { + return absl::InvalidArgumentError("Unknown kernel call type."); } - } else { - return absl::InvalidArgumentError("Unknown kernel call type."); - } + }; + + // We released the reader lock, so it may have been written by another thread. + // Create a new entry if it already exists or create a new one. + auto it = kernel_calls.emplace(std::string(opaque), get_kernel_call()).first; - return kernel_call.get(); + JAX_RETURN_IF_ERROR(it->second.status()); + return it->second->get(); } } // namespace From a72d46c54963bf967dab240be07f55437d6ff93f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Thu, 22 Aug 2024 01:37:59 -0700 Subject: [PATCH 203/702] Ignore LAPACK info parameter for QR Factorization The assumption is that QR Factorization will never fail from LAPACK's side because all necessary verification is happening right before the call. PiperOrigin-RevId: 666241215 --- jaxlib/cpu/lapack_kernels.cc | 14 ++++++-------- jaxlib/cpu/lapack_kernels.h | 3 +-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 9765b227d403..c1475d1f2ed2 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -306,14 +306,14 @@ template struct Geqrf>; // FFI Kernel template -ffi::Error QrFactorization::Kernel( - ffi::Buffer x, ffi::ResultBuffer x_out, - ffi::ResultBuffer tau, ffi::ResultBuffer info) { +ffi::Error QrFactorization::Kernel(ffi::Buffer x, + ffi::ResultBuffer x_out, + ffi::ResultBuffer tau) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* tau_data = tau->typed_data(); - auto* info_data = info->typed_data(); + lapack_int info; const int64_t work_size = GetWorkspaceSize(x_rows, x_cols); auto work_data = AllocateScratchMemory(work_size); @@ -328,10 +328,9 @@ ffi::Error QrFactorization::Kernel( const int64_t tau_step{std::min(x_rows, x_cols)}; for (int64_t i = 0; i < batch_count; ++i) { fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data, - work_data.get(), &workspace_dim_v, info_data); + work_data.get(), &workspace_dim_v, &info); x_out_data += x_out_step; tau_data += tau_step; - ++info_data; } return ffi::Error::Success(); } @@ -1713,8 +1712,7 @@ template struct Sytrd>; ::xla::ffi::Ffi::Bind() \ .Arg<::xla::ffi::Buffer>(/*x*/) \ .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*tau*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/)) + .Ret<::xla::ffi::Buffer>(/*tau*/)) #define JAX_CPU_DEFINE_ORGQR(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 8abf8e22daac..20823e785f32 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -194,8 +194,7 @@ struct QrFactorization { static ::xla::ffi::Error Kernel( ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer tau, - ::xla::ffi::ResultBuffer info); + ::xla::ffi::ResultBuffer tau); static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); }; From 4786930a4c295da5d68aee85999e8d45297ae06e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Thu, 22 Aug 2024 04:09:02 -0700 Subject: [PATCH 204/702] Determine LAPACK workspace during Eigenvalue Kernels runtime PiperOrigin-RevId: 666285759 --- jaxlib/cpu/_lapack/__init__.pyi | 4 - jaxlib/cpu/lapack.cc | 16 --- jaxlib/cpu/lapack_kernels.cc | 178 +++++++++++++++----------------- jaxlib/cpu/lapack_kernels.h | 38 +++---- 4 files changed, 100 insertions(+), 136 deletions(-) diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index f2a4d943086a..4275d8e48813 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -49,11 +49,7 @@ def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matr # FFI Kernel LAPACK Workspace Size Queries -def heevd_rwork_size_ffi(n: int) -> int: ... -def heevd_work_size_ffi(n: int) -> int: ... def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def syevd_iwork_size_ffi(n: int) -> int: ... -def syevd_work_size_ffi(n: int) -> int: ... diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 83ed7610ced7..354a1cf9ab34 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -37,14 +37,6 @@ svd::ComputationMode GetSvdComputationMode(bool job_opt_compute_uv, return svd::ComputationMode::kComputeFullUVt; } -// Due to enforced kComputeEigenvectors, this assumes a larger workspace size. -// Could be improved to more accurately estimate the expected size based on the -// eig::ComputationMode value. -template -inline constexpr auto BoundWithEigvecs = +[](lapack_int n) { - return f(n, eig::ComputationMode::kComputeEigenvectors); -}; - void GetLapackKernelsFromScipy() { static bool initialized = false; // Protected by GIL if (initialized) return; @@ -348,14 +340,6 @@ NB_MODULE(_lapack, m) { m.def("lapack_zungqr_workspace_ffi", &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("syevd_work_size_ffi", BoundWithEigvecs, - nb::arg("n")); - m.def("syevd_iwork_size_ffi", BoundWithEigvecs, - nb::arg("n")); - m.def("heevd_work_size_ffi", BoundWithEigvecs, - nb::arg("n")); - m.def("heevd_rwork_size_ffi", BoundWithEigvecs, - nb::arg("n")); } } // namespace diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index c1475d1f2ed2..fd0a12ef2ed1 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -954,21 +954,24 @@ template struct ComplexHeevd>; // FFI Kernel -lapack_int eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr eig::GetWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: - return CastNoOverflow(2 * x_cols + 1); + return MaybeCastNoOverflow(2 * x_cols + 1); case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(1 + 6 * x_cols + 2 * x_cols * x_cols); + return MaybeCastNoOverflow(1 + 6 * x_cols + + 2 * x_cols * x_cols); } } -lapack_int eig::GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr eig::GetIntWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: return 1; case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(3 + 5 * x_cols); + return MaybeCastNoOverflow(3 + 5 * x_cols); } } @@ -976,34 +979,34 @@ template ffi::Error EigenvalueDecompositionSymmetric::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, ffi::ResultBuffer x_out, ffi::ResultBuffer eigenvalues, - ffi::ResultBuffer info, ffi::ResultBuffer work, - ffi::ResultBuffer iwork, eig::ComputationMode mode) { + ffi::ResultBuffer info, eig::ComputationMode mode) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* eigenvalues_data = eigenvalues->typed_data(); auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); - auto* iwork_data = iwork->typed_data(); CopyIfDiffBuffer(x, x_out); auto mode_v = static_cast(mode); auto uplo_v = static_cast(uplo); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); - FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow( - iwork->dimensions().back())); FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, MaybeCastNoOverflow(x_cols)); + // Prepare LAPACK workspaces. + FFI_ASSIGN_OR_RETURN(lapack_int work_size_v, + eig::GetWorkspaceSize(x_cols, mode)); + FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v, + eig::GetIntWorkspaceSize(x_cols, mode)); + auto work_data = AllocateScratchMemory(work_size_v); + auto iwork_data = AllocateScratchMemory(iwork_size_v); const int64_t x_out_step{x_cols * x_cols}; const int64_t eigenvalues_step{x_cols}; for (int64_t i = 0; i < batch_count; ++i) { fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v, - eigenvalues_data, work_data, &workspace_dim_v, iwork_data, - &iworkspace_dim_v, info_data); + eigenvalues_data, work_data.get(), &work_size_v, iwork_data.get(), + &iwork_size_v, info_data); x_out_data += x_out_step; eigenvalues_data += eigenvalues_step; ++info_data; @@ -1013,21 +1016,24 @@ ffi::Error EigenvalueDecompositionSymmetric::Kernel( namespace eig { -lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr GetComplexWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: - return CastNoOverflow(x_cols + 1); + return MaybeCastNoOverflow(x_cols + 1); case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(2 * x_cols + x_cols * x_cols); + return MaybeCastNoOverflow(2 * x_cols + x_cols * x_cols); } } -lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr GetRealWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: - return CastNoOverflow(std::max(x_cols, int64_t{1})); + return MaybeCastNoOverflow(std::max(x_cols, int64_t{1})); case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(1 + 5 * x_cols + 2 * x_cols * x_cols); + return MaybeCastNoOverflow(1 + 5 * x_cols + + 2 * x_cols * x_cols); } } @@ -1038,37 +1044,37 @@ ffi::Error EigenvalueDecompositionHermitian::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, ffi::ResultBuffer x_out, ffi::ResultBuffer eigenvalues, - ffi::ResultBuffer info, ffi::ResultBuffer work, - ffi::ResultBuffer rwork, - ffi::ResultBuffer iwork, eig::ComputationMode mode) { + ffi::ResultBuffer info, eig::ComputationMode mode) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* eigenvalues_data = eigenvalues->typed_data(); auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); - auto* iwork_data = iwork->typed_data(); CopyIfDiffBuffer(x, x_out); auto mode_v = static_cast(mode); auto uplo_v = static_cast(uplo); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); - FFI_ASSIGN_OR_RETURN(auto rworkspace_dim_v, MaybeCastNoOverflow( - rwork->dimensions().back())); - FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow( - iwork->dimensions().back())); FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, MaybeCastNoOverflow(x_cols)); + // Prepare LAPACK workspaces. + FFI_ASSIGN_OR_RETURN(lapack_int work_size_v, + eig::GetComplexWorkspaceSize(x_cols, mode)); + FFI_ASSIGN_OR_RETURN(lapack_int rwork_size_v, + eig::GetRealWorkspaceSize(x_cols, mode)); + FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v, + eig::GetIntWorkspaceSize(x_cols, mode)); + auto work_data = AllocateScratchMemory(work_size_v); + auto iwork_data = AllocateScratchMemory(iwork_size_v); + auto rwork_data = AllocateScratchMemory(rwork_size_v); const int64_t x_out_step{x_cols * x_cols}; const int64_t eigenvalues_step{x_cols}; for (int64_t i = 0; i < batch_count; ++i) { fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v, - eigenvalues_data, work_data, &workspace_dim_v, rwork->typed_data(), - &rworkspace_dim_v, iwork_data, &iworkspace_dim_v, info_data); + eigenvalues_data, work_data.get(), &work_size_v, rwork_data.get(), + &rwork_size_v, iwork_data.get(), &iwork_size_v, info_data); x_out_data += x_out_step; eigenvalues_data += eigenvalues_step; ++info_data; @@ -1265,16 +1271,11 @@ ffi::Error EigenvalueDecomposition::Kernel( ffi::ResultBuffer eigvals_imag, ffi::ResultBuffer eigvecs_left, ffi::ResultBuffer eigvecs_right, - ffi::ResultBuffer info, ffi::ResultBuffer x_work, - ffi::ResultBuffer work_eigvecs_left, - ffi::ResultBuffer work_eigvecs_right) { + ffi::ResultBuffer info) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); const auto* x_data = x.typed_data(); - auto* x_work_data = x_work->typed_data(); - auto* work_eigvecs_left_data = work_eigvecs_left->typed_data(); - auto* work_eigvecs_right_data = work_eigvecs_right->typed_data(); auto* eigvecs_left_data = eigvecs_left->typed_data(); auto* eigvecs_right_data = eigvecs_right->typed_data(); auto* eigvals_real_data = eigvals_real->typed_data(); @@ -1284,43 +1285,45 @@ ffi::Error EigenvalueDecomposition::Kernel( auto compute_left_v = static_cast(compute_left); auto compute_right_v = static_cast(compute_right); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - + // Prepare LAPACK workspaces. int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right); FFI_ASSIGN_OR_RETURN(auto work_size_v, MaybeCastNoOverflow(work_size)); - // TODO(phawkins): preallocate workspace using XLA. - auto work = std::make_unique(work_size); - auto* work_data = work.get(); + auto work_data = AllocateScratchMemory(work_size); + const int64_t x_size{x_cols * x_cols}; + auto x_copy = AllocateScratchMemory(x_size); + auto work_eigvecs_left = AllocateScratchMemory(x_size); + auto work_eigvecs_right = AllocateScratchMemory(x_size); const auto is_finite = [](ValueType* data, int64_t size) { return absl::c_all_of(absl::MakeSpan(data, size), [](ValueType value) { return std::isfinite(value); }); }; - const int64_t x_size{x_cols * x_cols}; [[maybe_unused]] const auto x_size_bytes = static_cast(x_size) * sizeof(ValueType); [[maybe_unused]] const auto x_cols_bytes = static_cast(x_cols) * sizeof(ValueType); for (int64_t i = 0; i < batch_count; ++i) { - std::copy_n(x_data, x_size, x_work_data); - if (is_finite(x_work_data, x_size)) { - fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v, - eigvals_real_data, eigvals_imag_data, work_eigvecs_left_data, - &x_cols_v, work_eigvecs_right_data, &x_cols_v, work_data, &work_size_v, - info_data); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes); + std::copy_n(x_data, x_size, x_copy.get()); + if (is_finite(x_copy.get(), x_size)) { + fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v, + eigvals_real_data, eigvals_imag_data, work_eigvecs_left.get(), + &x_cols_v, work_eigvecs_right.get(), &x_cols_v, work_data.get(), + &work_size_v, info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left_data, x_size_bytes); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right_data, + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left.get(), + x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int)); if (info_data[0] == 0) { - UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left_data, + UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left.get(), eigvecs_left_data); - UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_right_data, - eigvecs_right_data); + UnpackEigenvectors(x_cols_v, eigvals_imag_data, + work_eigvecs_right.get(), eigvecs_right_data); } } else { info_data[0] = -4; @@ -1341,12 +1344,10 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( eig::ComputationMode compute_right, ffi::ResultBuffer eigvals, ffi::ResultBuffer eigvecs_left, ffi::ResultBuffer eigvecs_right, - ffi::ResultBuffer info, ffi::ResultBuffer x_work, - ffi::ResultBuffer rwork) { + ffi::ResultBuffer info) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); const auto* x_data = x.typed_data(); - auto* x_work_data = x_work->typed_data(); auto* eigvecs_left_data = eigvecs_left->typed_data(); auto* eigvecs_right_data = eigvecs_right->typed_data(); auto* eigvals_data = eigvals->typed_data(); @@ -1355,13 +1356,14 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( auto compute_left_v = static_cast(compute_left); auto compute_right_v = static_cast(compute_right); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - + // Prepare LAPACK workspaces. int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right); FFI_ASSIGN_OR_RETURN(auto work_size_v, MaybeCastNoOverflow(work_size)); - // TODO(phawkins): preallocate workspace using XLA. - auto work = std::make_unique(work_size); - auto* work_data = work.get(); + auto work_data = AllocateScratchMemory(work_size); + const int64_t x_size{x_cols * x_cols}; + auto x_copy = AllocateScratchMemory(x_size); + auto rwork_data = AllocateScratchMemory(2 * x_cols); const auto is_finite = [](ValueType* data, int64_t size) { return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) { @@ -1369,18 +1371,18 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( }); }; - const int64_t x_size{x_cols * x_cols}; [[maybe_unused]] const auto x_size_bytes = static_cast(x_size) * sizeof(ValueType); [[maybe_unused]] const auto x_cols_bytes = static_cast(x_cols) * sizeof(ValueType); for (int64_t i = 0; i < batch_count; ++i) { - std::copy_n(x_data, x_size, x_work_data); - if (is_finite(x_work_data, x_size)) { - fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v, + std::copy_n(x_data, x_size, x_copy.get()); + if (is_finite(x_copy.get(), x_size)) { + fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v, eigvals_data, eigvecs_left_data, &x_cols_v, eigvecs_right_data, - &x_cols_v, work_data, &work_size_v, rwork->typed_data(), info_data); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes); + &x_cols_v, work_data.get(), &work_size_v, rwork_data.get(), + info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_left_data, x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_right_data, x_size_bytes); @@ -1766,23 +1768,18 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*x_out*/) \ .Ret<::xla::ffi::Buffer>(/*eigenvalues*/) \ .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ .Attr("mode")) -#define JAX_CPU_DEFINE_HEEVD(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, EigenvalueDecompositionHermitian::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Attr("uplo") \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ - /*eigenvalues*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ +#define JAX_CPU_DEFINE_HEEVD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, EigenvalueDecompositionHermitian::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("uplo") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ + /*eigenvalues*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ .Attr("mode")) #define JAX_CPU_DEFINE_GEEV(name, data_type) \ @@ -1798,12 +1795,7 @@ template struct Sytrd>; /*eigvecs_left*/) \ .Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \ /*eigvecs_right*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*x_work*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ - /*work_eigvecs_left*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ - /*work_eigvecs_right*/)) + .Ret<::xla::ffi::Buffer>(/*info*/)) #define JAX_CPU_DEFINE_GEEV_COMPLEX(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ @@ -1815,9 +1807,7 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*eigvals*/) \ .Ret<::xla::ffi::Buffer>(/*eigvecs_left*/) \ .Ret<::xla::ffi::Buffer>(/*eigvecs_right*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*x_work*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/)) + .Ret<::xla::ffi::Buffer>(/*info*/)) // FFI Handlers diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 20823e785f32..4d021b688de9 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -395,12 +395,16 @@ struct ComplexHeevd { namespace eig { // Eigenvalue Decomposition -lapack_int GetWorkspaceSize(int64_t x_cols, ComputationMode mode); -lapack_int GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode); +absl::StatusOr GetWorkspaceSize(int64_t x_cols, + ComputationMode mode); +absl::StatusOr GetIntWorkspaceSize(int64_t x_cols, + ComputationMode mode); // Hermitian Eigenvalue Decomposition -lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode); -lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode); +absl::StatusOr GetComplexWorkspaceSize(int64_t x_cols, + ComputationMode mode); +absl::StatusOr GetRealWorkspaceSize(int64_t x_cols, + ComputationMode mode); } // namespace eig @@ -417,14 +421,12 @@ struct EigenvalueDecompositionSymmetric { inline static FnType* fn = nullptr; - static ::xla::ffi::Error Kernel( - ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, - ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer eigenvalues, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work, - ::xla::ffi::ResultBuffer iwork, - eig::ComputationMode mode); + static ::xla::ffi::Error Kernel(::xla::ffi::Buffer x, + MatrixParams::UpLo uplo, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer eigenvalues, + ::xla::ffi::ResultBuffer info, + eig::ComputationMode mode); }; template <::xla::ffi::DataType dtype> @@ -445,9 +447,6 @@ struct EigenvalueDecompositionHermitian { ::xla::ffi::ResultBuffer x_out, ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues, ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork, - ::xla::ffi::ResultBuffer iwork, eig::ComputationMode mode); }; @@ -496,10 +495,7 @@ struct EigenvalueDecomposition { ::xla::ffi::ResultBuffer eigvals_imag, ::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_left, ::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_right, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer x_work, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_left, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_right); + ::xla::ffi::ResultBuffer info); static int64_t GetWorkspaceSize(lapack_int x_cols, eig::ComputationMode compute_left, @@ -526,9 +522,7 @@ struct EigenvalueDecompositionComplex { ::xla::ffi::ResultBuffer eigvals, ::xla::ffi::ResultBuffer eigvecs_left, ::xla::ffi::ResultBuffer eigvecs_right, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer x_work, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork); + ::xla::ffi::ResultBuffer info); static int64_t GetWorkspaceSize(lapack_int x_cols, eig::ComputationMode compute_left, From 0b4f64e0025f63ebae1fe3dd8bc811f80a3cd5c9 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 Aug 2024 04:58:26 -0700 Subject: [PATCH 205/702] [Mosaic GPU] Allow tile sizes to exceed dimension size Otherwise, the dimension size still needs to be a multiple of tiling. PiperOrigin-RevId: 666298624 --- jax/experimental/mosaic/gpu/__init__.py | 14 ++++-- tests/mosaic/gpu_test.py | 67 +++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index f5a92cc67f21..9617d53075a3 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -145,7 +145,13 @@ def apply(self, ref: ir.Value) -> ir.Value: tiling_rank = len(self.tiling) tiled_rank = untiled_rank + tiling_rank for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]): - ref = utils.memref_unfold(ref, d, (None, t)) + s = ir.MemRefType(ref.type).shape[d] + if s % t and s > t: + raise ValueError( + f"Dimension {d} must have size smaller or a multiple of its tiling" + f" {t}, but got {s}" + ) + ref = utils.memref_unfold(ref, d, (None, min(t, s))) permutation = ( *range(untiled_rank - tiling_rank), *range(untiled_rank - tiling_rank, tiled_rank, 2), @@ -175,8 +181,10 @@ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: for size, tile_size in zip(shape[-tiling_rank:], self.tiling): if size % tile_size: raise ValueError( - f"Expected GMEM slice shape {shape} suffix to be a multiple" - f" of tiling {self.tiling}" + f"Expected GMEM slice shape {shape} suffix to be a multiple of" + f" tiling {self.tiling}.\nIf you're using padded async copies, your" + " slice might need to extend out of bounds of the GMEM buffer (OOB" + " accesses will be skipped)." ) return ( *shape[:-tiling_rank], diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f737ce721510..44fd518344e3 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1024,6 +1024,73 @@ def kernel(ctx, src, dst, tmp): y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) np.testing.assert_array_equal(y, x) + @parameterized.parameters(0, 1) + def test_tma_small_tile_load(self, small_dim): + if small_dim == 0: + shape = (4, 128) + elif small_dim == 1: + shape = (128, 8) + else: + raise ValueError("small_dim must be 0 or 1") + tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64) + padded_shape = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2])) + def kernel(ctx, src, dst, smem): + tmp, barrier = smem + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + swizzle=128, + gmem_transform=mosaic_gpu.TileTransform((64, 64)), + gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])), + barrier=barrier, + ) + barrier.wait() + copy(tmp, dst, swizzle=128) + x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape) + tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16) + y_tiled = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, tiled, (tiled, mgpu.TMABarrier()), + )(x) + y = y_tiled.swapaxes(1, 2).reshape(padded_shape) + # y should contain x and zero everywhere else. + np.testing.assert_array_equal(y[:shape[0], :shape[1]], x) + y_mut = np.asarray(y).copy() + y_mut[:shape[0], :shape[1]] = 0 + np.testing.assert_array_equal(y_mut, np.zeros_like(y_mut)) + + @parameterized.parameters(0, 1) + def test_tma_small_tile_store(self, small_dim): + if small_dim == 0: + shape = (4, 128) + elif small_dim == 1: + shape = (128, 8) + else: + raise ValueError("small_dim must be 0 or 1") + tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64) + padded_shape = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2])) + def kernel(ctx, dst, tmp): + vals = iota_tensor( + m=padded_shape[0], n=padded_shape[1], mlir_dtype=ir.F16Type.get() + ) + vals.store_tiled(tmp, swizzle=128) + ctx.async_copy( + src_ref=tmp, + dst_ref=dst, + swizzle=128, + gmem_transform=mosaic_gpu.TileTransform((64, 64)), + gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])), + ) + ctx.await_async_copy(0) + tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16) + out = jax.ShapeDtypeStruct(shape, jnp.float16) + y = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out, tiled, + )() + iota = np.arange(np.prod(padded_shape), dtype=jnp.float16).reshape( + padded_shape + ) + np.testing.assert_array_equal(y, iota[:shape[0], :shape[1]]) + def test_tma_invalid(self): def kernel(ctx, src, dst, tmp): copy(src, tmp) From b56ed8eeddc5794f3981832a38b6bcc195eb20f8 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 22 Aug 2024 05:22:39 -0700 Subject: [PATCH 206/702] Port GPU kernel for Householder transformation to FFI. PiperOrigin-RevId: 666305682 --- jaxlib/gpu/gpu_kernels.cc | 2 + jaxlib/gpu/solver.cc | 1 + jaxlib/gpu/solver_kernels_ffi.cc | 171 ++++++++++++++++++++++++------- jaxlib/gpu/solver_kernels_ffi.h | 1 + 4 files changed, 140 insertions(+), 35 deletions(-) diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 1814641bb4fb..3841393654a8 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -51,6 +51,8 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", + OrgqrFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 4ee7a9f1dbf7..fee1c1014c75 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -477,6 +477,7 @@ nb::dict Registrations() { dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); + dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi); return dict; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 91124a847121..2b1f5552977f 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -51,13 +51,13 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, } // namespace #define SOLVER_DISPATCH_IMPL(impl, ...) \ - if (dataType == ffi::DataType::F32) { \ + if (dataType == ffi::F32) { \ return impl(__VA_ARGS__); \ - } else if (dataType == ffi::DataType::F64) { \ + } else if (dataType == ffi::F64) { \ return impl(__VA_ARGS__); \ - } else if (dataType == ffi::DataType::C64) { \ + } else if (dataType == ffi::C64) { \ return impl(__VA_ARGS__); \ - } else if (dataType == ffi::DataType::C128) { \ + } else if (dataType == ffi::C128) { \ return impl(__VA_ARGS__); \ } @@ -94,8 +94,8 @@ template ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { + ffi::Result> ipiv, + ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); @@ -110,13 +110,12 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, auto ipiv_data = ipiv->typed_data(); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols, - gpuMemcpyDeviceToDevice, stream))); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); } int ipiv_step = std::min(m, n); - for (int i = 0; i < batch; ++i) { + for (auto i = 0; i < batch; ++i) { FFI_RETURN_IF_ERROR_STATUS(GetrfKernel::Run( handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data)); out_data += m * n; @@ -147,8 +146,8 @@ template ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { + ffi::Result> ipiv, + ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(auto batch_ptrs, @@ -159,9 +158,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, auto ipiv_data = ipiv->typed_data(); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols, - gpuMemcpyDeviceToDevice, stream))); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); } MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch, @@ -176,8 +174,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { + ffi::Result> ipiv, + ffi::Result> info) { auto dataType = a.element_type(); if (dataType != out->element_type()) { return ffi::Error::InvalidArgument( @@ -201,15 +199,14 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, } } // namespace -XLA_FFI_DEFINE_HANDLER_SYMBOL( - GetrfFfi, GetrfDispatch, - ffi::Ffi::Bind() - .Ctx>() - .Ctx() - .Arg() // a - .Ret() // out - .Ret>() // ipiv - .Ret>() // info +XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // a + .Ret() // out + .Ret>() // ipiv + .Ret>() // info ); // QR decomposition: geqrf @@ -264,14 +261,13 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, auto out_data = static_cast(out->untyped_data()); auto tau_data = static_cast(tau->untyped_data()); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols, - gpuMemcpyDeviceToDevice, stream))); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); } int out_step = m * n; int tau_step = std::min(m, n); - for (int i = 0; i < batch; ++i) { + for (auto i = 0; i < batch; ++i) { FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel::Run( handle.get(), m, n, out_data, tau_data, workspace, lwork, info)); out_data += out_step; @@ -284,8 +280,8 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, template <> \ struct GeqrfBatchedKernel { \ static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \ - type** tau, int* info, int batch) { \ - return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \ + type** tau, int* info, int batch) { \ + return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \ } \ } @@ -314,9 +310,8 @@ ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols, auto out_data = out->untyped_data(); auto tau_data = tau->untyped_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols, - gpuMemcpyDeviceToDevice, stream))); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); } MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch, @@ -369,6 +364,112 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch, .Ret() // tau ); +// Householder transformations: orgqr + +namespace { +#define ORGQR_KERNEL_IMPL(type, name) \ + template <> \ + struct OrgqrKernel { \ + static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ + int n, int k) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + name##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m, \ + /*tau=*/nullptr, &lwork))); \ + return lwork; \ + } \ + static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \ + type* a, type* tau, type* workspace, int lwork, \ + int* info) { \ + return JAX_AS_STATUS( \ + name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \ + } \ + } + +template +struct OrgqrKernel; +ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr); +ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr); +ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr); +ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr); +#undef ORGQR_KERNEL_IMPL + +template +ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + ffi::AnyBuffer a, ffi::AnyBuffer tau, + ffi::Result out) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto k, MaybeCastNoOverflow(size)); + + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(int lwork, + OrgqrKernel::BufferSize(handle.get(), m, n, k)); + + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "orgqr")); + // Note: We ignore the returned value of info because it is only used for + // shape checking (which we already do ourselves), but it is expected to be + // in device memory, so we need to allocate it. + FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace(scratch, 1, "orgqr")); + + auto a_data = static_cast(a.untyped_data()); + auto tau_data = static_cast(tau.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + } + + int out_step = m * n; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel::Run( + handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info)); + out_data += out_step; + tau_data += k; + } + return ffi::Error::Success(); +} + +ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + ffi::AnyBuffer a, ffi::AnyBuffer tau, + ffi::Result out) { + auto dataType = a.element_type(); + if (dataType != tau.element_type() || dataType != out->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to orgqr must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + FFI_ASSIGN_OR_RETURN((auto [tau_batch, size]), + SplitBatch1D(tau.dimensions())); + if (tau_batch != batch) { + return ffi::Error::InvalidArgument( + "The batch dimensions of the inputs to orgqr must match"); + } + if (size > cols) { + return ffi::Error::InvalidArgument( + "The trailing dimension of the tau input to orgqr must be less than or " + "equal to the number of columns of the input matrix"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "orgqr")); + SOLVER_DISPATCH_IMPL(OrgqrImpl, batch, rows, cols, size, stream, scratch, a, + tau, out); + return ffi::Error::InvalidArgument("Unsupported element type for orgqr"); +} +} // namespace + +XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // a + .Arg() // tau + .Ret() // out +); + #undef SOLVER_DISPATCH_IMPL } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index d9c3da47655a..7dbc7454c2e6 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -24,6 +24,7 @@ namespace JAX_GPU_NAMESPACE { XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); } // namespace JAX_GPU_NAMESPACE } // namespace jax From 9c3f2dcefc2ca6a603a122bf4ab8bcfe67e247f7 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 Aug 2024 08:06:04 -0700 Subject: [PATCH 207/702] [Mosaic GPU] Make CUDA context part of the hash key + replace kernel id with a SHA256 digest XLA runtime creates a context per device, so we need to make sure that a kernel is loaded separately on each device. PiperOrigin-RevId: 666353098 --- jax/experimental/mosaic/gpu/__init__.py | 15 ++++++++-- jaxlib/mosaic/gpu/custom_call.cc | 38 ++++++++++++++++--------- tests/mosaic/BUILD | 1 + tests/mosaic/gpu_test.py | 13 +++++++++ 4 files changed, 51 insertions(+), 16 deletions(-) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 9617d53075a3..eb8eba9dfacf 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -18,6 +18,7 @@ import ctypes import dataclasses import functools +import hashlib import itertools import math import os @@ -92,11 +93,19 @@ def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes): return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] # TODO(apaszke): Implement a proper system for managing kernel lifetimes -kernel_idx = itertools.count() +KNOWN_KERNELS = {} def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes): del out_types # Unused. - idx_bytes = next(kernel_idx).to_bytes(8, byteorder="little") + kernel_id = hashlib.sha256(module).digest() + # Note that this is technically only a half measure. Someone might load a + # compiled module with a hash collision from disk. But that's so unlikely with + # SHA256 that it shouldn't be a problem. + if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None: + if kernel_text != module: + raise RuntimeError("Hash collision!") + else: + KNOWN_KERNELS[kernel_id] = module op = mlir.custom_call( "mosaic_gpu", result_types=[ @@ -109,7 +118,7 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes) operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out] + [[0]], - backend_config=idx_bytes + module, + backend_config=kernel_id + module, ) return op.results[:-1] # Skip the scratch space. diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 47ad893eaa05..8fdd34cd91dc 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -18,7 +18,9 @@ limitations under the License. #include #include +#include #include +#include #include #include #include @@ -366,11 +368,14 @@ class CompiledKernel { MosaicHostFunc* host_launch_; }; -std::pair*, absl::Mutex*> +using KernelHash = std::array; +using CacheKey = std::pair; + +std::pair*, absl::Mutex*> GetKernelCache() { static absl::Mutex mutex; static auto& context_cache = - *new absl::flat_hash_map; + *new absl::flat_hash_map; return std::make_pair(&context_cache, &mutex); } @@ -378,7 +383,7 @@ GetKernelCache() { // a single HLO module. So it should be safe to not include the CUDA context // in the key. absl::StatusOr> CompileAndInit( - uint64_t kernel_id, const char* module) { + CacheKey key, const char* module) { auto cache_and_mutex = GetKernelCache(); auto* cache = cache_and_mutex.first; auto* mutex = cache_and_mutex.second; @@ -386,14 +391,14 @@ absl::StatusOr> CompileAndInit( { // Fast path uses reader lock (as hash map look-up is relatively slow). absl::ReaderMutexLock lock(mutex); - auto it = cache->find(kernel_id); + auto it = cache->find(key); if (ABSL_PREDICT_TRUE(it != cache->end())) return it->second.GetHostLaunch(); } absl::MutexLock lock(mutex); // We released the reader lock, another thread might have initialized it. - if (cache->find(kernel_id) == cache->end()) { + if (cache->find(key) == cache->end()) { mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); InitContext(&context); mlir::ParserConfig parse_config(&context); @@ -418,22 +423,29 @@ absl::StatusOr> CompileAndInit( void** kernel_ptr_ptr = &kernel_ptr; void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); - CUmodule module = static_cast(module_ptr); - CUdeviceptr scratch_addr; - cuModuleGetGlobal(&scratch_addr, nullptr, module, "global_scratch"); cache->insert_or_assign( - kernel_id, + key, CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(scratch_addr), + nullptr, // TODO(apaszke): Clean this up. reinterpret_cast(*main))); } - return cache->at(kernel_id).GetHostLaunch(); + return cache->at(key).GetHostLaunch(); } void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { - uint64_t kernel_id = *reinterpret_cast(opaque); - auto ctx_and_kernel = CompileAndInit(kernel_id, opaque + sizeof(uint64_t)); + if (reinterpret_cast(opaque) % alignof(KernelHash)) { + fprintf(stderr, "Misaligned opaque pointer\n"); + abort(); + } + auto hash = *reinterpret_cast(opaque); + CUcontext ctx; + if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) { + fprintf(stderr, "Failed to get current CUDA context\n"); + abort(); + } + CacheKey key(hash, reinterpret_cast(ctx)); + auto ctx_and_kernel = CompileAndInit(key, opaque + sizeof(KernelHash)); if (!ctx_and_kernel.ok()) { XlaCustomCallStatusSetFailure(status, ctx_and_kernel.status().message().data(), diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index a52a62962b9c..fdb7ad7b0a1f 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -47,6 +47,7 @@ jax_test( srcs = ["gpu_test.py"], disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, + enable_configs = ["gpu_h100_2gpu"], shard_count = 4, deps = [ "//jax:mosaic_gpu", diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 44fd518344e3..e2bca54ab230 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1286,6 +1286,19 @@ def test_measure(self): x = jnp.arange(1024 * 1024) profiler.measure(lambda x, y: x + y, x, x) # This is just a smoke test + def test_multigpu(self): + if len(jax.devices()) < 2: + self.skipTest("Need at least 2 devices") + def kernel(ctx, src, dst, _): + mgpu.FragmentedArray.load_strided(src).store_untiled(dst) + x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + f = jax.jit(mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, () + )) + # Make sure we can invoke the same program on different devices. + for xd in (jax.device_put(x, d) for d in jax.devices()[:2]): + jax.block_until_ready(f(xd)) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 5c2ffa893f3def1ae822a50afdad56d709558713 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Mon, 19 Aug 2024 14:32:35 -0500 Subject: [PATCH 208/702] * Add conditional docker interactive mode Interactive causes bazel to output more useful info when running locally. * Fix issue with rocm el8 repo urls Work around quirk with rocm version when it ends with 0 * Fix package name conflict Ubu22 and higher have a package name conflict between the debian versions and the AMD provided versions. * [ROCm] Use clang env --- build/rocm/ci_build | 15 ++++++++++- build/rocm/ci_build.sh | 12 +++++++++ build/rocm/tools/build_wheels.py | 12 +++++++-- build/rocm/tools/get_rocm.py | 46 +++++++++++++++++++++++++------- 4 files changed, 72 insertions(+), 13 deletions(-) diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 43f34f6ca758..aeb0201e27ed 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -34,7 +34,8 @@ def image_by_name(name): def dist_wheels( - rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="" + rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="", + compiler="gcc" ): if xla_path: xla_path = os.path.abspath(xla_path) @@ -71,6 +72,8 @@ def dist_wheels( rocm_version, "--python-versions", pyver_string, + "--compiler", + compiler, ] if xla_path: @@ -92,6 +95,9 @@ def dist_wheels( cmd.extend(mounts) + if os.isatty(sys.stdout.fileno()): + cmd.append("-it") + # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID cmd.extend( [ @@ -251,6 +257,12 @@ def parse_args(): help="Path to XLA source to use during jaxlib build, instead of builtin XLA", ) + p.add_argument( + "--compiler", + choices=["gcc", "clang"], + help="Compiler backend to use when compiling jax/jaxlib" + ) + subp = p.add_subparsers(dest="action", required=True) dwp = subp.add_parser("dist_wheels") @@ -288,6 +300,7 @@ def main(): args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, + args.compiler, ) dist_docker( args.rocm_version, diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 7f93af8cae4c..302a0449b19e 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -50,6 +50,7 @@ ROCM_BUILD_JOB="" ROCM_BUILD_NUM="" BASE_DOCKER="ubuntu:20.04" CUSTOM_INSTALL="" +JAX_USE_CLANG="" POSITIONAL_ARGS=() RUNTIME_FLAG=1 @@ -89,6 +90,10 @@ while [[ $# -gt 0 ]]; do ROCM_BUILD_NUM="$2" shift 2 ;; + --use_clang) + JAX_USE_CLANG="$2" + shift 2 + ;; *) POSITIONAL_ARGS+=("$1") shift @@ -135,6 +140,12 @@ echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERF export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}" +# default to gcc +JAX_COMPILER="gcc" +if [ -n "$JAX_USE_CLANG" ]; then + JAX_COMPILER="clang" +fi + # ci_build.sh is mostly a compatibility wrapper for ci_build # 'dist_docker' will run 'dist_wheels' followed by a Docker build to create the "JAX image", @@ -145,6 +156,7 @@ export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}" --xla-source-dir=$XLA_CLONE_DIR \ --rocm-build-job=$ROCM_BUILD_JOB \ --rocm-build-num=$ROCM_BUILD_NUM \ + --compiler=$JAX_COMPILER \ dist_docker \ --dockerfile $DOCKERFILE_PATH \ --image-tag $DOCKER_IMG_NAME diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index 7f2a4c862bf0..1ba9e0b910db 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -56,7 +56,8 @@ def update_rocm_targets(rocm_path, targets): open(version_fp, "a").close() -def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None): +def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"): + use_clang = "true" if compiler == "clang" else "false" cmd = [ "python", "build/build.py", @@ -64,6 +65,7 @@ def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None): "--build_gpu_plugin", "--gpu_plugin_rocm_version=60", "--rocm_path=%s" % rocm_path, + "--use_clang=%s" % use_clang, ] if xla_path: @@ -194,6 +196,12 @@ def parse_args(): default=None, help="Optional directory where XLA source is located to use instead of JAX builtin XLA", ) + p.add_argument( + "--compiler", + type=str, + default="gcc", + help="Compiler backend to use when compiling jax/jaxlib", + ) p.add_argument("jax_path", help="Directory where JAX source directory is located") @@ -225,7 +233,7 @@ def main(): update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS) for py in python_versions: - build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path) + build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler) wheel_paths = find_wheels(os.path.join(args.jax_path, "dist")) for wheel_path in wheel_paths: # skip jax wheel since it is non-platform diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py index d29f67982d4c..5334bf40ece7 100644 --- a/build/rocm/tools/get_rocm.py +++ b/build/rocm/tools/get_rocm.py @@ -115,6 +115,23 @@ def install_rocm(self): ) +def parse_version(version_str): + if isinstance(version_str, str): + parts = version_str.split(".") + rv = type("Version", (), {})() + rv.major = int(parts[0].strip()) + rv.minor = int(parts[1].strip()) + rv.rev = None + + if len(parts) > 2: + rv.rev = int(parts[2].strip()) + + else: + rv = version_str + + return rv + + def get_system(): md = os_release_meta() @@ -210,16 +227,7 @@ def install_amdgpu_installer_internal(rocm_version): def _build_installer_url(rocm_version, metadata): md = metadata - if isinstance(rocm_version, str): - parts = rocm_version.split(".") - rv = type("Version", (), {})() - rv.major = parts[0] - rv.minor = parts[1] - - if len(parts) > 2: - rv.rev = parts[2] - else: - rv = rocm_version + rv = parse_version(rocm_version) base_url = "http://artifactory-cdn.amd.com/artifactory/list" @@ -247,8 +255,21 @@ def _build_installer_url(rocm_version, metadata): return url, package_name +APT_RADEON_PIN_CONTENT = """ +Package: * +Pin: release o=repo.radeon.com +Pin-Priority: 600 +""" + + def setup_repos_ubuntu(rocm_version_str): + rv = parse_version(rocm_version_str) + + # if X.Y.0 -> repo url version should be X.Y + if rv.rev == 0: + rocm_version_str = "%d.%d" % (rv.major, rv.minor) + s = get_system() s.install_packages(["wget", "sudo", "gnupg"]) @@ -270,6 +291,11 @@ def setup_repos_ubuntu(rocm_version_str): % (rocm_version_str, codename) ) + # on ubuntu 22 or greater, debian community rocm packages + # conflict with repo.radeon.com packages + with open("/etc/apt/preferences.d/rocm-pin-600", "w") as fd: + fd.write(APT_RADEON_PIN_CONTENT) + # update indexes subprocess.check_call(["apt-get", "update"]) From 2c221f2d5a47da4435abd8be9f6e2d4479012716 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 22 Aug 2024 09:41:53 -0700 Subject: [PATCH 209/702] Register several jax.numpy argument name deprecations --- jax/__init__.py | 5 ----- jax/_src/deprecations.py | 13 +++++++++++++ jax/_src/numpy/lax_numpy.py | 23 ++++++++++------------- jax/_src/numpy/linalg.py | 25 +++++++++++++++---------- jax/_src/numpy/reductions.py | 25 +++++++++++++++++-------- tests/lax_numpy_reducers_test.py | 15 +++++++++++++++ tests/lax_numpy_test.py | 20 ++++++++++++++++++++ tests/linalg_test.py | 23 +++++++++++++++++++++++ 8 files changed, 113 insertions(+), 36 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index dc3d9af3a0c4..037386317ee4 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -180,11 +180,6 @@ import jax.experimental.compilation_cache.compilation_cache as _ccache del _ccache -from jax._src.deprecations import register as _register_deprecation -_register_deprecation('jax-scipy-beta-args') -_register_deprecation('tracer-hash') -del _register_deprecation - _deprecations = { # Added July 2022 "treedef_is_leaf": ( diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index fb346ca9b372..96eca4ccf45c 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -117,3 +117,16 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: else: warnings.warn(message, category=DeprecationWarning, stacklevel=stacklevel + 1) + + +# Register a number of deprecations: we do this here to ensure they're +# always registered by the time `accelerate` and `is_acelerated` are called. +register("jax-numpy-astype-complex-to-real") +register("jax-numpy-array-none") +register('jax-scipy-beta-args') +register('tracer-hash') +register('jax-numpy-reshape-newshape') +register('jax-numpy-clip-args') +register('jax-numpy-linalg-matrix_rank-tol') +register('jax-numpy-linalg-pinv-rcond') +register('jax-numpy-quantile-interpolation') diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f5abb3971cfc..c5feb67332c2 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1388,6 +1388,8 @@ def reshape( JAX does not support ``order="A"``. copy: unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away. + newshape: deprecated alias of the ``shape`` argument. Will result in a + :class:`DeprecationWarning` if used. Returns: reshaped copy of input array with the specified shape. @@ -1452,11 +1454,10 @@ def reshape( "jnp.reshape received both `shape` and `newshape` arguments. Note that " "using `newshape` is deprecated, please only use `shape` instead." ) - warnings.warn( - "The newshape argument of jax.numpy.reshape is deprecated and setting it " - "will soon raise an error. To avoid an error in the future, and to " - "suppress this warning, please use the shape argument instead.", - DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-reshape-newshape", + ("The newshape argument of jax.numpy.reshape is deprecated. " + "Please use the shape argument instead."), stacklevel=2) shape = newshape del newshape elif shape is None: @@ -2502,10 +2503,10 @@ def clip( min = a_min if not isinstance(a_min, DeprecatedArg) else min max = a_max if not isinstance(a_max, DeprecatedArg) else max if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)): - warnings.warn( - "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is " - "deprecated. Please use 'arr', 'min' or 'max' respectively instead.", - DeprecationWarning, + deprecations.warn( + "jax-numpy-clip-args", + ("Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is " + "deprecated. Please use 'arr', 'min' or 'max' respectively instead."), stacklevel=2, ) @@ -3479,8 +3480,6 @@ def _supports_buffer_protocol(obj): https://jax.readthedocs.io/en/latest/faq.html). """ -deprecations.register("jax-numpy-array-none") - def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0, @@ -3670,8 +3669,6 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: return x -deprecations.register("jax-numpy-astype-complex-to-real") - def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 729dc81adb90..bb0ba2e85499 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -28,6 +28,7 @@ from jax import jit, custom_jvp from jax import lax +from jax._src import deprecations from jax._src.lax import lax as lax_internal from jax._src.lax.lax import PrecisionLike from jax._src.lax import linalg as lax_linalg @@ -408,6 +409,8 @@ def matrix_rank( smaller than `rtol * largest_singular_value` are considered to be zero. If ``rtol`` is None (the default), a reasonable default is chosen based the floating point precision of the input. + tol: deprecated alias of the ``rtol`` argument. Will result in a + :class:`DeprecationWarning` if used. Returns: array of shape ``a.shape[-2]`` giving the matrix rank. @@ -433,11 +436,11 @@ def matrix_rank( if not isinstance(tol, DeprecatedArg): rtol = tol del tol - warnings.warn( - "The tol argument for linalg.matrix_rank is deprecated using it will soon raise " - "an error. To prepare for future releases, and suppress this warning, " - "please use rtol instead.", - DeprecationWarning, stacklevel=2 + deprecations.warn( + "jax-numpy-linalg-matrix_rank-tol", + ("The tol argument for linalg.matrix_rank is deprecated. " + "Please use rtol instead."), + stacklevel=2 ) M, = promote_dtypes_inexact(jnp.asarray(M)) if M.ndim < 2: @@ -891,6 +894,8 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None, determined based on the floating point precision of the dtype. hermitian: if True, then the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False) + rcond: deprecated alias of the ``rtol`` argument. Will result in a + :class:`DeprecationWarning` if used. Returns: An array of shape ``(..., N, M)`` containing the pseudo-inverse of ``a``. @@ -921,11 +926,11 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None, if not isinstance(rcond, DeprecatedArg): rtol = rcond del rcond - warnings.warn( - "The rcond argument for linalg.pinv is deprecated using it will soon " - "raise an error. To prepare for future releases, and suppress this " - "warning, please use rtol instead.", - DeprecationWarning, stacklevel=2 + deprecations.warn( + "jax-numpy-linalg-pinv-rcond", + ("The rcond argument for linalg.pinv is deprecated. " + "Please use rtol instead."), + stacklevel=2 ) return _pinv(a, rtol, hermitian) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index fa8899325879..c619fdf02a80 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -28,6 +28,7 @@ from jax import lax from jax._src import api from jax._src import core +from jax._src import deprecations from jax._src import dtypes from jax._src.numpy import ufuncs from jax._src.numpy.util import ( @@ -1893,8 +1894,10 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No "out != None") raise ValueError(msg) if not isinstance(interpolation, DeprecatedArg): - warnings.warn("The interpolation= argument to 'quantile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-quantile-interpolation", + ("The interpolation= argument to 'quantile' is deprecated. " + "Use 'method=' instead."), stacklevel=2) method = interpolation return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) @@ -1910,8 +1913,10 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = "out != None") raise ValueError(msg) if not isinstance(interpolation, DeprecatedArg): - warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-quantile-interpolation", + ("The interpolation= argument to 'nanquantile' is deprecated. " + "Use 'method=' instead."), stacklevel=2) method = interpolation return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True) @@ -2047,8 +2052,10 @@ def percentile(a: ArrayLike, q: ArrayLike, check_arraylike("percentile", a, q) q, = promote_dtypes_inexact(q) if not isinstance(interpolation, DeprecatedArg): - warnings.warn("The interpolation= argument to 'percentile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-quantile-interpolation", + ("The interpolation= argument to 'percentile' is deprecated. " + "Use 'method=' instead."), stacklevel=2) method = interpolation return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) @@ -2063,8 +2070,10 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, check_arraylike("nanpercentile", a, q) q = ufuncs.true_divide(q, 100.0) if not isinstance(interpolation, DeprecatedArg): - warnings.warn("The interpolation= argument to 'nanpercentile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-quantile-interpolation", + ("The interpolation= argument to 'nanpercentile' is deprecated. " + "Use 'method=' instead."), stacklevel=2) method = interpolation return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 0edc09fa7c14..402e206ef37b 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -27,6 +27,7 @@ from jax import numpy as jnp from jax._src import config +from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.util import NumpyComplexWarning @@ -722,6 +723,20 @@ def np_fun(*args): tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) + @jtu.sample_product( + op=['quantile', 'nanquantile', 'percentile', 'nanpercentile'] + ) + def testQuantileDeprecatedArgs(self, op): + func = getattr(jnp, op) + msg = f"The interpolation= argument to '{op}' is deprecated. " + def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-quantile-interpolation"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(msg): + func(jnp.arange(4), 0.5, interpolation='linear') + @unittest.skipIf(not config.enable_x64.value, "test requires X64") @jtu.run_on_devices("cpu") # test is for CPU float64 precision def testPercentilePrecision(self): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 75bfbbccc1e1..6ca36ffe9035 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -983,6 +983,16 @@ def testClipComplexInputError(self): with self.assertRaisesRegex(ValueError, msg): jnp.clip(x, max=jnp.array([-1+5j])) + def testClipDeprecatedArgs(self): + msg = "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated" + def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-clip-args"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(msg): + jnp.clip(jnp.arange(4), a_min=2, a_max=3) + def testHypotComplexInputError(self): rng = jtu.rand_default(self.rng()) x = rng((5,), dtype=jnp.complex64) @@ -3366,6 +3376,16 @@ def testReshape(self, arg_shape, out_shape, dtype, order): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def testReshapeDeprecatedArgs(self): + msg = "The newshape argument of jax.numpy.reshape is deprecated." + def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-reshape-newshape"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(msg): + jnp.reshape(jnp.arange(4), newshape=(2, 2)) + @jtu.sample_product( [dict(arg_shape=arg_shape, out_shape=out_shape) for arg_shape, out_shape in [ diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 1f4488fd5014..ce7b0b1991c8 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -30,6 +30,7 @@ from jax import numpy as jnp from jax import scipy as jsp from jax._src import config +from jax._src import deprecations from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge @@ -1155,6 +1156,17 @@ def np_fn(a): # TODO(phawkins): 6e-2 seems like a very loose tolerance. jtu.check_grads(jnp_fn, args_maker(), 1, rtol=6e-2, atol=1e-3) + def testPinvDeprecatedArgs(self): + msg = "The rcond argument for linalg.pinv is deprecated." + def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-linalg-pinv-rcond"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + x = jnp.ones((3, 3)) + with assert_warns_or_errors(msg): + jnp.linalg.pinv(x, rcond=1E-2) + def testPinvGradIssue2792(self): def f(p): a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2) @@ -1197,6 +1209,17 @@ def testMatrixRank(self, shape, dtype): self._CompileAndCheck(jnp.linalg.matrix_rank, args_maker, check_dtypes=False, rtol=1e-3) + def testMatrixRankDeprecatedArgs(self): + msg = "The tol argument for linalg.matrix_rank is deprecated." + def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-linalg-matrix_rank-tol"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + x = jnp.ones((3, 3)) + with assert_warns_or_errors(msg): + jnp.linalg.matrix_rank(x, tol=1E-2) + @jtu.sample_product( shapes=[ [(3, ), (3, 1)], # quick-out codepath From 498ddd50ef8cd0e2d4bc39445a078adf45dd3f12 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 Aug 2024 10:22:59 -0700 Subject: [PATCH 210/702] [Mosaic TPU] Allow overriding memory space assignment of kernel outputs PiperOrigin-RevId: 666400770 --- .../pallas/mosaic/pallas_call_registration.py | 1 + jax/_src/tpu_custom_call.py | 31 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index cfb55240a876..71091af27ca3 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -208,6 +208,7 @@ def _maybe_cast_inputs(*args): device_type=mosaic_params.get("device_type"), internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"), collective_id=mosaic_params.get("collective_id", None), + output_memory_spaces=None, # TODO(apaszke,sharadmv): Implement this. ) _maybe_cast_to_bool = lambda x, aval: x.astype( jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index f77ed0666705..86b6e443e854 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -21,6 +21,7 @@ import collections.abc from collections.abc import Callable, Sequence import dataclasses +import enum import functools import io import os @@ -67,6 +68,20 @@ tpu_custom_call_p.multiple_results = True +class MemorySpace(enum.Enum): + HBM = enum.auto() + VMEM = enum.auto() + + @property + def color(self) -> int: + if self == MemorySpace.HBM: + return 0 + elif self == MemorySpace.VMEM: + return 1 + else: + raise ValueError("invalid memory space: " + str(self)) + + @dataclasses.dataclass(frozen=True) class CostEstimate: flops: int @@ -95,6 +110,7 @@ class CustomCallBackendConfig: allow_input_fusion: list[bool] | None serialization_format: int | None internal_scratch_in_bytes: int | None + output_memory_spaces: tuple[MemorySpace, ...] | None # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -137,6 +153,13 @@ def to_json(self) -> bytes: if self.internal_scratch_in_bytes is not None: config.write(b', "internal_scratch_in_bytes": ') config.write(str(self.internal_scratch_in_bytes).encode("ascii")) + if self.output_memory_spaces is not None: + config.write(b', "output_memory_colors": [') + for i, memory_space in enumerate(self.output_memory_spaces): + if i: + config.write(b",") + config.write(str(memory_space.color).encode("ascii")) + config.write(b"]") config.write(b"}") # End of custom_call_config. if self.device_type is not None: config.write(b', "device_type": ') @@ -420,6 +443,7 @@ def _lower_to_custom_call_config( internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, + output_memory_spaces: tuple[MemorySpace, ...] | None = None, ) -> CustomCallBackendConfig: lowered_module_asm, ( has_communication, @@ -445,6 +469,7 @@ def _lower_to_custom_call_config( has_communication=has_communication, needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, + output_memory_spaces=output_memory_spaces, ) @@ -463,6 +488,7 @@ def _lowered_to_custom_call_config( needs_hlo_passes: bool, needs_layout_passes: bool, device_type: str | None, + output_memory_spaces: tuple[MemorySpace, ...] | None = None, ): if has_custom_barrier: if collective_id is None: @@ -492,6 +518,7 @@ def _lowered_to_custom_call_config( allow_input_fusion, serialization_format, internal_scratch_in_bytes, + output_memory_spaces, ) return config @@ -511,6 +538,7 @@ def lower_module_to_custom_call( internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, + output_memory_spaces: tuple[MemorySpace, ...] | None, device_type: str | None, ) -> Sequence[ir.Value]: config = _lower_to_custom_call_config( @@ -524,6 +552,7 @@ def lower_module_to_custom_call( collective_id=collective_id, device_type=device_type, serialization_format=serialization_format, + output_memory_spaces=output_memory_spaces, ) return _tpu_custom_call_lowering( ctx, @@ -550,6 +579,7 @@ def as_tpu_kernel( internal_scratch_in_bytes: int | None = None, collective_id: int | None = None, serialization_format: int | None = 1, + output_memory_spaces: tuple[MemorySpace, ...] | None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" config = _lower_to_custom_call_config( @@ -563,6 +593,7 @@ def as_tpu_kernel( internal_scratch_in_bytes=internal_scratch_in_bytes, collective_id=collective_id, serialization_format=serialization_format, + output_memory_spaces=output_memory_spaces, ) return _as_jax_callable( config, From ef8532bff5dd1c0bbe7b1a6609f10b91001c9134 Mon Sep 17 00:00:00 2001 From: selamw1 Date: Wed, 21 Aug 2024 15:12:39 -0700 Subject: [PATCH 211/702] roll_docstring_added roll_docstring_added see_also_doc_fixed examples_adjusted --- jax/_src/numpy/lax_numpy.py | 38 ++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f5abb3971cfc..aa88db82040f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -7807,9 +7807,45 @@ def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array: dimension=ax) return a -@util.implements(np.roll) def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], axis: int | Sequence[int] | None = None) -> Array: + """Roll the elements of an array along a specified axis. + + JAX implementation of :func:`numpy.roll`. + + Args: + a: input array. + shift: the number of positions to shift the specified axis. If an integer, + all axes are shifted by the same amount. If a tuple, the shift for each + axis is specified individually. + axis: the axis or axes to roll. If ``None``, the array is flattened, shifted, + and then reshaped to its original shape. + + Returns: + A copy of ``a`` with elements rolled along the specified axis or axes. + + See also: + - :func:`jax.numpy.rollaxis`: roll the specified axis to a given position. + + Examples: + >>> a = jnp.array([0, 1, 2, 3, 4, 5]) + >>> jnp.roll(a, 2) + Array([4, 5, 0, 1, 2, 3], dtype=int32) + + Roll elements along a specific axis: + + >>> a = jnp.array([[ 0, 1, 2, 3], + ... [ 4, 5, 6, 7], + ... [ 8, 9, 10, 11]]) + >>> jnp.roll(a, 1, axis=0) + Array([[ 8, 9, 10, 11], + [ 0, 1, 2, 3], + [ 4, 5, 6, 7]], dtype=int32) + >>> jnp.roll(a, [2, 3], axis=[0, 1]) + Array([[ 5, 6, 7, 4], + [ 9, 10, 11, 8], + [ 1, 2, 3, 0]], dtype=int32) + """ util.check_arraylike("roll", a) arr = asarray(a) if axis is None: From d72104de59e5e0c5ca0e0e7f17a9fc4caa0f1c21 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 22 Aug 2024 12:35:59 -0700 Subject: [PATCH 212/702] Use StableHLO filegroup for python APIs in jaxlib MLIR build. PiperOrigin-RevId: 666450684 --- jaxlib/mlir/_mlir_libs/BUILD.bazel | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 06fa9e760a70..b7101a07d989 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -310,7 +310,7 @@ py_extension( py_extension( name = "_chlo", srcs = [ - "@stablehlo//:stablehlo/integrations/python/ChloModule.cpp", + "@stablehlo//:chlo_py_api_files", ], copts = COPTS, linkopts = LINKOPTS, @@ -327,9 +327,7 @@ py_extension( py_extension( name = "_stablehlo", srcs = [ - "@stablehlo//:stablehlo/integrations/python/PortableApi.cpp", - "@stablehlo//:stablehlo/integrations/python/PortableApi.h", - "@stablehlo//:stablehlo/integrations/python/StablehloModule.cpp", + "@stablehlo//:stablehlo_py_api_files", ], copts = COPTS, linkopts = LINKOPTS, From 5da27432e1494c6837e304ba7d1459dfe3466f17 Mon Sep 17 00:00:00 2001 From: Tongfei Guo Date: Thu, 22 Aug 2024 13:31:53 -0700 Subject: [PATCH 213/702] [XLA:SPMD] Check gather/scatter partitioning for index parallel case have the index parallel dimensions matches for operand and indices. PiperOrigin-RevId: 666469705 --- tests/pjit_test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b1ab9a613b58..3af5dfe4cd37 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5128,6 +5128,24 @@ def test_mesh_with_string_axis_names(self): mesh = jax.sharding.Mesh(jax.devices(), 'dp') self.assertTupleEqual(mesh.axis_names, ('dp',)) + def test_sharded_in_place_assignment(self): + mesh = jtu.create_global_mesh((8,), ('data',)) + + idx = [0, 2, 5, 7, 8, 10, 13, 15] + n = 16 + def _init(): + w = jnp.zeros((n, n)) + idx1 = jnp.array(idx) + w = w.at[idx1, jnp.arange(n//2)].set(1) + return w + + w = jax.jit(_init, out_shardings=NamedSharding(mesh, P(None, 'data')))() + + w_gt = np.zeros((n, n)) + for j, i in enumerate(idx): + w_gt[i, j] = 1 + + self.assertArraysEqual(w, w_gt) @jtu.with_config(jax_use_shardy_partitioner=True) class SdyIntegrationTest(jtu.JaxTestCase): From 74d96eeb835d873cc00b1a21a659ef1d2ab73ac6 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 22 Aug 2024 15:12:35 -0700 Subject: [PATCH 214/702] [Pallas TPU] Raise a clear error when trying to load/store to a non-SMEM/non-VMEM buffer PiperOrigin-RevId: 666506411 --- jax/_src/pallas/mosaic/lowering.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 2c5854876009..b9b95fc3f1e8 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1091,6 +1091,13 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): raise ValueError("Can only load scalars from SMEM") return _maybe_cast_load_to_bool( aval_out, memref.LoadOp(ref, starts).result) + elif str(ref_type.memory_space) != "#tpu.memory_space": + extra = "" + if str(ref_type.memory_space) == "#tpu.memory_space": + extra = " ANY memory space can only be accessed using async_copy." + raise ValueError( + "Loads are only allowed on VMEM and SMEM references." + extra + ) load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype) if need_stride: load_val = tpu.StridedLoadOp( @@ -1226,6 +1233,13 @@ def _masked_swap_lowering_rule( val = _maybe_cast_store_to_memref_type(val_aval, val) memref.StoreOp(val, ref, starts) return result + elif str(ref_type.memory_space) != "#tpu.memory_space": + extra = "" + if str(ref_type.memory_space) == "#tpu.memory_space": + extra = " ANY memory space can only be accessed using async_copy." + raise ValueError( + "Loads and stores are only allowed on VMEM and SMEM references." + extra + ) mem_slice_shape = list(aval_out.shape) for i, a in enumerate(idx_aval.indices): if not isinstance(a, primitives.Slice): From b6f2840f2ab7566acb92fae1c8f08ad35fc8eb8e Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 22 Aug 2024 15:44:54 -0700 Subject: [PATCH 215/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/cc369fd42afddf02ec9ea7775c2798d775f1d219. PiperOrigin-RevId: 666518740 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f260f2dc65ac..2783a0cf5fe9 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "6cdbb866c613947b22607664c32a7e06a23fe666" -XLA_SHA256 = "273b9d3e13f9c922357df4d0fabb1e4e3fc0a80f3848abafa60f33cd49185c10" +XLA_COMMIT = "cc369fd42afddf02ec9ea7775c2798d775f1d219" +XLA_SHA256 = "218b50279f0b61e8f2cdbed0ab23279bc121c0cad4a511fb05643ec3b61bc8b6" def repo(): tf_http_archive( From 07767e81a05ea0ba3ae528237d8ec74484248cc1 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 22 Aug 2024 23:56:32 -0700 Subject: [PATCH 216/702] [Pallas] Add support for casting to/from unsigned integer types. PiperOrigin-RevId: 666663406 --- jax/_src/pallas/mosaic/lowering.py | 24 +++++++++++------------- tests/pallas/tpu_ops_test.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index b9b95fc3f1e8..0f85b3bcdb64 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -85,12 +85,6 @@ map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin -UNSIGNED_TO_SIGNED = { - np.dtype('uint8'): np.dtype('int8'), - np.dtype('uint16'): np.dtype('int16'), - np.dtype('uint32'): np.dtype('int32'), - np.dtype('uint64'): np.dtype('int64'), -} @dataclasses.dataclass class MeshContext: @@ -1543,6 +1537,12 @@ def _convert_helper(x, *, to_dtype): if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4: x = x.astype(jnp.float32) return x.astype(to_dtype) + if jnp.issubdtype(from_dtype, jnp.unsignedinteger): + if from_dtype.itemsize < 4: + x = x.astype(jnp.uint32) + if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4: + x = x.astype(jnp.float32) + return x.astype(to_dtype) if jnp.issubdtype(from_dtype, jnp.floating): if jnp.issubdtype(to_dtype, jnp.signedinteger): if from_dtype.itemsize < 4: @@ -1567,10 +1567,6 @@ def _convert_element_type_lowering_rule( old_dtype = ctx.avals_in[0].dtype out_type = aval_to_ir_type(out_aval) - # TODO(justinfu): Remove after mosaic supports unsigned types. - # This conversion makes mosaic interpret all unsigned types as signed types. - if np.issubdtype(new_dtype, jnp.unsignedinteger): - new_dtype = UNSIGNED_TO_SIGNED[new_dtype] if old_dtype == new_dtype: return x if jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype( @@ -1580,18 +1576,20 @@ def _convert_element_type_lowering_rule( return arith.ExtFOp(out_type, x).result elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4: return arith.TruncFOp(out_type, x).result - elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype( - new_dtype, jnp.signedinteger + elif jnp.issubdtype(old_dtype, jnp.integer) and jnp.issubdtype( + new_dtype, jnp.integer ): if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4: return arith.ExtSIOp(out_type, x).result elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4: return arith.TruncIOp(out_type, x).result + else: # This case triggers when casting signed to unsigned or vice versa. + return x elif jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype( new_dtype, jnp.signedinteger ) and old_dtype.itemsize == new_dtype.itemsize == 4: return arith.FPToSIOp(out_type, x).result - elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype( + elif jnp.issubdtype(old_dtype, jnp.integer) and jnp.issubdtype( new_dtype, jnp.floating ) and old_dtype.itemsize == new_dtype.itemsize == 4: return arith.SIToFPOp(out_type, x).result diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index cc39c879b121..e7d0e04b05b0 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized import jax +from jax import lax import jax.numpy as jnp from jax._src import test_util as jtu from jax.experimental import pallas as pl @@ -173,6 +174,18 @@ def kernel(x_ref, y_ref, out_ref): )(x, y) np.testing.assert_array_equal(out, inp.reshape(m * 2, n)) + def test_tpu_unsigned_int(self): + def body(x_ref, o_ref): + # Test cast from uint16 -> uint32 + ux = lax.convert_element_type(x_ref[...], jnp.uint32) + res = ux + 1 + # Test cast from uint32 -> float32 + o_ref[...] = res.astype(jnp.float32) + out = jax.ShapeDtypeStruct((8, 128), jnp.float32) + x = jnp.arange(8 * 128, dtype=jnp.uint16).reshape((8, 128)) + result = self.pallas_call(body, out_shape=out)(x) + np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0) + class OpsInterpretTest(OpsTest): INTERPRET = True From c430b0c5e3d158bb66c0b0863474e81eb0162fc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Fri, 23 Aug 2024 03:20:55 -0700 Subject: [PATCH 217/702] Activate QR Factorization to XLA's FFI PiperOrigin-RevId: 666722604 --- jax/_src/export/_export.py | 19 +- .../cpu_qr_lapack_geqrf.py | 465 ++++++++++++++++++ jax/_src/lax/linalg.py | 15 +- jaxlib/lapack.py | 85 ++-- tests/export_back_compat_test.py | 10 + 5 files changed, 556 insertions(+), 38 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 1bc9d8ab1c8c..88e4f21546fe 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -923,9 +923,22 @@ def _check_lowering(lowering) -> None: "\n".join(not_implemented_msgs)) _CPU_FFI_KERNELS = [ - "lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi", - "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", - "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", + "lapack_spotrf_ffi", + "lapack_dpotrf_ffi", + "lapack_cpotrf_ffi", + "lapack_zpotrf_ffi", + "lapack_sgeqrf_ffi", + "lapack_dgeqrf_ffi", + "lapack_cgeqrf_ffi", + "lapack_zgeqrf_ffi", + "lapack_sgesdd_ffi", + "lapack_dgesdd_ffi", + "lapack_cgesdd_ffi", + "lapack_zgesdd_ffi", + "lapack_sgetrf_ffi", + "lapack_dgetrf_ffi", + "lapack_cgetrf_ffi", + "lapack_zgetrf_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. # Their backwards compatibility is tested by back_compat_test.py. diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py index 21448430ead6..045e8df55cd2 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py @@ -98,6 +98,7 @@ xla_call_module_version=4, ) # End paste + # Pasted from the test output (see back_compat_test.py module docstring) data_2023_03_17["f64"] = dict( testdata_version=1, @@ -180,6 +181,7 @@ xla_call_module_version=4, ) # End paste + # Pasted from the test output (see back_compat_test.py module docstring) data_2023_03_17["c64"] = dict( testdata_version=1, @@ -346,3 +348,466 @@ mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0bOO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x16\n\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x0b\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xd2\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8dW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf\x00lapack_zungqr\x00callee\x00", xla_call_module_version=4, ) # End paste + + +data_2024_08_22 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_22['c128'] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr'], + serialized_date=datetime.date(2024, 8, 22), + inputs=(), + expected_outputs=( + array([ + [0.0 + 0.0j, 0.9128709291752773 + 0.0j, 0.40824829046386235 + 0.0j], + [ + -0.447213595499958 - 0.0j, + 0.3651483716701102 + 0.0j, + -0.8164965809277263 + 0.0j, + ], + [ + -0.894427190999916 - 0.0j, + -0.1825741858350548 + 0.0j, + 0.40824829046386324 + 0.0j, + ], + ]), + array([ + [ + -6.7082039324993694e00 + 0.0j, + -8.0498447189992444e00 + 0.0j, + -9.3914855054991175e00 + 0.0j, + ], + [ + 0.0000000000000000e00 + 0.0j, + 1.0954451150103341e00 + 0.0j, + 2.1908902300206665e00 + 0.0j, + ], + [ + 0.0000000000000000e00 + 0.0j, + 0.0000000000000000e00 + 0.0j, + -8.8817841970012523e-16 + 0.0j, + ], + ]), + ), + mlir_module_text=r""" +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) +#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) + %c = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) + %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) + %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) + %4:3 = stablehlo.custom_call @lapack_zungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) + %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) + %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) + %cst_10 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc9) + %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc9) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) + %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) + %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) + return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) + %c = stablehlo.constant dense<-1> : tensor loc(#loc10) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) + %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) + %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) + %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) + return %6 : tensor<3x3xcomplex> loc(#loc10) + } loc(#loc10) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) +#loc4 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) +#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) +#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) +#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) +#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) +#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) +#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) +#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +""", + mlir_module_serialized=( + b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0bOO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xf2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x0b\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xaa\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8dW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" + b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' + b' out_layouts=(None,) resource_env=None donated_invars=(False,)' + b' name=triu keep_unused=False' + b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' + b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128' + b' shape=(9,)' + b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' + b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' + b' 0, 0), (0, 0,' + b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf_ffi\x00lapack_zungqr\x00callee\x00' + ), + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_22['c64'] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr'], + serialized_date=datetime.date(2024, 8, 22), + inputs=(), + expected_outputs=( + array( + [ + [0.0 + 0.0j, 0.91287076 + 0.0j, 0.4082487 + 0.0j], + [-0.44721356 - 0.0j, 0.36514866 + 0.0j, -0.8164965 + 0.0j], + [-0.8944271 - 0.0j, -0.18257445 + 0.0j, 0.40824816 + 0.0j], + ], + dtype=complex64, + ), + array( + [ + [ + -6.7082043e00 + 0.0j, + -8.0498438e00 + 0.0j, + -9.3914852e00 + 0.0j, + ], + [0.0000000e00 + 0.0j, 1.0954441e00 + 0.0j, 2.1908894e00 + 0.0j], + [ + 0.0000000e00 + 0.0j, + 0.0000000e00 + 0.0j, + 7.1525574e-07 + 0.0j, + ], + ], + dtype=complex64, + ), + ), + mlir_module_text=r""" +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) +#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) + %c = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) + %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) + %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) + %4:3 = stablehlo.custom_call @lapack_cungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) + %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) + %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) + %cst_10 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc9) + %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc9) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) + %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) + %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) + return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) + %c = stablehlo.constant dense<-1> : tensor loc(#loc10) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) + %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) + %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) + %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) + return %6 : tensor<3x3xcomplex> loc(#loc10) + } loc(#loc10) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) +#loc4 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) +#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) +#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) +#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) +#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) +#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) +#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) +#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +""", + mlir_module_serialized=( + b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xb2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\t\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xa6\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8bW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" + b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' + b' out_layouts=(None,) resource_env=None donated_invars=(False,)' + b' name=triu keep_unused=False' + b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' + b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64' + b' shape=(9,)' + b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' + b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' + b' 0, 0), (0, 0,' + b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf_ffi\x00lapack_cungqr\x00callee\x00' + ), + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_22['f32'] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr'], + serialized_date=datetime.date(2024, 8, 22), + inputs=(), + expected_outputs=( + array( + [ + [0.0, 0.91287076, 0.4082487], + [-0.44721356, 0.36514866, -0.8164965], + [-0.8944271, -0.18257445, 0.40824816], + ], + dtype=float32, + ), + array( + [ + [-6.7082043e00, -8.0498438e00, -9.3914852e00], + [0.0000000e00, 1.0954441e00, 2.1908894e00], + [0.0000000e00, 0.0000000e00, 7.1525574e-07], + ], + dtype=float32, + ), + ), + mlir_module_text=r""" +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) +#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc5) + %c = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc8) + %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) + %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) + %4:3 = stablehlo.custom_call @lapack_sorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<96xf32>) loc(#loc9) + %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) + %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) + %cst_10 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc9) + %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc9) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) + %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc9) + %11 = call @triu(%2#0) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc10) + return %10, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) + } loc(#loc) + func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) + %c = stablehlo.constant dense<-1> : tensor loc(#loc10) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) + %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) + %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc15) + %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc16) + return %6 : tensor<3x3xf32> loc(#loc10) + } loc(#loc10) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) +#loc4 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) +#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) +#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) +#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) +#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) +#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) +#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) +#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +""", + mlir_module_serialized=( + b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b\x1fO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\x8a\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\t\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\t\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" + b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' + b' out_layouts=(None,) resource_env=None donated_invars=(False,)' + b' name=triu keep_unused=False' + b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' + b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32' + b' shape=(9,)' + b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' + b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' + b' 0, 0), (0, 0,' + b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf_ffi\x00lapack_sorgqr\x00callee\x00' + ), + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_22['f64'] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr'], + serialized_date=datetime.date(2024, 8, 22), + inputs=(), + expected_outputs=( + array([ + [0.0, 0.9128709291752773, 0.40824829046386235], + [-0.447213595499958, 0.3651483716701102, -0.8164965809277263], + [-0.894427190999916, -0.1825741858350548, 0.40824829046386324], + ]), + array([ + [ + -6.7082039324993694e00, + -8.0498447189992444e00, + -9.3914855054991175e00, + ], + [ + 0.0000000000000000e00, + 1.0954451150103341e00, + 2.1908902300206665e00, + ], + [ + 0.0000000000000000e00, + 0.0000000000000000e00, + -8.8817841970012523e-16, + ], + ]), + ), + mlir_module_text=r""" +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) +#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc5) + %c = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc8) + %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) + %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) + %4:3 = stablehlo.custom_call @lapack_dorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf64>, tensor<3xf64>) -> (tensor<3x3xf64>, tensor, tensor<96xf64>) loc(#loc9) + %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) + %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) + %cst_10 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc9) + %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc9) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) + %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc9) + %11 = call @triu(%2#0) : (tensor<3x3xf64>) -> tensor<3x3xf64> loc(#loc10) + return %10, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) + } loc(#loc) + func.func private @triu(%arg0: tensor<3x3xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) + %c = stablehlo.constant dense<-1> : tensor loc(#loc10) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) + %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) + %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc15) + %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc16) + return %6 : tensor<3x3xf64> loc(#loc10) + } loc(#loc10) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) +#loc4 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) +#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) +#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) +#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) +#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) +#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) +#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) +#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +""", + mlir_module_serialized=( + b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xaa\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x0b\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" + b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' + b' out_layouts=(None,) resource_env=None donated_invars=(False,)' + b' name=triu keep_unused=False' + b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' + b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64' + b' shape=(9,)' + b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' + b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' + b' 0, 0), (0, 0,' + b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf_ffi\x00lapack_dorgqr\x00callee\x00' + ), + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0d0db8eb38c1..45eed43e0b4f 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1625,8 +1625,19 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *, a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) - a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a, - a_shape_vals=a_shape_vals) + # TODO(b/344892332): Remove the conditional after the compatibility period + ctx_args = ( + (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else () + ) + a_out, taus, *maybe_info_geqrf = geqrf_impl( + *ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals + ) + if not ctx.is_forward_compat(): + # Skip the info parameter verification for the FFI kernel. + return a_out, taus + # TODO(b/344892332): This parameter will no longer be needed after + # the forward compatibility period + info_geqrf = maybe_info_geqrf[0] zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) ok = mlir.compare_hlo(info_geqrf, zeros, "EQ", "SIGNED") select_ok_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_)) diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index 09bb75597904..a71f219acd1d 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -202,9 +202,10 @@ def getrf_hlo(dtype, a: ir.Value, *, a_shape_vals: tuple[DimensionSize, ...]): # # ?geqrf: QR decomposition -def geqrf_hlo(dtype, a: ir.Value, *, - a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() + +def geqrf_hlo( + ctx, dtype, a: ir.Value, *, a_shape_vals: tuple[DimensionSize, ...] +): a_type = ir.RankedTensorType(a.type) assert len(a_shape_vals) >= 2 m, n = a_shape_vals[-2:] @@ -213,51 +214,69 @@ def geqrf_hlo(dtype, a: ir.Value, *, batch_dims_vals = a_shape_vals[:-2] num_bd = len(batch_dims_vals) + fn_base = prepare_lapack_call(fn_base="geqrf", dtype=dtype) - if dtype == np.float32: - fn = "lapack_sgeqrf" - lwork = _lapack.lapack_sgeqrf_workspace(m, n) - elif dtype == np.float64: - fn = "lapack_dgeqrf" - lwork = _lapack.lapack_dgeqrf_workspace(m, n) - elif dtype == np.complex64: - fn = "lapack_cgeqrf" - lwork = _lapack.lapack_cgeqrf_workspace(m, n) - elif dtype == np.complex128: - fn = "lapack_zgeqrf" - lwork = _lapack.lapack_zgeqrf_workspace(m, n) - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) i32_type = ir.IntegerType.get_signless(32) - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + if ctx.is_forward_compat(): + fn = fn_base + if dtype == np.float32: + lwork = _lapack.lapack_sgeqrf_workspace(m, n) + elif dtype == np.float64: + lwork = _lapack.lapack_dgeqrf_workspace(m, n) + elif dtype == np.complex64: + lwork = _lapack.lapack_cgeqrf_workspace(m, n) + elif dtype == np.complex128: + lwork = _lapack.lapack_zgeqrf_workspace(m, n) + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + scalar_layout = [] + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + shape_type_pairs: Sequence[ShapeTypePair] = [ + (a_shape_vals, a_type.element_type), + (batch_dims_vals + (min(m, n),), a_type.element_type), + (batch_dims_vals, i32_type), + ([lwork], a_type.element_type), + ] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + return custom_call( + fn, + result_types=result_types, + operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], + operand_layouts=[scalar_layout] * 4 + [layout], + result_layouts=[ + layout, + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + [0], + ], + operand_output_aliases={4: 0}, + result_shapes=result_shapes, + ).results[:3] + fn = fn_base + "_ffi" shape_type_pairs: Sequence[ShapeTypePair] = [ (a_shape_vals, a_type.element_type), (batch_dims_vals + (min(m, n),), a_type.element_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type), ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - out = custom_call( + return custom_call( fn, result_types=result_types, - operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], - operand_layouts=[scalar_layout] * 4 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], + layout, + tuple(range(num_bd, -1, -1)), ], - operand_output_aliases={4: 0}, + operand_output_aliases={0: 0}, result_shapes=result_shapes, + backend_config={}, + api_version=4, ).results - return out[:3] # # ?orgqr: product of elementary Householder reflectors: diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 92072d0a0168..26eb34088460 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -114,6 +114,7 @@ def test_custom_call_coverage(self): targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) cpu_ffi_testdatas = [ cpu_cholesky_lapack_potrf.data_2024_05_31, + cpu_qr_lapack_geqrf.data_2024_08_22, cpu_lu_lapack_getrf.data_2024_05_31, cpu_svd_lapack_gesdd.data_2024_08_13, ] @@ -397,6 +398,15 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] self.run_one_test(func, data, rtol=rtol) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 32) + if has_xla_ffi_support: + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{batched}", From c76787571b81c5538226f8856dc2c3f87e6d4a2a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 23 Aug 2024 05:48:29 -0700 Subject: [PATCH 218/702] [Mosaic GPU] Expose wait_parity on collective barrier PiperOrigin-RevId: 666761011 --- jax/experimental/mosaic/gpu/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 64c9f409ef7f..3c0cedfc2807 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -712,8 +712,11 @@ def arrive(self): has_side_effects=True, ) - def wait(self): - self.barrier.wait() + def wait(self, *args, **kwargs): + self.barrier.wait(*args, **kwargs) + + def wait_parity(self, *args, **kwargs): + self.barrier.wait_parity(*args, **kwargs) class Partition: From f54e2204308cc647650d795635c3f58875dacb8c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 23 Aug 2024 06:07:53 -0700 Subject: [PATCH 219/702] [Mosaic GPU] Add support for short n dimension in WGMMA PiperOrigin-RevId: 666766079 --- jax/experimental/mosaic/gpu/utils.py | 2 + jax/experimental/mosaic/gpu/wgmma.py | 22 +++++++++-- tests/mosaic/gpu_test.py | 56 ++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 4 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 3c0cedfc2807..546411c82c4c 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -309,6 +309,8 @@ class DynamicSlice: def memref_slice(ref: ir.Value, index) -> ir.Value: ref_ty = ir.MemRefType(ref.type) base_indices, slice_shape, is_squeezed = parse_indices(index, ref_ty.shape) + # TODO(apaszke): Check that slice is within the memref (indices might be + # dynamic, but we can at least catch some OOB slices). memref_strides, offset = ref_ty.get_strides_and_offset() new_offset = offset diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index fc2fe892ac03..b64418022d0e 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -156,14 +156,14 @@ def wgmma_m64( out_ty = ir.VectorType(acc.flat[0].type).element_type if not _supported_wgmma_types(out_ty, element_type): raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}") + if n % 8: + raise ValueError i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) index = ir.IndexType.get() if b_k_stride % 16: raise ValueError - if n % (swizzle // bytewidth(element_type)): - raise ValueError # Only 16-bit types support transposes supports_transpose = bytewidth(element_type) == 2 if not supports_transpose and (a_transpose or b_transpose): @@ -326,7 +326,15 @@ def wgmma( kn_tile = swizzle // element_bytewidth groups_k, groups_n = b_ty.shape[:2] - if b_ty.shape[2:] != [kn_tile, kn_tile]: + k_group_size, n_group_size = ( + b_ty.shape[2:] if b_order == WGMMALayout.ROW_MAJOR else b_ty.shape[:1:-1] + ) + # Note that while this technically allows n to be smaller than kn_tile, + # the stride checks below will still enforce that the memory region is padded. + # It might be possible to relax that requirement, but I haven't tested it. + if n_group_size > kn_tile and n_group_size % kn_tile: + raise ValueError(n_group_size, kn_tile) + if k_group_size != kn_tile: raise ValueError(b_ty.shape) if a_in_regs: @@ -353,6 +361,12 @@ def wgmma( if a_order == WGMMALayout.COL_MAJOR and swizzle != 128: # Not sure what the layout is like, since the tiles aren't square. raise NotImplementedError + expected_acc_shape = (groups_m * 64, groups_n * n_group_size) + if acc.value.shape != expected_acc_shape: + raise ValueError( + f"Accumulator shape mismatch: expected {expected_acc_shape}, got" + f" {acc.value.shape}" + ) row_major = WGMMALayout.ROW_MAJOR col_major = WGMMALayout.COL_MAJOR @@ -375,7 +389,7 @@ def wgmma( b_transpose=b_order == row_major, a_k_stride=(2 if a_order == row_major else 128) << 4, b_k_stride=(swizzle if b_order == row_major else 2) << 4, - n=(groups_n * kn_tile), + n=(groups_n * n_group_size), swizzle=swizzle, element_type=ir.FloatTF32Type.get() if ir.F32Type.isinstance(element_type) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index e2bca54ab230..b7f99ab7b290 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -702,6 +702,62 @@ def kernel(ctx, rhs, out, rhs_smem): rtol = 5e-4 np.testing.assert_allclose(z, ref, rtol=rtol, atol=0) + @parameterized.product( + rhs_transpose=(False, True), + swizzle=(32, 64, 128), + ) + def test_narrow_n(self, rhs_transpose, swizzle): + m, n, k_steps = 64, 8, 2 + + row_major = mgpu.WGMMALayout.ROW_MAJOR + col_major = mgpu.WGMMALayout.COL_MAJOR + rhs_order = col_major if rhs_transpose else row_major + bytewidth = 2 + nk_tile = swizzle // bytewidth + k = nk_tile * k_steps + + def kernel(ctx, rhs, out, smem): + rhs_smem, barrier = smem + gmem_slice = (ds(0, k), ds(0, nk_tile)) + smem_slice = (slice(None), slice(None), slice(None), ds(0, n)) + transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),) + if rhs_transpose: + gmem_slice = gmem_slice[::-1] + smem_slice = (slice(None), slice(None), ds(0, n), slice(None)) + transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=swizzle, + gmem_slice=gmem_slice, + gmem_transform=transform, + barrier=barrier, + ) + barrier.wait() + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) + lhs_regs = iota_tensor(m, k, ir.F16Type.get()) + rhs_smem = memref_slice(rhs_smem, smem_slice) + acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order, swizzle=swizzle) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(0) + acc.value.store_untiled(out) + + jax_dtype = jnp.float16 + y_shape = (n, k) if rhs_transpose else (k, n) + y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + rhs_scratch_shape = jax.ShapeDtypeStruct( + (k_steps, 1, nk_tile, nk_tile), jax_dtype + ) + z = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()), + )(y) + x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) + ref = jax.lax.dot( + x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 + ) + np.testing.assert_allclose(z, ref, rtol=5e-4, atol=0) + class BarrierTest(TestCase): From 71b7e7891602919d06790d8518eb9df6b041b5d1 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 23 Aug 2024 06:50:14 -0700 Subject: [PATCH 220/702] Add jax_test configs for shardy and enable it for pjit_test.py and fix any tests. Tests fixed include: - `test_globally_sharded_key_array_8x4_multi_device` - Issue was in `replicate_trailing_dims` where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding. - `test_aot_out_info` - Issue was there was no mesh since there weren't any NamedShardings. Fixed by not asserting a mesh tuple exists in `lower_jaxpr_to_module` when adding the sdy MeshOp (there won't be any propagation) - `test_concurrent_pjit` - In Shardy if there was a tensor dimension of size 0, we'd emit a verification error if the dimension is sharded on an axes. But if the axis is of size 1, then JAX says this is okay. So have shardy assume the same. - `test_globally_sharded_key_array_result_8x4_single_device` - This tests adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we should create a mesh named `mesh` with a single device id in case it doesn't exist. - `testLowerCostAnalysis` - This calls into `mlir_module_to_xla_computation` which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. Needed to register the SDY dialect in it. - `testShardingConstraintWithArray` - This calls `.compiler_ir(dialect="hlo")` which calls `PyMlirModuleToXlaComputation` which converts the MLIR to HLO, but the Sdy dialect is still inside. Export it before converting it to HLO. PiperOrigin-RevId: 666777167 --- jax/_src/interpreters/mlir.py | 27 ++++++---- tests/BUILD | 5 ++ tests/pjit_test.py | 93 ++++++++++++++++++++++++++++++----- 3 files changed, 105 insertions(+), 20 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index bc1c00948943..75353197ae2b 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1119,14 +1119,15 @@ def lower_jaxpr_to_module( # XLA computation preserves the module name. attrs = ctx.module.operation.attributes if config.use_shardy_partitioner.value: - assert (isinstance(axis_context, sharding_impls.ShardingContext) and - axis_context.mesh_shape is not None) - ctx.module.body.append( - dialects.sdy.MeshOp( - "mesh", - dialects.sdy.MeshAttr.get( - [dialects.sdy.MeshAxisAttr.get(name, size) - for name, size in axis_context.mesh_shape]))) + if (isinstance(axis_context, sharding_impls.ShardingContext) and + axis_context.mesh_shape is not None): + sdy_mesh_attr = dialects.sdy.MeshAttr.get( + [dialects.sdy.MeshAxisAttr.get(name, size) + for name, size in axis_context.mesh_shape]) + else: + sdy_mesh_attr = dialects.sdy.MeshAttr.get([]) + + ctx.module.body.append(dialects.sdy.MeshOp("mesh", sdy_mesh_attr)) module_name = _module_name_regex.sub("_", module_name) attrs["sym_name"] = ir.StringAttr.get(module_name) attrs["mhlo.num_replicas"] = i32_attr(num_replicas) @@ -1633,7 +1634,15 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). # The below custom call achieves the sharding like above example. - return wrap_with_sharding_op( + if config.use_shardy_partitioner.value: + physical_ndim = core.physical_aval(aval).ndim + s = sharding.SdyArraySharding( + mesh_name='mesh', + dimension_shardings=[sharding.SdyDimSharding(axes=[], is_closed=i >= aval.ndim) + for i in range(physical_ndim)]) + return wrap_with_sharding_op(ctx, val, aval, s) + else: + return wrap_with_sharding_op( ctx, val, aval, xc.HloSharding.replicate().to_proto(), unspecified_dims=set(range(aval.ndim))) diff --git a/tests/BUILD b/tests/BUILD index 14d1d409c2ce..eab1d11287e2 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -234,6 +234,11 @@ jax_test( "tpu": ["notsan"], # Times out under tsan. "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + enable_configs = [ + "gpu_2gpu_shardy", + "tpu_df_2x2_shardy", + "tpu_pf_2x2_shardy", + ], shard_count = { "cpu": 5, "gpu": 5, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3af5dfe4cd37..392a25f32612 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -402,6 +402,8 @@ def f(inp1, inp2, inp3): @jtu.run_on_devices('tpu') def testBufferDonationWithOutputShardingInferenceAndTokens(self): + if config.use_shardy_partitioner.value: + self.skipTest('b/355263220: Shardy does not support callbacks yet.') mesh = jtu.create_global_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) @@ -453,10 +455,16 @@ def f(x): check_dtypes=False) hlo = f.lower(np.ones(shape)).compiler_ir() - # Annotation from with_sharding_constraint - self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo)) - # Annotation from pjit - self.assertIn('sharding = "{replicated}"', str(hlo)) + if config.use_shardy_partitioner.value: + # Annotation from with_sharding_constraint + self.assertIn('<@mesh, [{"x"}, {"y"}]>', str(hlo)) + # Annotation from pjit + self.assertIn('sharding = #sdy.sharding<@mesh, [{}, {}]>}', str(hlo)) + else: + # Annotation from with_sharding_constraint + self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo)) + # Annotation from pjit + self.assertIn('sharding = "{replicated}"', str(hlo)) def testShardingConstraintWithArray(self): mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) @@ -484,6 +492,8 @@ def f(x): self.assertIn("sharding={replicated}", hlo.as_hlo_text()) def testShardingConstraintWithArrayOpSharding(self): + if config.use_shardy_partitioner.value: + self.skipTest("Shardy doesn't support PositionalSharding") shape = (8, 8) mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) @@ -555,8 +565,12 @@ def f(x): self.assertLen(actual[0]['a'].addressable_shards, 4) mlir_str = str(f.lower(x).compiler_ir()) - self.assertIn("unspecified_dims=[0]", mlir_str) - self.assertIn("unspecified_dims=[1]", mlir_str) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {"y"}, {}]>', mlir_str) + self.assertIn('<@mesh, [{"x"}, {?}, {}]>', mlir_str) + else: + self.assertIn("unspecified_dims=[0]", mlir_str) + self.assertIn("unspecified_dims=[1]", mlir_str) @jtu.with_mesh([('x', 2), ('y', 2)]) def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self): @@ -575,8 +589,12 @@ def f(x): x = [{'a': v, 'b': v * 2}, v * 3] mlir_str = str(f.lower(x).compiler_ir()) - self.assertIn("unspecified_dims=[0,1]", mlir_str) - self.assertIn("unspecified_dims=[0,2]", mlir_str) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {?}, {"y"}]>', mlir_str) + self.assertIn('<@mesh, [{?}, {"x"}, {?}]>', mlir_str) + else: + self.assertIn("unspecified_dims=[0,1]", mlir_str) + self.assertIn("unspecified_dims=[0,2]", mlir_str) def testCaching(self): def f(x): @@ -847,6 +865,9 @@ def f_for_pjit(x): def testOutfeed(self): if xla_bridge.using_pjrt_c_api(): raise unittest.SkipTest('outfeed not implemented in PJRT C API') + if config.use_shardy_partitioner.value: + self.skipTest( + 'b/355263220: outfeed lowering not supported by Shardy') devices = np.array(jax.local_devices()) nr_devices = len(devices) @@ -1280,6 +1301,9 @@ class CustomPartitionerTest(jtu.JaxTestCase): def skip_if_custom_partitioning_not_supported(self): if jtu.is_cloud_tpu(): raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") + if config.use_shardy_partitioner.value: + self.skipTest( + 'Custom partitioning is not supported with Shardy yet.') @jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU. @jtu.with_mesh([('x', 4), ('y', 2)]) @@ -1564,6 +1588,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase): ) def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape, mesh_axis_names): + if config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1580,6 +1606,8 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape, self.assertArraysEqual(out._value, input_data) def test_xla_arr_sharding_mismatch(self): + if config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_input_shape = (6, 2) input_data = np.arange( @@ -1607,6 +1635,8 @@ def test_xla_arr_sharding_mismatch(self): compiled(arr) def test_gda_auto_shardings_len(self): + if config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_input_shape = (4, 2) input_data = np.arange( @@ -1627,6 +1657,8 @@ def test_gda_auto_shardings_len(self): ) def test_jit_arr_partial_auto_sharding_array( self, mesh_shape, mesh_axis_names, pspec): + if config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) input_data = np.arange( @@ -1667,6 +1699,8 @@ def test_jit_auto_sharding_partial_tuple_input_shardings( self, mesh_shape, mesh_axis_names): if not jtu.test_device_matches(["tpu"]): self.skipTest('Parameters are tupled only on TPU if >2000 parameters') + if config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) @@ -1838,6 +1872,11 @@ def _checks(out, input_data): ) def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, s2_shape, s3_shape, s4_shape): + if config.use_shardy_partitioner.value: + self.skipTest( + 'TODO(b/355263220) Shardy conflict resolution is not complete. Issue ' + 'here is that for `a1 @ a1.T` GSPMD gives dim 0 sharded on `x` while ' + 'Shardy gives it fully replicated.') global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) global_input_shape = (8, 2) @@ -2400,6 +2439,10 @@ def test_device_put_sharding_prng(self): self.assertTrue(jax.dtypes.issubdtype(a.dtype, jax.dtypes.prng_key)) self.assertEqual(a.sharding, out_p.sharding) + if config.use_shardy_partitioner.value: + # OpSharding is not supported in shardy. + return + op = xc.OpSharding() op.type = xc.OpSharding.Type.OTHER op.tile_assignment_dimensions = [8] @@ -3405,6 +3448,8 @@ def g(x): jtu.check_grads(g, (arr,), order=2) def test_pjit_out_sharding_preserved(self): + if config.use_shardy_partitioner.value: + raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3483,6 +3528,8 @@ def test_list_in_pspec(self): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) def test_sharding_preserved_trivial(self): + if config.use_shardy_partitioner.value: + raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3535,6 +3582,8 @@ def test_sharding_on_output_with_vmap(self): self.assertEqual(count[0], 1) def test_jit_mul_sum_sharding_preserved(self): + if config.use_shardy_partitioner.value: + raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3608,6 +3657,8 @@ def test_none_out_sharding(self): self.assertEqual(out2.sharding.spec, P()) def test_sharding_preserved_apply_primitive(self): + if config.use_shardy_partitioner.value: + raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) @@ -3848,6 +3899,9 @@ def f(): f() # doesn't crash def test_lowering_cache_hit_different_devices(self): + if config.use_shardy_partitioner.value: + self.skipTest('b/358322664: different axis names results in ' + 'a cache miss with Shardy.') if jax.device_count() < 4: self.skipTest('Requires >=4 devices') @@ -3945,7 +3999,10 @@ def make_keys(seeds): self.assertEqual(base_array.sharding, NamedSharding(mesh, P('y', 'x', None))) lowered_text = make_keys.lower(seeds).as_text() - self.assertIn('unspecified_dims=[0,1]', lowered_text) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text) + else: + self.assertIn('unspecified_dims=[0,1]', lowered_text) def test_prng_sharding_propagation_with_nested_jit(self): input_shape = (8, 2) @@ -3971,7 +4028,10 @@ def f(): self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', None))) lowered_text = make_keys.lower(seeds).as_text() - self.assertIn('unspecified_dims=[0,1]', lowered_text) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text) + else: + self.assertIn('unspecified_dims=[0,1]', lowered_text) def test_partial_sharded_prng_key_inp(self): input_shape = (8, 2, 2) @@ -3995,7 +4055,10 @@ def make_keys(seeds): self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) lowered_text = make_keys.lower(seeds).as_text() - self.assertIn('unspecified_dims=[0,1,2]', lowered_text) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {?}, {?}, {}]>', lowered_text) + else: + self.assertIn('unspecified_dims=[0,1,2]', lowered_text) def test_jit_partially_specified_shardings(self): @@ -4048,6 +4111,8 @@ def f(*args): f(inps) # doesn't crash def test_spmd_preserves_input_sharding_vmap_grad(self): + if config.use_shardy_partitioner.value: + self.skipTest("Shardy doesn't support PositionalSharding") # https://github.com/google/jax/issues/20710 n_devices = jax.device_count() sharding = PositionalSharding(jax.devices()) @@ -4211,6 +4276,9 @@ def f(x): self.assertArraysEqual(out2, np.arange(8) * 2) def test_device_put_efficient_reshard_single_host(self): + if config.use_shardy_partitioner.value: + self.skipTest( + '_different_device_order_reshard is creating a GSPMDSharding') if jax.device_count() < 4: self.skipTest('Requires >= 4 devices') @@ -4235,6 +4303,9 @@ def test_device_put_efficient_reshard_single_host(self): ("8_384", (8, 384)), ) def test_device_put_efficient_reshard_complex_mesh(self, shape): + if config.use_shardy_partitioner.value: + self.skipTest( + '_different_device_order_reshard is creating a GSPMDSharding') if jax.device_count() < 8: self.skipTest('Requires >= 8 devices') From be59f6ec4748a3e55b832ee4b3e5715608a6d99a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 23 Aug 2024 08:05:43 -0700 Subject: [PATCH 221/702] [Mosaic GPU] Support tiled stores of arrays with fewer columns than swizzling PiperOrigin-RevId: 666798285 --- .../mosaic/gpu/fragmented_array.py | 13 ++++++-- tests/mosaic/gpu_test.py | 31 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 406f5eaba3b2..892cd2d09332 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -686,6 +686,8 @@ def store_tiled(self, ref, swizzle: int | None): assert m % 64 == 0 # This is implied by the layout. cols_per_tile = swizzle // bw expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile] + if n < cols_per_tile: # We allow singular tiles shorter than swizzle. + expected_shape = [m // 64, 1, 64, cols_per_tile] if ir.MemRefType(ref.type).shape != expected_shape: raise ValueError(ref.type, (m, n)) for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle): @@ -715,9 +717,12 @@ def transfer_tiled(shape, dtype, swizzle: int | None): # TODO(apaszke): We could use ldmatrix/stmatrix for 16-bit types. bw = mgpu.bytewidth(dtype) m, n = shape - cols_per_tile = swizzle // bw - if n % cols_per_tile != 0: - raise NotImplementedError + assert m % 64 == 0 and n % 8 == 0 # Implied by the layout. + cols_per_tile = swizzle_elems = swizzle // bw + if n < swizzle_elems: + cols_per_tile = n + else: + assert n % swizzle_elems == 0, (n, swizzle_elems) if swizzle not in {32, 64, 128}: raise NotImplementedError("Only swizzled stores supported") @@ -752,6 +757,8 @@ def transfer_tiled(shape, dtype, swizzle: int | None): case _: raise AssertionError(swizzle) stagger_amount = swizzle // 64 + if (cols_per_tile // 8) % (stagger_amount + 1): + raise NotImplementedError else: # We rely on canonicalization to clean up the selects. i1 = ir.IntegerType.get_signless(1) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index b7f99ab7b290..ce1c02f5a01b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -423,6 +423,37 @@ def kernel(ctx, out, smem): )() np.testing.assert_array_equal(iota, expected) + @parameterized.product( + dtypes=( + (ir.F16Type.get, jnp.float16), + (partial(ir.IntegerType.get_signless, 8), jnp.int8), + ), + swizzle=(32, 64, 128), + ) + def test_store_tiled_short_n(self, dtypes, swizzle): + mlir_dtype_cls, jax_dtype = dtypes + mlir_dtype = mlir_dtype_cls() + col_tiling = swizzle // bytewidth(mlir_dtype) + m = 128 + n = 16 // bytewidth(mlir_dtype) + tiling = (64, col_tiling) + def kernel(ctx, out, smem): + iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=swizzle) + ctx.async_copy( + src_ref=smem, + dst_ref=out, + swizzle=swizzle, + gmem_slice=(ds(0, m), ds(0, col_tiling)), + gmem_transform=mosaic_gpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), jax_dtype) + expected = np.arange(m * n, dtype=jax_dtype).reshape(m, n) + iota = mosaic_gpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), expected, smem_shape + )() + np.testing.assert_array_equal(iota, expected) + @parameterized.named_parameters( ("bf16_i8", ir.BF16Type.get, jnp.bfloat16, From 279977c61daa4451f88a52952190026579142eab Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 23 Aug 2024 10:43:52 -0700 Subject: [PATCH 222/702] Refactor hermetic CUDA flags and update `--config=cuda` to add CUDA dependencies both for `bazel build` and `bazel test` phases. Add `--@local_config_cuda//cuda:override_include_cuda_libs` to override settings for TF wheel. Forbid building TF wheel with `--@local_config_cuda//cuda:include_cuda_libs=true` PiperOrigin-RevId: 666848518 --- .bazelrc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.bazelrc b/.bazelrc index 767e0982a1e4..4456994bf2f7 100644 --- a/.bazelrc +++ b/.bazelrc @@ -68,8 +68,8 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true # Default hermetic CUDA and CUDNN versions. build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" -# This flag is needed to include hermetic CUDA libraries for bazel tests. -test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true +# This flag is needed to include CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_cuda_libs=true # Requires MSVC and LLVM to be installed build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl From 20d13abfa056f302d2e66f49e1251f8b9748895f Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 23 Aug 2024 11:37:49 -0700 Subject: [PATCH 223/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b0d313b58ec33fdb6f09be119edfc052c1b019e1. PiperOrigin-RevId: 666868666 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 2783a0cf5fe9..57f43e2b4d2e 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "cc369fd42afddf02ec9ea7775c2798d775f1d219" -XLA_SHA256 = "218b50279f0b61e8f2cdbed0ab23279bc121c0cad4a511fb05643ec3b61bc8b6" +XLA_COMMIT = "b0d313b58ec33fdb6f09be119edfc052c1b019e1" +XLA_SHA256 = "92a520dac4393535fafb380f0cb6b18e8a97154ea845dbe786d7c54b8c1125b3" def repo(): tf_http_archive( From 9090b8a4f95f80f608eb6c685cd485708321dfc0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 22 Aug 2024 10:22:04 -0700 Subject: [PATCH 224/702] Better docs for jnp quantile & percentile --- jax/_src/numpy/reductions.py | 168 +++++++++++++++++++++++++++++++++-- 1 file changed, 161 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index c619fdf02a80..b583e5f6d200 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -1883,16 +1883,53 @@ def cumulative_sum( # Quantiles # TODO(jakevdp): interpolation argument deprecated 2024-05-16 -@implements(np.quantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: + """Compute the quantile of the data along the specified axis. + + JAX implementation of :func:`numpy.quantile`. + + Args: + a: N-dimensional array input. + q: scalar or 1-dimensional array specifying the desired quantiles. ``q`` + should contain floating-point values between ``0.0`` and ``1.0``. + axis: optional axis or tuple of axes along which to compute the quantile + out: not implemented by JAX; will error if not None + overwrite_input: not implemented by JAX; will error if not False + method: specify the interpolation method to use. Options are one of + ``["linear", "lower", "higher", "midpoint", "nearest"]``. + default is ``linear``. + keepdims: if True, then the returned array will have the same number of + dimensions as the input. Default is False. + interpolation: deprecated alias of the ``method`` argument. Will result + in a :class:`DeprecationWarning` if used. + + Returns: + An array containing the specified quantiles along the specified axes. + + See also: + - :func:`jax.numpy.nanquantile`: compute the quantile while ignoring NaNs + - :func:`jax.numpy.percentile`: compute the percentile (0-100) + + Examples: + Computing the median and quartiles of an array, with linear interpolation: + + >>> x = jnp.arange(10) + >>> q = jnp.array([0.25, 0.5, 0.75]) + >>> jnp.quantile(x, q) + Array([2.25, 4.5 , 6.75], dtype=float32) + + Computing the quartiles using nearest-value interpolation: + + >>> jnp.quantile(x, q, method='nearest') + Array([2., 4., 7.], dtype=float32) + """ check_arraylike("quantile", a, q) if overwrite_input or out is not None: - msg = ("jax.numpy.quantile does not support overwrite_input=True or " - "out != None") - raise ValueError(msg) + raise ValueError("jax.numpy.quantile does not support overwrite_input=True " + "or out != None") if not isinstance(interpolation, DeprecatedArg): deprecations.warn( "jax-numpy-quantile-interpolation", @@ -1902,11 +1939,50 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 -@implements(np.nanquantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: + """Compute the quantile of the data along the specified axis, ignoring NaNs. + + JAX implementation of :func:`numpy.nanquantile`. + + Args: + a: N-dimensional array input. + q: scalar or 1-dimensional array specifying the desired quantiles. ``q`` + should contain floating-point values between ``0.0`` and ``1.0``. + axis: optional axis or tuple of axes along which to compute the quantile + out: not implemented by JAX; will error if not None + overwrite_input: not implemented by JAX; will error if not False + method: specify the interpolation method to use. Options are one of + ``["linear", "lower", "higher", "midpoint", "nearest"]``. + default is ``linear``. + keepdims: if True, then the returned array will have the same number of + dimensions as the input. Default is False. + interpolation: deprecated alias of the ``method`` argument. Will result + in a :class:`DeprecationWarning` if used. + + Returns: + An array containing the specified quantiles along the specified axes. + + See also: + - :func:`jax.numpy.quantile`: compute the quantile without ignoring nans + - :func:`jax.numpy.nanpercentile`: compute the percentile (0-100) + + Examples: + Computing the median and quartiles of a 1D array: + + >>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6]) + >>> q = jnp.array([0.25, 0.5, 0.75]) + + Because of the NaN value, :func:`jax.numpy.quantile` returns all NaNs, + while :func:`~jax.numpy.nanquantile` ignores them: + + >>> jnp.quantile(x, q) + Array([nan, nan, nan], dtype=float32) + >>> jnp.nanquantile(x, q) + Array([1.5, 3. , 4.5], dtype=float32) + """ check_arraylike("nanquantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " @@ -2043,12 +2119,50 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, return lax.convert_element_type(result, a.dtype) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 -@implements(np.percentile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + """Compute the percentile of the data along the specified axis. + + JAX implementation of :func:`numpy.percentile`. + + Args: + a: N-dimensional array input. + q: scalar or 1-dimensional array specifying the desired quantiles. ``q`` + should contain integer or floating point values between ``0`` and ``100``. + axis: optional axis or tuple of axes along which to compute the quantile + out: not implemented by JAX; will error if not None + overwrite_input: not implemented by JAX; will error if not False + method: specify the interpolation method to use. Options are one of + ``["linear", "lower", "higher", "midpoint", "nearest"]``. + default is ``linear``. + keepdims: if True, then the returned array will have the same number of + dimensions as the input. Default is False. + interpolation: deprecated alias of the ``method`` argument. Will result + in a :class:`DeprecationWarning` if used. + + Returns: + An array containing the specified percentiles along the specified axes. + + See also: + - :func:`jax.numpy.quantile`: compute the quantile (0.0-1.0) + - :func:`jax.numpy.nanpercentile`: compute the percentile while ignoring NaNs + + Examples: + Computing the median and quartiles of a 1D array: + + >>> x = jnp.array([0, 1, 2, 3, 4, 5, 6]) + >>> q = jnp.array([25, 50, 75]) + >>> jnp.percentile(x, q) + Array([1.5, 3. , 4.5], dtype=float32) + + Computing the same percentiles with nearest rather than linear interpolation: + + >>> jnp.percentile(x, q, method='nearest') + Array([1., 3., 4.], dtype=float32) + """ check_arraylike("percentile", a, q) q, = promote_dtypes_inexact(q) if not isinstance(interpolation, DeprecatedArg): @@ -2061,12 +2175,52 @@ def percentile(a: ArrayLike, q: ArrayLike, method=method, keepdims=keepdims) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 -@implements(np.nanpercentile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + """Compute the percentile of the data along the specified axis, ignoring NaN values. + + JAX implementation of :func:`numpy.nanpercentile`. + + Args: + a: N-dimensional array input. + q: scalar or 1-dimensional array specifying the desired quantiles. ``q`` + should contain integer or floating point values between ``0`` and ``100``. + axis: optional axis or tuple of axes along which to compute the quantile + out: not implemented by JAX; will error if not None + overwrite_input: not implemented by JAX; will error if not False + method: specify the interpolation method to use. Options are one of + ``["linear", "lower", "higher", "midpoint", "nearest"]``. + default is ``linear``. + keepdims: if True, then the returned array will have the same number of + dimensions as the input. Default is False. + interpolation: deprecated alias of the ``method`` argument. Will result + in a :class:`DeprecationWarning` if used. + + Returns: + An array containing the specified percentiles along the specified axes. + + See also: + - :func:`jax.numpy.nanquantile`: compute the nan-aware quantile (0.0-1.0) + - :func:`jax.numpy.percentile`: compute the percentile without special + handling of NaNs. + + Examples: + Computing the median and quartiles of a 1D array: + + >>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6]) + >>> q = jnp.array([25, 50, 75]) + + Because of the NaN value, :func:`jax.numpy.percentile` returns all NaNs, + while :func:`~jax.numpy.nanpercentile` ignores them: + + >>> jnp.percentile(x, q) + Array([nan, nan, nan], dtype=float32) + >>> jnp.nanpercentile(x, q) + Array([1.5, 3. , 4.5], dtype=float32) + """ check_arraylike("nanpercentile", a, q) q = ufuncs.true_divide(q, 100.0) if not isinstance(interpolation, DeprecatedArg): From 670a648b7bd1e1a3916599065c0de3a97670b5f0 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 23 Aug 2024 21:21:55 +0000 Subject: [PATCH 225/702] add experimental jax.no_tracing context manager --- jax/__init__.py | 1 + jax/_src/config.py | 5 +++++ jax/_src/pjit.py | 3 +++ tests/api_test.py | 14 ++++++++++++++ 4 files changed, 23 insertions(+) diff --git a/jax/__init__.py b/jax/__init__.py index 037386317ee4..168ac9278586 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -56,6 +56,7 @@ debug_nans as debug_nans, debug_infs as debug_infs, log_compiles as log_compiles, + no_tracing as no_tracing, explain_cache_misses as explain_cache_misses, default_device as default_device, default_matmul_precision as default_matmul_precision, diff --git a/jax/_src/config.py b/jax/_src/config.py index 46b3273278e0..b6d2358f4c26 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1501,6 +1501,11 @@ def _update_disable_jit_thread_local(val): upgrade=True, help='Enable eager-mode pmap when jax_disable_jit is activated.') +no_tracing = bool_state( + name='jax_no_tracing', + default=False, + help='Disallow tracing for JIT compilation.') + disable_vmap_shmap_error = bool_state( name='jax_disable_vmap_shmap_error', default=False, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d8784848fe58..0997f1107a02 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -353,6 +353,9 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @api_boundary def cache_miss(*args, **kwargs): + if config.no_tracing.value: + raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " + "`jit`, but 'no_tracing' is set") outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper( fun, jit_info, *args, **kwargs) executable = _read_most_recent_pjit_call_executable(jaxpr) diff --git a/tests/api_test.py b/tests/api_test.py index eee4924435ed..3d3d08c092c3 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1463,6 +1463,20 @@ def f(x): self.assertAllClose(f(np.nan), np.nan) self.assertAllClose(jit(f)(np.nan), np.nan) + def test_no_tracing(self): + @jax.jit + def f(x): + return x + + x = jnp.arange(3) + y = jnp.arange(4) + + _ = f(x) # no crash + + with self.assertRaisesRegex(RuntimeError, 'no_tracing'): + with jax.no_tracing(): + _ = f(y) # crash! + class APITest(jtu.JaxTestCase): From 276c87eba007e3d45d0fb61d2d9360c0f7bff423 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Fri, 23 Aug 2024 14:42:08 -0700 Subject: [PATCH 226/702] Add a more helpful error message in `create_hybrid_device_mesh` for missing attribute `process_index` or `slice_index. PiperOrigin-RevId: 666928476 --- jax/experimental/mesh_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 473700024ad7..7cac0338a923 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -776,7 +776,11 @@ def create_hybrid_device_mesh( if devices is None: devices = xb.devices() attr = 'process_index' if process_is_granule else 'slice_index' - assert hasattr(devices[0], attr) + if not hasattr(devices[0], attr): + raise ValueError( + f'Device {devices[0]} does not have attribute {attr}. See' + ' `process_is_granule` option.' + ) granule_dict = collections.defaultdict(list) for dev in devices: granule_dict[getattr(dev, attr)].append(dev) From 6a5ca0bb5224e7f40ee272b4f2a89a59f7c95d77 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 23 Aug 2024 15:09:33 -0700 Subject: [PATCH 227/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9738684ff8df93d138442a290de407c726f95a6f. PiperOrigin-RevId: 666937202 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 57f43e2b4d2e..e613126de4b3 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "b0d313b58ec33fdb6f09be119edfc052c1b019e1" -XLA_SHA256 = "92a520dac4393535fafb380f0cb6b18e8a97154ea845dbe786d7c54b8c1125b3" +XLA_COMMIT = "9738684ff8df93d138442a290de407c726f95a6f" +XLA_SHA256 = "f899c583d9ae189adf8e904f3902dbd5db858203af17e54b9d0400b7d1aa1f67" def repo(): tf_http_archive( From a2a351f88bd596d31f6e99374cc561bdeb0b4c45 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 23 Aug 2024 15:18:14 -0700 Subject: [PATCH 228/702] Fix pallas int4->int8 conversion PiperOrigin-RevId: 666939965 --- jax/_src/pallas/mosaic/lowering.py | 3 ++- tests/pallas/tpu_ops_test.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 0f85b3bcdb64..11ef428ae279 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1583,7 +1583,8 @@ def _convert_element_type_lowering_rule( return arith.ExtSIOp(out_type, x).result elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4: return arith.TruncIOp(out_type, x).result - else: # This case triggers when casting signed to unsigned or vice versa. + elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits: + # This case triggers when casting signed to unsigned or vice versa. return x elif jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype( new_dtype, jnp.signedinteger diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index e7d0e04b05b0..a34c2b2f2f61 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -186,6 +186,27 @@ def body(x_ref, o_ref): result = self.pallas_call(body, out_shape=out)(x) np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0) + def test_tpu_signed_int_upcast(self): + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("TPUv5+ needed for integer matmuls") + + def body(x_ref, o_ref): + # Test cast from int4 -> int8 + ux = lax.convert_element_type(x_ref[...], jnp.int8) + o_ref[...] = jax.lax.dot(ux, ux, preferred_element_type=jnp.int32) + + out = jax.ShapeDtypeStruct((128, 128), jnp.int32) + x = jnp.arange(128 * 128, dtype=jnp.int4).reshape((128, 128)) + result = self.pallas_call(body, out_shape=out)(x) + np.testing.assert_array_equal( + result, + jax.lax.dot( + x.astype(jnp.int8), + x.astype(jnp.int8), + preferred_element_type=jnp.int32, + ), + ) + class OpsInterpretTest(OpsTest): INTERPRET = True From 7253b9ac8bf07b27df2b38464552a4600fa95d91 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 23 Aug 2024 16:07:05 -0700 Subject: [PATCH 229/702] [Pallas] Fix pallas interpret mode DMA test failures. PiperOrigin-RevId: 666953373 --- jax/_src/pallas/mosaic/primitives.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 60d3d23cb884..b7d02e5ccf8b 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -593,9 +593,6 @@ def dma_start_discharge_rule(in_avals, out_avals, global_updates = jax.lax.all_gather(updates, shard_axis) updates = jax.lax.dynamic_index_in_dim( global_updates, index, axis=0, keepdims=False) - global_dst_sem = jax.lax.all_gather(dst_sem, shard_axis) - dst_sem = jax.lax.dynamic_index_in_dim( - global_dst_sem, index, axis=0, keepdims=False) # Handle asymmetrical indexing when devices do not share the same # dst_indexer. @@ -604,17 +601,13 @@ def dma_start_discharge_rule(in_avals, out_avals, dst_indexers = tree_util.tree_map( lambda x: jax.lax.dynamic_index_in_dim( x, index, axis=0, keepdims=False), global_dst_indexers) - global_dst_sem_indexers = tree_util.tree_map( - lambda x: jax.lax.all_gather(x, shard_axis), dst_sem_indexers) - dst_sem_indexers = tree_util.tree_map( - lambda x: jax.lax.dynamic_index_in_dim( - x, index, axis=0, keepdims=False), global_dst_sem_indexers) _, new_dst = state_discharge.index_swap_array( dst_ref, dst_indexers, updates ) # Update semaphore values. + # TODO(justinfu): Potentially handle asymmetric copy sizes. recv_size = jnp.array(updates.size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) dst_sem_value = _index_semaphore(dst_sem, dst_sem_indexers, dst_sem_aval) _, new_dst_sem = state_discharge.index_swap_array( From 9ce8de5fb0df7a4ef82b87a42f6d1ee0fb58e4c5 Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Fri, 23 Aug 2024 09:36:50 -0500 Subject: [PATCH 230/702] [ROCm] add build file. --- build/rocm/dev_build_rocm.py | 165 +++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100755 build/rocm/dev_build_rocm.py diff --git a/build/rocm/dev_build_rocm.py b/build/rocm/dev_build_rocm.py new file mode 100755 index 000000000000..2be64152f667 --- /dev/null +++ b/build/rocm/dev_build_rocm.py @@ -0,0 +1,165 @@ +# !/usr/bin/env python3 +# +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE(ruturaj4): This script automates the build process for JAX and XLA on ROCm, +# allowing for optional uninstallation of existing packages, and custom paths for ROCm and XLA repositories. + +import argparse +import os +import shutil +import subprocess +import sys + + +def get_rocm_version(): + try: + version = subprocess.check_output( + "cat /opt/rocm/.info/version | cut -d '-' -f 1", shell=True + ) + return version.decode("utf-8").strip() + except subprocess.CalledProcessError as e: + print(f"Error fetching ROCm version: {e}") + return None + + +def get_rocm_target(): + try: + target_info = subprocess.check_output( + "rocminfo | grep gfx | head -n 1", shell=True + ) + target = target_info.decode("utf-8").split()[1] + return target + except subprocess.CalledProcessError as e: + print(f"Error fetching ROCm target: {e}") + return None + + +def uninstall_existing_packages(packages): + cmd = ["python3", "-m", "pip", "uninstall", "-y"] + cmd.extend(packages) + + try: + subprocess.run(cmd, check=True) + print(f"Successfully uninstalled {packages}") + except subprocess.CalledProcessError as e: + print(f"Failed to uninstall {packages}: {e}") + + +def clean_dist_directory(): + try: + shutil.rmtree("dist") + print("Cleaned dist directory.") + except FileNotFoundError: + print("dist directory not found, skipping cleanup.") + except Exception as e: + print(f"Failed to clean dist directory: {e}") + sys.exit(1) + + +def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path): + bazel_options = ( + f"--bazel_options=--override_repository=xla={xla_path}" if xla_path else "" + ) + clang_option = f"--clang_path={clang_path}" if clang_path else "" + build_command = [ + "python3", + "./build/build.py", + "--enable_rocm", + "--build_gpu_plugin", + "--gpu_plugin_rocm_version=60", + f"--use_clang={str(use_clang).lower()}", + f"--rocm_amdgpu_targets={rocm_target}", + f"--rocm_path=/opt/rocm-{rocm_version}/", + bazel_options, + ] + + if clang_option: + build_command.append(clang_option) + + print("Executing build command:") + print(" ".join(build_command)) + + try: + subprocess.run(build_command, check=True) + print("Build completed successfully.") + except subprocess.CalledProcessError as e: + print(f"Build failed: {e}") + sys.exit(1) + + +def install_wheel(): + try: + subprocess.run( + ["python3", "-m", "pip", "install", "dist/*.whl"], check=True, shell=True + ) + print("Packages installed successfully.") + except subprocess.CalledProcessError as e: + print(f"Failed to install packages: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser(description="Script to build JAX and XLA on ROCm.") + parser.add_argument( + "--clang-path", type=str, default="", help="Specify the Clang compiler path" + ) + parser.add_argument( + "--skip-uninstall", + action="store_true", + help="Skip uninstall of old versions during package install", + ) + parser.add_argument( + "--use-clang", default="false", help="Use Clang compiler if set" + ) + parser.add_argument( + "--xla-path", type=str, default="", help="Specify the XLA repository path" + ) + + args = parser.parse_args() + + if args.xla_path: + args.xla_path = os.path.abspath(args.xla_path) + print(f"Converted XLA path to absolute: {args.xla_path}") + + rocm_version = get_rocm_version() + if not rocm_version: + print("Could not determine ROCm version. Exiting.") + sys.exit(1) + + rocm_target = get_rocm_target() + if not rocm_target: + print("Could not determine ROCm target. Exiting.") + sys.exit(1) + + if not args.skip_uninstall: + print("Uninstalling existing packages...") + packages = ["jax", "jaxlib", "jax-rocm60-pjrt", "jax-rocm60-plugin"] + uninstall_existing_packages(packages) + + clean_dist_directory() + + print( + f"Building JAX and XLA with ROCm version: {rocm_version}, Target: {rocm_target}" + ) + build_jax_xla( + args.xla_path, rocm_version, rocm_target, args.use_clang, args.clang_path + ) + + install_wheel() + + +if __name__ == "__main__": + main() From e9143623e02a8e64ee3fe826370535b35571f1bb Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 24 Aug 2024 14:34:12 -0700 Subject: [PATCH 231/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/4bfb5c82a427151d6fe5acad8ebe12cee403036a. PiperOrigin-RevId: 667177243 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e613126de4b3..d20edfe2328a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "9738684ff8df93d138442a290de407c726f95a6f" -XLA_SHA256 = "f899c583d9ae189adf8e904f3902dbd5db858203af17e54b9d0400b7d1aa1f67" +XLA_COMMIT = "4bfb5c82a427151d6fe5acad8ebe12cee403036a" +XLA_SHA256 = "83dfcddaf29205f8f426cf7044ab89242985ec34b51a8a52bb97dc5092fc1da0" def repo(): tf_http_archive( From a9b41e9fe7fb3957155e89dc71e4765da35632fa Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sat, 24 Aug 2024 17:29:20 -0700 Subject: [PATCH 232/702] improve `scan` error message on non-concrete `length` argument Specifically, make it speak concretely about the `length` argument. --- jax/_src/lax/control_flow/loops.py | 6 +++++- tests/lax_control_flow_test.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 443470e129fa..d9cf1f89c5c4 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -224,7 +224,11 @@ def scan(f, init, xs, length=None): if not hasattr(x, 'shape')))) from err if length is not None: - length = int(length) + try: + length = int(length) + except core.ConcretizationTypeError as err: + msg = 'The `length` argument to `scan` expects a concrete `int` value.' + raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type] if not all(length == l for l in lengths): msg = ("scan got `length` argument of {} which disagrees with " "leading axis sizes {}.") diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 192603de3655..829169b40778 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2950,6 +2950,14 @@ def body(carry, x): hlo_text = fn.lower(init).as_text('hlo') self.assertNotIn('4,1,2,2', hlo_text) + def test_scan_length_concrete_error(self): + f = jax.jit(lambda n, x: jax.lax.scan(lambda c, z: (c, z), x, (), n)) + + with self.assertRaisesRegex( + core.ConcretizationTypeError, + "The `length` argument to `scan` expects a concrete `int` value.*"): + f(3, 1.) + def test_cond_vmap_forwarding_doesnt_promote(self): def f(x, y): x, y = jax.lax.cond( From b3e3115391a9cf2373cf8f3ed1f68029ce956e60 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sat, 24 Aug 2024 22:39:28 -0700 Subject: [PATCH 233/702] improve `scan` error message on non-concrete `unroll` argument --- jax/_src/lax/control_flow/loops.py | 4 ++++ tests/lax_control_flow_test.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index d9cf1f89c5c4..f7f09424a9e8 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -296,6 +296,10 @@ def _create_jaxpr(init): raise NotImplementedError( f'Effects not supported in `scan`: {disallowed_effects}') + unroll = core.concrete_or_error( + None, unroll, + "The `unroll` argument to `scan` expects a concrete `int` or `bool` " + "value.") if isinstance(unroll, bool): unroll = max(length, 1) if unroll else 1 if unroll < 1: diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 829169b40778..fd83d269b41c 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2958,6 +2958,17 @@ def test_scan_length_concrete_error(self): "The `length` argument to `scan` expects a concrete `int` value.*"): f(3, 1.) + def test_scan_unroll_concrete_error(self): + f = jax.jit(lambda n, x: jax.lax.scan( + lambda c, z: (c, z), x, (), 10, unroll=n)) + + msg = ("The `unroll` argument to `scan` expects a concrete `int` or " + "`bool` value.*") + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(3, 1.) + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(True, 1.) + def test_cond_vmap_forwarding_doesnt_promote(self): def f(x, y): x, y = jax.lax.cond( From 4b1c9f483cffffaef31cd1c1e3035cc0cef5fa67 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 26 Aug 2024 09:04:15 -0700 Subject: [PATCH 234/702] jnp.mean: fix normalizer for large arrays --- jax/_src/numpy/reductions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index c619fdf02a80..f19125928a57 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -766,7 +766,8 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, else: normalizer = core.dimension_as_value(_axis_size(a, axis)) else: - normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims) + normalizer = sum(_broadcast_to(where, np.shape(a)), axis, + dtype=computation_dtype, keepdims=keepdims) return lax.div( sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where), From 6d1f51e63d714c77b44e9118389ffdbf6e7fa929 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 26 Aug 2024 09:10:26 -0700 Subject: [PATCH 235/702] Clean up BUILD files. PiperOrigin-RevId: 667604964 --- benchmarks/mosaic/BUILD | 4 +- docs/cuda_custom_call/BUILD | 2 +- examples/jax_cpp/BUILD | 8 +- jax/BUILD | 25 ++--- jax/_src/lib/BUILD | 2 +- jax/_src/pallas/BUILD | 2 +- jax/_src/pallas/mosaic/BUILD | 2 +- jax/_src/pallas/mosaic_gpu/BUILD | 2 +- jax/_src/pallas/triton/BUILD | 2 +- jax/experimental/jax2tf/g3doc/BUILD | 2 +- .../jax2tf/tests/back_compat_testdata/BUILD | 2 +- .../jax2tf/tests/flax_models/BUILD | 4 +- jax/experimental/mosaic/gpu/examples/BUILD | 18 ++-- jax/tools/build_defs.bzl | 4 +- jaxlib/BUILD | 20 ++-- jaxlib/cpu/BUILD | 10 +- jaxlib/cuda/BUILD | 102 +++++++++--------- jaxlib/gpu/BUILD | 2 +- jaxlib/mosaic/BUILD | 26 ++--- jaxlib/mosaic/gpu/BUILD | 20 ++-- jaxlib/mosaic/python/BUILD | 2 +- jaxlib/triton/BUILD | 8 +- tests/BUILD | 2 +- tests/mosaic/BUILD | 4 +- 24 files changed, 134 insertions(+), 141 deletions(-) diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 027da12ce6d3..72aae09af4a2 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -49,8 +49,8 @@ jax_test( disable_configs = DISABLED_CONFIGS, tags = ["notap"], deps = [ + "//jax:mosaic_gpu", + "//jax/experimental/mosaic/gpu/examples:matmul", "//third_party/py/google_benchmark", - "//third_party/py/jax:mosaic_gpu", - "//third_party/py/jax/experimental/mosaic/gpu/examples:matmul", ] + py_deps("absl/testing") + py_deps("numpy"), ) diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD index 93715bdac171..0591eed1fbec 100644 --- a/docs/cuda_custom_call/BUILD +++ b/docs/cuda_custom_call/BUILD @@ -56,8 +56,8 @@ cuda_library( name = "foo_", srcs = ["foo.cu.cc"], deps = [ + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@local_config_cuda//cuda:cuda_headers", ], ) diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index fccf0cc37048..6e4647b5e491 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_applicable_licenses = ["//third_party/py/jax:license"]) +package(default_applicable_licenses = ["//jax:license"]) licenses(["notice"]) @@ -21,13 +21,13 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ - "//third_party/absl/status:statusor", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:platform_port", "@xla//xla:literal", "@xla//xla:literal_util", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt/cpu:cpu_client", "@xla//xla/tools:hlo_module_loader", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:platform_port", ], ) diff --git a/jax/BUILD b/jax/BUILD index ec350b4b99a7..574559688c4d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -76,39 +76,32 @@ package_group( packages = [ # Intentionally avoid jax dependencies on jax.extend. # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html - "//third_party/py/jax/tests/...", + "//tests/...", ] + jax_extend_internal_users, ) package_group( name = "mosaic_users", - packages = [ - "//...", - ] + mosaic_internal_users, + includes = [":internal"], + packages = mosaic_internal_users, ) package_group( name = "pallas_gpu_users", - packages = [ - "//...", - "//learning/brain/research/jax", - ] + pallas_gpu_internal_users, + includes = [":internal"], + packages = pallas_gpu_internal_users, ) package_group( name = "pallas_tpu_users", - packages = [ - "//...", - "//learning/brain/research/jax", - ] + pallas_tpu_internal_users, + includes = [":internal"], + packages = pallas_tpu_internal_users, ) package_group( name = "mosaic_gpu_users", - packages = [ - "//...", - "//learning/brain/research/jax", - ] + mosaic_gpu_internal_users, + includes = [":internal"], + packages = mosaic_gpu_internal_users, ) # JAX-private test utilities. diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 09cc3a81c2c2..7068c0ef6732 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -22,7 +22,7 @@ load( package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) py_library_providing_imports_info( diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index c0fa02131bc8..4ff7062ac1e8 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -21,7 +21,7 @@ load( package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index f1616962f349..071f09f3f567 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -20,7 +20,7 @@ load("//jaxlib:jax.bzl", "py_deps") package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 8f351020a86f..9d2dfd8dfa0f 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -24,7 +24,7 @@ load( package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 01d2480983d5..c40fb19ec808 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -23,7 +23,7 @@ load( package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/jax/experimental/jax2tf/g3doc/BUILD b/jax/experimental/jax2tf/g3doc/BUILD index 424d3b8b9e5d..6222b82b3550 100644 --- a/jax/experimental/jax2tf/g3doc/BUILD +++ b/jax/experimental/jax2tf/g3doc/BUILD @@ -15,7 +15,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"], + default_visibility = ["//jax/experimental/jax2tf:__subpackages__"], ) filegroup( diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD index f584ab5d3191..3417c1abf6ac 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD @@ -18,7 +18,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"], + default_visibility = ["//jax/experimental/jax2tf:__subpackages__"], ) py_library( diff --git a/jax/experimental/jax2tf/tests/flax_models/BUILD b/jax/experimental/jax2tf/tests/flax_models/BUILD index 19afb4a6877c..d3af9581ae02 100644 --- a/jax/experimental/jax2tf/tests/flax_models/BUILD +++ b/jax/experimental/jax2tf/tests/flax_models/BUILD @@ -19,7 +19,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"], + default_visibility = ["//jax/experimental/jax2tf:__subpackages__"], ) py_library( @@ -27,8 +27,8 @@ py_library( srcs = glob(["*.py"]), srcs_version = "PY3", deps = [ + "//jax", "//third_party/py/flax:core", - "//third_party/py/jax", "//third_party/py/jraph", "//third_party/py/numpy", "//third_party/py/typing_extensions", diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index 3f9496b38376..6f5af51fbf0f 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib:jax.bzl", "py_deps") load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("//jaxlib:jax.bzl", "py_deps") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax:mosaic_gpu_users"], + default_visibility = ["//jax:mosaic_gpu_users"], ) exports_files( @@ -27,15 +27,15 @@ exports_files( "flash_attention.py", "matmul.py", ], - visibility = ["//third_party/py/jax:internal"], + visibility = ["//jax:internal"], ) py_library( name = "matmul", srcs = ["matmul.py"], deps = [ - "//third_party/py/jax", - "//third_party/py/jax:mosaic_gpu", + "//jax", + "//jax:mosaic_gpu", ], ) @@ -43,8 +43,8 @@ py_library( name = "flash_attention", srcs = ["flash_attention.py"], deps = [ - "//third_party/py/jax", - "//third_party/py/jax:mosaic_gpu", + "//jax", + "//jax:mosaic_gpu", ], ) @@ -58,8 +58,8 @@ py_test( "requires-gpu-sm90-only", ], deps = [ + "//jax", + "//jax:mosaic_gpu", "//learning/brain/research/jax:gpu_support", - "//third_party/py/jax", - "//third_party/py/jax:mosaic_gpu", ] + py_deps("numpy"), ) diff --git a/jax/tools/build_defs.bzl b/jax/tools/build_defs.bzl index 1540afe42a6a..06f5e69833c5 100644 --- a/jax/tools/build_defs.bzl +++ b/jax/tools/build_defs.bzl @@ -146,9 +146,9 @@ EOF ) if format == "TF": - jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir_with_tensorflow" + jax_to_ir_rule = "//jax/tools:jax_to_ir_with_tensorflow" else: - jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir" + jax_to_ir_rule = "//jax/tools:jax_to_ir" py_binary( name = runner, diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 77b46d6d51aa..ab60b3fadd37 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -14,19 +14,19 @@ # JAX is Autograd and XLA -load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "py_library_providing_imports_info", "pybind_extension", "pytype_library", ) +load("//jaxlib:symlink_files.bzl", "symlink_files") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) # This makes xla_extension module accessible from jax._src.lib. @@ -129,13 +129,13 @@ cc_library( hdrs = ["ffi_helpers.h"], features = ["-use_header_modules"], deps = [ - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", ], ) @@ -149,10 +149,10 @@ cc_library( features = ["-use_header_modules"], deps = [ ":kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/base", "@nanobind", + "@xla//xla/ffi/api:c_api", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -201,10 +201,10 @@ pybind_extension( srcs = ["utils.cc"], module_name = "utils", deps = [ - "@xla//third_party/python_runtime:headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:inlined_vector", "@nanobind", + "@xla//third_party/python_runtime:headers", ], ) @@ -238,6 +238,9 @@ pybind_extension( module_name = "rocm_plugin_extension", deps = [ "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/status", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", "@xla//third_party/python_runtime:headers", "@xla//xla:status", "@xla//xla:util", @@ -248,9 +251,6 @@ pybind_extension( "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/python:py_client_gpu", "@xla//xla/tsl/python/lib/core:numpy", - "@com_google_absl//absl/status", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", ], ) diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 48332ee1a4d2..d3d15c4fc939 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -23,7 +23,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) # LAPACK @@ -36,13 +36,13 @@ cc_library( features = ["-use_header_modules"], deps = [ "//jaxlib:ffi_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -71,8 +71,8 @@ pybind_extension( deps = [ ":lapack_kernels", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/ffi/api:ffi", "@nanobind", + "@xla//xla/ffi/api:ffi", ], ) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index bd74be6732fd..a7a47f431a1d 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) cc_library( @@ -37,9 +37,9 @@ cc_library( defines = ["JAX_GPU_CUDA=1"], visibility = ["//visibility:public"], deps = [ - "@xla//xla/tsl/cuda:cupti", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", + "@xla//xla/tsl/cuda:cupti", ], ) @@ -57,9 +57,6 @@ cc_library( features = ["-use_header_modules"], deps = [ ":cuda_vendor", - "@xla//xla/tsl/cuda:cupti", - "@xla//xla/tsl/cuda:cusolver", - "@xla//xla/tsl/cuda:cusparse", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -69,6 +66,9 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cupti", + "@xla//xla/tsl/cuda:cusolver", + "@xla//xla/tsl/cuda:cusparse", ], ) @@ -90,11 +90,11 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:handle_pool", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -108,9 +108,6 @@ cc_library( ":cuda_make_batch_pointers", ":cuda_vendor", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -122,6 +119,9 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -145,12 +145,12 @@ pybind_extension( ":cublas_kernels", ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@nanobind", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -163,13 +163,13 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cudnn", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cudnn", ], ) @@ -201,11 +201,11 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:handle_pool", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", ], ) @@ -218,12 +218,12 @@ cc_library( ":cuda_solver_handle_pool", ":cuda_vendor", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", ], ) @@ -238,13 +238,13 @@ cc_library( ":cuda_solver_handle_pool", ":cuda_vendor", "//jaxlib:ffi_helpers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@xla//xla/ffi/api:ffi", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", ], ) @@ -272,15 +272,15 @@ pybind_extension( ":cusolver_kernels", ":cusolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -293,13 +293,13 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusparse", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusparse", ], ) @@ -324,9 +324,6 @@ pybind_extension( ":cuda_vendor", ":cusparse_kernels", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusparse", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -338,6 +335,9 @@ pybind_extension( "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusparse", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -354,13 +354,13 @@ cc_library( ":cuda_vendor", "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -390,10 +390,10 @@ pybind_extension( ":cuda_linalg_kernels", ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -409,12 +409,12 @@ cc_library( ":cuda_vendor", "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -428,9 +428,9 @@ cuda_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:kernel_helpers", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", - "@local_config_cuda//cuda:cuda_headers", ], ) @@ -447,9 +447,9 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -483,10 +483,6 @@ cc_library( ":cuda_vendor", ":triton_utils", "//jaxlib/gpu:triton_cc_proto", - "@xla//xla/service:custom_call_status", - "@xla//xla/stream_executor/cuda:cuda_asm_compiler", - "@xla//xla/tsl/cuda:cudart", - "@tsl//tsl/platform:env", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -497,6 +493,10 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:env", + "@xla//xla/service:custom_call_status", + "@xla//xla/stream_executor/cuda:cuda_asm_compiler", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -556,6 +556,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + "@com_google_absl//absl/base:dynamic_annotations", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", @@ -563,7 +564,6 @@ cc_library( "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusparse", - "@com_google_absl//absl/base:dynamic_annotations", ], ) @@ -594,6 +594,8 @@ pybind_extension( ":versions_helpers", "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/status:statusor", + "@nanobind", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", @@ -601,8 +603,6 @@ pybind_extension( "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusparse", - "@com_google_absl//absl/status:statusor", - "@nanobind", ], ) diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 706cac6b46d4..f3524ccdf781 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -20,7 +20,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) exports_files(srcs = [ diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 10acec815475..5452520204b8 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:mosaic_users", ], ) @@ -54,6 +54,14 @@ cc_library( # compatible with libtpu deps = [ ":tpu_inc_gen", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", @@ -71,18 +79,10 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorTransforms", + "@tsl//tsl/platform:statusor", "@xla//xla:array", "@xla//xla:shape_util", "@xla//xla:util", - "@tsl//tsl/platform:statusor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", ], ) @@ -192,14 +192,14 @@ cc_library( deps = [ ":tpu_dialect", ":tpu_inc_gen", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@xla//xla:array", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", ], ) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 20fcf2b4ce74..e5eaeb347137 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -17,7 +17,7 @@ load("//jaxlib:jax.bzl", "pybind_extension") package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:mosaic_gpu_users"], ) py_library( @@ -105,6 +105,12 @@ cc_library( deps = [ ":passes", "//jaxlib/cuda:cuda_vendor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", @@ -142,12 +148,6 @@ cc_library( "@llvm-project//mlir:VectorDialect", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", ], alwayslink = True, ) @@ -168,11 +168,11 @@ pybind_extension( deps = [ "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/synchronization", "@nanobind", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -192,7 +192,7 @@ cc_binary( "notap", ], deps = [ - "@xla//xla/tsl/cuda:cudart", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", ], ) diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 639e61a89062..48268bfcf30a 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -14,8 +14,8 @@ # Mosaic Python bindings -load("@rules_python//python:defs.bzl", "py_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") +load("@rules_python//python:defs.bzl", "py_library") gentbl_filegroup( name = "tpu_python_gen_raw", diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 1d994209ffcc..95482e47e864 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") +load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) pytype_strict_library( @@ -56,8 +56,8 @@ genrule( out=$(RULEDIR)/$${base//_raw/} echo '# pytype: skip-file' > $${out} && \ cat $${src} | - sed -e 's/^from \\.\\./from jaxlib.mlir\\./g' | - sed -e 's/^from \\./from jaxlib.mlir\\.dialects\\./g' >> $${out} + sed -e 's/^from \\.\\./from jaxlib\\.mlir\\./g' | + sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' >> $${out} done """, ) diff --git a/tests/BUILD b/tests/BUILD index eab1d11287e2..ef5f27f9bccb 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1567,6 +1567,6 @@ filegroup( exclude = [], ) + ["BUILD"], visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index fdb7ad7b0a1f..255b03d3a002 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -68,10 +68,10 @@ jax_test( jax_test( name = "flash_attention", - srcs = ["//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py"], + srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"], disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, - main = "//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py", + main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py", tags = ["notap"], deps = [ "//jax:mosaic_gpu", From ffa53b5f050034ae5e069782f702d621d7b45f20 Mon Sep 17 00:00:00 2001 From: quattro Date: Mon, 26 Aug 2024 09:41:40 -0700 Subject: [PATCH 236/702] fixes cache miss in abstract_eval_shape for bcoo dot general --- jax/experimental/sparse/bcoo.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 4cbe52383751..20917c3f7152 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -738,12 +738,11 @@ def result(out_array, lhs_data, lhs_indices, rhs): @bcoo_dot_general_p.def_abstract_eval def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, preferred_element_type, lhs_spinfo: SparseInfo): - out_aval = jax.eval_shape( - partial(lax.dot_general, - dimension_numbers=dimension_numbers, - preferred_element_type=preferred_element_type), - jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype), - jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)) + out_aval = jax.jit(lax.dot_general, static_argnames=("dimension_numbers", "preferred_element_type")).eval_shape( + jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype), + jax.ShapeDtypeStruct(rhs.shape, rhs.dtype), + dimension_numbers=dimension_numbers, + preferred_element_type=preferred_element_type) (lhs_contracting, _), (lhs_batch, _) = dimension_numbers n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_spinfo.shape) From d2b1ebd0aa504d2d0dddc3e2ffdb7240be7bdff6 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Fri, 9 Aug 2024 13:45:44 -0500 Subject: [PATCH 237/702] Update pgo_nsys_converter.py to use the NVTX kern sum report when available. --- jax/tools/pgo_nsys_converter.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 5e87220be606..5460edd960f5 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -38,7 +38,14 @@ profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') assert isinstance(nsys_path, str) - stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] + + # Older versions of nsys use `nvtxsum` for the report name so determine which is available. + query_reports_command = [nsys_path, "stats", "--help-reports"] + reports_list = subprocess.run(query_reports_command, capture_output=True, text=True).stdout + report_name = "nvtx_sum" if "nvtx_sum" in reports_list else "nvtxsum" + + assert isinstance(nsys_path, str) + stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", report_name, f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] print(f""" ******Starting stats command****** @@ -49,10 +56,10 @@ thunk_re = re.compile("hlo_op=(.*)#") with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: - with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile: + with open(f"{pgle_folder}{pgle_filename}.pbtxt_{report_name}.csv", newline='') as csvfile: reader = csv.DictReader(csvfile) for row in reader: - name = row['NVTX Range'] + name = row['Range'] time_ns = float(row['Avg (ns)']) m = thunk_re.search(name) if m is not None: From 416f79bb5cbdb769bfc3772e2d90751c20c2f758 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 26 Aug 2024 10:48:29 -0700 Subject: [PATCH 238/702] DOC: update docstrings for broadcast-related functions --- jax/_src/numpy/lax_numpy.py | 108 +++++++++++++++++++++++++++++++++--- 1 file changed, 101 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 7dfabb63c0d8..61ef2acbdb1c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2366,25 +2366,119 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | core.Tracer] ) -> tuple[int | core.Tracer, ...]: ... -@util.implements(getattr(np, "broadcast_shapes", None)) def broadcast_shapes(*shapes): + """Broadcast input shapes to a common output shape. + + JAX implementation of :func:`numpy.broadcast_shapes`. JAX uses NumPy-style + broadcasting rules, which you can read more about at `NumPy broadcasting`_. + + Args: + shapes: 0 or more shapes specified as sequences of integers + + Returns: + The broadcasted shape as a tuple of integers. + + See Also: + - :func:`jax.numpy.broadcast_arrays`: broadcast arrays to a common shape. + - :func:`jax.numpy.broadcast_to`: broadcast an array to a specified shape. + + Examples: + Some compatible shapes: + + >>> jnp.broadcast_shapes((1,), (4,)) + (4,) + >>> jnp.broadcast_shapes((3, 1), (4,)) + (3, 4) + >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) + (5, 3, 4) + + Incompatible shapes: + + >>> jnp.broadcast_shapes((3, 1), (4, 1)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)] + + .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html + """ if not shapes: return () shapes = [(shape,) if np.ndim(shape) == 0 else tuple(shape) for shape in shapes] return lax.broadcast_shapes(*shapes) -@util.implements(np.broadcast_arrays, lax_description="""\ -The JAX version does not necessarily return a view of the input. -""") def broadcast_arrays(*args: ArrayLike) -> list[Array]: + """Broadcast arrays to a common shape. + + JAX implementation of :func:`numpy.broadcast_arrays`. JAX uses NumPy-style + broadcasting rules, which you can read more about at `NumPy broadcasting`_. + + Args: + args: zero or more array-like objects to be broadcasted. + + Returns: + a list of arrays containing broadcasted copies of the inputs. + + See also: + - :func:`jax.numpy.broadcast_shapes`: broadcast input shapes to a common shape. + - :func:`jax.numpy.broadcast_to`: broadcast an array to a specified shape. + + Examples: + + >>> x = jnp.arange(3) + >>> y = jnp.int32(1) + >>> jnp.broadcast_arrays(x, y) + [Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)] + + >>> x = jnp.array([[1, 2, 3]]) + >>> y = jnp.array([[10], + ... [20]]) + >>> x2, y2 = jnp.broadcast_arrays(x, y) + >>> x2 + Array([[1, 2, 3], + [1, 2, 3]], dtype=int32) + >>> y2 + Array([[10, 10, 10], + [20, 20, 20]], dtype=int32) + + .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html + """ return util._broadcast_arrays(*args) -@util.implements(np.broadcast_to, lax_description="""\ -The JAX version does not necessarily return a view of the input. -""") def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: + """Broadcast an array to a specified shape. + + JAX implementation of :func:`numpy.broadcast_to`. JAX uses NumPy-style + broadcasting rules, which you can read more about at `NumPy broadcasting`_. + + Args: + array: array to be broadcast. + shape: shape to which the array will be broadcast. + + Returns: + a copy of array broadcast to the specified shape. + + See also: + - :func:`jax.numpy.broadcast_arrays`: broadcast arrays to a common shape. + - :func:`jax.numpy.broadcast_shapes`: broadcast input shapes to a common shape. + + Examples: + >>> x = jnp.int32(1) + >>> jnp.broadcast_to(x, (1, 4)) + Array([[1, 1, 1, 1]], dtype=int32) + + >>> x = jnp.array([1, 2, 3]) + >>> jnp.broadcast_to(x, (2, 3)) + Array([[1, 2, 3], + [1, 2, 3]], dtype=int32) + + >>> x = jnp.array([[2], [4]]) + >>> jnp.broadcast_to(x, (2, 4)) + Array([[2, 2, 2, 2], + [4, 4, 4, 4]], dtype=int32) + + .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html + """ return util._broadcast_to(array, shape) From c33ce857847750b836b8c899c5d48c12b2842afc Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 26 Aug 2024 10:50:23 -0700 Subject: [PATCH 239/702] Small fix for the jax trace dumping path PiperOrigin-RevId: 667639334 --- jax/_src/profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 4752106c7688..cdf739944f4b 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -129,7 +129,7 @@ def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, _profile_state.create_perfetto_link = create_perfetto_link _profile_state.create_perfetto_trace = ( create_perfetto_trace or create_perfetto_link) - _profile_state.log_dir = pathlib.Path(log_dir) + _profile_state.log_dir = str(log_dir) def _write_perfetto_trace_file(log_dir: os.PathLike | str): From 14c719e810f58a721a1ba3f626439767f7a965fb Mon Sep 17 00:00:00 2001 From: quattro Date: Mon, 26 Aug 2024 13:33:42 -0700 Subject: [PATCH 240/702] fixes cache miss in for eval shape in BCOO related functions --- jax/experimental/sparse/bcoo.py | 37 +++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 20917c3f7152..a0a4df4d898f 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -1186,12 +1186,11 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic dimension_numbers, preferred_element_type): lhs_shape = lhs_spinfo.shape rhs_shape = rhs_spinfo.shape - out_aval = jax.eval_shape( - partial(lax.dot_general, - dimension_numbers=dimension_numbers, - preferred_element_type=preferred_element_type), - jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype), - jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype)) + out_aval = jax.jit(lax.dot_general, static_argnames=("dimension_numbers", "preferred_element_type")).eval_shape( + jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype), + jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype), + dimension_numbers=dimension_numbers, + preferred_element_type=preferred_element_type) lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape) rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape) @@ -1772,9 +1771,9 @@ def bcoo_concatenate(operands: Sequence[BCOO], *, dimension: int) -> BCOO: raise ValueError("bcoo_concatenate: expected operands to be a sequence of BCOO arrays. " f"Got {operands}") # Validate inputs using lax.concatenate abstract evaluation. - out_aval = jax.eval_shape( - functools.partial(lax.concatenate, dimension=dimension), - [core.ShapedArray(op.shape, op.dtype) for op in operands]) + out_aval = jax.jit(lax.concatenate, static_argnames=("dimension",)).eval_shape( + [core.ShapedArray(op.shape, op.dtype) for op in operands], + dimension=dimension) if len({op.n_dense for op in operands}) > 1: raise ValueError("bcoo_concatenate requires inputs to have matching nse dimensions.") @@ -1890,8 +1889,9 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in def bcoo_rev(operand, dimensions): """Sparse implementation of {func}`jax.lax.rev`""" # Check validity of dimensions via original implementation. - _ = jax.eval_shape(partial(lax.rev, dimensions=dimensions), - jax.ShapeDtypeStruct(operand.shape, operand.dtype)) + _ = jax.jit(lax.rev, static_argnames=("dimensions",)).eval_shape( + jax.ShapeDtypeStruct(operand.shape, operand.dtype), + dimensions=dimensions) batch_dims = [d for d in dimensions if d < operand.n_batch] sparse_dims = [d for d in dimensions if operand.n_batch <= d < operand.n_batch + operand.n_sparse] dense_dims = [d for d in dimensions if d >= operand.n_batch + operand.n_sparse] @@ -2035,8 +2035,9 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq out: BCOO array containing the slice. """ # Use abstract eval to validate inputs. - jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes), - jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices) + jax.jit(lax.dynamic_slice, static_argnames=("slice_sizes",)).eval_shape( + jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices, + slice_sizes=slice_sizes) if not isinstance(mat, BCOO): raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") start_indices = tuple(jnp.asarray(i) for i in start_indices) @@ -2302,9 +2303,13 @@ def bcoo_gather(operand: BCOO, start_indices: Array, mode=mode, fill_value=fill_value) # Abstract eval lax.gather to validate arguments & determine output shape. - out_aval = jax.eval_shape(partial(lax.gather, **kwds), - jax.ShapeDtypeStruct(operand.shape, operand.dtype), - jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype)) + static_argnames = ("dimension_numbers", "slice_sizes", "unique_indices", + "indices_are_sorted", "mode", "fill_value",) + out_aval = jax.jit(lax.gather, static_argnames=static_argnames).eval_shape( + jax.ShapeDtypeStruct(operand.shape, operand.dtype), + jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype), + **kwds) + offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims start_index_map = dimension_numbers.start_index_map From b38f985b013858846348f1e3ac77ac3b9f5b4dfd Mon Sep 17 00:00:00 2001 From: Bryan Massoth Date: Mon, 26 Aug 2024 14:04:07 -0700 Subject: [PATCH 241/702] Add a callout that LibTPU now supports profiling of SparseCore for TPUv5p chips which will be viewable in Tensorboard Profiler's TraceViewer tool. PiperOrigin-RevId: 667708094 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9abb705c33e1..2d09d610bdb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. various Bazel targets. This enables more reproducible builds for JAX and its supported CUDA versions. +* Changes + * SparseCore profiling is added. + * JAX now supports profiling [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore) on TPUv5p chips. These traces will be viewable in Tensorboard Profiler's [TraceViewer](https://www.tensorflow.org/guide/profiler#trace_viewer). + ## jax 0.4.31 (July 29, 2024) * Deletion From 57c0d59d040f28fdbe950899e99a69a4f3023da8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 26 Aug 2024 14:38:59 -0700 Subject: [PATCH 242/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/baf026d13bc20c6232ffeab0991628ce758982f3. PiperOrigin-RevId: 667720484 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d20edfe2328a..b8ef7cd6699f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "4bfb5c82a427151d6fe5acad8ebe12cee403036a" -XLA_SHA256 = "83dfcddaf29205f8f426cf7044ab89242985ec34b51a8a52bb97dc5092fc1da0" +XLA_COMMIT = "baf026d13bc20c6232ffeab0991628ce758982f3" +XLA_SHA256 = "28adba477042bda38a541eebb79aeff49067d6bd2bc4c6ded86583ebdee60e08" def repo(): tf_http_archive( From f812d0f28be5df0f4da598e29dfa913a67d4b63f Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Mon, 26 Aug 2024 16:00:36 -0700 Subject: [PATCH 243/702] Clarify meaning of tall and wide Jacobian matrices in autodiff docs. --- docs/_tutorials/advanced-autodiff.md | 2 +- docs/notebooks/autodiff_cookbook.ipynb | 2 +- docs/notebooks/autodiff_cookbook.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index d58a45d1ddf3..0449b82e9a5b 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -315,7 +315,7 @@ print("jacrev result, with shape", J.shape) print(J) ``` -These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`. +These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`. You can also use {func}`jax.jacfwd` and {func}`jax.jacrev` with container types: diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index edfd0d4535f8..86c8bfea8468 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -487,7 +487,7 @@ "id": "iZDL-n_AvgBt" }, "source": [ - "These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for \"tall\" Jacobian matrices, while `jacrev` uses reverse-mode, which is more efficient for \"wide\" Jacobian matrices. For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`." + "These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for \"tall\" Jacobian matrices (more outputs than inputs), while `jacrev` uses reverse-mode, which is more efficient for \"wide\" Jacobian matrices (more inputs than outputs). For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`." ] }, { diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index c24d05c0e7c9..bc2d803f1228 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -276,7 +276,7 @@ print(J) +++ {"id": "iZDL-n_AvgBt"} -These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while `jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`. +These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while `jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`. +++ {"id": "zeKlr7Xz8bfm"} From 7087d0cf081b29d5a1938fe354289dba051e1e41 Mon Sep 17 00:00:00 2001 From: quattro Date: Mon, 26 Aug 2024 16:45:34 -0700 Subject: [PATCH 244/702] fixes cache miss and addresses static arg in BCOO dynamic slice --- jax/experimental/sparse/bcoo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index a0a4df4d898f..9eafa0db0fc2 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -2034,6 +2034,7 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq Returns: out: BCOO array containing the slice. """ + slice_sizes = tuple(operator.index(i) for i in slice_sizes) # Use abstract eval to validate inputs. jax.jit(lax.dynamic_slice, static_argnames=("slice_sizes",)).eval_shape( jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices, @@ -2043,7 +2044,6 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq start_indices = tuple(jnp.asarray(i) for i in start_indices) assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices) assert all(i.shape == () for i in start_indices) - slice_sizes = tuple(operator.index(i) for i in slice_sizes) if len(start_indices) != len(slice_sizes) != mat.ndim: raise ValueError(f"bcoo_dynamic_slice: indices must have size mat.ndim={mat.ndim}") if not all(0 <= slice_size <= axis_size for slice_size, axis_size in zip(slice_sizes, mat.shape)): From 90271017375110371779a165358c217beac127d2 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 26 Aug 2024 16:51:05 -0700 Subject: [PATCH 245/702] Update usages of mosaic compiler params with TPUCompilerParams. PiperOrigin-RevId: 667762992 --- jax/_src/pallas/mosaic/core.py | 4 +- jax/experimental/pallas/ops/tpu/all_gather.py | 2 +- .../pallas/ops/tpu/flash_attention.py | 12 ++---- jax/experimental/pallas/ops/tpu/matmul.py | 5 +-- .../pallas/ops/tpu/megablox/gmm.py | 16 +++----- .../paged_attention/paged_attention_kernel.py | 3 +- .../splash_attention_kernel.py | 37 +++++++------------ 7 files changed, 29 insertions(+), 50 deletions(-) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 94b53a4067f0..e549ee05e770 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -68,8 +68,8 @@ class TPUCompilerParams(pallas_core.CompilerParams): device_type: The device type to compile for. """ PLATFORM: ClassVar[str] = "mosaic" - dimension_semantics: list[str] | None = None - allow_input_fusion: list[bool] | None = None + dimension_semantics: Sequence[str] | None = None + allow_input_fusion: Sequence[bool] | None = None vmem_limit_bytes: int | None = None collective_id: int | None = None flags: dict[str, Any] | None = None diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index e121db894122..8fb975504e26 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -136,7 +136,7 @@ def ag_local(x_shard): out = pl.pallas_call( functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh), out_shape=out_shape, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, scratch_shapes=( diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index f0332a87b508..6ce3a1886b1c 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -745,15 +745,13 @@ def kv_segment_ids_index_map( ), out_shape=out_shape, debug=debug, - compiler_params=dict( - mosaic=dict( + compiler_params=pltpu.TPUCompilerParams( dimension_semantics=( "parallel", "parallel", "parallel", "arbitrary", ) - ) ), )(q, k, v, ab, q_segment_ids, kv_segment_ids) if save_residuals: @@ -1105,15 +1103,13 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _): ), out_shape=out_shapes, debug=debug, - compiler_params=dict( - mosaic=dict( + compiler_params=pltpu.TPUCompilerParams( dimension_semantics=( "parallel", "parallel", "parallel", "arbitrary", ) - ) ), )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di) assert dk.shape == k.shape @@ -1450,15 +1446,13 @@ def kv_segment_ids_index_map( ), out_shape=out_shapes, debug=debug, - compiler_params=dict( - mosaic=dict( + compiler_params=pltpu.TPUCompilerParams( dimension_semantics=( "parallel", "parallel", "parallel", "arbitrary", ) - ) ), )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di) diff --git a/jax/experimental/pallas/ops/tpu/matmul.py b/jax/experimental/pallas/ops/tpu/matmul.py index 2145fbc95b55..4ff82acbb5dd 100644 --- a/jax/experimental/pallas/ops/tpu/matmul.py +++ b/jax/experimental/pallas/ops/tpu/matmul.py @@ -78,8 +78,7 @@ def matmul( grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k), scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)], ), - compiler_params=dict( - mosaic=dict(dimension_semantics=("parallel", "parallel", "arbitrary")) - ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), debug=debug, )(x, y) diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py index 320851422abf..5c2f938597e7 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/gmm.py +++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py @@ -538,11 +538,8 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=dict( - mosaic=dict( - dimension_semantics=("parallel", "arbitrary", "arbitrary"), - ) - ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, cost_estimate=cost_estimate, ) @@ -780,13 +777,10 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=dict( - mosaic=dict( - dimension_semantics=("parallel", "arbitrary", "arbitrary"), - cost_estimate=cost_estimate, - ) - ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, + cost_estimate=cost_estimate, ) out = call_gmm( diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index 82fa5f7427bd..cd811a874385 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -640,7 +640,8 @@ def paged_attention( grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=dimension_sematics), out_shape=[ jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 4ae761d78953..536c32e574b2 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -1071,11 +1071,6 @@ def logsumexp_index_map(h, i, *_): out_shapes += [None] out_specs += [None] - mosaic_params = dict( - dimension_semantics=("parallel", "arbitrary", "arbitrary"), - flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, - ) - kernel_name = get_kernel_name( dataclasses.asdict(block_sizes), is_mqa=is_mqa, @@ -1112,7 +1107,9 @@ def logsumexp_index_map(h, i, *_): out_specs=out_specs, grid=grid, ), - compiler_params=dict(mosaic=mosaic_params), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary"), + ), out_shape=out_shapes, name=kernel_name, interpret=interpret, @@ -1545,11 +1542,6 @@ def logsumexp_index_map(h, i, *_): ) num_scalar_prefetch = 3 - mosaic_params = dict( - dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), - flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, - ) - kernel_name = get_kernel_name( dict( block_q_dq=bq, @@ -1573,7 +1565,9 @@ def logsumexp_index_map(h, i, *_): grid=grid, ), out_shape=out_shapes, - compiler_params=dict(mosaic=mosaic_params), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), + ), name=kernel_name, interpret=interpret, )( @@ -2088,16 +2082,6 @@ def logsumexp_index_map( ) num_scalar_prefetch = 3 - # We set all dimensions to arbitrary because: - # 1) for kv_seq_len, the splash attention prefetch schedule assumes no - # megacore - # 2) for heads, we are reducing over heads - # 3) for q_seq_len, we are reducing over it to compute dkv - mosaic_params = dict( - dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), - flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, - ) - kernel_name = get_kernel_name( dict( block_q_dkv=bq, @@ -2122,7 +2106,14 @@ def logsumexp_index_map( grid=grid, ), out_shape=out_shapes, - compiler_params=dict(mosaic=mosaic_params), + # We set all dimensions to arbitrary because: + # 1) for kv_seq_len, the splash attention prefetch schedule assumes no + # megacore + # 2) for heads, we are reducing over heads + # 3) for q_seq_len, we are reducing over it to compute dkv + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), + ), name=kernel_name, interpret=interpret, )( From 45b871950e5a8355060244fc75989d4a67bbdc7a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 26 Aug 2024 17:03:27 -0700 Subject: [PATCH 246/702] Fix a number of minor problems in the ROCM build. Change in preparation for adding more presubmits for AMD ROCM. PiperOrigin-RevId: 667766343 --- jaxlib/cuda/BUILD | 2 + jaxlib/gpu/linalg_kernels.cc | 7 +-- jaxlib/gpu/linalg_kernels.cu.cc | 62 ++++++++++++---------- jaxlib/gpu/linalg_kernels.h | 10 ++-- jaxlib/gpu/solver_kernels.cc | 8 +-- jaxlib/gpu/solver_kernels_ffi.cc | 19 +++++-- jaxlib/gpu/sparse.cc | 84 ++++++++++++++++-------------- jaxlib/gpu/sparse_kernels.cc | 20 ++++--- jaxlib/gpu/sparse_kernels.h | 2 +- jaxlib/gpu/triton_kernels.cc | 22 ++++++-- jaxlib/gpu/vendor.h | 30 ++++++----- jaxlib/rocm/{BUILD.bazel => BUILD} | 13 ++++- jaxlib/rocm_plugin_extension.cc | 12 ++--- 13 files changed, 173 insertions(+), 118 deletions(-) rename jaxlib/rocm/{BUILD.bazel => BUILD} (96%) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index a7a47f431a1d..72db9868e427 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -295,6 +295,7 @@ cc_library( "//jaxlib:kernel_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/service:custom_call_status", @@ -323,6 +324,7 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":cusparse_kernels", + "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index b22248409b60..039a9b5c1019 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -40,7 +40,8 @@ absl::Status CholeskyUpdateImpl(gpuStream_t stream, void** buffers, auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); const CholeskyUpdateDescriptor& d = **s; - LaunchCholeskyUpdateKernel(stream, buffers, d); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(LaunchCholeskyUpdateKernel(stream, buffers, d))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); return absl::OkStatus(); } @@ -98,8 +99,8 @@ ffi::Error CholeskyUpdateFfiImpl(gpuStream_t stream, ffi::AnyBuffer matrix_in, gpuMemcpyDeviceToDevice, stream))); } for (auto n = 0; n < batch; ++n) { - LaunchCholeskyUpdateFfiKernel(stream, matrix, vector, size, - is_single_precision); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(LaunchCholeskyUpdateFfiKernel( + stream, matrix, vector, size, is_single_precision))); FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); } return ffi::Error::Success(); diff --git a/jaxlib/gpu/linalg_kernels.cu.cc b/jaxlib/gpu/linalg_kernels.cu.cc index 50c653d8cf16..7f87d66fb4ef 100644 --- a/jaxlib/gpu/linalg_kernels.cu.cc +++ b/jaxlib/gpu/linalg_kernels.cu.cc @@ -67,8 +67,9 @@ __global__ void CholeskyUpdateKernel(T* rMatrix, T* uVector, int nSize) { } // namespace template -void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, - int grid_dim, int block_dim, int nSize) { +gpuError_t LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, + int grid_dim, int block_dim, + int nSize) { T* rMatrix = reinterpret_cast(buffers[2]); T* uVector = reinterpret_cast(buffers[3]); @@ -77,39 +78,40 @@ void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, reinterpret_cast(&uVector), reinterpret_cast(&nSize), }; - gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, - block_dim, arg_ptrs, - /*dynamic_shared_mem_bytes=*/0, stream); + return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, + block_dim, arg_ptrs, + /*dynamic_shared_mem_bytes=*/0, stream); } -void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, - CholeskyUpdateDescriptor descriptor) { +gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, + CholeskyUpdateDescriptor descriptor) { int nSize = descriptor.matrix_size; LinalgType type = descriptor.linalg_type; int dev = 0; gpuDeviceProp deviceProp; - gpuGetDeviceProperties(&deviceProp, dev); + gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev); + if (err != gpuSuccess) { + return err; + } int block_dim = deviceProp.maxThreadsPerBlock; int grid_dim = deviceProp.multiProcessorCount; switch (type) { case LinalgType::F64: - LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, - block_dim, nSize); - break; + return LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, + block_dim, nSize); case LinalgType::F32: - LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, - block_dim, nSize); - break; + return LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, + block_dim, nSize); } } template -void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix, - void* vector, int grid_dim, - int block_dim, int nSize) { +gpuError_t LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix, + void* vector, int grid_dim, + int block_dim, int nSize) { T* rMatrix = reinterpret_cast(matrix); T* uVector = reinterpret_cast(vector); @@ -118,26 +120,30 @@ void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix, reinterpret_cast(&uVector), reinterpret_cast(&nSize), }; - gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, - block_dim, arg_ptrs, - /*dynamic_shared_mem_bytes=*/0, stream); + return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, + block_dim, arg_ptrs, + /*dynamic_shared_mem_bytes=*/0, stream); } -void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, - void* vector, int size, - bool is_single_precision) { +gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, + void* vector, int size, + bool is_single_precision) { int dev = 0; gpuDeviceProp deviceProp; - gpuGetDeviceProperties(&deviceProp, dev); + gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev); + if (err != gpuSuccess) { + return err; + } + int block_dim = deviceProp.maxThreadsPerBlock; int grid_dim = deviceProp.multiProcessorCount; if (is_single_precision) { - LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, grid_dim, - block_dim, size); + return LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, + grid_dim, block_dim, size); } else { - LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, grid_dim, - block_dim, size); + return LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, + grid_dim, block_dim, size); } } diff --git a/jaxlib/gpu/linalg_kernels.h b/jaxlib/gpu/linalg_kernels.h index 47ada398c3a2..2c41b7f4350d 100644 --- a/jaxlib/gpu/linalg_kernels.h +++ b/jaxlib/gpu/linalg_kernels.h @@ -36,15 +36,15 @@ struct CholeskyUpdateDescriptor { std::int64_t matrix_size; // leading dim (N) for a square (NxN)matrix }; -void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, - CholeskyUpdateDescriptor descriptor); +gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, + CholeskyUpdateDescriptor descriptor); void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); -void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, - void* vector, int size, - bool is_single_precision); +gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, + void* vector, int size, + bool is_single_precision); XLA_FFI_DECLARE_HANDLER_SYMBOL(CholeskyUpdateFfi); void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index 8d90c70537c7..8c22dfcdbca7 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" @@ -421,9 +421,9 @@ static absl::Status Syevd_(gpuStream_t stream, void** buffers, int output_idx = 1; // with static shapes buffers[1] is the first output if (d.batch == -1) { // the batch is passed as a second operand - gpuMemcpyAsync((void*)&batch, - reinterpret_cast(buffers[1]), - sizeof(batch), gpuMemcpyDeviceToHost, stream); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( + (void*)&batch, reinterpret_cast(buffers[1]), + sizeof(batch), gpuMemcpyDeviceToHost, stream))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); output_idx = 2; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 2b1f5552977f..6e988a6ca5e6 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -61,6 +61,17 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, return impl(__VA_ARGS__); \ } +#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...) \ + if (dataType == ffi::F32) { \ + return impl(__VA_ARGS__); \ + } else if (dataType == ffi::F64) { \ + return impl(__VA_ARGS__); \ + } else if (dataType == ffi::C64) { \ + return impl(__VA_ARGS__); \ + } else if (dataType == ffi::C128) { \ + return impl(__VA_ARGS__); \ + } + // LU decomposition: getrf namespace { @@ -189,8 +200,8 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, ipiv->dimensions(), {batch, std::min(rows, cols)}, "ipiv", "getrf")); FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "getrf")); if (batch > 1 && rows == cols && rows / batch <= 128) { - SOLVER_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a, out, - ipiv, info); + SOLVER_BLAS_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a, + out, ipiv, info); } else { SOLVER_DISPATCH_IMPL(GetrfImpl, batch, rows, cols, stream, scratch, a, out, ipiv, info); @@ -345,8 +356,8 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, FFI_RETURN_IF_ERROR(CheckShape( tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqrf")); if (batch > 1 && rows / batch <= 128 && cols / batch <= 128) { - SOLVER_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream, scratch, - a, out, tau); + SOLVER_BLAS_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream, + scratch, a, out, tau); } else { SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out, tau); diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index b9eb51388fa9..2eeb94e309ce 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" @@ -57,14 +58,19 @@ gpusparseIndexType_t DtypeToCuSparseIndexType(const dtype& np_type) { gpuDataType DtypeToCudaDataType(const dtype& np_type) { static auto* types = new absl::flat_hash_map, gpuDataType>({ - {{'f', 2}, GPU_R_16F}, {{'c', 4}, GPU_C_16F}, {{'f', 4}, GPU_R_32F}, - {{'c', 8}, GPU_C_32F}, {{'f', 8}, GPU_R_64F}, - {{'c', 16}, GPU_C_64F}, + {{'f', 2}, GPU_R_16F}, + {{'c', 4}, GPU_C_16F}, + {{'f', 4}, GPU_R_32F}, + {{'c', 8}, GPU_C_32F}, + {{'f', 8}, GPU_R_64F}, + {{'c', 16}, GPU_C_64F}, #ifdef JAX_GPU_CUDA - {{'i', 1}, CUDA_R_8I}, {{'u', 1}, CUDA_R_8U}, - {{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U}, + {{'i', 1}, CUDA_R_8I}, + {{'u', 1}, CUDA_R_8U}, + {{'i', 4}, CUDA_R_32I}, + {{'u', 4}, CUDA_R_32U}, #if JAX_GPU_HAVE_SPARSE - {{'V', 2}, CUDA_R_16BF}, + {{'V', 2}, CUDA_R_16BF}, #endif // JAX_GPU_HAVE_SPARSE #endif // JAX_GPU_CUDA }); @@ -78,9 +84,8 @@ gpuDataType DtypeToCudaDataType(const dtype& np_type) { } // Returns the descriptor for a Sparse matrix. SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype, - const dtype& index_dtype, - int rows, int cols, int nnz, - int batch_count, + const dtype& index_dtype, int rows, + int cols, int nnz, int batch_count, int batch_stride) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); gpusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype); @@ -89,16 +94,15 @@ SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype, } // Returns the descriptor for a Dense matrix. -DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype, - int rows, int cols, int batch_count, +DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype, int rows, + int cols, int batch_count, int batch_stride) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride}; } // Returns the descriptor for a Dense vector. -DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, - int size) { +DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, int size) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); return DenseVecDescriptor{value_type, size}; } @@ -107,9 +111,10 @@ DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, // CsrToDense: Convert CSR matrix to dense matrix // Returns the descriptor for a Sparse matrix. -std::pair BuildCsrToDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { +std::pair BuildCsrToDenseDescriptor(const dtype& data_dtype, + const dtype& index_dtype, + int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -185,8 +190,8 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, // Returns the descriptor for a CsrFromDense operation. std::pair BuildCsrFromDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -261,9 +266,8 @@ void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, // Returns the descriptor for a CsrMatvec operation. std::pair BuildCsrMatvecDescriptor( - const dtype& data_dtype, const dtype& x_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int nnz, bool transpose) { + const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -292,7 +296,7 @@ std::pair BuildCsrMatvecDescriptor( JAX_THROW_IF_ERROR( JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; - SparseConst alpha = ConstOne(y.type); + SparseConst alpha = ValueOrThrow(ConstOne(y.type)); SparseConst beta = ConstZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, @@ -309,9 +313,9 @@ std::pair BuildCsrMatvecDescriptor( // Returns the descriptor for a CsrMatmat operation. std::pair BuildCsrMatmatDescriptor( - const dtype& data_dtype, const dtype& b_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int BCcols, int nnz, bool transpose) { + const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int BCcols, int nnz, + bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -344,7 +348,7 @@ std::pair BuildCsrMatmatDescriptor( JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, empty, C.type, GPUSPARSE_ORDER_ROW))); size_t buffer_size; - SparseConst alpha = ConstOne(C.type); + SparseConst alpha = ValueOrThrow(ConstOne(C.type)); SparseConst beta = ConstZero(C.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize( handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, @@ -360,9 +364,10 @@ std::pair BuildCsrMatmatDescriptor( // CooToDense: Convert COO matrix to dense matrix // Returns the descriptor for a CooToDense operation. -std::pair BuildCooToDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { +std::pair BuildCooToDenseDescriptor(const dtype& data_dtype, + const dtype& index_dtype, + int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -398,8 +403,8 @@ std::pair BuildCooToDenseDescriptor( // Returns the descriptor for a CooFromDense operation. std::pair BuildCooFromDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -434,9 +439,8 @@ std::pair BuildCooFromDenseDescriptor( // Returns the descriptor for a CooMatvec operation. std::pair BuildCooMatvecDescriptor( - const dtype& data_dtype, const dtype& x_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int nnz, bool transpose) { + const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -465,7 +469,7 @@ std::pair BuildCooMatvecDescriptor( JAX_THROW_IF_ERROR( JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; - SparseConst alpha = ConstOne(y.type); + SparseConst alpha = ValueOrThrow(ConstOne(y.type)); SparseConst beta = ConstZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, @@ -482,10 +486,10 @@ std::pair BuildCooMatvecDescriptor( // Returns the descriptor for a CooMatmat operation. std::pair BuildCooMatmatDescriptor( - const dtype& data_dtype, const dtype& b_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int BCcols, int nnz, bool transpose, int batch_count, - int lhs_batch_stride, int rhs_batch_stride) { + const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int BCcols, int nnz, + bool transpose, int batch_count, int lhs_batch_stride, + int rhs_batch_stride) { // Three batch modes are supported, C_i = A_i B, C_i = A B_i, and // Ci = A_i B_i, where `i` denotes the batch dimension. // All three matrices A, B, and C must have the same batch count. @@ -535,7 +539,7 @@ std::pair BuildCooMatmatDescriptor( JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch( mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride))); size_t buffer_size; - SparseConst alpha = ConstOne(C.type); + SparseConst alpha = ValueOrThrow(ConstOne(C.type)); SparseConst beta = ConstZero(C.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize( handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 93c6aef17008..a44d4b33149d 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" @@ -58,7 +59,7 @@ SparseConst ConstZero(gpuDataType type) { return c; } -SparseConst ConstOne(gpuDataType type) { +absl::StatusOr ConstOne(gpuDataType type) { SparseConst c; std::memset(&c, 0, sizeof(c)); switch (type) { @@ -138,6 +139,9 @@ SparseConst ConstOne(gpuDataType type) { case GPU_C_64F: c.f64[0] = 1.0; break; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported data type: ", type)); } return c; } @@ -248,7 +252,7 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.y.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type)); SparseConst beta = ConstZero(d.y.type); gpusparseSpMatDescr_t mat_a = 0; @@ -305,7 +309,7 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.C.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type)); SparseConst beta = ConstZero(d.C.type); gpusparseSpMatDescr_t mat_a = 0; @@ -446,7 +450,7 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.y.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type)); SparseConst beta = ConstZero(d.y.type); gpusparseSpMatDescr_t mat_a = 0; @@ -502,7 +506,7 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.C.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type)); SparseConst beta = ConstZero(d.C.type); gpusparseSpMatDescr_t mat_a = 0; @@ -567,7 +571,7 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, T* du = static_cast(buffers[2]); T* B = static_cast(buffers[3]); T* X = static_cast(buffers[4]); - void* buffer = static_cast(buffers[5]); + void* buffer = static_cast(buffers[5]); // The solution X is written in place to B. We need to therefore copy the // contents of B into the output buffer X and pass that into the kernel as B. @@ -581,8 +585,8 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream))); } for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(computeGtsv2( - handle.get(), m, n, dl, d, du, X, ldb, buffer))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + computeGtsv2(handle.get(), m, n, dl, d, du, X, ldb, buffer))); dl += m; d += m; du += m; diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 2180767b0cf7..48433b3d6eaa 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -51,7 +51,7 @@ union SparseConst { }; SparseConst ConstZero(gpuDataType type); -SparseConst ConstOne(gpuDataType type); +absl::StatusOr ConstOne(gpuDataType type); struct SparseMatDescriptor { gpuDataType value_type; diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index c96c6b5c54b0..c4a9af5ffe2e 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -34,7 +34,11 @@ #ifdef JAX_GPU_CUDA #include "xla/stream_executor/cuda/cuda_asm_compiler.h" -#endif +#endif // JAX_GPU_CUDA + +#ifdef JAX_GPU_HIP +#include "tsl/platform/env.h" +#endif // JAX_GPU_HIP #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) @@ -44,7 +48,12 @@ namespace { constexpr float kBenchmarkTimeMillis = 10.; struct gpuModuleDeleter { - void operator()(gpuModule_t module) { gpuModuleUnload(module); } + void operator()(gpuModule_t module) { + absl::Status status = JAX_AS_STATUS(gpuModuleUnload(module)); + if (!status.ok()) { + LOG(WARNING) << "Failed to unload GPU module: " << status; + } + } }; using OwnedGPUmodule = @@ -52,11 +61,11 @@ using OwnedGPUmodule = absl::StatusOr GetStreamDevice(gpuStream_t stream) { gpuDevice_t device; - gpuContext_t context; #ifdef JAX_GPU_HIP int device_id = gpuGetStreamDeviceId(stream); GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id)); #else // JAX_GPU_CUDA + gpuContext_t context; GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context)); GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; @@ -210,7 +219,12 @@ class ModuleImage { } GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); - absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; + absl::Cleanup ctx_restorer = [] { + absl::Status status = JAX_AS_STATUS(gpuCtxPopCurrent(nullptr)); + if (!status.ok()) { + LOG(WARNING) << "Failed to pop GPU context: " << status; + } + }; gpuModule_t module; GPU_RETURN_IF_ERROR(gpuModuleLoadData(&module, module_image_.data())); diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index ef635bebd401..077d3bb54185 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -22,18 +22,20 @@ limitations under the License. #if defined(JAX_GPU_CUDA) -#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cooperative_groups.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuComplex.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cublas_v2.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuda_fp8.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusolver_common.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export -#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export +// IWYU pragma: begin_exports +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" +#include "third_party/gpus/cuda/include/cooperative_groups.h" +#include "third_party/gpus/cuda/include/cuComplex.h" +#include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_fp8.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/cufft.h" +#include "third_party/gpus/cuda/include/cusolverDn.h" +#include "third_party/gpus/cuda/include/cusolver_common.h" +#include "third_party/gpus/cuda/include/cusparse.h" +#include "third_party/gpus/cudnn/cudnn.h" +// IWYU pragma: end_exports #if CUDA_VERSION < 11080 #error "JAX requires CUDA 11.8 or newer." @@ -305,11 +307,13 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #elif defined(JAX_GPU_HIP) -#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h" +// IWYU pragma: begin_exports +#include "rocm/include/hip/hip_cooperative_groups.h" #include "rocm/include/hip/hip_runtime_api.h" #include "rocm/include/hipblas/hipblas.h" #include "rocm/include/hipsolver/hipsolver.h" #include "rocm/include/hipsparse/hipsparse.h" +// IWYU pragma: end_exports #define JAX_GPU_NAMESPACE hip #define JAX_GPU_PREFIX "hip" diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD similarity index 96% rename from jaxlib/rocm/BUILD.bazel rename to jaxlib/rocm/BUILD index 58dfd076dbab..ce856ae5f83d 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD @@ -14,6 +14,7 @@ # AMD HIP kernels +load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "if_rocm_is_configured", @@ -23,7 +24,10 @@ load( licenses(["notice"]) -package(default_visibility = ["//:__subpackages__"]) +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) cc_library( name = "hip_vendor", @@ -203,6 +207,7 @@ pybind_extension( ":hipsolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", @@ -223,6 +228,7 @@ cc_library( "//jaxlib:kernel_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", @@ -243,6 +249,7 @@ pybind_extension( ":hip_gpu_kernel_helpers", ":hip_vendor", ":hipsparse_kernels", + "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -277,6 +284,7 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -376,14 +384,15 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", - "@xla//xla/stream_executor/gpu:asm_compiler", "@xla//xla/tsl/util:env_var", ], ) diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc index dde4e57a97cf..8a732380d1e2 100644 --- a/jaxlib/rocm_plugin_extension.cc +++ b/jaxlib/rocm_plugin_extension.cc @@ -35,8 +35,8 @@ namespace nb = nanobind; namespace xla { namespace { absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, - nb::capsule fn, int api_version, - XLA_FFI_Handler_Traits traits) { + nb::capsule fn, int api_version, + XLA_FFI_Handler_Traits traits) { if (c_api->extension_start == nullptr) { return Unimplemented("The plugin does not have extension."); } @@ -139,11 +139,11 @@ NB_MODULE(rocm_plugin_extension, m) { void* data_ptr = reinterpret_cast(data_value); hipError_t result = hipPointerGetAttribute(static_cast(&device_ordinal), - HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, - reinterpret_cast(data_ptr)); + HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(data_ptr)); if (result != hipSuccess) { - LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << data_ptr - << ". Error: " << ToString(result); + LOG(FATAL) << "Not able to get the device_ordinal for ptr: " + << data_ptr << ". Error: " << ToString(result); } return device_ordinal; }, From d63df3974469f8f395e88630adbefa1ea3cbbffa Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 26 Aug 2024 17:18:48 -0700 Subject: [PATCH 247/702] Support donating arrays with non-default layouts by setting up XLA donation directly instead of defining aliasing for arrays with potentially incompatible layouts. We only fallback to xla dontation for exactly those arrays which have input and output layouts explicitly set to conflicting layouts. PiperOrigin-RevId: 667770224 --- jax/_src/interpreters/mlir.py | 37 ++++++++++++++++++--------- tests/layout_test.py | 48 +++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 12 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 75353197ae2b..ab2c77833f15 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1066,14 +1066,15 @@ def lower_jaxpr_to_module( if num_partitions > 1 and ( result_shardings is None or all(s is None for s in result_shardings)): xla_donated_args = donated_args + donated_args = [False] * len(donated_args) if xla_donated_args is None: - input_output_aliases, donated_args = _set_up_aliases( + input_output_aliases, donated_args, xla_donated_args = _set_up_aliases( input_output_aliases, in_avals, out_avals, donated_args, - arg_memory_kinds, result_memory_kinds) + arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts) unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects) if unlowerable_effects: raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}') - if xla_donated_args is None and any(donated_args): + if any(donated_args): unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d] msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation." if not platforms_with_donation: @@ -1082,11 +1083,6 @@ def lower_jaxpr_to_module( warnings.warn("Some donated buffers were not usable:" f" {', '.join(unused_donations)}.\n{msg}") - if xla_donated_args is not None: - assert input_output_aliases is None - if input_output_aliases is not None: - assert xla_donated_args is None - # Delete donated_args by default here, since it's not needed beyond this point del donated_args @@ -1170,8 +1166,10 @@ def emit_diagnostic_info(d): ctx.shape_poly_state) -def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, - arg_memory_kinds, result_memory_kinds): +def _set_up_aliases(input_output_aliases, avals_in, avals_out, + donated_args, + arg_memory_kinds, result_memory_kinds, + in_layouts, out_layouts): if input_output_aliases is None: input_output_aliases = [None] * len(avals_in) else: @@ -1200,6 +1198,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, if donated and aliased is None: donations[(aval, am)].append(i) + xla_donated_args = None out_donated_args = list(donated_args) for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)): # Only donate if memory kinds match. Relax this when the compiler can @@ -1207,10 +1206,24 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, key = (aval, rm) if donations.get(key, ()): input_id = donations[key].popleft() - input_output_aliases[input_id] = i out_donated_args[input_id] = False + if (in_layouts is None or + out_layouts is None or + in_layouts[input_id] == out_layouts[i] or + # We can alias if XLA performs layout assignment because XLA will + # respect the aliases when assigning layouts. Its only for two + # mismatched explicitly assigned layouts that XLA will certainly + # fail. + isinstance(in_layouts[input_id], (AutoLayout, type(None))) or + isinstance(out_layouts[i], (AutoLayout, type(None)))): + input_output_aliases[input_id] = i + else: + # Fallback to xla donation if layouts don't match. + if xla_donated_args is None: + xla_donated_args = [False] * len(avals_in) + xla_donated_args[input_id] = True - return input_output_aliases, out_donated_args + return input_output_aliases, out_donated_args, xla_donated_args Token = ir.Value token_type = hlo.TokenType.get diff --git a/tests/layout_test.py b/tests/layout_test.py index 3cfc117b925e..33f3318bf4bb 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -527,6 +527,54 @@ def test_in_layouts_jit_jnp_input(self): out4 = f(np_inp) self.assertArraysEqual(out4, np_inp + 1) + def test_layout_donation(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (16, 128) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + custom_dll = DLL(major_to_minor=(0, 1)) + arr = jax.device_put(np_inp, Layout(custom_dll, s)) + + @partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0) + def f(x): + return x + + out = f(arr) + self.assertTrue(arr.is_deleted()) + + def test_layout_donation_auto(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (128, 16) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + arr = jax.device_put(np_inp, s) + + @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0) + def f(x): + return x * x + + out = f(arr) + self.assertTrue(arr.is_deleted()) + + def test_layout_donation_matching_in_and_out(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (128, 16) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + custom_dll = DLL(major_to_minor=(0, 1)) + l = Layout(custom_dll, s) + arr = jax.device_put(np_inp, l) + + @partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0) + def f(x): + return x * x + + out = f(arr) + self.assertTrue(arr.is_deleted()) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 087f5697595d4d0d0008f1dec3466e78be6c7283 Mon Sep 17 00:00:00 2001 From: Oliver Dutton Date: Tue, 27 Aug 2024 14:55:54 +0100 Subject: [PATCH 248/702] Fast jvp for 2x3 and 2x2 determinants Speed up jvp's for 3x3 and 2x2 determinants The current det implementation custom_jvp is all encompassing, so while there's fast functions for the 2 and 3d cases they still go via the slow general jvp. PR localises the custom_jvp to the generic case. This general case is ~10x slower on GPU (A100) and ~250x slower on TPU (v2). ```python import jax from jax import numpy as jnp from jax import random def det_3x3(a: jax.Array) -> jax.Array: return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] + a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] + a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] - a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] - a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] - a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2]) key = random.key(42) x = random.normal(key, (int(1e5), 3, 3)) general_grad = jax.grad(lambda x: jnp.linalg.det(x).sum()) direct_3by3_grad = jax.vmap(jax.grad(det_3x3)) general_grad, direct_3by3_grad = (jax.jit(f) for f in (general_grad, direct_3by3_grad)) _ = jax.block_until_ready(general_grad(x)) _ = jax.block_until_ready(direct_3by3_grad(x)) %timeit _ = jax.block_until_ready(general_grad(x)) %timeit _ = jax.block_until_ready(direct_3by3_grad(x)) --- jax/_src/numpy/linalg.py | 24 ++++++++++++++---------- tests/linalg_test.py | 2 +- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index bb0ba2e85499..2af25bcf80c4 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -662,6 +662,19 @@ def _det_3x3(a: Array) -> Array: @custom_jvp +def _det(a): + sign, logdet = slogdet(a) + return sign * ufuncs.exp(logdet).astype(sign.dtype) + + +@_det.defjvp +def _det_jvp(primals, tangents): + x, = primals + g, = tangents + y, z = _cofactor_solve(x, g) + return y, jnp.trace(z, axis1=-1, axis2=-2) + + @jit def det(a: ArrayLike) -> Array: """ @@ -692,21 +705,12 @@ def det(a: ArrayLike) -> Array: elif len(a_shape) >= 2 and a_shape[-1] == 3 and a_shape[-2] == 3: return _det_3x3(a) elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]: - sign, logdet = slogdet(a) - return sign * ufuncs.exp(logdet).astype(sign.dtype) + return _det(a) else: msg = "Argument to _det() must have shape [..., n, n], got {}" raise ValueError(msg.format(a_shape)) -@det.defjvp -def _det_jvp(primals, tangents): - x, = primals - g, = tangents - y, z = _cofactor_solve(x, g) - return y, jnp.trace(z, axis1=-1, axis2=-2) - - def eig(a: ArrayLike) -> tuple[Array, Array]: """ Compute the eigenvalues and eigenvectors of a square array. diff --git a/tests/linalg_test.py b/tests/linalg_test.py index ce7b0b1991c8..6cd0110538eb 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -111,7 +111,7 @@ def testDetOfSingularMatrix(self): self.assertAllClose(np.float32(0), jsp.linalg.det(x)) @jtu.sample_product( - shape=[(1, 1), (3, 3), (2, 4, 4)], + shape=[(1, 1), (2, 2), (3, 3), (2, 2, 2), (2, 3, 3), (2, 4, 4), (5, 7, 7)], dtype=float_types, ) @jtu.skip_on_flag("jax_skip_slow_tests", True) From 2f3d428e7835080f2df8e12738020bc8f43eb541 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Tue, 27 Aug 2024 19:40:28 +0530 Subject: [PATCH 249/702] Improved docs for jnp.fix and jnp.trunc --- jax/_src/numpy/lax_numpy.py | 59 +++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f5abb3971cfc..d85bb5c2ef83 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -465,9 +465,36 @@ def result_type(*args: Any) -> DType: return dtypes.result_type(*args) -@util.implements(np.trunc, module='numpy') @jit def trunc(x: ArrayLike) -> Array: + """Round input to the nearest integer towards zero. + + JAX implementation of :func:`numpy.trunc`. + + Args: + x: input array or scalar. + + Returns: + An array with same shape and dtype as ``x`` containing the rounded values. + + See also: + - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero. + - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. + - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. + + Examples: + >>> key = jax.random.key(42) + >>> x = jax.random.uniform(key, (3, 3), minval=-10, maxval=10) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(x) + [[ 2.88 -3.55 -6.13] + [ 7.73 4.49 -6.16] + [-3.1 -4.95 2.64]] + >>> jnp.trunc(x) + Array([[ 2., -3., -6.], + [ 7., 4., -6.], + [-3., -4., 2.]], dtype=float32) + """ util.check_arraylike('trunc', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax_internal.asarray(x) @@ -2558,9 +2585,37 @@ def _round_float(x: ArrayLike) -> Array: round_ = round -@util.implements(np.fix, skip_params=['out']) @jit def fix(x: ArrayLike, out: None = None) -> Array: + """Round input to the nearest integer towards zero. + + JAX implementation of :func:`numpy.fix`. + + Args: + x: input array. + out: unused by JAX. + + Returns: + An array with same shape and dtype as ``x`` containing the rounded values. + + See also: + - :func:`jax.numpy.trunc`: Rounds the input to nearest integer towards zero. + - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. + - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. + + Examples: + >>> key = jax.random.key(0) + >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(x) + [[-1.45 1.04 -0.72] + [-2.69 1.74 -0.6 ] + [-2.49 -2.23 2.68]] + >>> jnp.fix(x) + Array([[-1., 1., -0.], + [-2., 1., -0.], + [-2., -2., 2.]], dtype=float32) + """ util.check_arraylike("fix", x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.fix is not supported.") From 859eacb5a12ea1eae68e792877f2c6e4f28b244f Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 27 Aug 2024 16:57:53 +0100 Subject: [PATCH 250/702] Fix mypy error --- build/rocm/run_single_gpu.py | 2 +- build/rocm/tools/build_wheels.py | 2 +- build/rocm/tools/symbols.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build/rocm/run_single_gpu.py b/build/rocm/run_single_gpu.py index 4e7660ca1f15..e1fa26c72872 100755 --- a/build/rocm/run_single_gpu.py +++ b/build/rocm/run_single_gpu.py @@ -169,7 +169,7 @@ def run_parallel(all_testmodules, p, c): def find_num_gpus(): - cmd = ["lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"] + cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"] _, _, stdout = run_shell_command(cmd, shell=True) return int(stdout) diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index 1ba9e0b910db..b6dd1256e2f5 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -101,7 +101,7 @@ def build_jax_wheel(jax_path, python_version): env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) - pattern = re.compile("Successfully built jax-.+ and (jax-.+\.whl)\n") + pattern = re.compile(r"Successfully built jax-.+ and (jax-.+\.whl)\n") _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout") diff --git a/build/rocm/tools/symbols.py b/build/rocm/tools/symbols.py index f2bf2d561f72..2982bb187c9e 100644 --- a/build/rocm/tools/symbols.py +++ b/build/rocm/tools/symbols.py @@ -42,7 +42,7 @@ def main(): def highest_for_file(sofile): output = subprocess.check_output(["objdump", "-T", sofile]) - r = re.compile("\(GLIBC_(.*)\)") + r = re.compile(r"\(GLIBC_(.*)\)") versions = {} for line in output.decode("utf-8").split("\n"): From afff0e09aa85f0abf767eeae56b83c7ddadb16b2 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 27 Aug 2024 13:30:12 -0700 Subject: [PATCH 251/702] Improve the error message to specify shapes too PiperOrigin-RevId: 668117141 --- jax/_src/pjit.py | 6 +++--- jax/experimental/shard_map.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0997f1107a02..7badefab7922 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2520,8 +2520,8 @@ def _sharding_constraint_impl(x, sharding, layout, resource_env, unconstrained_dims): if (isinstance(sharding, NamedSharding) and isinstance(sharding.mesh, AbstractMesh)): + aval = shaped_abstractify(x) if not hasattr(x, 'sharding'): - aval = shaped_abstractify(x) raise ValueError( 'Target sharding contains a `jax.sharding.AbstractMesh` which' ' requires the input passed should be a `jax.Array`. Got' @@ -2530,12 +2530,12 @@ def _sharding_constraint_impl(x, sharding, layout, resource_env, raise TypeError( 'The sharding on the input must be a `NamedSharding` since the target' ' sharding has an `AbstractMesh` in it. Got sharding type' - f' {type(x.sharding)}') + f' {type(x.sharding)} for shape {aval.str_short()}') if x.sharding.mesh.shape_tuple != sharding.mesh.shape_tuple: raise ValueError( f'Mesh shape of the input {x.sharding.mesh.shape_tuple} does not' ' match the mesh shape of the target sharding' - f' {sharding.mesh.shape_tuple}') + f' {sharding.mesh.shape_tuple} for shape {aval.str_short()}') sharding = NamedSharding._from_parsed_pspec( x.sharding.mesh, sharding._parsed_pspec) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 656f97027c18..94d2e2693e87 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -28,7 +28,6 @@ import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec from jax._src import ad_checkpoint -from jax._src import array from jax._src import ad_util from jax._src import callback from jax._src import config @@ -719,10 +718,11 @@ def get_mesh_from_args(args_flat, mesh): for a in args_flat: if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding): if a.sharding.mesh.shape_tuple != mesh.shape_tuple: + aval = shaped_abstractify(a) raise ValueError( f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" " match the mesh shape passed to shard_map " - f" {mesh.shape_tuple}") + f" {mesh.shape_tuple} for shape {aval.str_short()}") mesh = a.sharding.mesh if isinstance(mesh, AbstractMesh): raise ValueError( From a9e54b3e0a0e01300b983679c1609df14f20aa40 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Tue, 20 Aug 2024 15:29:43 -0500 Subject: [PATCH 252/702] Add docker builds for ubu22 and 24 --- build/rocm/Dockerfile.ms | 36 +++++++++++---- build/rocm/docker/Dockerfile.jax-ubu22 | 64 ++++++++++++++++++++++++++ build/rocm/docker/Dockerfile.jax-ubu24 | 63 +++++++++++++++++++++++++ build/rocm/docker/Makefile | 20 ++++++++ 4 files changed, 174 insertions(+), 9 deletions(-) create mode 100644 build/rocm/docker/Dockerfile.jax-ubu22 create mode 100644 build/rocm/docker/Dockerfile.jax-ubu24 create mode 100644 build/rocm/docker/Makefile diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index dffe42de77f6..0bcc89f493ce 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -3,10 +3,10 @@ FROM ubuntu:20.04 AS rocm_base ################################################################################ RUN --mount=type=cache,target=/var/cache/apt \ - apt-get update && apt-get install -y python3 + apt-get update && apt-get install -y python3 python-is-python3 # Add target file to help determine which device(s) to build for -ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} # Install ROCM @@ -16,6 +16,7 @@ ENV ROCM_PATH=${ROCM_PATH} ARG ROCM_BUILD_JOB ARG ROCM_BUILD_NUM RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + --mount=type=cache,target=/var/cache/apt \ python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM # Set up paths @@ -42,14 +43,30 @@ RUN git clone https://github.com/pyenv/pyenv.git /pyenv ENV PYENV_ROOT /pyenv ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH RUN pyenv install $PYTHON_VERSION -RUN eval "$(pyenv init -)" && \ +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + eval "$(pyenv init -)" && \ pyenv local ${PYTHON_VERSION} && \ pip3 install --upgrade --force-reinstall setuptools pip && \ - pip install \ - numpy setuptools build wheel six auditwheel scipy \ - pytest pytest-html pytest_html_merger pytest-reportlog \ - pytest-rerunfailures cloudpickle portpicker matplotlib absl-py \ - flatbuffers hypothesis pytest-json-report pytest-csv + pip3 install \ + "numpy<2" \ + build \ + wheel \ + six \ + auditwheel \ + scipy \ + pytest \ + pytest-html \ + pytest_html_merger \ + pytest-reportlog \ + pytest-rerunfailures \ + pytest-json-report \ + pytest-csv \ + cloudpickle \ + portpicker \ + matplotlib \ + absl-py \ + flatbuffers \ + hypothesis ################################################################################ FROM rocm_base AS rt_build @@ -65,6 +82,7 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ com.amdgpu.jax_commit="$JAX_COMMIT" \ com.amdgpu.xla_commit="$XLA_COMMIT" -RUN --mount=type=bind,source=wheelhouse,target=/wheelhouse \ +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + --mount=type=bind,source=wheelhouse,target=/wheelhouse \ pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/docker/Dockerfile.jax-ubu22 b/build/rocm/docker/Dockerfile.jax-ubu22 new file mode 100644 index 000000000000..ba64efbbc682 --- /dev/null +++ b/build/rocm/docker/Dockerfile.jax-ubu22 @@ -0,0 +1,64 @@ +FROM ubuntu:22.04 + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y python3 python-is-python3 + +# Add target file to help determine which device(s) to build for +ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} + +# Install ROCM +ARG ROCM_VERSION=6.0.0 +ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} +ENV ROCM_PATH=${ROCM_PATH} +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM +RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + --mount=type=cache,target=/var/cache/apt \ + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM + +# Set up paths +ENV HCC_HOME=$ROCM_PATH/hcc +ENV HIP_PATH=$ROCM_PATH/ +ENV OPENCL_ROOT=$ROCM_PATH/opencl +ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}" +ENV PATH="$ROCM_PATH/bin:${PATH}" +ENV PATH="$OPENCL_ROOT/bin:${PATH}" +ENV PATH="/root/bin:/root/.local/bin:$PATH" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + pip3 install --upgrade --force-reinstall setuptools pip && \ + pip3 install \ + "numpy<2" \ + build \ + wheel \ + six \ + auditwheel \ + scipy \ + pytest \ + pytest-html \ + pytest_html_merger \ + pytest-reportlog \ + pytest-rerunfailures \ + pytest-json-report \ + pytest-csv \ + cloudpickle \ + portpicker \ + matplotlib \ + absl-py \ + flatbuffers \ + hypothesis + +ARG JAX_VERSION +ARG JAX_COMMIT +ARG XLA_COMMIT + +LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ + com.amdgpu.python_version="3.10" \ + com.amdgpu.jax_version="$JAX_VERSION" \ + com.amdgpu.jax_commit="$JAX_COMMIT" \ + com.amdgpu.xla_commit="$XLA_COMMIT" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + --mount=type=bind,source=wheelhouse,target=/wheelhouse \ + pip3 install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/docker/Dockerfile.jax-ubu24 b/build/rocm/docker/Dockerfile.jax-ubu24 new file mode 100644 index 000000000000..44c59b1b7e6b --- /dev/null +++ b/build/rocm/docker/Dockerfile.jax-ubu24 @@ -0,0 +1,63 @@ +FROM ubuntu:24.04 + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y python3 python-is-python3 python3-pip + +# Add target file to help determine which device(s) to build for +ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} + +# Install ROCM +ARG ROCM_VERSION=6.2.0 +ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} +ENV ROCM_PATH=${ROCM_PATH} +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM +RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + --mount=type=cache,target=/var/cache/apt \ + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM + +# Set up paths +ENV HCC_HOME=$ROCM_PATH/hcc +ENV HIP_PATH=$ROCM_PATH/ +ENV OPENCL_ROOT=$ROCM_PATH/opencl +ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}" +ENV PATH="$ROCM_PATH/bin:${PATH}" +ENV PATH="$OPENCL_ROOT/bin:${PATH}" +ENV PATH="/root/bin:/root/.local/bin:$PATH" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + pip3 install --break-system-packages \ + "numpy<2" \ + build \ + wheel \ + six \ + auditwheel \ + scipy \ + pytest \ + pytest-html \ + pytest_html_merger \ + pytest-reportlog \ + pytest-rerunfailures \ + pytest-json-report \ + pytest-csv \ + cloudpickle \ + portpicker \ + matplotlib \ + absl-py \ + flatbuffers \ + hypothesis + +ARG JAX_VERSION +ARG JAX_COMMIT +ARG XLA_COMMIT + +LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ + com.amdgpu.python_version="3.12" \ + com.amdgpu.jax_version="$JAX_VERSION" \ + com.amdgpu.jax_commit="$JAX_COMMIT" \ + com.amdgpu.xla_commit="$XLA_COMMIT" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + --mount=type=bind,source=wheelhouse,target=/wheelhouse \ + pip3 install --break-system-packages --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/docker/Makefile b/build/rocm/docker/Makefile new file mode 100644 index 000000000000..7fb38a936a64 --- /dev/null +++ b/build/rocm/docker/Makefile @@ -0,0 +1,20 @@ +.PHONY: all clean + +all: .docker-jax-ubu22 .docker-jax-ubu24 + +clean: clean-jax-ubu22 clean-jax-ubu24 + +ROCM_VERSION = 6.2.0 + +.docker-% : build/rocm/docker/Dockerfile.% + docker build -f $< --tag $(*F) --progress plain \ + --build-arg=ROCM_VERSION=${ROCM_VERSION} \ + --build-arg=JAX_VERSION=$(shell python setup.py -V) \ + --build-arg=JAX_COMMIT=$(shell git rev-parse HEAD) \ + . + @touch $@ + + +clean-%: + -docker rmi $(*F) + @rm -f .docker-$(*F) From db9e44fe565fcf7cccead0e69bff73c7a0ad63ab Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 27 Aug 2024 14:46:50 -0700 Subject: [PATCH 253/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/4170b9b1a900e8fcdee42e0810aacb7e0618701c. PiperOrigin-RevId: 668146966 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b8ef7cd6699f..fa41f6f7d101 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "baf026d13bc20c6232ffeab0991628ce758982f3" -XLA_SHA256 = "28adba477042bda38a541eebb79aeff49067d6bd2bc4c6ded86583ebdee60e08" +XLA_COMMIT = "4170b9b1a900e8fcdee42e0810aacb7e0618701c" +XLA_SHA256 = "5b8bb058d802e8fa83aa70210637d2f8d903348b3f1b5f7e315a04e144f6c123" def repo(): tf_http_archive( From 68be5b5085238e5bfc56f1568c934c8fe126a30c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 27 Aug 2024 14:54:11 -0700 Subject: [PATCH 254/702] CI: update ruff to v0.6.1 --- .pre-commit-config.yaml | 2 +- cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb | 4 +- cloud_tpu_colabs/JAX_demo.ipynb | 13 +- cloud_tpu_colabs/Wave_Equation.ipynb | 1 - docs/_tutorials/advanced-autodiff.md | 4 +- docs/jep/9407-type-promotion.ipynb | 2 - docs/jep/9407-type-promotion.md | 2 - docs/notebooks/Common_Gotchas_in_JAX.ipynb | 12 +- docs/notebooks/Common_Gotchas_in_JAX.md | 12 +- ...tom_derivative_rules_for_Python_code.ipynb | 6 +- ...Custom_derivative_rules_for_Python_code.md | 6 +- ...arrays_and_automatic_parallelization.ipynb | 2 - ...ed_arrays_and_automatic_parallelization.md | 2 - docs/notebooks/How_JAX_primitives_work.ipynb | 142 +++++++++--------- docs/notebooks/How_JAX_primitives_work.md | 142 +++++++++--------- .../Neural_Network_and_Data_Loading.ipynb | 4 +- .../Neural_Network_and_Data_Loading.md | 4 +- .../Writing_custom_interpreters_in_Jax.ipynb | 18 +-- .../Writing_custom_interpreters_in_Jax.md | 18 +-- docs/notebooks/autodiff_cookbook.ipynb | 4 +- docs/notebooks/autodiff_cookbook.md | 4 +- docs/notebooks/autodiff_remat.ipynb | 2 - docs/notebooks/autodiff_remat.md | 2 - docs/notebooks/convolutions.ipynb | 11 +- docs/notebooks/convolutions.md | 11 +- docs/notebooks/external_callbacks.ipynb | 1 - docs/notebooks/external_callbacks.md | 1 - .../neural_network_with_tfds_data.ipynb | 4 +- .../neural_network_with_tfds_data.md | 4 +- docs/notebooks/vmapped_log_probs.ipynb | 36 ++--- docs/notebooks/vmapped_log_probs.md | 33 ++-- docs/sharded-computation.ipynb | 3 - docs/sharded-computation.md | 3 - jax/_src/core.py | 2 +- jax/_src/cudnn/fused_attention_stablehlo.py | 4 +- jax/_src/pallas/mosaic/pipeline.py | 6 +- jax/_src/test_util.py | 2 +- pyproject.toml | 11 ++ tests/mosaic/gpu_test.py | 2 +- tests/notebooks/colab_cpu.ipynb | 9 -- tests/pallas/indexing_test.py | 4 +- 41 files changed, 241 insertions(+), 314 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 355f134f0551..79a0df6e9e38 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: files: \.py$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.4 + rev: v0.6.1 hooks: - id: ruff diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index 279aef3e9c65..cb5a42ced8c4 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -510,8 +510,8 @@ "outputs": [], "source": [ "image_partitions = P(1, 1, 4, 2)\n", - "sharded_conv = sharded_jit(conv, \n", - " in_parts=(image_partitions, None), \n", + "sharded_conv = sharded_jit(conv,\n", + " in_parts=(image_partitions, None),\n", " out_parts=image_partitions)\n", "\n", "sharded_conv(image, kernel)" diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index 9acb1971c3b6..4952cdbe9365 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -877,7 +877,7 @@ " def g(z):\n", " return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n", " return grad(lambda w: jnp.sum(g(w)))(x)\n", - " \n", + "\n", "f(x)" ] }, @@ -950,17 +950,6 @@ "per_example_hess = pmap(input_hess) # pmap!\n", "per_example_hess(inputs)" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "u3ggM_WYZ8QC" - }, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/cloud_tpu_colabs/Wave_Equation.ipynb b/cloud_tpu_colabs/Wave_Equation.ipynb index 0591739191e0..16f675a76140 100644 --- a/cloud_tpu_colabs/Wave_Equation.ipynb +++ b/cloud_tpu_colabs/Wave_Equation.ipynb @@ -67,7 +67,6 @@ "source": [ "from functools import partial\n", "import jax\n", - "from jax import jit, pmap\n", "from jax import lax\n", "from jax import tree_util\n", "import jax.numpy as jnp\n", diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index 0449b82e9a5b..20affa8cf29f 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -640,7 +640,7 @@ def our_jacrev(f): y, vjp_fun = vjp(f, x) # Use vmap to do a matrix-Jacobian product. # Here, the matrix is the Euclidean basis, so we get all - # entries in the Jacobian at once. + # entries in the Jacobian at once. J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) return J return jacfun @@ -654,7 +654,7 @@ from jax import jacfwd as builtin_jacfwd def our_jacfwd(f): def jacfun(x): _jvp = lambda s: jvp(f, (x,), (s,))[1] - Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x))) + Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x))) return jnp.transpose(Jt) return jacfun diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index 2aef1768112f..3e99daabed93 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -3317,7 +3317,6 @@ ], "source": [ "# @title\n", - "from jax import dtypes\n", "import jax\n", "import jax.numpy as jnp\n", "import pandas as pd\n", @@ -3802,7 +3801,6 @@ ], "source": [ "# @title\n", - "from jax import dtypes\n", "import jax\n", "import jax.numpy as jnp\n", "import pandas as pd\n", diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index 107bcd8c968b..2d12944f16f9 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -908,7 +908,6 @@ display.HTML(table.to_html()) :tags: [hide-input] # @title -from jax import dtypes import jax import jax.numpy as jnp import pandas as pd @@ -963,7 +962,6 @@ display.HTML(table.to_html()) :tags: [hide-input] # @title -from jax import dtypes import jax import jax.numpy as jnp import pandas as pd diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index d769144406eb..c143b520af51 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -226,7 +226,6 @@ ], "source": [ "import jax.numpy as jnp\n", - "import jax.lax as lax\n", "from jax import make_jaxpr\n", "\n", "# lax.fori_loop\n", @@ -1031,7 +1030,6 @@ } ], "source": [ - "from jax import random\n", "key = random.key(0)\n", "key" ] @@ -1105,8 +1103,8 @@ "print(\"old key\", key)\n", "key, subkey = random.split(key)\n", "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(\" \\---SPLIT --> new key \", key)\n", - "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" + "print(r\" \\---SPLIT --> new key \", key)\n", + "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" ] }, { @@ -1140,8 +1138,8 @@ "print(\"old key\", key)\n", "key, subkey = random.split(key)\n", "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(\" \\---SPLIT --> new key \", key)\n", - "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" + "print(r\" \\---SPLIT --> new key \", key)\n", + "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" ] }, { @@ -1701,7 +1699,7 @@ ], "source": [ "init_val = 0\n", - "cond_fun = lambda x: x<10\n", + "cond_fun = lambda x: x < 10\n", "body_fun = lambda x: x+1\n", "lax.while_loop(cond_fun, body_fun, init_val)\n", "# --> array(10, dtype=int32)" diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index edf5c9446743..0b21a57e36c9 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -130,7 +130,6 @@ It is not recommended to use iterators in any JAX function you want to `jit` or :outputId: 52d885fd-0239-4a08-f5ce-0c38cc008903 import jax.numpy as jnp -import jax.lax as lax from jax import make_jaxpr # lax.fori_loop @@ -471,7 +470,6 @@ The random state is described by a special array element that we call a __key__: :id: yPHE7KTWgAWs :outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3 -from jax import random key = random.key(0) key ``` @@ -504,8 +502,8 @@ Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a ne print("old key", key) key, subkey = random.split(key) normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(" \---SPLIT --> new key ", key) -print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom) +print(r" \---SPLIT --> new key ", key) +print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) ``` +++ {"id": "tqtFVE4MthO3"} @@ -519,8 +517,8 @@ We propagate the __key__ and make new __subkeys__ whenever we need a new random print("old key", key) key, subkey = random.split(key) normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(" \---SPLIT --> new key ", key) -print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom) +print(r" \---SPLIT --> new key ", key) +print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) ``` +++ {"id": "0KLYUluz3lN3"} @@ -805,7 +803,7 @@ def while_loop(cond_fun, body_fun, init_val): :outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e init_val = 0 -cond_fun = lambda x: x<10 +cond_fun = lambda x: x < 10 body_fun = lambda x: x+1 lax.while_loop(cond_fun, body_fun, init_val) # --> array(10, dtype=int32) diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 3abb6d9cbaec..ec85f6e63159 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -247,7 +247,6 @@ } ], "source": [ - "import jax.numpy as jnp\n", "\n", "def log1pexp(x):\n", " return jnp.log(1. + jnp.exp(x))\n", @@ -984,7 +983,7 @@ " (a, x_star, x_star_bar),\n", " x_star_bar))\n", " return a_bar, jnp.zeros_like(x_star)\n", - " \n", + "\n", "def rev_iter(f, packed, u):\n", " a, x_star, x_star_bar = packed\n", " _, vjp_x = vjp(lambda x: f(a, x), x_star)\n", @@ -1884,7 +1883,6 @@ } ], "source": [ - "from jax import vjp\n", "\n", "y, f_vjp = vjp(f, 3.)\n", "print(y)" @@ -1983,7 +1981,7 @@ " return x, x\n", "\n", "def debug_bwd(x, g):\n", - " import pdb; pdb.set_trace()\n", + " pdb.set_trace()\n", " return g\n", "\n", "debug.defvjp(debug_fwd, debug_bwd)" diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index ad577d55cd0d..3c60cce0cf30 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -145,7 +145,6 @@ Say we want to write a function called `log1pexp`, which computes $x \mapsto \lo :id: 6lWbTvs40ET- :outputId: 8caff99e-add1-4c70-ace3-212c0c5c6f4e -import jax.numpy as jnp def log1pexp(x): return jnp.log(1. + jnp.exp(x)) @@ -524,7 +523,7 @@ def fixed_point_rev(f, res, x_star_bar): (a, x_star, x_star_bar), x_star_bar)) return a_bar, jnp.zeros_like(x_star) - + def rev_iter(f, packed, u): a, x_star, x_star_bar = packed _, vjp_x = vjp(lambda x: f(a, x), x_star) @@ -965,7 +964,6 @@ print(grad(f)(3.)) :id: s1Pn_qCIODcF :outputId: 423d34e0-35b8-4b57-e89d-f70f20e28ea9 -from jax import vjp y, f_vjp = vjp(f, 3.) print(y) @@ -1015,7 +1013,7 @@ def debug_fwd(x): return x, x def debug_bwd(x, g): - import pdb; pdb.set_trace() + pdb.set_trace() return g debug.defvjp(debug_fwd, debug_bwd) diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 3d8c5b0203d5..8bc0e0a52ce6 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -30,9 +30,7 @@ }, "outputs": [], "source": [ - "import os\n", "\n", - "import functools\n", "from typing import Optional\n", "\n", "import numpy as np\n", diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index cb5d4602c055..c5f3c08ed4f9 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -26,9 +26,7 @@ This tutorial discusses parallelism via `jax.Array`, the unified array object mo ```{code-cell} :id: FNxScTfq3vGF -import os -import functools from typing import Optional import numpy as np diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb index f42e3f74b4e3..0c20fc47dc47 100644 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ b/docs/notebooks/How_JAX_primitives_work.ipynb @@ -15,12 +15,12 @@ "*necula@google.com*, October 2019.\n", "\n", "JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,\n", - "`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, \n", + "`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,\n", "which means that as the Python function executes\n", "the only operations it applies to the data are either inspections of data\n", "attributes such as shape or type, or special operations called JAX primitives.\n", "In particular, a JAX-traceable function is sometimes invoked by JAX with\n", - "abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, \n", + "abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,\n", "which captures the type and the shape of values, but not the concrete data values.\n", "JAX primitives know how to operate on both concrete data\n", "values and on the JAX abstract values.\n", @@ -30,7 +30,7 @@ "to ensure that these transformations\n", "can be composed, e.g., `jit(jacfwd(grad(f)))`.\n", "\n", - "There are pre-defined JAX primitives corresponding to most XLA operations, \n", + "There are pre-defined JAX primitives corresponding to most XLA operations,\n", "e.g., add, matmul, sin, cos, indexing.\n", "JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs\n", "using JAX’s implementation of numpy are JAX-traceable and therefore transformable.\n", @@ -42,8 +42,8 @@ "**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**\n", "\n", "Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n", - "as \"multiply_add(x, y, z) = x * y + z\". \n", - "This function operates on 3 identically-shaped tensors of floating point \n", + "as \"multiply_add(x, y, z) = x * y + z\".\n", + "This function operates on 3 identically-shaped tensors of floating point\n", "values and performs the operations pointwise." ] }, @@ -56,7 +56,7 @@ "## Using existing primitives\n", "\n", "The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other\n", - "functions that are themselves written using JAX primitives, e.g., those \n", + "functions that are themselves written using JAX primitives, e.g., those\n", "defined in the `jax.lax` module:" ] }, @@ -165,7 +165,7 @@ " return str(v)\n", " def pp_values(args):\n", " return \", \".join([pp(arg) for arg in args])\n", - " \n", + "\n", " @functools.wraps(func)\n", " def func_wrapper(*args):\n", " _trace_indent(\"call {}({})\".format(name, pp_values(args)))\n", @@ -199,7 +199,7 @@ "id": "Qf4eLrLCFYDl" }, "source": [ - "Instead of using `jax.lax` primitives directly, we can use other functions \n", + "Instead of using `jax.lax` primitives directly, we can use other functions\n", "that are already written in terms of those primitives, such as those in `jax.numpy`:" ] }, @@ -244,7 +244,7 @@ "def square_add_numpy(a, b):\n", " return multiply_add_numpy(a, a, b)\n", "\n", - "print(\"\\nNormal evaluation:\") \n", + "print(\"\\nNormal evaluation:\")\n", "print(\"square_add_numpy = \", square_add_numpy(2., 10.))\n", "print(\"\\nGradient evaluation:\")\n", "print(\"grad(square_add_numpy) = \", api.grad(square_add_numpy)(2.0, 10.))" @@ -257,13 +257,13 @@ }, "source": [ "Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and\n", - "`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further \n", - "below in this colab). \n", - "It is important to remember that a JAX-traceable function must be able to \n", + "`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further\n", + "below in this colab).\n", + "It is important to remember that a JAX-traceable function must be able to\n", "operate not only on concrete arguments but also on special abstract arguments\n", "that JAX may use to abstract the function execution.\n", "\n", - "The JAX traceability property is satisfied as long as the function is written \n", + "The JAX traceability property is satisfied as long as the function is written\n", "in terms of JAX primitives." ] }, @@ -277,7 +277,7 @@ "\n", "The right way to add support for multiply-add is in terms of existing\n", "JAX primitives, as shown above. However, in order to demonstrate how JAX\n", - "primitives work let us pretend that we want to add a new primitive to \n", + "primitives work let us pretend that we want to add a new primitive to\n", "JAX for the multiply-add functionality." ] }, @@ -295,9 +295,9 @@ "@trace(\"multiply_add_prim\")\n", "def multiply_add_prim(x, y, z):\n", " \"\"\"The JAX-traceable way to use the JAX primitive.\n", - " \n", + "\n", " Note that the traced arguments must be passed as positional arguments\n", - " to `bind`. \n", + " to `bind`.\n", " \"\"\"\n", " return multiply_add_p.bind(x, y, z)\n", "\n", @@ -392,7 +392,7 @@ "\n", " This function does not need to be JAX traceable.\n", " Args:\n", - " x, y, z: the concrete arguments of the primitive. Will only be called with \n", + " x, y, z: the concrete arguments of the primitive. Will only be called with\n", " concrete values.\n", " Returns:\n", " the concrete result of the primitive.\n", @@ -485,17 +485,17 @@ }, "source": [ "#### Abstract evaluation rules\n", - "In order to JIT the function, and for other transformations as well, \n", - "JAX first evaluates it abstractly using only the \n", + "In order to JIT the function, and for other transformations as well,\n", + "JAX first evaluates it abstractly using only the\n", "shape and type of the arguments. This abstract evaluation serves multiple\n", "purposes:\n", "\n", - " * Gets the sequence of JAX primitives that are used in the computation. This \n", - " sequence will be compiled. \n", - " * Computes the shape and type of all vectors and operations used in the computation. \n", + " * Gets the sequence of JAX primitives that are used in the computation. This\n", + " sequence will be compiled.\n", + " * Computes the shape and type of all vectors and operations used in the computation.\n", "\n", "\n", - "For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. \n", + "For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.\n", "In the latter case, JAX uses the actual concrete value wrapped as an abstract value." ] }, @@ -527,7 +527,7 @@ " \"\"\"Abstract evaluation of the primitive.\n", "\n", " This function does not need to be JAX traceable. It will be invoked with\n", - " abstractions of the actual arguments. \n", + " abstractions of the actual arguments.\n", " Args:\n", " xs, ys, zs: abstractions of the arguments.\n", " Result:\n", @@ -603,7 +603,7 @@ "\n", "JAX compilation works by compiling each primitive into a graph of XLA operations.\n", "\n", - "This is the biggest hurdle to adding new functionality to JAX, because the \n", + "This is the biggest hurdle to adding new functionality to JAX, because the\n", "set of XLA operations is limited, and JAX already has pre-defined primitives\n", "for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++." ] @@ -642,7 +642,7 @@ }, "source": [ "Now we succeed to JIT. Notice below that JAX first evaluates the function\n", - "abstractly, which triggers the `multiply_add_abstract_eval` function, and \n", + "abstractly, which triggers the `multiply_add_abstract_eval` function, and\n", "then compiles the set of primitives it has encountered, including `multiply_add`.\n", "At this point JAX invokes `multiply_add_xla_translation`." ] @@ -682,7 +682,7 @@ "source": [ "Below is another use of `jit` where we compile only\n", "with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads\n", - "in the third argument to `multiply_add_abstract_eval` being \n", + "in the third argument to `multiply_add_abstract_eval` being\n", "`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with\n", "both `ShapedArray` and `ConcreteArray`." ] @@ -711,7 +711,7 @@ } ], "source": [ - "assert api.jit(lambda x, y: square_add_prim(x, y), \n", + "assert api.jit(lambda x, y: square_add_prim(x, y),\n", " static_argnums=1)(2., 10.) == 14." ] }, @@ -794,16 +794,16 @@ "def multiply_add_value_and_jvp(arg_values, arg_tangents):\n", " \"\"\"Evaluates the primal output and the tangents (Jacobian-vector product).\n", "\n", - " Given values of the arguments and perturbation of the arguments (tangents), \n", + " Given values of the arguments and perturbation of the arguments (tangents),\n", " compute the output of the primitive and the perturbation of the output.\n", "\n", - " This method must be JAX-traceable. JAX may invoke it with abstract values \n", + " This method must be JAX-traceable. JAX may invoke it with abstract values\n", " for the arguments and tangents.\n", "\n", " Args:\n", " arg_values: a tuple of arguments\n", - " arg_tangents: a tuple with the tangents of the arguments. The tuple has \n", - " the same length as the arg_values. Some of the tangents may also be the \n", + " arg_tangents: a tuple with the tangents of the arguments. The tuple has\n", + " the same length as the arg_values. Some of the tangents may also be the\n", " special value ad.Zero to specify a zero tangent.\n", " Returns:\n", " a pair of the primal output and the tangent.\n", @@ -811,26 +811,26 @@ " x, y, z = arg_values\n", " xt, yt, zt = arg_tangents\n", " _trace(\"Primal evaluation:\")\n", - " # Now we have a JAX-traceable computation of the output. \n", - " # Normally, we can use the ma primitive itself to compute the primal output. \n", + " # Now we have a JAX-traceable computation of the output.\n", + " # Normally, we can use the ma primitive itself to compute the primal output.\n", " primal_out = multiply_add_prim(x, y, z)\n", - " \n", + "\n", " _trace(\"Tangent evaluation:\")\n", - " # We must use a JAX-traceable way to compute the tangent. It turns out that \n", + " # We must use a JAX-traceable way to compute the tangent. It turns out that\n", " # the output tangent can be computed as (xt * y + x * yt + zt),\n", " # which we can implement in a JAX-traceable way using the same \"multiply_add_prim\" primitive.\n", - " \n", - " # We do need to deal specially with Zero. Here we just turn it into a \n", - " # proper tensor of 0s (of the same shape as 'x'). \n", - " # An alternative would be to check for Zero and perform algebraic \n", + "\n", + " # We do need to deal specially with Zero. Here we just turn it into a\n", + " # proper tensor of 0s (of the same shape as 'x').\n", + " # An alternative would be to check for Zero and perform algebraic\n", " # simplification of the output tangent computation.\n", " def make_zero(tan):\n", - " return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan \n", - " \n", + " return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan\n", + "\n", " output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))\n", " return (primal_out, output_tangent)\n", "\n", - "# Register the forward differentiation rule with JAX \n", + "# Register the forward differentiation rule with JAX\n", "ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp" ] }, @@ -880,7 +880,7 @@ "id": "69QsEcu-lP4u" }, "source": [ - "TO EXPLAIN: \n", + "TO EXPLAIN:\n", "\n", " * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.\n", " * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n", @@ -941,7 +941,7 @@ } ], "source": [ - "assert api.jit(lambda arg_values, arg_tangents: \n", + "assert api.jit(lambda arg_values, arg_tangents:\n", " api.jvp(square_add_prim, arg_values, arg_tangents))(\n", " (2., 10.), (1., 1.)) == (14., 5.)" ] @@ -953,7 +953,7 @@ }, "source": [ "Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn\n", - "evaluates abstractly both the primal and the tangent evaluation (a total of \n", + "evaluates abstractly both the primal and the tangent evaluation (a total of\n", "3 invocations of the `ma` primitive). Then we compile the 3 occurrences\n", "of the primitive." ] @@ -967,21 +967,21 @@ "### Reverse differentiation\n", "\n", "If we attempt now to use reverse differentiation we\n", - "see that JAX starts by using the `multiply_add_value_and_jvp` to \n", + "see that JAX starts by using the `multiply_add_value_and_jvp` to\n", "compute the forward differentiation for abstract values, but then runs\n", - "into a `NotImplementedError`. \n", + "into a `NotImplementedError`.\n", "\n", "When computing the reverse differentiation JAX first does abstract evaluation\n", - "of the forward differentiation code `multiply_add_value_and_jvp` to obtain a \n", - "trace of primitives that compute the output tangent. \n", + "of the forward differentiation code `multiply_add_value_and_jvp` to obtain a\n", + "trace of primitives that compute the output tangent.\n", "Observe that JAX performs this abstract evaluation with concrete values\n", - "for the differentiation point, and abstract values for the tangents. \n", + "for the differentiation point, and abstract values for the tangents.\n", "Observe also that JAX uses the special abstract tangent value `Zero` for\n", - "the tangent corresponding to the 3rd argument of `ma`. This reflects the \n", + "the tangent corresponding to the 3rd argument of `ma`. This reflects the\n", "fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n", "which flows to the 3rd argument to `multiply_add_prim`.\n", "\n", - "Observe also that during the abstract evaluation of the tangent we pass the \n", + "Observe also that during the abstract evaluation of the tangent we pass the\n", "value 0.0 as the tangent for the 3rd argument. This is due to the use\n", "of the `make_zero` function in the definition of `multiply_add_value_and_jvp`." ] @@ -1071,7 +1071,7 @@ "\n", "As explained above, when computing reverse differentiation JAX obtains\n", "a trace of primitives that compute the tangent using forward differentiation.\n", - "Then, **JAX interprets this trace abstractly backwards** and for each \n", + "Then, **JAX interprets this trace abstractly backwards** and for each\n", "primitive it applies a **transposition** rule.\n", "\n", "To understand what is going on, consider for now a simpler example of the function \"f(x, y) = x * y + y\". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:\n", @@ -1082,7 +1082,7 @@ " ft = c + yt\n", "```\n", "\n", - "By construction, the tangent calculation is always linear in the input tangents. \n", + "By construction, the tangent calculation is always linear in the input tangents.\n", "The only non-linear operator that may arise in the tangent calculation is multiplication,\n", "but then one of the operands is constant.\n", "\n", @@ -1108,8 +1108,8 @@ " xct += act * 4.\n", "```\n", "\n", - "One can verify that this computation produces `xct = 4.` and `yct = 3.`, which \n", - "are the partial derivatives of the function `f`. \n", + "One can verify that this computation produces `xct = 4.` and `yct = 3.`, which\n", + "are the partial derivatives of the function `f`.\n", "\n", "JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:\n", "```\n", @@ -1117,10 +1117,10 @@ "```\n", "\n", "Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other\n", - "arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned \n", + "arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned\n", "for the constant arguments.\n", "\n", - "In particular, \n", + "In particular,\n", "```\n", " add_transpose(out_ct, _, _) = (out_ct, out_ct)\n", " mult_transpose(out_ct, x, _) = (None, x * out_ct)\n", @@ -1140,16 +1140,16 @@ "def multiply_add_transpose(ct, x, y, z):\n", " \"\"\"Evaluates the transpose of a linear primitive.\n", "\n", - " This method is only used when computing the backward gradient following \n", - " value_and_jvp, and is only needed for primitives that are used in the JVP \n", - " calculation for some other primitive. We need transposition for multiply_add_prim, \n", - " because we have used multiply_add_prim in the computation of the output_tangent in \n", + " This method is only used when computing the backward gradient following\n", + " value_and_jvp, and is only needed for primitives that are used in the JVP\n", + " calculation for some other primitive. We need transposition for multiply_add_prim,\n", + " because we have used multiply_add_prim in the computation of the output_tangent in\n", " multiply_add_value_and_jvp.\n", "\n", - " In our case, multiply_add is not a linear primitive. However, it is used linearly \n", + " In our case, multiply_add is not a linear primitive. However, it is used linearly\n", " w.r.t. tangents in multiply_add_value_and_jvp:\n", " output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n", - " \n", + "\n", " Always one of the first two multiplicative arguments is a constant.\n", "\n", " Args:\n", @@ -1244,7 +1244,7 @@ }, "source": [ "Notice the two calls to `multiply_add_transpose`. They correspond to the two\n", - "uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the \n", + "uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the\n", "last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0." ] }, @@ -1254,7 +1254,7 @@ "id": "EIJs6FYmPg6c" }, "source": [ - "#### JIT of reverse differentiation \n", + "#### JIT of reverse differentiation\n", "\n", "Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n", "abstract values, while in the absence of JIT we used `ConcreteArray`." @@ -1397,20 +1397,20 @@ "@trace(\"multiply_add_batch\")\n", "def multiply_add_batch(vector_arg_values, batch_axes):\n", " \"\"\"Computes the batched version of the primitive.\n", - " \n", + "\n", " This must be a JAX-traceable function.\n", - " \n", + "\n", " Since the multiply_add primitive already operates pointwise on arbitrary\n", " dimension tensors, to batch it we can use the primitive itself. This works as\n", " long as both the inputs have the same dimensions and are batched along the\n", " same axes. The result is batched along the axis that the inputs are batched.\n", - " \n", + "\n", " Args:\n", " vector_arg_values: a tuple of two arguments, each being a tensor of matching\n", " shape.\n", " batch_axes: the axes that are being batched. See vmap documentation.\n", " Returns:\n", - " a tuple of the result, and the result axis that was batched. \n", + " a tuple of the result, and the result axis that was batched.\n", " \"\"\"\n", " assert batch_axes[0] == batch_axes[1]\n", " assert batch_axes[0] == batch_axes[2]\n", diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md index 0ebf202f2258..656cd0e59f5a 100644 --- a/docs/notebooks/How_JAX_primitives_work.md +++ b/docs/notebooks/How_JAX_primitives_work.md @@ -22,12 +22,12 @@ kernelspec: *necula@google.com*, October 2019. JAX implements certain transformations of Python functions, e.g., `jit`, `grad`, -`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, +`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with -abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, +abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on the JAX abstract values. @@ -37,7 +37,7 @@ The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., `jit(jacfwd(grad(f)))`. -There are pre-defined JAX primitives corresponding to most XLA operations, +There are pre-defined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX’s implementation of numpy are JAX-traceable and therefore transformable. @@ -49,8 +49,8 @@ one can define a new primitive that encapsulates the behavior of the function. **The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.** Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically -as "multiply_add(x, y, z) = x * y + z". -This function operates on 3 identically-shaped tensors of floating point +as "multiply_add(x, y, z) = x * y + z". +This function operates on 3 identically-shaped tensors of floating point values and performs the operations pointwise. +++ {"id": "HIJYIHNTD1yI"} @@ -58,7 +58,7 @@ values and performs the operations pointwise. ## Using existing primitives The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other -functions that are themselves written using JAX primitives, e.g., those +functions that are themselves written using JAX primitives, e.g., those defined in the `jax.lax` module: ```{code-cell} ipython3 @@ -134,7 +134,7 @@ def trace(name): return str(v) def pp_values(args): return ", ".join([pp(arg) for arg in args]) - + @functools.wraps(func) def func_wrapper(*args): _trace_indent("call {}({})".format(name, pp_values(args))) @@ -164,7 +164,7 @@ class expectNotImplementedError(object): +++ {"id": "Qf4eLrLCFYDl"} -Instead of using `jax.lax` primitives directly, we can use other functions +Instead of using `jax.lax` primitives directly, we can use other functions that are already written in terms of those primitives, such as those in `jax.numpy`: ```{code-cell} ipython3 @@ -182,7 +182,7 @@ def multiply_add_numpy(x, y, z): def square_add_numpy(a, b): return multiply_add_numpy(a, a, b) -print("\nNormal evaluation:") +print("\nNormal evaluation:") print("square_add_numpy = ", square_add_numpy(2., 10.)) print("\nGradient evaluation:") print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) @@ -191,13 +191,13 @@ print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) +++ {"id": "Sg-D8EdeFn4a"} Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and -`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further -below in this colab). -It is important to remember that a JAX-traceable function must be able to +`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further +below in this colab). +It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on special abstract arguments that JAX may use to abstract the function execution. -The JAX traceability property is satisfied as long as the function is written +The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives. +++ {"id": "WxrQO7-XGLcg"} @@ -206,7 +206,7 @@ in terms of JAX primitives. The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX -primitives work let us pretend that we want to add a new primitive to +primitives work let us pretend that we want to add a new primitive to JAX for the multiply-add functionality. ```{code-cell} ipython3 @@ -218,9 +218,9 @@ multiply_add_p = core.Primitive("multiply_add") # Create the primitive @trace("multiply_add_prim") def multiply_add_prim(x, y, z): """The JAX-traceable way to use the JAX primitive. - + Note that the traced arguments must be passed as positional arguments - to `bind`. + to `bind`. """ return multiply_add_p.bind(x, y, z) @@ -257,7 +257,7 @@ def multiply_add_impl(x, y, z): This function does not need to be JAX traceable. Args: - x, y, z: the concrete arguments of the primitive. Will only be called with + x, y, z: the concrete arguments of the primitive. Will only be called with concrete values. Returns: the concrete result of the primitive. @@ -293,17 +293,17 @@ with expectNotImplementedError(): +++ {"id": "rHS1bAGHH44E"} #### Abstract evaluation rules -In order to JIT the function, and for other transformations as well, -JAX first evaluates it abstractly using only the +In order to JIT the function, and for other transformations as well, +JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes: - * Gets the sequence of JAX primitives that are used in the computation. This - sequence will be compiled. - * Computes the shape and type of all vectors and operations used in the computation. + * Gets the sequence of JAX primitives that are used in the computation. This + sequence will be compiled. + * Computes the shape and type of all vectors and operations used in the computation. -For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. +For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. In the latter case, JAX uses the actual concrete value wrapped as an abstract value. ```{code-cell} ipython3 @@ -316,7 +316,7 @@ def multiply_add_abstract_eval(xs, ys, zs): """Abstract evaluation of the primitive. This function does not need to be JAX traceable. It will be invoked with - abstractions of the actual arguments. + abstractions of the actual arguments. Args: xs, ys, zs: abstractions of the arguments. Result: @@ -349,7 +349,7 @@ with expectNotImplementedError(): JAX compilation works by compiling each primitive into a graph of XLA operations. -This is the biggest hurdle to adding new functionality to JAX, because the +This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++. @@ -378,7 +378,7 @@ mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') +++ {"id": "K98LX-VaJkFu"} Now we succeed to JIT. Notice below that JAX first evaluates the function -abstractly, which triggers the `multiply_add_abstract_eval` function, and +abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_xla_translation`. @@ -393,7 +393,7 @@ assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. Below is another use of `jit` where we compile only with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads -in the third argument to `multiply_add_abstract_eval` being +in the third argument to `multiply_add_abstract_eval` being `ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with both `ShapedArray` and `ConcreteArray`. @@ -401,7 +401,7 @@ both `ShapedArray` and `ConcreteArray`. :id: mPfTwIBoKOEK :outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b -assert api.jit(lambda x, y: square_add_prim(x, y), +assert api.jit(lambda x, y: square_add_prim(x, y), static_argnums=1)(2., 10.) == 14. ``` @@ -437,16 +437,16 @@ from jax.interpreters import ad def multiply_add_value_and_jvp(arg_values, arg_tangents): """Evaluates the primal output and the tangents (Jacobian-vector product). - Given values of the arguments and perturbation of the arguments (tangents), + Given values of the arguments and perturbation of the arguments (tangents), compute the output of the primitive and the perturbation of the output. - This method must be JAX-traceable. JAX may invoke it with abstract values + This method must be JAX-traceable. JAX may invoke it with abstract values for the arguments and tangents. Args: arg_values: a tuple of arguments - arg_tangents: a tuple with the tangents of the arguments. The tuple has - the same length as the arg_values. Some of the tangents may also be the + arg_tangents: a tuple with the tangents of the arguments. The tuple has + the same length as the arg_values. Some of the tangents may also be the special value ad.Zero to specify a zero tangent. Returns: a pair of the primal output and the tangent. @@ -454,26 +454,26 @@ def multiply_add_value_and_jvp(arg_values, arg_tangents): x, y, z = arg_values xt, yt, zt = arg_tangents _trace("Primal evaluation:") - # Now we have a JAX-traceable computation of the output. - # Normally, we can use the ma primitive itself to compute the primal output. + # Now we have a JAX-traceable computation of the output. + # Normally, we can use the ma primitive itself to compute the primal output. primal_out = multiply_add_prim(x, y, z) - + _trace("Tangent evaluation:") - # We must use a JAX-traceable way to compute the tangent. It turns out that + # We must use a JAX-traceable way to compute the tangent. It turns out that # the output tangent can be computed as (xt * y + x * yt + zt), # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive. - - # We do need to deal specially with Zero. Here we just turn it into a - # proper tensor of 0s (of the same shape as 'x'). - # An alternative would be to check for Zero and perform algebraic + + # We do need to deal specially with Zero. Here we just turn it into a + # proper tensor of 0s (of the same shape as 'x'). + # An alternative would be to check for Zero and perform algebraic # simplification of the output tangent computation. def make_zero(tan): - return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan - + return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan + output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt))) return (primal_out, output_tangent) -# Register the forward differentiation rule with JAX +# Register the forward differentiation rule with JAX ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp ``` @@ -487,7 +487,7 @@ assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.) +++ {"id": "69QsEcu-lP4u"} -TO EXPLAIN: +TO EXPLAIN: * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here. * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet @@ -504,7 +504,7 @@ We can apply JIT to the forward differentiation function: :id: hg-hzVu-N-hv :outputId: 38d32067-e152-4046-ad80-7f95a31ba628 -assert api.jit(lambda arg_values, arg_tangents: +assert api.jit(lambda arg_values, arg_tangents: api.jvp(square_add_prim, arg_values, arg_tangents))( (2., 10.), (1., 1.)) == (14., 5.) ``` @@ -512,7 +512,7 @@ assert api.jit(lambda arg_values, arg_tangents: +++ {"id": "jlZt1_v2mU88"} Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn -evaluates abstractly both the primal and the tangent evaluation (a total of +evaluates abstractly both the primal and the tangent evaluation (a total of 3 invocations of the `ma` primitive). Then we compile the 3 occurrences of the primitive. @@ -521,21 +521,21 @@ of the primitive. ### Reverse differentiation If we attempt now to use reverse differentiation we -see that JAX starts by using the `multiply_add_value_and_jvp` to +see that JAX starts by using the `multiply_add_value_and_jvp` to compute the forward differentiation for abstract values, but then runs -into a `NotImplementedError`. +into a `NotImplementedError`. When computing the reverse differentiation JAX first does abstract evaluation -of the forward differentiation code `multiply_add_value_and_jvp` to obtain a -trace of primitives that compute the output tangent. +of the forward differentiation code `multiply_add_value_and_jvp` to obtain a +trace of primitives that compute the output tangent. Observe that JAX performs this abstract evaluation with concrete values -for the differentiation point, and abstract values for the tangents. +for the differentiation point, and abstract values for the tangents. Observe also that JAX uses the special abstract tangent value `Zero` for -the tangent corresponding to the 3rd argument of `ma`. This reflects the +the tangent corresponding to the 3rd argument of `ma`. This reflects the fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`, which flows to the 3rd argument to `multiply_add_prim`. -Observe also that during the abstract evaluation of the tangent we pass the +Observe also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the 3rd argument. This is due to the use of the `make_zero` function in the definition of `multiply_add_value_and_jvp`. @@ -560,7 +560,7 @@ to use the forward differentiation code to compute reverse differentiation. As explained above, when computing reverse differentiation JAX obtains a trace of primitives that compute the tangent using forward differentiation. -Then, **JAX interprets this trace abstractly backwards** and for each +Then, **JAX interprets this trace abstractly backwards** and for each primitive it applies a **transposition** rule. To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`: @@ -571,7 +571,7 @@ To understand what is going on, consider for now a simpler example of the functi ft = c + yt ``` -By construction, the tangent calculation is always linear in the input tangents. +By construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, but then one of the operands is constant. @@ -597,8 +597,8 @@ of the operation: xct += act * 4. ``` -One can verify that this computation produces `xct = 4.` and `yct = 3.`, which -are the partial derivatives of the function `f`. +One can verify that this computation produces `xct = 4.` and `yct = 3.`, which +are the partial derivatives of the function `f`. JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is: ``` @@ -606,10 +606,10 @@ p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz) ``` Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other -arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned +arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned for the constant arguments. -In particular, +In particular, ``` add_transpose(out_ct, _, _) = (out_ct, out_ct) mult_transpose(out_ct, x, _) = (None, x * out_ct) @@ -623,16 +623,16 @@ In particular, def multiply_add_transpose(ct, x, y, z): """Evaluates the transpose of a linear primitive. - This method is only used when computing the backward gradient following - value_and_jvp, and is only needed for primitives that are used in the JVP - calculation for some other primitive. We need transposition for multiply_add_prim, - because we have used multiply_add_prim in the computation of the output_tangent in + This method is only used when computing the backward gradient following + value_and_jvp, and is only needed for primitives that are used in the JVP + calculation for some other primitive. We need transposition for multiply_add_prim, + because we have used multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. - In our case, multiply_add is not a linear primitive. However, it is used linearly + In our case, multiply_add is not a linear primitive. However, it is used linearly w.r.t. tangents in multiply_add_value_and_jvp: output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt)) - + Always one of the first two multiplicative arguments is a constant. Args: @@ -674,12 +674,12 @@ assert api.grad(square_add_prim)(2., 10.) == 4. +++ {"id": "8M1xLCXW4fK7"} Notice the two calls to `multiply_add_transpose`. They correspond to the two -uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the +uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0. +++ {"id": "EIJs6FYmPg6c"} -#### JIT of reverse differentiation +#### JIT of reverse differentiation Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only abstract values, while in the absence of JIT we used `ConcreteArray`. @@ -721,20 +721,20 @@ from jax.interpreters import batching @trace("multiply_add_batch") def multiply_add_batch(vector_arg_values, batch_axes): """Computes the batched version of the primitive. - + This must be a JAX-traceable function. - + Since the multiply_add primitive already operates pointwise on arbitrary dimension tensors, to batch it we can use the primitive itself. This works as long as both the inputs have the same dimensions and are batched along the same axes. The result is batched along the axis that the inputs are batched. - + Args: vector_arg_values: a tuple of two arguments, each being a tensor of matching shape. batch_axes: the axes that are being batched. See vmap documentation. Returns: - a tuple of the result, and the result axis that was batched. + a tuple of the result, and the result axis that was batched. """ assert batch_axes[0] == batch_axes[1] assert batch_axes[0] == batch_axes[2] diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index 16e623d0f28b..a4a4d7d1652b 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -119,7 +119,7 @@ " for w, b in params[:-1]:\n", " outputs = jnp.dot(w, activations) + b\n", " activations = relu(outputs)\n", - " \n", + "\n", " final_w, final_b = params[-1]\n", " logits = jnp.dot(final_w, activations) + final_b\n", " return logits - logsumexp(logits)" @@ -238,7 +238,7 @@ "def one_hot(x, k, dtype=jnp.float32):\n", " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", - " \n", + "\n", "def accuracy(params, images, targets):\n", " target_class = jnp.argmax(targets, axis=1)\n", " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index 87533117e56a..d234700e446c 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -96,7 +96,7 @@ def predict(params, image): for w, b in params[:-1]: outputs = jnp.dot(w, activations) + b activations = relu(outputs) - + final_w, final_b = params[-1] logits = jnp.dot(final_w, activations) + final_b return logits - logsumexp(logits) @@ -156,7 +156,7 @@ At this point, we have all the ingredients we need to define our neural network def one_hot(x, k, dtype=jnp.float32): """Create a one-hot encoding of x of size k.""" return jnp.array(x[:, None] == jnp.arange(k), dtype) - + def accuracy(params, images, targets): target_class = jnp.argmax(targets, axis=1) predicted_class = jnp.argmax(batched_predict(params, images), axis=1) diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 7e65aefe359c..2c231bf99c46 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -35,7 +35,6 @@ }, "outputs": [], "source": [ - "import numpy as np\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import jit, grad, vmap\n", @@ -214,7 +213,6 @@ "outputs": [], "source": [ "# Importing Jax functions useful for tracing/interpreting.\n", - "import numpy as np\n", "from functools import wraps\n", "\n", "from jax import core\n", @@ -273,7 +271,7 @@ "def eval_jaxpr(jaxpr, consts, *args):\n", " # Mapping from variable -> value\n", " env = {}\n", - " \n", + "\n", " def read(var):\n", " # Literals are values baked into the Jaxpr\n", " if type(var) is core.Literal:\n", @@ -290,16 +288,16 @@ " # Loop through equations and evaluate primitives using `bind`\n", " for eqn in jaxpr.eqns:\n", " # Read inputs to equation from environment\n", - " invals = safe_map(read, eqn.invars) \n", + " invals = safe_map(read, eqn.invars)\n", " # `bind` is how a primitive is called\n", " outvals = eqn.primitive.bind(*invals, **eqn.params)\n", " # Primitives may return multiple outputs or not\n", - " if not eqn.primitive.multiple_results: \n", + " if not eqn.primitive.multiple_results:\n", " outvals = [outvals]\n", " # Write the results of the primitive into the environment\n", - " safe_map(write, eqn.outvars, outvals) \n", + " safe_map(write, eqn.outvars, outvals)\n", " # Read the final result of the Jaxpr from the environment\n", - " return safe_map(read, jaxpr.outvars) " + " return safe_map(read, jaxpr.outvars)" ] }, { @@ -417,7 +415,7 @@ "source": [ "def inverse_jaxpr(jaxpr, consts, *args):\n", " env = {}\n", - " \n", + "\n", " def read(var):\n", " if type(var) is core.Literal:\n", " return var.val\n", @@ -431,12 +429,12 @@ "\n", " # Looping backward\n", " for eqn in jaxpr.eqns[::-1]:\n", - " # outvars are now invars \n", + " # outvars are now invars\n", " invals = safe_map(read, eqn.outvars)\n", " if eqn.primitive not in inverse_registry:\n", " raise NotImplementedError(\n", " f\"{eqn.primitive} does not have registered inverse.\")\n", - " # Assuming a unary function \n", + " # Assuming a unary function\n", " outval = inverse_registry[eqn.primitive](*invals)\n", " safe_map(write, eqn.invars, [outval])\n", " return safe_map(read, jaxpr.invars)" diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index e52c6a5f8742..883d64c374f3 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -32,7 +32,6 @@ Here we show how to add your own function transformations to the system, by writ ```{code-cell} ipython3 :id: s27RDKvKXFL8 -import numpy as np import jax import jax.numpy as jnp from jax import jit, grad, vmap @@ -146,7 +145,6 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr. :id: BHkg_3P1pXJj # Importing Jax functions useful for tracing/interpreting. -import numpy as np from functools import wraps from jax import core @@ -185,7 +183,7 @@ To do this, we first create an environment to store the values for each of the v def eval_jaxpr(jaxpr, consts, *args): # Mapping from variable -> value env = {} - + def read(var): # Literals are values baked into the Jaxpr if type(var) is core.Literal: @@ -202,16 +200,16 @@ def eval_jaxpr(jaxpr, consts, *args): # Loop through equations and evaluate primitives using `bind` for eqn in jaxpr.eqns: # Read inputs to equation from environment - invals = safe_map(read, eqn.invars) + invals = safe_map(read, eqn.invars) # `bind` is how a primitive is called outvals = eqn.primitive.bind(*invals, **eqn.params) # Primitives may return multiple outputs or not - if not eqn.primitive.multiple_results: + if not eqn.primitive.multiple_results: outvals = [outvals] # Write the results of the primitive into the environment - safe_map(write, eqn.outvars, outvals) + safe_map(write, eqn.outvars, outvals) # Read the final result of the Jaxpr from the environment - return safe_map(read, jaxpr.outvars) + return safe_map(read, jaxpr.outvars) ``` ```{code-cell} ipython3 @@ -279,7 +277,7 @@ Now we just need to define `inverse_jaxpr`, which will walk through the Jaxpr ba def inverse_jaxpr(jaxpr, consts, *args): env = {} - + def read(var): if type(var) is core.Literal: return var.val @@ -293,12 +291,12 @@ def inverse_jaxpr(jaxpr, consts, *args): # Looping backward for eqn in jaxpr.eqns[::-1]: - # outvars are now invars + # outvars are now invars invals = safe_map(read, eqn.outvars) if eqn.primitive not in inverse_registry: raise NotImplementedError( f"{eqn.primitive} does not have registered inverse.") - # Assuming a unary function + # Assuming a unary function outval = inverse_registry[eqn.primitive](*invals) safe_map(write, eqn.invars, [outval]) return safe_map(read, jaxpr.invars) diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 86c8bfea8468..478d84935b7f 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -1148,7 +1148,7 @@ " y, vjp_fun = vjp(f, x)\n", " # Use vmap to do a matrix-Jacobian product.\n", " # Here, the matrix is the Euclidean basis, so we get all\n", - " # entries in the Jacobian at once. \n", + " # entries in the Jacobian at once.\n", " J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))\n", " return J\n", " return jacfun\n", @@ -1169,7 +1169,7 @@ "def our_jacfwd(f):\n", " def jacfun(x):\n", " _jvp = lambda s: jvp(f, (x,), (s,))[1]\n", - " Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n", + " Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n", " return jnp.transpose(Jt)\n", " return jacfun\n", "\n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index bc2d803f1228..6dcba2470ef6 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -675,7 +675,7 @@ def our_jacrev(f): y, vjp_fun = vjp(f, x) # Use vmap to do a matrix-Jacobian product. # Here, the matrix is the Euclidean basis, so we get all - # entries in the Jacobian at once. + # entries in the Jacobian at once. J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) return J return jacfun @@ -691,7 +691,7 @@ from jax import jacfwd as builtin_jacfwd def our_jacfwd(f): def jacfun(x): _jvp = lambda s: jvp(f, (x,), (s,))[1] - Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x))) + Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x))) return jnp.transpose(Jt) return jacfun diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index f0552e52688f..041cf65314f2 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -739,8 +739,6 @@ "metadata": {}, "outputs": [], "source": [ - "from jax.ad_checkpoint import checkpoint_name\n", - "\n", "def predict(params, x):\n", " *Ws, Wlast = params\n", " for i, W in enumerate(Ws):\n", diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index b31e093b6f91..077a8b6b17b6 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -370,8 +370,6 @@ Notice also that by providing a policy, we didn't need to edit the code defining Some policies can refer to values named with `jax.ad_checkpoint.checkpoint_name`: ```{code-cell} -from jax.ad_checkpoint import checkpoint_name - def predict(params, x): *Ws, Wlast = params for i, W in enumerate(Ws): diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index 5246e810de64..f628625bd041 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -410,7 +410,7 @@ ], "source": [ "dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape\n", - " kernel.shape, # only ndim matters, not shape \n", + " kernel.shape, # only ndim matters, not shape\n", " ('NHWC', 'HWIO', 'NHWC')) # the important bit\n", "print(dn)" ] @@ -806,8 +806,8 @@ ], "source": [ "# 1D kernel - WIO layout\n", - "kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], \n", - " [[1, 1, 1], [-1, -1, -1]]], \n", + "kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],\n", + " [[1, 1, 1], [-1, -1, -1]]],\n", " dtype=jnp.float32).transpose([2,1,0])\n", "# 1D data - NWC layout\n", "data = np.zeros((1, 200, 2), dtype=jnp.float32)\n", @@ -895,8 +895,8 @@ "# Random 3D kernel - HWDIO layout\n", "kernel = jnp.array([\n", " [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n", - " [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n", - " [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n", + " [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],\n", + " [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],\n", " dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]\n", "\n", "# 3D data - NHWDC layout\n", @@ -919,7 +919,6 @@ "print(\"out shape: \", out.shape)\n", "\n", "# Make some simple 3d density plots:\n", - "from mpl_toolkits.mplot3d import Axes3D\n", "def make_alpha(cmap):\n", " my_cmap = cmap(jnp.arange(cmap.N))\n", " my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 2dec35847359..467deeec2c89 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -210,7 +210,7 @@ The important argument is the 3-tuple of axis layout arguments: :outputId: d5a569b3-febc-4832-f725-1d5e8fd31b9b dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape - kernel.shape, # only ndim matters, not shape + kernel.shape, # only ndim matters, not shape ('NHWC', 'HWIO', 'NHWC')) # the important bit print(dn) ``` @@ -363,8 +363,8 @@ You aren't limited to 2D convolutions, a simple 1D demo is below: :outputId: 67c46ace-6adc-4c47-c1c7-1f185be5fd4b # 1D kernel - WIO layout -kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], - [[1, 1, 1], [-1, -1, -1]]], +kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], + [[1, 1, 1], [-1, -1, -1]]], dtype=jnp.float32).transpose([2,1,0]) # 1D data - NWC layout data = np.zeros((1, 200, 2), dtype=jnp.float32) @@ -406,8 +406,8 @@ import matplotlib as mpl # Random 3D kernel - HWDIO layout kernel = jnp.array([ [[0, 0, 0], [0, 1, 0], [0, 0, 0]], - [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], - [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], + [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], + [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis] # 3D data - NHWDC layout @@ -430,7 +430,6 @@ out = lax.conv_general_dilated(data, # lhs = image tensor print("out shape: ", out.shape) # Make some simple 3d density plots: -from mpl_toolkits.mplot3d import Axes3D def make_alpha(cmap): my_cmap = cmap(jnp.arange(cmap.N)) my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3 diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index 25c551c9834e..3c022124e3cc 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -840,7 +840,6 @@ }, "outputs": [], "source": [ - "from functools import partial\n", "j1 = partial(jv, 1)\n", "z = jnp.arange(5.0)" ] diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index c93139e1658c..be76f9913928 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -410,7 +410,6 @@ This lets us call into `scipy.special.jv` from transformed JAX code, including w ```{code-cell} :id: f4e46670f4e4 -from functools import partial j1 = partial(jv, 1) z = jnp.arange(5.0) ``` diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 7d353c924845..91f2ee571b4b 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -132,7 +132,7 @@ " for w, b in params[:-1]:\n", " outputs = jnp.dot(w, activations) + b\n", " activations = relu(outputs)\n", - " \n", + "\n", " final_w, final_b = params[-1]\n", " logits = jnp.dot(final_w, activations) + final_b\n", " return logits - logsumexp(logits)" @@ -251,7 +251,7 @@ "def one_hot(x, k, dtype=jnp.float32):\n", " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", - " \n", + "\n", "def accuracy(params, images, targets):\n", " target_class = jnp.argmax(targets, axis=1)\n", " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 2f7ba3271312..0c0c4bc5cb8e 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -104,7 +104,7 @@ def predict(params, image): for w, b in params[:-1]: outputs = jnp.dot(w, activations) + b activations = relu(outputs) - + final_w, final_b = params[-1] logits = jnp.dot(final_w, activations) + final_b return logits - logsumexp(logits) @@ -164,7 +164,7 @@ At this point, we have all the ingredients we need to define our neural network def one_hot(x, k, dtype=jnp.float32): """Create a one-hot encoding of x of size k.""" return jnp.array(x[:, None] == jnp.arange(k), dtype) - + def accuracy(params, images, targets): target_class = jnp.argmax(targets, axis=1) predicted_class = jnp.argmax(batched_predict(params, images), axis=1) diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index a355959ba45d..9aef1a8eb599 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -25,17 +25,10 @@ }, "outputs": [], "source": [ - "import functools\n", - "import itertools\n", - "import re\n", - "import sys\n", - "import time\n", - "\n", - "from matplotlib.pyplot import *\n", + "import matplotlib.pyplot as plt\n", "\n", "import jax\n", "\n", - "from jax import lax\n", "import jax.numpy as jnp\n", "import jax.scipy as jsp\n", "from jax import random\n", @@ -348,7 +341,7 @@ "def elbo(beta_loc, beta_log_scale, epsilon):\n", " beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n", " return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n", - " \n", + "\n", "elbo = jax.jit(elbo)\n", "elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))" ] @@ -548,25 +541,16 @@ } ], "source": [ - "figure(figsize=(7, 7))\n", - "plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n", - "plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n", - "plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n", + "plt.figure(figsize=(7, 7))\n", + "plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n", + "plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\\sigma$ Error Bars')\n", + "plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n", "plot_scale = 3\n", - "plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n", - "xlabel('True beta')\n", - "ylabel('Estimated beta')\n", - "legend(loc='best')" + "plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n", + "plt.xlabel('True beta')\n", + "plt.ylabel('Estimated beta')\n", + "plt.legend(loc='best')" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "id": "_bXdOlvUEJl0" - }, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index 9ecbd9d23a0b..f8cfc3553cc6 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -27,17 +27,10 @@ Inspired by a notebook by @davmre. ```{code-cell} ipython3 :id: 8RZDkfbV3zdR -import functools -import itertools -import re -import sys -import time - -from matplotlib.pyplot import * +import matplotlib.pyplot as plt import jax -from jax import lax import jax.numpy as jnp import jax.scipy as jsp from jax import random @@ -192,7 +185,7 @@ batched_log_joint = jax.jit(jax.vmap(log_joint)) def elbo(beta_loc, beta_log_scale, epsilon): beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi)) - + elbo = jax.jit(elbo) elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1))) ``` @@ -240,19 +233,13 @@ Coverage isn't quite as good as we might like, but it's not bad, and nobody said :id: zt1NBLoVHtOG :outputId: fb159795-e6e7-497c-e501-9933ec761af4 -figure(figsize=(7, 7)) -plot(true_beta, beta_loc, '.', label='Approximated Posterior Means') -plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars') -plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.') +plt.figure(figsize=(7, 7)) +plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means') +plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars') +plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.') plot_scale = 3 -plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k') -xlabel('True beta') -ylabel('Estimated beta') -legend(loc='best') -``` - -```{code-cell} ipython3 -:id: _bXdOlvUEJl0 - - +plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k') +plt.xlabel('True beta') +plt.ylabel('Estimated beta') +plt.legend(loc='best') ``` diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 8fa2107795fd..b7bc919ebcde 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -189,9 +189,6 @@ ], "source": [ "# Pardon the boilerplate; constructing a sharding will become easier in future!\n", - "from jax.sharding import Mesh\n", - "from jax.sharding import PartitionSpec\n", - "from jax.sharding import NamedSharding\n", "from jax.experimental import mesh_utils\n", "\n", "P = jax.sharding.PartitionSpec\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 345ca7987b41..85dfcdc1733b 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -73,9 +73,6 @@ Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimens :outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5 # Pardon the boilerplate; constructing a sharding will become easier in future! -from jax.sharding import Mesh -from jax.sharding import PartitionSpec -from jax.sharding import NamedSharding from jax.experimental import mesh_utils P = jax.sharding.PartitionSpec diff --git a/jax/_src/core.py b/jax/_src/core.py index 61ed81cdeea9..f80cd0418b81 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1268,7 +1268,7 @@ def new_base_main(trace_type: type[Trace], @contextmanager def pop_level(level: int): if level == 0: - return (yield) + return (yield) # noqa: B901 prev, thread_local_state.trace_state.trace_stack.stack = \ thread_local_state.trace_state.trace_stack.stack, \ thread_local_state.trace_state.trace_stack.stack[:level] diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 51a86fdcb978..7ceac8940147 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -760,7 +760,7 @@ def _dot_product_attention_fwd_partition( scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, mesh, arg_shapes, result_shape): # args sharding - arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) out_shardings = _infer_fwd_output_sharding( mesh, arg_shapes, variadic_args, is_training) impl = functools.partial( @@ -810,7 +810,7 @@ def _dot_product_attention_bwd_partition( arg_shapes, result_shape): out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) # args sharding - arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) def sharded_impl(*args): impl = functools.partial( _dot_product_attention_bwd_impl, diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index bc548a69c5c5..af3d55f581be 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -238,7 +238,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef: Returns: Initialized BufferedRef """ - block_shape = tuple([1 if x is None else x for x in spec.block_shape]) + block_shape = tuple(1 if x is None else x for x in spec.block_shape) if buffer_type is BufferType.ACCUMULATOR: accum_ref = VMEM(block_shape, dtype) else: @@ -310,7 +310,7 @@ def memory_space(self): @property def current_ref(self): buffer_slice = tuple( - [0 if x is None else slice(None) for x in self.block_shape]) + 0 if x is None else slice(None) for x in self.block_shape) if self.memory_space == VMEM: return self.vmem_ref.at[buffer_slice] else: @@ -349,7 +349,7 @@ def bind_existing_ref(self, vmem_ref, indices): def compute_slice(self, grid_indices): """Compute DMA slice from grid indices.""" - block_shape = tuple([1 if x is None else x for x in self.block_shape]) + block_shape = tuple(1 if x is None else x for x in self.block_shape) indices = self.compute_index(*grid_indices) return jax.tree.map(_make_ds, indices, block_shape) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index b0f06624b718..72533e619fe6 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -172,7 +172,7 @@ def _normalize_tolerance(tol): if isinstance(tol, dict): return {np.dtype(k): v for k, v in tol.items()} else: - return {k: tol for k in _default_tolerance} + return dict.fromkeys(_default_tolerance, tol) def join_tolerance(tol1, tol2): tol1 = _normalize_tolerance(tol1) diff --git a/pyproject.toml b/pyproject.toml index 193c6b9fdad0..fb706dbbf8ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,8 @@ ignore = [ "C408", # Unnecessary map usage "C417", + # Unnecessary dict comprehension for iterable + "C420", # Object names too complex "C901", # Local variable is assigned to but never used @@ -141,7 +143,16 @@ max-complexity = 18 [tool.ruff.lint.per-file-ignores] # F811: Redefinition of unused name. +# F821: Undefined name. "docs/autodidax.py" = ["F811"] +"docs/pallas/tpu/matmul.ipynb" = ["F811"] +"docs/pallas/tpu/distributed.ipynb" = ["F811"] +"docs/pallas/quickstart.ipynb" = ["F811"] +"docs/notebooks/autodiff_cookbook.ipynb" = ["F811", "F821"] +"docs/notebooks/autodiff_remat.ipynb" = ["F811", "F821"] +"docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb" = ["F811"] +"docs/jep/9407-type-promotion.ipynb" = ["F811"] +"docs/autodidax.ipynb" = ["F811"] # Note: we don't use jax/*.py because this matches contents of jax/_src "__init__.py" = ["F401"] "jax/abstract_arrays.py" = ["F401"] diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index ce1c02f5a01b..ec9a7cd8b64e 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -308,7 +308,7 @@ def test_fold_strided( expanded_shape = get_packed_shape(strides, shape) total_size = np.prod(expanded_shape) np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape) - index = tuple([slice(0, s) for s in shape]) + index = tuple(slice(0, s) for s in shape) # Reference implementation def np_fold(inp, dim, fold_rank): diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb index 1540b3d20892..e8cd40a67b88 100644 --- a/tests/notebooks/colab_cpu.ipynb +++ b/tests/notebooks/colab_cpu.ipynb @@ -88,15 +88,6 @@ "height": 68 } }, - "source": [ - "from jaxlib import xla_extension\n", - "import jax\n", - "key = jax.random.PRNGKey(1701)\n", - "arr = jax.random.normal(key, (1000,))\n", - "device = arr.device()\n", - "print(f\"JAX device type: {device}\")\n", - "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\"" - ], "execution_count": 2, "outputs": [ { diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 2cad1d064e87..2818712c0359 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -500,11 +500,11 @@ def test_strided_load_and_store( def body(x_ref, y_ref1, y_ref2): if slice_type == "slice": slices = tuple( - [slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides)] + slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides) ) else: slices = tuple( - [pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides)] + pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides) ) if indexer_type == "state": y_ref1[...] = x_ref[slices] From 09fd345de9aa11de24ed417542967be16016bb59 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 27 Aug 2024 15:23:13 -0700 Subject: [PATCH 255/702] pre-commit: update hooks & pin using hashes --- .pre-commit-config.yaml | 8 ++++---- docs/_tutorials/advanced-autodiff.md | 2 +- docs/_tutorials/advanced-debugging.md | 2 +- docs/_tutorials/external-callbacks.md | 2 +- docs/_tutorials/gradient-checkpointing.md | 2 +- docs/_tutorials/jax-primitives.md | 2 +- docs/_tutorials/jaxpr.md | 2 +- docs/autodidax.md | 2 +- docs/autodidax.py | 2 +- docs/automatic-differentiation.md | 2 +- docs/automatic-vectorization.md | 2 +- docs/debugging.md | 2 +- docs/distributed_data_loading.md | 2 +- docs/ffi.md | 2 +- docs/jep/9407-type-promotion.md | 2 +- docs/jit-compilation.md | 2 +- docs/key-concepts.md | 2 +- docs/notebooks/Common_Gotchas_in_JAX.md | 2 +- docs/notebooks/Custom_derivative_rules_for_Python_code.md | 2 +- .../Distributed_arrays_and_automatic_parallelization.md | 2 +- docs/notebooks/How_JAX_primitives_work.md | 2 +- docs/notebooks/Neural_Network_and_Data_Loading.md | 2 +- docs/notebooks/Writing_custom_interpreters_in_Jax.md | 2 +- docs/notebooks/autodiff_cookbook.md | 2 +- docs/notebooks/autodiff_remat.md | 2 +- docs/notebooks/convolutions.md | 2 +- docs/notebooks/external_callbacks.md | 2 +- docs/notebooks/neural_network_with_tfds_data.md | 2 +- docs/notebooks/shard_map.md | 2 +- docs/notebooks/thinking_in_jax.md | 2 +- docs/notebooks/vmapped_log_probs.md | 2 +- docs/pallas/quickstart.md | 2 +- docs/pallas/tpu/distributed.md | 2 +- docs/pallas/tpu/matmul.md | 2 +- docs/pallas/tpu/pipelining.md | 2 +- docs/quickstart.md | 2 +- docs/random-numbers.md | 2 +- docs/sharded-computation.md | 2 +- docs/stateful-computations.md | 2 +- docs/working-with-pytrees.md | 2 +- 40 files changed, 43 insertions(+), 43 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 79a0df6e9e38..c89aa934d95d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: 2c9f875913ee60ca25ce70243dc24d5b6415598c # frozen: v4.6.0 hooks: - id: check-ast - id: check-merge-conflict @@ -26,12 +26,12 @@ repos: files: \.py$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.1 + rev: 8b5112a3b2ad121439a2092f8ff548c0d80f2514 # frozen: v0.6.1 hooks: - id: ruff - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.11.0' + rev: 'd4911cfb7f1010759fde68da196036feeb25b99d' # frozen: v1.11.2 hooks: - id: mypy files: (jax/|tests/typing_test\.py) @@ -40,7 +40,7 @@ repos: args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext - rev: v1.16.1 + rev: 8ed836db64ad5d304f2315e6bfd9049c9142e190 # frozen: v1.16.4 hooks: - id: jupytext files: docs/ diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index 20affa8cf29f..180f65f5d492 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/_tutorials/advanced-debugging.md b/docs/_tutorials/advanced-debugging.md index 56188e0958fa..d4462feaf829 100644 --- a/docs/_tutorials/advanced-debugging.md +++ b/docs/_tutorials/advanced-debugging.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/_tutorials/external-callbacks.md b/docs/_tutorials/external-callbacks.md index a46927e6a8b4..c404f320fca7 100644 --- a/docs/_tutorials/external-callbacks.md +++ b/docs/_tutorials/external-callbacks.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/_tutorials/gradient-checkpointing.md b/docs/_tutorials/gradient-checkpointing.md index b768514e4bb0..14a532b54dd1 100644 --- a/docs/_tutorials/gradient-checkpointing.md +++ b/docs/_tutorials/gradient-checkpointing.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/_tutorials/jax-primitives.md b/docs/_tutorials/jax-primitives.md index 51abe2916693..41ff86fd60f0 100644 --- a/docs/_tutorials/jax-primitives.md +++ b/docs/_tutorials/jax-primitives.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/_tutorials/jaxpr.md b/docs/_tutorials/jaxpr.md index 9fe990c0a8ba..974ed39c1663 100644 --- a/docs/_tutorials/jaxpr.md +++ b/docs/_tutorials/jaxpr.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/autodidax.md b/docs/autodidax.md index 0551b9905db3..471dd7c63f6d 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -6,7 +6,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/autodidax.py b/docs/autodidax.py index b09534381c69..6d295fc50301 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -20,7 +20,7 @@ # extension: .py # format_name: light # format_version: '1.5' -# jupytext_version: 1.16.1 +# jupytext_version: 1.16.4 # kernelspec: # display_name: Python 3 # name: python3 diff --git a/docs/automatic-differentiation.md b/docs/automatic-differentiation.md index cc4a19aaba64..07af05e3d973 100644 --- a/docs/automatic-differentiation.md +++ b/docs/automatic-differentiation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/automatic-vectorization.md b/docs/automatic-vectorization.md index 7559155e2e9e..032d1c56f27a 100644 --- a/docs/automatic-vectorization.md +++ b/docs/automatic-vectorization.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/debugging.md b/docs/debugging.md index 7ee36f19f5bf..94384035ca9d 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/distributed_data_loading.md b/docs/distributed_data_loading.md index be4d170eae81..14fb1bb55c35 100644 --- a/docs/distributed_data_loading.md +++ b/docs/distributed_data_loading.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/ffi.md b/docs/ffi.md index 4568b670e170..802fd4f2264e 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index 2d12944f16f9..cdb1f7805b7e 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 2d442c8411aa..bc6cb3c04cf8 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/key-concepts.md b/docs/key-concepts.md index c6cfb176e645..b87808d14faf 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 0b21a57e36c9..3324fdb53bcd 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 3c60cce0cf30..6c948650fcc1 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index c5f3c08ed4f9..97b07172b707 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md index 656cd0e59f5a..b926c22ea32f 100644 --- a/docs/notebooks/How_JAX_primitives_work.md +++ b/docs/notebooks/How_JAX_primitives_work.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index d234700e446c..03b8415fc91c 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 883d64c374f3..41d7a7e51dfc 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index 6dcba2470ef6..6615e65352d7 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 077a8b6b17b6..a4fb27c58128 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 467deeec2c89..83ab2d9fd56d 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md index be76f9913928..910d47bd72ae 100644 --- a/docs/notebooks/external_callbacks.md +++ b/docs/notebooks/external_callbacks.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 0c0c4bc5cb8e..480e7477b8ac 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 5b40e78dcfc3..21e4111d5bf9 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -7,7 +7,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 16be7b9e9369..dd0c73ec7699 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index f8cfc3553cc6..5989a87bc141 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index 36cc14bf5c34..b8f9254f21d9 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index b7c058b117ca..9f6048295826 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -6,7 +6,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/pallas/tpu/matmul.md b/docs/pallas/tpu/matmul.md index a00880ebaf37..aa19d7bc19d8 100644 --- a/docs/pallas/tpu/matmul.md +++ b/docs/pallas/tpu/matmul.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index d753b404db1a..5ae053da21a4 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/quickstart.md b/docs/quickstart.md index e071a7ce7555..e19cb33ea9c5 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 85bb5ce01974..2ad1eadb0968 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 85dfcdc1733b..4f8c1b0201bc 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 4e4063a68467..2eeffc30b255 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/working-with-pytrees.md b/docs/working-with-pytrees.md index 2bd1cc08ecdf..e41179996bc4 100644 --- a/docs/working-with-pytrees.md +++ b/docs/working-with-pytrees.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python From 76583c87bcb233a4ed30518da399bfece28f56c1 Mon Sep 17 00:00:00 2001 From: selamw1 Date: Tue, 27 Aug 2024 15:07:13 -0700 Subject: [PATCH 256/702] gcd_lcm_docstring_added description_improved --- jax/_src/numpy/lax_numpy.py | 72 +++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 692bd3d49aaa..ae7ff06cc5c2 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9166,9 +9166,43 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]: where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) -@util.implements(np.gcd, module='numpy') @jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: + """Compute the greatest common divisor of two arrays. + + JAX implementation of :func:`numpy.gcd`. + + Args: + x1: First input array. The elements must have integer dtype. + x2: Second input array. The elements must have integer dtype. + + Returns: + An array containing the greatest common divisors of the corresponding + elements from the absolute values of `x1` and `x2`. + + See also: + - :func:`jax.numpy.lcm`: compute the least common multiple of two arrays. + + Examples: + Scalar inputs: + + >>> jnp.gcd(12, 18) + Array(6, dtype=int32, weak_type=True) + + Array inputs: + + >>> x1 = jnp.array([12, 18, 24]) + >>> x2 = jnp.array([5, 10, 15]) + >>> jnp.gcd(x1, x2) + Array([1, 2, 3], dtype=int32) + + Broadcasting: + + >>> x1 = jnp.array([12]) + >>> x2 = jnp.array([6, 9, 12]) + >>> jnp.gcd(x1, x2) + Array([ 6, 3, 12], dtype=int32) + """ util.check_arraylike("gcd", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) if not issubdtype(_dtype(x1), integer): @@ -9178,9 +9212,43 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: return gcd -@util.implements(np.lcm, module='numpy') @jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: + """Compute the least common multiple of two arrays. + + JAX implementation of :func:`numpy.lcm`. + + Args: + x1: First input array. The elements must have integer dtype. + x2: Second input array. The elements must have integer dtype. + + Returns: + An array containing the least common multiple of the corresponding + elements from the absolute values of `x1` and `x2`. + + See also: + - :func:`jax.numpy.gcd`: compute the greatest common divisor of two arrays. + + Examples: + Scalar inputs: + + >>> jnp.lcm(12, 18) + Array(36, dtype=int32, weak_type=True) + + Array inputs: + + >>> x1 = jnp.array([12, 18, 24]) + >>> x2 = jnp.array([5, 10, 15]) + >>> jnp.lcm(x1, x2) + Array([ 60, 90, 120], dtype=int32) + + Broadcasting: + + >>> x1 = jnp.array([12]) + >>> x2 = jnp.array([6, 9, 12]) + >>> jnp.lcm(x1, x2) + Array([12, 36, 12], dtype=int32) + """ util.check_arraylike("lcm", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) x1, x2 = ufuncs.abs(x1), ufuncs.abs(x2) From 4c642280d478b9228e77c3cab5cb85069b57e840 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Wed, 28 Aug 2024 04:39:45 +0000 Subject: [PATCH 257/702] Use BlockSpecs when possible. --- jax/experimental/pallas/ops/gpu/attention.py | 58 +++++++++----------- 1 file changed, 26 insertions(+), 32 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 1cf8349e7da2..63541e8cb439 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -41,7 +41,7 @@ def mha_forward_kernel( block_d: int, block_k: int, ): - seq_len = q_ref.shape[0] + seq_len = k_ref.shape[0] start_q = pl.program_id(0) # o is the buffer where we accumulate the output on sram. @@ -55,7 +55,7 @@ def mha_forward_kernel( # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. # q tile has shape [block_q, block_d], block_d == head_dim. curr_q_slice = pl.dslice(start_q * block_q, block_q) - q = pl.load(q_ref, (curr_q_slice, pl.dslice(None))) + q = q_ref[...] q_segment_ids = ( None if segment_ids_ref is None @@ -123,12 +123,9 @@ def body(start_k, carry): if residual_refs: lse_ref = residual_refs[0] - lse_i = m_i + jnp.log(l_i) - pl.store(lse_ref, (curr_q_slice,), lse_i) + lse_ref[...] = m_i + jnp.log(l_i) # Write output to dram. - o = o.astype(o_ref.dtype) - pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o) - + o_ref[...] = o.astype(o_ref.dtype) def segment_mask( q_segment_ids: jax.Array, @@ -197,7 +194,7 @@ def mha( in_specs = [ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), pl.BlockSpec( (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) @@ -217,7 +214,7 @@ def mha( grid=grid_, in_specs=in_specs, out_specs=pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), compiler_params=dict( triton=dict(num_warps=num_warps_, num_stages=num_stages) @@ -268,7 +265,7 @@ def _mha_forward( ] in_specs = [ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), pl.BlockSpec( (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) @@ -288,9 +285,9 @@ def _mha_forward( in_specs=in_specs, out_specs=[ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), - pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), + pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), ], compiler_params=dict( triton=dict(num_warps=num_warps_, num_stages=num_stages) @@ -303,17 +300,14 @@ def _mha_forward( return out, (q, k, v, segment_ids, out, lse) -def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, *, block_q: int): - pid_m = pl.program_id(0) - - off_m = pl.ds(pid_m * block_q, block_q) +def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref): # load - o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32) - do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32) + o = out_ref[...].astype(jnp.float32) + do = dout_ref[...].astype(jnp.float32) # compute delta = jnp.sum(o * do, axis=1) # write-back - pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype)) + delta_ref[...] = delta.astype(delta_ref.dtype) @jax.named_scope("preprocess_backward") def _preprocess_backward(out, do, lse, block_q: int, @@ -321,17 +315,17 @@ def _preprocess_backward(out, do, lse, block_q: int, batch_size, seq_len, num_heads, head_dim = out.shape out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype) delta = pl.pallas_call( - functools.partial(_preprocess_backward_kernel, block_q=block_q), + _preprocess_backward_kernel, grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), ], - out_specs=pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), + out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), compiler_params=dict(triton=dict(num_warps=4, num_stages=3)), out_shape=out_shape, debug=debug, @@ -431,8 +425,8 @@ def inner_loop_dkdv(start_q, carry): dv, dk = lax.fori_loop( lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk) ) - pl.store(dv_ref, (curr_k_slice, slice(None)), dv.astype(dv_ref.dtype)) - pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype)) + dv_ref[...] = dv.astype(dv_ref.dtype) + dk_ref[...] = dk.astype(dk_ref.dtype) del dv, dk @@ -495,7 +489,7 @@ def inner_loop_dq(start_k, dq): upper_bound = pl.cdiv(seq_len, block_k2) dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) - pl.store(dq_ref, (curr_q_slice, slice(None)), dq.astype(dq_ref.dtype)) + dq_ref[...] = dq.astype(dq_ref.dtype) def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, @@ -566,16 +560,16 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, grid=grid, out_specs=[ pl.BlockSpec( - (None, seq_len, None, head_dim), - lambda i, j, _: (i, 0, j, 0), # dq + (None, block_q, None, head_dim), + lambda i, j, k: (i, k, j, 0), # dq ), pl.BlockSpec( - (None, seq_len, None, head_dim), - lambda i, j, _: (i, 0, j, 0), # dk + (None, block_k, None, head_dim), + lambda i, j, k: (i, k, j, 0), # dk ), pl.BlockSpec( - (None, seq_len, None, head_dim), - lambda i, j, _: (i, 0, j, 0), # dv + (None, block_k, None, head_dim), + lambda i, j, k: (i, k, j, 0), # dv ), ], name="mha_backward", From eaefabee85d0732b5b6d40a205a4c78c461ea92f Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 27 Aug 2024 23:37:15 -0700 Subject: [PATCH 258/702] Fixes to api_benchmark.py. Testcases always fail without these fixes. PiperOrigin-RevId: 668299061 --- benchmarks/api_benchmark.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 710ffb6d7cad..df9528ada9ff 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -566,7 +566,7 @@ def bench_repeated_static_slicing(state): while state: jax.block_until_ready([x[i:i + 2] for i in range(0, 1000, 2)]) -def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False): +def pjit_simple_benchmark(state, num_devices, num_args, use_aot=False): spec = jax.sharding.PartitionSpec('x') mesh = create_mesh((num_devices,), ('x',), state) if mesh is None: @@ -601,8 +601,7 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False): @google_benchmark.option.args([10]) @google_benchmark.option.args([100]) def pjit_simple_1_device(state): - pjit_simple_benchmark( - state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1)) + pjit_simple_benchmark(state, num_devices=1, num_args=state.range(0)) @google_benchmark.register @google_benchmark.option.arg_names(['num_args']) @@ -610,8 +609,7 @@ def pjit_simple_1_device(state): @google_benchmark.option.args([10]) @google_benchmark.option.args([100]) def pjit_simple_4_device(state): - pjit_simple_benchmark( - state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1)) + pjit_simple_benchmark(state, num_devices=4, num_args=state.range(0)) @google_benchmark.register @google_benchmark.option.arg_names(['num_args']) @@ -619,8 +617,7 @@ def pjit_simple_4_device(state): @google_benchmark.option.args([10]) @google_benchmark.option.args([100]) def pjit_simple_4000_device(state): - pjit_simple_benchmark( - state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1)) + pjit_simple_benchmark(state, num_devices=4000, num_args=state.range(0)) @google_benchmark.register @@ -633,7 +630,6 @@ def pjit_aot_1_device(state): state, num_devices=1, num_args=state.range(0), - cpp_jit=state.range(1), use_aot=True) @@ -647,7 +643,6 @@ def pjit_aot_4_device(state): state, num_devices=4, num_args=state.range(0), - cpp_jit=state.range(1), use_aot=True) @@ -661,7 +656,6 @@ def pjit_aot_4000_device(state): state, num_devices=4000, num_args=state.range(0), - cpp_jit=state.range(1), use_aot=True) @@ -697,6 +691,8 @@ def device_put_from_numpy_array(state): @google_benchmark.option.args([100]) @google_benchmark.option.args([1000]) def device_put_from_jax_array(state): + if len(jax.devices()) < 2: + state.skip_with_error('requires 2 devices') x = [np.array(1, np.int32)] * state.range(0) x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0])) d = jax.devices()[1] From b0bd9337c90e00747d264b8eb91b746d644e91af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Wed, 28 Aug 2024 03:05:50 -0700 Subject: [PATCH 259/702] Revert to initial formatting of CPU FFI Kernels list This list has accidentally been auto-formatted which has caused unnecessary conflicts for future PRs. PiperOrigin-RevId: 668368321 --- jax/_src/export/_export.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 88e4f21546fe..54a4f6d5498e 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -923,22 +923,10 @@ def _check_lowering(lowering) -> None: "\n".join(not_implemented_msgs)) _CPU_FFI_KERNELS = [ - "lapack_spotrf_ffi", - "lapack_dpotrf_ffi", - "lapack_cpotrf_ffi", - "lapack_zpotrf_ffi", - "lapack_sgeqrf_ffi", - "lapack_dgeqrf_ffi", - "lapack_cgeqrf_ffi", - "lapack_zgeqrf_ffi", - "lapack_sgesdd_ffi", - "lapack_dgesdd_ffi", - "lapack_cgesdd_ffi", - "lapack_zgesdd_ffi", - "lapack_sgetrf_ffi", - "lapack_dgetrf_ffi", - "lapack_cgetrf_ffi", - "lapack_zgetrf_ffi", + "lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi", + "lapack_sgeqrf_ffi", "lapack_dgeqrf_ffi", "lapack_cgeqrf_ffi", "lapack_zgeqrf_ffi", + "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", + "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. # Their backwards compatibility is tested by back_compat_test.py. From 3c6103f2dfd8c8a56c5f9104eef52c4320a3d1ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Wed, 28 Aug 2024 03:53:07 -0700 Subject: [PATCH 260/702] Activate Eigenvalue Decompositions to XLA's FFI Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD. PiperOrigin-RevId: 668381949 --- jax/_src/export/_export.py | 2 + .../cpu_eig_lapack_geev.py | 265 +++++++++++ .../cpu_eigh_lapack_syev.py | 443 ++++++++++++++++++ jax/_src/lax/linalg.py | 22 +- jaxlib/lapack.py | 294 +++++++----- tests/export_back_compat_test.py | 18 + 6 files changed, 932 insertions(+), 112 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 54a4f6d5498e..54defa0e9c54 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -925,6 +925,8 @@ def _check_lowering(lowering) -> None: _CPU_FFI_KERNELS = [ "lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi", "lapack_sgeqrf_ffi", "lapack_dgeqrf_ffi", "lapack_cgeqrf_ffi", "lapack_zgeqrf_ffi", + "lapack_ssyevd_ffi", "lapack_dsyevd_ffi", "lapack_cheevd_ffi", "lapack_zheevd_ffi", + "lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi", "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", ] diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py index 1e4b6428556b..bc28857fa325 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py @@ -283,3 +283,268 @@ mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\x02\r[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev\x00", xla_call_module_version=6, ) # End paste + + +data_2024_08_19 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729794e+00+0.j, + -1.4596915295025735e-15+0.j, 4.7403016698320490e-16+0.j]), array([[ 0.40377749076862324+0.j, 0.8288327563197503 +0.j, + -0.5409014947846461 +0.j, 0.10917005482608667-0.j], + [ 0.4648073711584899 +0.j, 0.43714638836388775-0.j, + 0.7854306338527134 +0.j, -0.5456169434539783 +0.j], + [ 0.5258372515483575 +0.j, 0.04546002040802463-0.j, + 0.05184321664851461-0.j, 0.7637237224296971 +0.j], + [ 0.5868671319382249 +0.j, -0.34622634754783843+0.j, + -0.296372355716581 +0.j, -0.32727683380180517+0.j]]), array([[ 0.11417645138733866+0.j, 0.7327780959803557 +0.j, + -0.5367326141844461 +0.j, -0.08617176416747369+0.j], + [ 0.33000459866554754+0.j, 0.28974835239692603-0.j, + 0.6342729310130916 +0.j, -0.28826848493327445+0.j], + [ 0.5458327459437569 +0.j, -0.15328139118650222+0.j, + 0.34165198052715445-0.j, 0.83505226236897 +0.j], + [ 0.7616608932219664 +0.j, -0.5963111347699301 +0.j, + -0.4391922973557999 +0.j, -0.460612013268222 +0.j]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:4 = stablehlo.custom_call @lapack_zgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %3 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %4 = stablehlo.compare EQ, %2#3, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %8 = stablehlo.select %7, %2#0, %6 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %12 = stablehlo.select %11, %2#1, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %13 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %16 = stablehlo.select %15, %2#2, %14 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %8, %12, %16 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xef\xa57\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1b/\x0f\x1f\x0f\x0b\x0bO/O\x01\x05\x0b\x0f\x033\x17\x07\x07\x13\x0b\x0f\x0f\x0f\x07\x17\x17\x1b\x07\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\xa6\x06\x1d;=\x03\x03\t\x99\x1f\x05\x19\x05\x1b\x03\x03\x07\x9f\x11\x03\x05\x05\x1d\x17\x13J\x03\x1d\x05\x1f\x03\x03\x07\x7f\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05!\x11\x01\x00\x05#\x05%\x05\'\x03\x0b\'])i+k\x0fy-{\x05)\x05+\x05-\x05/\x03\x031}\x051\x1d5\x11\x053\x1d9\x11\x055\x057\x17\x13N\x03\x1b\x03\x13A\x81C\x83E\x85G]I\x87K\x89M\x8fO]Q\x91\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05K\x05M\x03\x03\t\xa1\x03\x01\x1dO\x1dQ\x1dS\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13#V#\x1b\x03\x07mqu\r\x05_oac\x1dU\r\x05_sac\x1dW\r\x05_wac\x1dY\x1d[\x1d]\x13\x07\x01\x1f\x13\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d_\x1da\x05\x01\r\x05\x8bg\x8dg\x1dc\x1de\x03\x03e\x03\t\x93ee\x95\x1f\'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1f\x0f\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\r\x1d\x01)\x03\x11\r\x03\x1d)\x01!)\x01\r)\x01\x07\x13)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05\x0b)\x03A\r\x1b!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04"\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x035a\x0b\x033/\x03\x1f\r\x067\x03\x05\x03\x01\x05\x03\x01\x15\x03\x13\x05\x03\x01\x15\x03\x13\x0f\x07\x01?\t\x0b\x05\x05\x0f\x03\x03\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x11\x11\x07\x01U\x03-\x05\x0f\x13\x03\x07\x01\x03\x03/\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x19\x03\x07\x01[\x031\x03\x17\x07\x06\x01\x03\x0b\x07\x1d\t\x1b\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x17\x03\x19\x03!\x07\x06\x01\x03\x05\x07\'\x0b%\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03-\x03\x07\x01\x17\x03\x19\x03+\x07\x06\x01\x03\x05\x071\r/\x13\x04\x05\x07\x1f)3\x06\x03\x01\x05\x01\x00^\x0eg\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev_ffi\x00compute_left\x00compute_right\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464249e+01+0.j, -2.4642491e+00+0.j, -8.1492220e-07+0.j, + 3.0721142e-07+0.j], dtype=complex64), array([[ 0.40377736 +0.j, 0.8288328 +0.j, -0.53676015 +0.j, + 0.07707452 -0.j], + [ 0.4648074 +0.j, 0.43714643 -0.j, 0.79694915 +0.j, + -0.5069523 +0.j], + [ 0.52583736 +0.j, 0.04545992 -0.j, 0.016383484+0.j, + 0.7826807 +0.j], + [ 0.5868672 +0.j, -0.34622622 +0.j, -0.2765721 +0.j, + -0.35280296 +0.j]], dtype=complex64), array([[ 0.114176415+0.j, 0.73277825 +0.j, -0.54227245 +0.j, + -0.109032825+0.j], + [ 0.3300045 +0.j, 0.2897482 -0.j, 0.6655821 +0.j, + -0.25470036 +0.j], + [ 0.5458329 +0.j, -0.15328139 +0.j, 0.29565343 +0.j, + 0.83649963 +0.j], + [ 0.7616609 +0.j, -0.59631103 +0.j, -0.4189632 +0.j, + -0.47276634 +0.j]], dtype=complex64)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:4 = stablehlo.custom_call @lapack_cgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %3 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %4 = stablehlo.compare EQ, %2#3, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %8 = stablehlo.select %7, %2#0, %6 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %12 = stablehlo.select %11, %2#1, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %13 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %16 = stablehlo.select %15, %2#2, %14 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %8, %12, %16 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xef\xa57\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1b/\x0f\x1f\x0f\x0b\x0b//O\x01\x05\x0b\x0f\x033\x17\x07\x07\x13\x0b\x0f\x0f\x0f\x07\x17\x17\x1b\x07\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\x86\x06\x1d;=\x03\x03\t\x99\x1f\x05\x19\x05\x1b\x03\x03\x07\x9f\x11\x03\x05\x05\x1d\x17\x13J\x03\x1d\x05\x1f\x03\x03\x07\x7f\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05!\x11\x01\x00\x05#\x05%\x05\'\x03\x0b\'])i+k\x0fy-{\x05)\x05+\x05-\x05/\x03\x031}\x051\x1d5\x11\x053\x1d9\x11\x055\x057\x17\x13N\x03\x1b\x03\x13A\x81C\x83E\x85G]I\x87K\x89M\x8fO]Q\x91\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05K\x05M\x03\x03\t\xa1\x03\x01\x1dO\x1dQ\x1dS\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13#V#\x1b\x03\x07mqu\r\x05_oac\x1dU\r\x05_sac\x1dW\r\x05_wac\x1dY\x1d[\x1d]\x13\x07\x01\x1f\x13\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d_\x1da\x05\x01\r\x05\x8bg\x8dg\x1dc\x1de\x03\x03e\x03\t\x93ee\x95\x1f\'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1f\x0f\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\r\x1d\x01)\x03\x11\r\x03\x1d)\x01!)\x01\r)\x01\x07\x13)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05\t)\x03A\r\x1b!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04"\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x035a\x0b\x033/\x03\x1f\r\x067\x03\x05\x03\x01\x05\x03\x01\x15\x03\x13\x05\x03\x01\x15\x03\x13\x0f\x07\x01?\t\x0b\x05\x05\x0f\x03\x03\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x11\x11\x07\x01U\x03-\x05\x0f\x13\x03\x07\x01\x03\x03/\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x19\x03\x07\x01[\x031\x03\x17\x07\x06\x01\x03\x0b\x07\x1d\t\x1b\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x17\x03\x19\x03!\x07\x06\x01\x03\x05\x07\'\x0b%\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03-\x03\x07\x01\x17\x03\x19\x03+\x07\x06\x01\x03\x05\x071\r/\x13\x04\x05\x07\x1f)3\x06\x03\x01\x05\x01\x00Z\x0eg\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x85\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgeev_ffi\x00compute_left\x00compute_right\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464241e+01+0.j, -2.4642482e+00+0.j, -4.5555478e-07+0.j, + 2.9215252e-07+0.j], dtype=complex64), array([[-0.40377742+0.j, 0.8288328 +0.j, -0.5253654 +0.j, + -0.11065983+0.j], + [-0.46480736+0.j, 0.43714654+0.j, 0.8159359 +0.j, + 0.547376 +0.j], + [-0.52583736+0.j, 0.04545998+0.j, -0.0557748 +0.j, + -0.7627722 +0.j], + [-0.5868672 +0.j, -0.34622627+0.j, -0.23479532+0.j, + 0.32605612+0.j]], dtype=complex64), array([[-0.114176415+0.j, 0.7327782 +0.j, -0.5364275 +0.j, + 0.15489015 +0.j], + [-0.33000445 +0.j, 0.28974816 +0.j, 0.6327556 +0.j, + 0.18506403 +0.j], + [-0.54583275 +0.j, -0.15328142 +0.j, 0.34377125 +0.j, + -0.83479893 +0.j], + [-0.761661 +0.j, -0.5963111 +0.j, -0.44009918 +0.j, + 0.49484456 +0.j]], dtype=complex64)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xf32> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xf32>) -> tensor<4x4xf32> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:5 = stablehlo.custom_call @lapack_sgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %3 = stablehlo.complex %2#0, %2#1 : tensor<4xcomplex> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %4 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %5 = stablehlo.compare EQ, %2#4, %4, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %8 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %9 = stablehlo.select %8, %3, %7 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %13 = stablehlo.select %12, %2#2, %11 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %17 = stablehlo.select %16, %2#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %9, %13, %17 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xf3\xa5;\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f/\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1f\x0f\x1f\x0f\x0b\x0b//O\x01\x05\x0b\x0f\x037\x17\x07\x07\x13\x07\x0f\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x17\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\xae\x06\x1d;=\x03\x03\t\x99\x1f\x05\x1b\x05\x1d\x03\x03\x07\x9f\x11\x03\x05\x05\x1f\x17\x13J\x03\x1d\x05!\x03\x03\x07\x81\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05#\x11\x01\x00\x05%\x05'\x05)\x03\x0b'])k+m\x0f{-}\x05+\x05-\x05/\x051\x03\x031\x7f\x053\x1d5\x11\x055\x1d9\x11\x057\x059\x17\x13N\x03\x1b\x03\x13A\x83C\x85E\x87G]I\x89K\x8bM\x91O]Q\x93\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x05K\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05M\x05O\x03\x03\t\xa1\x03\x01\x1dQ\x1dS\x1dU\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13'V\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07osw\r\x05_qac\x1dW\r\x05_uac\x1dY\r\x05_yac\x1d[\x1d]\x1d_\x13\x07\x01\x1f\x15\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1da\x1dc\x05\x01\r\x05\x8dg\x8fg\x1de\x1dg\x03\x03e\x03\x0biiee\x95\x1f-\x01\x1f\x0f\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x13\x1d\x01)\x03\x11\x13\t)\x01%)\x01\x13\x03\r)\x01\x07\x13)\x03\x11\r)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05)\x03A\r)\x05\x11\x11\r\x1b!)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04F\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x039e\x0b\x033/\x03!\r\x067\x03#\x03\x01\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x0b\x19\x19\x05\x05\x0f\x03\x03\x11\x06\x01\x03\x0b\x05\t\x0b\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x15\x13\x07\x01U\x031\x05\x11\x17\x03\x07\x01\x03\x033\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x1d\x03\x07\x01[\x035\x03\x1b\x07\x06\x01\x03\x0b\x07!\x13\x1f\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03'\x03\x07\x01\x17\x03\x1d\x03%\x07\x06\x01\x03\x05\x07+\r)\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x0f3\x15\x04\x05\x07#-7\x06\x03\x01\x05\x01\x00\x82\x0ei\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgeev_ffi\x00compute_left\x00compute_right\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729789e+00+0.j, + -1.4885423746029788e-15+0.j, 4.7495173217146935e-16+0.j]), array([[-0.40377749076862246 +0.j, -0.8288327563197503 +0.j, + -0.541090767303977 +0.j, 0.10767692008040902 +0.j], + [-0.4648073711584901 +0.j, -0.43714638836388775 +0.j, + 0.7847911174458492 +0.j, -0.5438508504687168 +0.j], + [-0.5258372515483576 +0.j, -0.045460020408024666+0.j, + 0.05369006702023438 +0.j, 0.7646709406962073 +0.j], + [-0.5868671319382248 +0.j, 0.34622634754783854 +0.j, + -0.2973904171621061 +0.j, -0.32849701030789913 +0.j]]), array([[-0.11417645138733848+0.j, -0.7327780959803556 +0.j, + -0.5370341524353898 +0.j, -0.0849751818967924 +0.j], + [-0.33000459866554754+0.j, -0.2897483523969262 +0.j, + 0.6357878989446506 +0.j, -0.29000500336734825+0.j], + [-0.545832745943757 +0.j, 0.15328139118650214+0.j, + 0.33952665941686755+0.j, 0.8349355524250736 +0.j], + [-0.7616608932219664 +0.j, 0.5963111347699303 +0.j, + -0.43828040592612855+0.j, -0.45995536716093305+0.j]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xf64> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xf64>) -> tensor<4x4xf64> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:5 = stablehlo.custom_call @lapack_dgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf64>) -> (tensor<4xf64>, tensor<4xf64>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %3 = stablehlo.complex %2#0, %2#1 : tensor<4xcomplex> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %4 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %5 = stablehlo.compare EQ, %2#4, %4, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %8 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %9 = stablehlo.select %8, %3, %7 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %13 = stablehlo.select %12, %2#2, %11 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %17 = stablehlo.select %16, %2#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %9, %13, %17 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xf3\xa5;\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f/\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1f\x0f\x1f\x0f\x0b\x0bO/O\x01\x05\x0b\x0f\x037\x17\x07\x07\x13\x07\x0f\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x17\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\xce\x06\x1d;=\x03\x03\t\x99\x1f\x05\x1b\x05\x1d\x03\x03\x07\x9f\x11\x03\x05\x05\x1f\x17\x13J\x03\x1d\x05!\x03\x03\x07\x81\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05#\x11\x01\x00\x05%\x05'\x05)\x03\x0b'])k+m\x0f{-}\x05+\x05-\x05/\x051\x03\x031\x7f\x053\x1d5\x11\x055\x1d9\x11\x057\x059\x17\x13N\x03\x1b\x03\x13A\x83C\x85E\x87G]I\x89K\x8bM\x91O]Q\x93\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x05K\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05M\x05O\x03\x03\t\xa1\x03\x01\x1dQ\x1dS\x1dU\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13'V\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07osw\r\x05_qac\x1dW\r\x05_uac\x1dY\r\x05_yac\x1d[\x1d]\x1d_\x13\x07\x01\x1f\x15\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1da\x1dc\x05\x01\r\x05\x8dg\x8fg\x1de\x1dg\x03\x03e\x03\x0biiee\x95\x1f-\x01\x1f\x0f\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x13\x1d\x01)\x03\x11\x13\x0b)\x01%)\x01\x13\x03\r)\x01\x07\x13)\x03\x11\r)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05)\x03A\r)\x05\x11\x11\r\x1b!)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04F\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x039e\x0b\x033/\x03!\r\x067\x03#\x03\x01\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x0b\x19\x19\x05\x05\x0f\x03\x03\x11\x06\x01\x03\x0b\x05\t\x0b\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x15\x13\x07\x01U\x031\x05\x11\x17\x03\x07\x01\x03\x033\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x1d\x03\x07\x01[\x035\x03\x1b\x07\x06\x01\x03\x0b\x07!\x13\x1f\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03'\x03\x07\x01\x17\x03\x1d\x03%\x07\x06\x01\x03\x05\x07+\r)\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x0f3\x15\x04\x05\x07#-7\x06\x03\x01\x05\x01\x00\x82\x0ei\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgeev_ffi\x00compute_left\x00compute_right\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py index fcc32058bbee..f0696db1aeda 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py @@ -383,3 +383,446 @@ xla_call_module_version=4, ), # End paste ) + +data_2024_08_19 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zheevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-6.1857700048412056e-01+0.j, 2.4081403770912022e-01+0.j, + 3.5662489253627483e-01+0.j, -6.3034019033669797e-01+0.j, + 1.0043483479985752e-16+0.j, -2.8842036081919542e-02+0.j, + 7.7164692943283169e-25+0.j, -1.8446994643771725e-01+0.j], + [-4.7070881487314609e-01+0.j, 4.7473787464450828e-01+0.j, + -4.8036836210243361e-01+0.j, 4.3802686872516400e-01+0.j, + 1.7961797619639255e-01+0.j, 8.3080980076741355e-03+0.j, + 2.1415294457221759e-01+0.j, -2.2856669794666584e-01+0.j], + [-3.2284062926217072e-01+0.j, -5.4336490915553370e-01+0.j, + 2.2181041859724987e-01+0.j, 2.9947877954402286e-01+0.j, + -3.6491813600134637e-01+0.j, 3.2867679819727436e-01+0.j, + 3.8223299448843473e-01+0.j, -2.7266344945561438e-01+0.j], + [-1.7497244365119527e-01+0.j, -8.9251550609769331e-02+0.j, + -6.3518515114898352e-02+0.j, 1.9162997359209963e-01+0.j, + -2.2087281326110142e-01+0.j, 5.9957027043505008e-02+0.j, + -8.7632498908241274e-01+0.j, -3.1676020096456303e-01+0.j], + [-2.7104258040220017e-02+0.j, -3.3772873786627688e-01+0.j, + 2.5901386593721754e-01+0.j, 1.7032650752287815e-01+0.j, + 6.7521217612940321e-01+0.j, -4.5036136532965476e-01+0.j, + -1.2279030059078447e-02+0.j, -3.6085695247351163e-01+0.j], + [ 1.2076392757075533e-01+0.j, -3.3834734096469249e-01+0.j, + -6.5506827461665529e-01+0.j, -5.0472498521116760e-01+0.j, + 6.9987430903492132e-02+0.j, 1.0595648906599270e-01+0.j, + 8.3443844143082035e-02+0.j, -4.0495370398246017e-01+0.j], + [ 2.6863211318173102e-01+0.j, 2.2958613191407312e-01+0.j, + 6.3952843755683969e-02+0.j, 1.8776775771084192e-02+0.j, + -5.3523731432241317e-01+0.j, -5.9199531677602002e-01+0.j, + 1.7916671834524250e-01+0.j, -4.4905045549140887e-01+0.j], + [ 4.1650029879270667e-01+0.j, 3.6355449432857068e-01+0.j, + 2.9755313100756148e-01+0.j, 1.6826270392616000e-02+0.j, + 1.9621068035557282e-01+0.j, 5.6830030587314817e-01+0.j, + 2.9607517592514260e-02+0.j, -4.9314720700035747e-01+0.j]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, + -1.9932120610662194e-14, -5.7323356091157378e-15, + -4.5459724251334835e-16, 4.0479851042511616e-14, + 9.2325194924982089e-14, 2.7659880477613365e+02])), + mlir_module_text=r""" +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc18 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> loc(#loc10) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc11) + %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> loc(#loc12) + %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> loc(#loc13) + %5 = stablehlo.negate %4 : tensor<8x8xf64> loc(#loc14) + %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> loc(#loc15) + %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> loc(#loc16) + %cst = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc17) + %9 = stablehlo.divide %7, %8 : tensor<8x8xcomplex> loc(#loc17) + %10 = call @tril(%9) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc18) + %c = stablehlo.constant dense<8> : tensor loc(#loc19) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc19) + %11:3 = stablehlo.custom_call @lapack_zheevd_ffi(%10) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xcomplex>) -> (tensor<8x8xcomplex>, tensor<8xf64>, tensor) loc(#loc19) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc19) + %12 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc19) + %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor, tensor) -> tensor loc(#loc19) + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc19) + %cst_2 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc19) + %15 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc19) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc19) + %17 = stablehlo.select %16, %11#0, %15 : tensor<8x8xi1>, tensor<8x8xcomplex> loc(#loc19) + %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> loc(#loc19) + %cst_3 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc19) + %19 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf64> loc(#loc19) + %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc19) + %21 = stablehlo.select %20, %11#1, %19 : tensor<8xi1>, tensor<8xf64> loc(#loc19) + return %17, %21 : tensor<8x8xcomplex>, tensor<8xf64> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7))) -> (tensor<8x8xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc20) + %c = stablehlo.constant dense<0> : tensor loc(#loc18) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc21) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc21) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc22) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc23) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc18) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc24) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xcomplex> loc(#loc25) + return %6 : tensor<8x8xcomplex> loc(#loc18) + } loc(#loc18) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:25) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc9 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]"(#loc1)) +#loc10 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc11 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc12 = loc("jit()/jit(main)/real"(#loc4)) +#loc13 = loc("jit()/jit(main)/imag"(#loc4)) +#loc14 = loc("jit()/jit(main)/neg"(#loc4)) +#loc15 = loc("jit()/jit(main)/complex"(#loc4)) +#loc16 = loc("jit()/jit(main)/add"(#loc5)) +#loc17 = loc("jit()/jit(main)/div"(#loc6)) +#loc19 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc8)) +#loc20 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc7)) +#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7)) +#loc22 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc7)) +#loc23 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7)) +#loc24 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc7)) +#loc25 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc7)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\'\x03\xda\x02*\x02?\x01\xab\x0f\x0b\x13\x17\x0f\x0b\x07\x17\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03a\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO/\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0bOO//\x01\x0b\x1f\x17\x17\x17\x17\x01\x05\x0b\x0f\x03;\x17\x07\x0f\x0f\x07\x07\x13\x17\x0b\x17\x0f\x07\x07\x17\x13\x07\x0f\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xa6\n\x1d\x93\x95\x05)\x03\x03\x13\xd5\x17\x03V\x047\x1d?\x07\x05+\x1f\x17\x03>\x043\x05-\x05/\x11\x03\x05\x051\x053\x055\x057\x03\x03!\xd1\x059\x03\x03\x0b\xd3\x1dE\x07\x05;\x05=\x1d\x8b\x8d\x03\x03\x0b\xe1\x03\t135\x157\x15\x119\x05?\x11\x01\x00\x05A\x05C\x05E\x03\x0b\x17\xaf\x19\xbb\x1b\xbd\x11\xc7\x1d\xc9\x03\x0b\x17\xb3\x19\xcd\x1b\xb3\x11\xb5\x1d\xcf\x05G\x1dC\x07\x05I\x05K\x03\x03!\xd7\x1dK\x07\x05M\x03\x05\'\xb7)\xd9\x1dQ\x07\x05O\x03\x03\x0b\xdb\x1dW\x07\x05Q\x1d[\x07\x05S\x1d_a\x05U\x17\x036\x045\x1deg\x05W\x17\x036\x04\x1d\x03\x03k\xdd\x05Y\x1doq\x05[\x17\x03>\x04E\x1du\x0f\x05]\x1dy\x0f\x05_\x1d}\x0f\x05a\x1d\x81\x0f\x05c\x1d\x85\x87\x05e\x17\x03>\x04\x1f\x03\x03\x0b\xdf\x05g\x17\x03>\x04\x1d\x03\x03\x91\xb5\x05i\x05k\x17\x03V\x04\x17\x03\x13\x99\xe3\x9b\xe5\x9d\xe7\x9f\xaf\xa1\xe9\xa3\xeb\xa5\xf5\xa7\xf7\xa9\xfb\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x05}\x1d\x7f\x1d\x81\x03\x01\x1d\x83\x03\x03\xcb\x1d\x85\t\x07\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\'\x03\x05\xbf\xc3\r\x05\xb1\xc1\xab\xad\x1d\x87\r\x05\xb1\xc5\xab\xad\x1d\x89\x1d\x8b\x1d\x8d\r\x03\xab\xad#)\x1d\x8f\x13\x07\x01\x1f\x0b\t\x00\x00\x00\x00\x1f+\x01\x13\x07\x05\x07\x05\x1f\t!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x91\x1d\x93\x05\x01\r\x05\xed\xef\xf1\xf3\x1d\x95\x13#V\x1d\x97\x13#L\x03\x03\xb9\x03\x03\xf9\x15\x03\x01\x01\x01\x03\x07\xb9\xfd\xff\x1f1\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x01\x07\x01\x1f\t!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05\'\xb7)\x02\x02\x03\x03\x0b\x06\x02\x03\x03\x13\n\x02\x03\x03\x0b\x0e\x02\x03\x03\x13\x12\x02\x01\t\x01\x02\x02)\x05!!\x15\x1d)\x01\x15)\x01\x1d\x01\x0b)\x03!\x0f)\x05!!\x1d\x03\x0f)\x05!!\x0f)\x01\x07\x13\x1b)\x05!!\r)\x03\t\x07!)\x01\x0f\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x07)\x03\x02\x02\x15)\x03\t\x1b)\x03\x05\x1b)\x03\x01\x1b)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\x07\x04\x06\x05\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x03=u\x07\x03]\x1f\x03-\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\x17\x06s\x03\x17\x03\x05\x19\x06w\x03\x17\x03\x05\x1b\x06{\x03\x17\x03\t\x1d\x06\x7f\x03\x05\x05\x07\x0b\r\x06\x83\x03\x05\x05\x03\r\x05\x03\r\x89\x03\t\x03\x07+\x05\x03\x05\x03\x11\x1f\x06+\x03\x05\x05\x0f\x13!\x07\t\x8f\x03\x05\x03\x15\x05\x03\x01-\x03\x19\x05\x03\x01-\x03\x19#\x07\x01\x97\x07\x05\x11\x0b\x03\x17\x05\x03\x01#\x03\x0b\x03\x07\x01\x05\x03\x0b\x03#\x0f\x07\x01\x16\x02\x035\x05!%\x03\x07\x01\x05\x037\x03\'\x05\x03\x01\x1a\x02\x03\t\x03\x07\x01\x05\x03\x05\x03+\x03\x07\x01\x1e\x02\x03\x1f\x03)\t\x06\x01\x03\x05\x07/\x1d-\x03\x07\x01\x05\x039\x03\'\x05\x03\x01"\x02\x03%\x03\x07\x01\x05\x03\x11\x035\x03\x07\x01&\x02\x03;\x033\t\x06\x01\x03\x11\x079\x1f7\x11\x04\r\x051;\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1f\x03\x13\x05\x03\t#\x03\x0b\x03\x07%\x05\x03\x13\x03\x05\r\x06%\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1f\x05\t\x0b\x05\x03\tS\x03\t\x03\x07U\x05\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xe2\x1c\x99\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99A9;;m\x19\x85\x8fW\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_zheevd_ffi\x00mode\x00uplo\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cheevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-0.6185769 +0.j, -0.20142993 +0.j, -0.09725195 +0.j, + 0.62983674 +0.j, -0.07926044 +0.j, 0.3605001 -0.j, + -0.019093221 +0.j, -0.18446997 +0.j], + [-0.47070873 +0.j, 0.29325768 +0.j, -0.19454116 +0.j, + -0.6394365 +0.j, 0.06229549 +0.j, 0.33249345 +0.j, + 0.28112718 +0.j, -0.22856665 +0.j], + [-0.32284075 +0.j, -0.12361939 +0.j, 0.20547704 +0.j, + -0.18307868 +0.j, 0.47294614 +0.j, -0.3170349 +0.j, + -0.6373532 +0.j, -0.27266347 +0.j], + [-0.17497246 +0.j, -0.079641335 +0.j, 0.15042792 +0.j, + -0.15416273 +0.j, -0.815209 +0.j, -0.38054234 +0.j, + -0.083263926 +0.j, -0.31676024 +0.j], + [-0.027104257 +0.j, -0.26490977 +0.j, 0.32271704 +0.j, + 0.08653544 +0.j, 0.30305928 +0.j, -0.33998996 +0.j, + 0.6926741 +0.j, -0.360857 +0.j], + [ 0.120763965 +0.j, 0.43288827 +0.j, -0.64385164 +0.j, + 0.2652551 +0.j, 0.094823755 +0.j, -0.37435007 +0.j, + 0.00091664493+0.j, -0.40495378 +0.j], + [ 0.26863196 +0.j, 0.51607686 +0.j, 0.53846526 +0.j, + 0.16969058 +0.j, -0.0216703 +0.j, 0.35755336 +0.j, + -0.113144726 +0.j, -0.4490505 +0.j], + [ 0.4165004 +0.j, -0.57262254 +0.j, -0.28144246 +0.j, + -0.17463988 +0.j, -0.016984984 +0.j, 0.3613705 +0.j, + -0.12186296 +0.j, -0.49314725 +0.j]], dtype=complex64), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, + -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], + dtype=float32)), + mlir_module_text=r""" +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc18 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> loc(#loc10) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc11) + %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> loc(#loc12) + %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> loc(#loc13) + %5 = stablehlo.negate %4 : tensor<8x8xf32> loc(#loc14) + %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> loc(#loc15) + %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> loc(#loc16) + %cst = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc17) + %9 = stablehlo.divide %7, %8 : tensor<8x8xcomplex> loc(#loc17) + %10 = call @tril(%9) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc18) + %c = stablehlo.constant dense<8> : tensor loc(#loc19) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc19) + %11:3 = stablehlo.custom_call @lapack_cheevd_ffi(%10) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xcomplex>) -> (tensor<8x8xcomplex>, tensor<8xf32>, tensor) loc(#loc19) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc19) + %12 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc19) + %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor, tensor) -> tensor loc(#loc19) + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc19) + %cst_2 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc19) + %15 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc19) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc19) + %17 = stablehlo.select %16, %11#0, %15 : tensor<8x8xi1>, tensor<8x8xcomplex> loc(#loc19) + %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> loc(#loc19) + %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc19) + %19 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf32> loc(#loc19) + %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc19) + %21 = stablehlo.select %20, %11#1, %19 : tensor<8xi1>, tensor<8xf32> loc(#loc19) + return %17, %21 : tensor<8x8xcomplex>, tensor<8xf32> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7))) -> (tensor<8x8xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc20) + %c = stablehlo.constant dense<0> : tensor loc(#loc18) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc21) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc21) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc22) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc23) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc18) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc24) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xcomplex> loc(#loc25) + return %6 : tensor<8x8xcomplex> loc(#loc18) + } loc(#loc18) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:25) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc9 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]"(#loc1)) +#loc10 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc11 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc12 = loc("jit()/jit(main)/real"(#loc4)) +#loc13 = loc("jit()/jit(main)/imag"(#loc4)) +#loc14 = loc("jit()/jit(main)/neg"(#loc4)) +#loc15 = loc("jit()/jit(main)/complex"(#loc4)) +#loc16 = loc("jit()/jit(main)/add"(#loc5)) +#loc17 = loc("jit()/jit(main)/div"(#loc6)) +#loc19 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc8)) +#loc20 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc7)) +#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7)) +#loc22 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc7)) +#loc23 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7)) +#loc24 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc7)) +#loc25 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc7)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\'\x03\xda\x02*\x02?\x01\xab\x0f\x0b\x13\x17\x0f\x0b\x07\x17\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03a\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O//\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0b/O\x1f/\x01\x0b\x1f\x17\x17\x17\x17\x01\x05\x0b\x0f\x03;\x17\x07\x0f\x0f\x07\x07\x13\x17\x0b\x17\x0f\x07\x07\x17\x13\x07\x0f\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x026\n\x1d\x93\x95\x05)\x03\x03\x13\xd5\x17\x03V\x047\x1d?\x07\x05+\x1f\x17\x03>\x043\x05-\x05/\x11\x03\x05\x051\x053\x055\x057\x03\x03!\xd1\x059\x03\x03\x0b\xd3\x1dE\x07\x05;\x05=\x1d\x8b\x8d\x03\x03\x0b\xe1\x03\t135\x157\x15\x119\x05?\x11\x01\x00\x05A\x05C\x05E\x03\x0b\x17\xaf\x19\xbb\x1b\xbd\x11\xc7\x1d\xc9\x03\x0b\x17\xb3\x19\xcd\x1b\xb3\x11\xb5\x1d\xcf\x05G\x1dC\x07\x05I\x05K\x03\x03!\xd7\x1dK\x07\x05M\x03\x05\'\xb7)\xd9\x1dQ\x07\x05O\x03\x03\x0b\xdb\x1dW\x07\x05Q\x1d[\x07\x05S\x1d_a\x05U\x17\x036\x045\x1deg\x05W\x17\x036\x04\x1d\x03\x03k\xdd\x05Y\x1doq\x05[\x17\x03>\x04E\x1du\x0f\x05]\x1dy\x0f\x05_\x1d}\x0f\x05a\x1d\x81\x0f\x05c\x1d\x85\x87\x05e\x17\x03>\x04\x1f\x03\x03\x0b\xdf\x05g\x17\x03>\x04\x1d\x03\x03\x91\xb5\x05i\x05k\x17\x03V\x04\x17\x03\x13\x99\xe3\x9b\xe5\x9d\xe7\x9f\xaf\xa1\xe9\xa3\xeb\xa5\xf5\xa7\xf7\xa9\xfb\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x05}\x1d\x7f\x1d\x81\x03\x01\x1d\x83\x03\x03\xcb\x1d\x85\t\x07\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\'\x03\x05\xbf\xc3\r\x05\xb1\xc1\xab\xad\x1d\x87\r\x05\xb1\xc5\xab\xad\x1d\x89\x1d\x8b\x1d\x8d\r\x03\xab\xad#)\x1d\x8f\x13\x07\x01\x1f\x0b\t\x00\x00\x00\x00\x1f+\x01\x13\x07\x05\x07\x05\x1f\t\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x19\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x91\x1d\x93\x05\x01\r\x05\xed\xef\xf1\xf3\x1d\x95\x13#V\x1d\x97\x13#L\x03\x03\xb9\x03\x03\xf9\x15\x03\x01\x01\x01\x03\x07\xb9\xfd\xff\x1f1\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x01\x07\x01\x1f\t\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f%\t\x00\x00\xc0\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05\'\xb7)\x02\x02\x03\x03\x0b\x06\x02\x03\x03\x13\n\x02\x03\x03\x0b\x0e\x02\x03\x03\x13\x12\x02\x01\t\x01\x02\x02)\x05!!\x15\x1d)\x01\x15)\x01\x1d\x01\t)\x03!\x0f)\x05!!\x1d\x03\x0f)\x05!!\x0f)\x01\x07\x13\x1b)\x05!!\r)\x03\t\x07!)\x01\x0f\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x07)\x03\x02\x02\x15)\x03\t\x1b)\x03\x05\x1b)\x03\x01\x1b)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\x07\x04\x06\x05\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x03=u\x07\x03]\x1f\x03-\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\x17\x06s\x03\x17\x03\x05\x19\x06w\x03\x17\x03\x05\x1b\x06{\x03\x17\x03\t\x1d\x06\x7f\x03\x05\x05\x07\x0b\r\x06\x83\x03\x05\x05\x03\r\x05\x03\r\x89\x03\t\x03\x07+\x05\x03\x05\x03\x11\x1f\x06+\x03\x05\x05\x0f\x13!\x07\t\x8f\x03\x05\x03\x15\x05\x03\x01-\x03\x19\x05\x03\x01-\x03\x19#\x07\x01\x97\x07\x05\x11\x0b\x03\x17\x05\x03\x01#\x03\x0b\x03\x07\x01\x05\x03\x0b\x03#\x0f\x07\x01\x16\x02\x035\x05!%\x03\x07\x01\x05\x037\x03\'\x05\x03\x01\x1a\x02\x03\t\x03\x07\x01\x05\x03\x05\x03+\x03\x07\x01\x1e\x02\x03\x1f\x03)\t\x06\x01\x03\x05\x07/\x1d-\x03\x07\x01\x05\x039\x03\'\x05\x03\x01"\x02\x03%\x03\x07\x01\x05\x03\x11\x035\x03\x07\x01&\x02\x03;\x033\t\x06\x01\x03\x11\x079\x1f7\x11\x04\r\x051;\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1f\x03\x13\x05\x03\t#\x03\x0b\x03\x07%\x05\x03\x13\x03\x05\r\x06%\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1f\x05\t\x0b\x05\x03\tS\x03\t\x03\x07U\x05\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xde\x1c\x99\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99A9;;m\x19\x85\x8dW\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_cheevd_ffi\x00mode\x00uplo\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_ssyevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-0.6185769 , -0.20142993 , -0.09725195 , 0.62983674 , + -0.07926044 , 0.3605001 , -0.019093221 , -0.18446997 ], + [-0.47070873 , 0.29325768 , -0.19454119 , -0.6394365 , + 0.0622955 , 0.33249345 , 0.28112718 , -0.22856665 ], + [-0.32284075 , -0.12361939 , 0.20547704 , -0.18307868 , + 0.47294614 , -0.3170349 , -0.6373532 , -0.27266347 ], + [-0.17497246 , -0.079641335 , 0.15042791 , -0.15416273 , + -0.815209 , -0.38054234 , -0.083263926 , -0.31676024 ], + [-0.027104253 , -0.26490977 , 0.32271704 , 0.08653544 , + 0.30305928 , -0.33998996 , 0.6926741 , -0.360857 ], + [ 0.12076397 , 0.43288827 , -0.64385164 , 0.2652551 , + 0.09482376 , -0.37435007 , 0.00091664493, -0.40495378 ], + [ 0.26863196 , 0.51607686 , 0.53846526 , 0.16969058 , + -0.021670295 , 0.35755336 , -0.113144726 , -0.4490505 ], + [ 0.4165004 , -0.57262254 , -0.2814425 , -0.17463988 , + -0.01698498 , 0.3613705 , -0.12186296 , -0.49314725 ]], + dtype=float32), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, + -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], + dtype=float32)), + mlir_module_text=r""" +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc13 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xf32> loc(#loc8) + %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32> loc(#loc9) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32> loc(#loc10) + %3 = stablehlo.add %1, %2 : tensor<8x8xf32> loc(#loc11) + %cst = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf32> loc(#loc12) + %5 = stablehlo.divide %3, %4 : tensor<8x8xf32> loc(#loc12) + %6 = call @tril(%5) : (tensor<8x8xf32>) -> tensor<8x8xf32> loc(#loc13) + %c = stablehlo.constant dense<8> : tensor loc(#loc14) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc14) + %7:3 = stablehlo.custom_call @lapack_ssyevd_ffi(%6) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xf32>) -> (tensor<8x8xf32>, tensor<8xf32>, tensor) loc(#loc14) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc14) + %8 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc14) + %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc14) + %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc14) + %cst_2 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc14) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor) -> tensor<8x8xf32> loc(#loc14) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc14) + %13 = stablehlo.select %12, %7#0, %11 : tensor<8x8xi1>, tensor<8x8xf32> loc(#loc14) + %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc14) + %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc14) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf32> loc(#loc14) + %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc14) + %17 = stablehlo.select %16, %7#1, %15 : tensor<8xi1>, tensor<8xf32> loc(#loc14) + return %13, %17 : tensor<8x8xf32>, tensor<8xf32> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6))) -> (tensor<8x8xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc15) + %c = stablehlo.constant dense<0> : tensor loc(#loc13) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc16) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc16) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc17) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc18) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc13) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf32> loc(#loc19) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xf32> loc(#loc20) + return %6 : tensor<8x8xf32> loc(#loc13) + } loc(#loc13) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc8 = loc("jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]"(#loc1)) +#loc9 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc10 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc11 = loc("jit()/jit(main)/add"(#loc4)) +#loc12 = loc("jit()/jit(main)/div"(#loc5)) +#loc14 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc7)) +#loc15 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc6)) +#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6)) +#loc17 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc6)) +#loc18 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6)) +#loc19 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc6)) +#loc20 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01+\x05\x01\x03\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff9\x01\xa1\x0f\x13\x17\x0b\x0f\x0b\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x03_\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f/\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0b\x1fO/\x01\x05\x0b\x0f\x035\x17\x0f\x07\x0f\x07\x07\x13\x17\x0f\x07\x07\x17\x13\x07\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x02:\t\x1d\x83\x85\x03\x03\x11\xcb\x17\x07V\x047\x05!\x1d?\x05\x05#\x1f\x05%\x05'\x11\x03\x05\x05)\x05+\x05-\x05/\x03\x03\x1f\xc7\x051\x03\x03\x0b\xc9\x1dE\x05\x053\x055\x1d{}\x03\x03\x0b\xd7\x03\x03\x0b\xf9\x03\t135\x137\x13\x0f9\x057\x11\x01\x00\x059\x05;\x05=\x03\x0b\x15\xa5\x17\xb1\x19\xb3\x0f\xbd\x1b\xbf\x03\x0b\x15\xa9\x17\xc3\x19\xa9\x0f\xab\x1b\xc5\x05?\x1dC\x05\x05A\x05C\x03\x03\x1f\xcd\x1dK\x05\x05E\x03\x05%\xad'\xcf\x1dQ\x05\x05G\x03\x03\x0b\xd1\x1dW\x05\x05I\x1d[\x05\x05K\x1d_a\x05M\x17\x076\x045\x1deg\x05O\x17\x076\x04\x1d\x03\x03k\xd3\x05Q\x1doq\x05S\x17\x07>\x04E\x1duw\x05U\x17\x07>\x04\x1f\x03\x03\x0b\xd5\x05W\x17\x07>\x04\x1d\x03\x03\x81\xab\x05Y\x05[\x17\x07V\x04\x17\x03\x13\x89\xd9\x8b\xdb\x8d\xdd\x8f\xa5\x91\xdf\x93\xe1\x95\xeb\x97\xed\x99\xf1\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x05k\x05m\x03\x05%\xad'\xf7\x03\x03\x11\xfb\x03\x03\x11\xfd\x1do\x1dq\x03\x01\x1ds\x03\x03\xc1\x1du\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#!\x03\x05\xb5\xb9\r\x05\xa7\xb7\xa1\xa3\x1dw\r\x05\xa7\xbb\xa1\xa3\x1dy\x1d{\x1d}\r\x03\xa1\xa3##\x1d\x7f\x13\t\x01\x1f\x0b\t\x00\x00\x00\x00\x1f%\x01\x13\t\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x1d!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x15\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x81\x1d\x83\x05\x01\r\x05\xe3\xe5\xe7\xe9\x1d\x85\x13\x1fV\x1d\x87\x13\x1fL\x03\x03\xaf\x03\x03\xef\x15\x03\x01\x01\x01\x03\x07\xaf\xf3\xf5\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05!!\x0f)\x01\x0f\x1d)\x01\x19\x01\t)\x03!\x0f)\x05!!\x19)\x01\t\x13\x1b)\x05!!\r)\x03\t\t!\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\t)\x03\x02\x02\x0f)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\t\x04~\x04\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x035e\x07\x03]\x1d\x03'\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\r\x06s\x03\x05\x05\x03\x05\x05\x03\ry\x03\x07\x03\x07)\x03\x03\x05\x03\t\x17\x06)\x03\x05\x05\x07\x0b\x19\x07\t\x7f\x03\x05\x03\r\x05\x03\x01+\x03\x15\x05\x03\x01+\x03\x15\x1b\x07\x01\x87\x07\x05\x11\x0b\x03\x0f\x05\x03\x01!\x03\x0b\x03\x07\x01\x03\x03\x0b\x03\x1b\x0f\x07\x01\x9b\x03/\x05\x19\x1d\x03\x07\x01\x03\x031\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x9d\x03\x1b\x03!\t\x06\x01\x03\x05\x07'\x15%\x03\x07\x01\x03\x033\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x11\x03-\x03\x07\x01\x9f\x035\x03+\t\x06\x01\x03\x11\x071\x17/\x11\x04\r\x05)3\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1d\x03\x13\x05\x03\t!\x03\x0b\x03\x07#\x03\x03\x13\x03\x05\r\x06#\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1b\x05\t\x0b\x05\x03\tS\x03\x07\x03\x07U\x03\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00J\x1a\x89\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99m\x19\x85\x89W\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_ssyevd_ffi\x00mode\x00uplo\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dsyevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-6.1857700048412056e-01, 2.4081403770912022e-01, + 3.5662489253627483e-01, -6.3034019033669797e-01, + 1.0043483479985752e-16, -2.8842036081919542e-02, + 7.7164692943283169e-25, -1.8446994643771725e-01], + [-4.7070881487314614e-01, 4.7473787464450845e-01, + -4.8036836210243367e-01, 4.3802686872516400e-01, + 1.7961797619639258e-01, 8.3080980076741355e-03, + 2.1415294457221756e-01, -2.2856669794666584e-01], + [-3.2284062926217072e-01, -5.4336490915553370e-01, + 2.2181041859724990e-01, 2.9947877954402297e-01, + -3.6491813600134632e-01, 3.2867679819727436e-01, + 3.8223299448843473e-01, -2.7266344945561438e-01], + [-1.7497244365119530e-01, -8.9251550609769414e-02, + -6.3518515114898394e-02, 1.9162997359209971e-01, + -2.2087281326110139e-01, 5.9957027043505064e-02, + -8.7632498908241274e-01, -3.1676020096456303e-01], + [-2.7104258040220038e-02, -3.3772873786627672e-01, + 2.5901386593721748e-01, 1.7032650752287815e-01, + 6.7521217612940332e-01, -4.5036136532965476e-01, + -1.2279030059078447e-02, -3.6085695247351163e-01], + [ 1.2076392757075530e-01, -3.3834734096469254e-01, + -6.5506827461665540e-01, -5.0472498521116749e-01, + 6.9987430903492118e-02, 1.0595648906599275e-01, + 8.3443844143082022e-02, -4.0495370398246017e-01], + [ 2.6863211318173097e-01, 2.2958613191407318e-01, + 6.3952843755683941e-02, 1.8776775771084137e-02, + -5.3523731432241317e-01, -5.9199531677602002e-01, + 1.7916671834524248e-01, -4.4905045549140887e-01], + [ 4.1650029879270661e-01, 3.6355449432857079e-01, + 2.9755313100756142e-01, 1.6826270392615944e-02, + 1.9621068035557282e-01, 5.6830030587314817e-01, + 2.9607517592514246e-02, -4.9314720700035747e-01]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, + -1.9932120610662194e-14, -5.7323356091157378e-15, + -4.5459724251334835e-16, 4.0479851042511616e-14, + 9.2325194924982089e-14, 2.7659880477613365e+02])), + mlir_module_text=r""" +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc13 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xf64> loc(#loc8) + %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64> loc(#loc9) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64> loc(#loc10) + %3 = stablehlo.add %1, %2 : tensor<8x8xf64> loc(#loc11) + %cst = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf64> loc(#loc12) + %5 = stablehlo.divide %3, %4 : tensor<8x8xf64> loc(#loc12) + %6 = call @tril(%5) : (tensor<8x8xf64>) -> tensor<8x8xf64> loc(#loc13) + %c = stablehlo.constant dense<8> : tensor loc(#loc14) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc14) + %7:3 = stablehlo.custom_call @lapack_dsyevd_ffi(%6) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8xf64>, tensor) loc(#loc14) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc14) + %8 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc14) + %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc14) + %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc14) + %cst_2 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc14) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor) -> tensor<8x8xf64> loc(#loc14) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc14) + %13 = stablehlo.select %12, %7#0, %11 : tensor<8x8xi1>, tensor<8x8xf64> loc(#loc14) + %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc14) + %cst_3 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc14) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf64> loc(#loc14) + %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc14) + %17 = stablehlo.select %16, %7#1, %15 : tensor<8xi1>, tensor<8xf64> loc(#loc14) + return %13, %17 : tensor<8x8xf64>, tensor<8xf64> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6))) -> (tensor<8x8xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc15) + %c = stablehlo.constant dense<0> : tensor loc(#loc13) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc16) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc16) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc17) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc18) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc13) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf64> loc(#loc19) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xf64> loc(#loc20) + return %6 : tensor<8x8xf64> loc(#loc13) + } loc(#loc13) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc8 = loc("jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]"(#loc1)) +#loc9 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc10 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc11 = loc("jit()/jit(main)/add"(#loc4)) +#loc12 = loc("jit()/jit(main)/div"(#loc5)) +#loc14 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc7)) +#loc15 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc6)) +#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6)) +#loc17 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc6)) +#loc18 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6)) +#loc19 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc6)) +#loc20 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01+\x05\x01\x03\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff9\x01\xa1\x0f\x13\x17\x0b\x0f\x0b\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x03_\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O//\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0b/O/\x01\x05\x0b\x0f\x035\x17\x0f\x07\x0f\x07\x07\x13\x17\x0f\x07\x07\x17\x13\x07\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x02j\t\x1d\x83\x85\x03\x03\x11\xcb\x17\x07V\x047\x05!\x1d?\x05\x05#\x1f\x05%\x05'\x11\x03\x05\x05)\x05+\x05-\x05/\x03\x03\x1f\xc7\x051\x03\x03\x0b\xc9\x1dE\x05\x053\x055\x1d{}\x03\x03\x0b\xd7\x03\x03\x0b\xf9\x03\t135\x137\x13\x0f9\x057\x11\x01\x00\x059\x05;\x05=\x03\x0b\x15\xa5\x17\xb1\x19\xb3\x0f\xbd\x1b\xbf\x03\x0b\x15\xa9\x17\xc3\x19\xa9\x0f\xab\x1b\xc5\x05?\x1dC\x05\x05A\x05C\x03\x03\x1f\xcd\x1dK\x05\x05E\x03\x05%\xad'\xcf\x1dQ\x05\x05G\x03\x03\x0b\xd1\x1dW\x05\x05I\x1d[\x05\x05K\x1d_a\x05M\x17\x076\x045\x1deg\x05O\x17\x076\x04\x1d\x03\x03k\xd3\x05Q\x1doq\x05S\x17\x07>\x04E\x1duw\x05U\x17\x07>\x04\x1f\x03\x03\x0b\xd5\x05W\x17\x07>\x04\x1d\x03\x03\x81\xab\x05Y\x05[\x17\x07V\x04\x17\x03\x13\x89\xd9\x8b\xdb\x8d\xdd\x8f\xa5\x91\xdf\x93\xe1\x95\xeb\x97\xed\x99\xf1\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x05k\x05m\x03\x05%\xad'\xf7\x03\x03\x11\xfb\x03\x03\x11\xfd\x1do\x1dq\x03\x01\x1ds\x03\x03\xc1\x1du\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#!\x03\x05\xb5\xb9\r\x05\xa7\xb7\xa1\xa3\x1dw\r\x05\xa7\xbb\xa1\xa3\x1dy\x1d{\x1d}\r\x03\xa1\xa3##\x1d\x7f\x13\t\x01\x1f\x0b\t\x00\x00\x00\x00\x1f%\x01\x13\t\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x15\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x81\x1d\x83\x05\x01\r\x05\xe3\xe5\xe7\xe9\x1d\x85\x13\x1fV\x1d\x87\x13\x1fL\x03\x03\xaf\x03\x03\xef\x15\x03\x01\x01\x01\x03\x07\xaf\xf3\xf5\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05!!\x0f)\x01\x0f\x1d)\x01\x19\x01\x0b)\x03!\x0f)\x05!!\x19)\x01\t\x13\x1b)\x05!!\r)\x03\t\t!\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\t)\x03\x02\x02\x0f)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\t\x04~\x04\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x035e\x07\x03]\x1d\x03'\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\r\x06s\x03\x05\x05\x03\x05\x05\x03\ry\x03\x07\x03\x07)\x03\x03\x05\x03\t\x17\x06)\x03\x05\x05\x07\x0b\x19\x07\t\x7f\x03\x05\x03\r\x05\x03\x01+\x03\x15\x05\x03\x01+\x03\x15\x1b\x07\x01\x87\x07\x05\x11\x0b\x03\x0f\x05\x03\x01!\x03\x0b\x03\x07\x01\x03\x03\x0b\x03\x1b\x0f\x07\x01\x9b\x03/\x05\x19\x1d\x03\x07\x01\x03\x031\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x9d\x03\x1b\x03!\t\x06\x01\x03\x05\x07'\x15%\x03\x07\x01\x03\x033\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x11\x03-\x03\x07\x01\x9f\x035\x03+\t\x06\x01\x03\x11\x071\x17/\x11\x04\r\x05)3\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1d\x03\x13\x05\x03\t!\x03\x0b\x03\x07#\x03\x03\x13\x03\x05\r\x06#\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1b\x05\t\x0b\x05\x03\tS\x03\x07\x03\x07U\x03\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00J\x1a\x89\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99m\x19\x85\x89W\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_dsyevd_ffi\x00mode\x00uplo\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 45eed43e0b4f..1a792e3adc0c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -615,7 +615,9 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand, + # TODO(b/344892332): Remove the conditional after the compatibility period. + ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand, input_shape_vals=op_shape_vals, jobvl=compute_left_eigenvectors, jobvr=compute_right_eigenvectors) @@ -801,7 +803,8 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index): def _eigh_cpu_gpu_lowering( - syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index + syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index, + platform=None ): del sort_eigenvalues # The CPU/GPU implementations always sort. operand_aval, = ctx.avals_in @@ -821,7 +824,12 @@ def _eigh_cpu_gpu_lowering( raise NotImplementedError("subset_by_index not implemented for CPU and GPU") op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - v, w, info = syevd_impl(operand_aval.dtype, operand, + cpu_args = [] + if platform == "cpu": + # TODO(b/344892332): Remove the conditional after the compatibility period. + ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + cpu_args.extend(ctx_args) + v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand, a_shape_vals=op_shape_vals, lower=lower) zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) @@ -955,15 +963,17 @@ def _eigh_batching_rule( batching.primitive_batchers[eigh_p] = _eigh_batching_rule mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo), + eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo, platform='cpu'), platform='cpu') if gpu_solver is not None: mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd), + eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd, + platform='cuda'), platform='cuda') mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd), + eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd, + platform='rocm'), platform='rocm') mlir.register_lowering( diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index a71f219acd1d..a389380a61ec 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -27,6 +27,7 @@ from jaxlib import xla_client from .cpu import _lapack +from .cpu._lapack import eig from .hlo_helpers import ( custom_call, hlo_u8, hlo_s32, ensure_hlo_s32, hlo_add, hlo_min, @@ -523,10 +524,9 @@ def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, # # syevd: Symmetric eigendecomposition -def syevd_hlo(dtype, a: ir.Value, +def syevd_hlo(ctx, dtype, a: ir.Value, a_shape_vals: tuple[DimensionSize, ...], lower=False): - _lapack.initialize() a_type = ir.RankedTensorType(a.type) assert len(a_shape_vals) >= 2 m, n = a_shape_vals[-2:] @@ -535,76 +535,110 @@ def syevd_hlo(dtype, a: ir.Value, batch_dims_vals = a_shape_vals[:-2] num_bd = len(a_shape_vals) - 2 + mode = _enum_to_char_attr(eig.ComputationMode.kComputeEigenvectors) i32_type = ir.IntegerType.get_signless(32) workspace: list[ShapeTypePair] - if dtype == np.float32: - fn = "lapack_ssyevd" - eigvals_type = ir.F32Type.get() - workspace = [ - ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.float64: - fn = "lapack_dsyevd" - eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex64: - fn = "lapack_cheevd" + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + # Hermitian is for complex square matrices, symmetric otherwise. + fn_base = "he" if dtype == np.complex64 or dtype == np.complex128 else "sy" + fn_base = prepare_lapack_call(fn_base=fn_base + "evd", dtype=dtype) + if ctx.is_forward_compat(): + fn = fn_base + if dtype == np.float32: + eigvals_type = ir.F32Type.get() + workspace = [ + ([_lapack.syevd_work_size(n)], a_type.element_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + elif dtype == np.float64: + eigvals_type = ir.F64Type.get() + workspace = [ + ([_lapack.syevd_work_size(n)], a_type.element_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + elif dtype == np.complex64: + eigvals_type = ir.F32Type.get() + workspace = [ + ([_lapack.heevd_work_size(n)], a_type.element_type), + ([_lapack.heevd_rwork_size(n)], eigvals_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + elif dtype == np.complex128: + eigvals_type = ir.F64Type.get() + workspace = [ + ([_lapack.heevd_work_size(n)], a_type.element_type), + ([_lapack.heevd_rwork_size(n)], eigvals_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + + scalar_layout = [] + shape_layout = [0] + workspace_layouts = [shape_layout] * len(workspace) + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + + result_types, result_shapes = mk_result_types_and_shapes( + [(a_shape_vals, a_type.element_type), + (batch_dims_vals + (n,), eigvals_type), + (batch_dims_vals, i32_type)] + workspace + ) + + return custom_call( + fn, + result_types=result_types, + operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a], + operand_layouts=[scalar_layout] * 3 + [layout], + result_layouts=[ + layout, + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + ] + workspace_layouts, + operand_output_aliases={3: 0}, + result_shapes=result_shapes, + ).results[:3] + fn = fn_base + "_ffi" + if dtype == np.float32 or dtype == np.complex64: eigvals_type = ir.F32Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex128: - fn = "lapack_zheevd" + elif dtype == np.float64 or dtype == np.complex128: eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] else: raise NotImplementedError(f"Unsupported dtype {dtype}") - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - - scalar_layout = [] - shape_layout = [0] - workspace_layouts = [shape_layout] * len(workspace) - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - - result_types, result_shapes = mk_result_types_and_shapes( - [(a_shape_vals, a_type.element_type), - (batch_dims_vals + (n,), eigvals_type), - (batch_dims_vals, i32_type)] + workspace - ) + result_types, result_shapes = mk_result_types_and_shapes([ + (a_shape_vals, a_type.element_type), + (batch_dims_vals + (n,), eigvals_type), + (batch_dims_vals, i32_type), + ]) - out = custom_call( + return custom_call( fn, result_types=result_types, - operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a], - operand_layouts=[scalar_layout] * 3 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ layout, tuple(range(num_bd, -1, -1)), tuple(range(num_bd - 1, -1, -1)), - ] + workspace_layouts, - operand_output_aliases={3: 0}, + ], + operand_output_aliases={0: 0}, result_shapes=result_shapes, + backend_config={ + "uplo": _matrix_uplo_attr(lower=lower), + "mode": mode, + }, + api_version=4, ).results - return out[:3] # # geev: Nonsymmetric eigendecomposition (eig) -def geev_hlo(dtype, input, *, +def geev_hlo(ctx, dtype, input, *, input_shape_vals: tuple[DimensionSize, ...], # input.shape as ir.Values jobvl=True, jobvr=True): # input_shape_vals are used for when input has dynamic shapes. @@ -617,80 +651,128 @@ def geev_hlo(dtype, input, *, layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - jobvl_c = ord('V' if jobvl else 'N') - jobvr_c = ord('V' if jobvr else 'N') + compute_left = ( + eig.ComputationMode.kComputeEigenvectors + if jobvl + else eig.ComputationMode.kNoEigenvectors + ) + + compute_right = ( + eig.ComputationMode.kComputeEigenvectors + if jobvr + else eig.ComputationMode.kNoEigenvectors + ) + fn_base = build_lapack_fn_target(fn_base="geev", dtype=dtype) i32_type = ir.IntegerType.get_signless(32) f32_type = ir.F32Type.get() f64_type = ir.F64Type.get() c64_type = ir.ComplexType.get(ir.F32Type.get()) c128_type = ir.ComplexType.get(ir.F64Type.get()) + if ctx.is_forward_compat(): + fn = fn_base + workspaces: list[ShapeTypePair] + eigvals: list[ShapeTypePair] + if dtype == np.float32: + real = True + eigvecs_type = c64_type + workspaces = [([n, n], f32_type)] * 3 + workspace_layouts = [[0, 1]] * 3 + eigvals = [(batch_dims_vals + (n,), f32_type)] * 2 + eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 + elif dtype == np.float64: + real = True + eigvecs_type = c128_type + workspaces = [([n, n], f64_type)] * 3 + workspace_layouts = [[0, 1]] * 3 + eigvals = [(batch_dims_vals + (n,), f64_type)] * 2 + eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 + elif dtype == np.complex64: + real = False + eigvecs_type = c64_type + workspaces = [([n, n], c64_type), ([hlo_add(n, n)], f32_type)] + workspace_layouts = [[0, 1], [0]] + eigvals = [(batch_dims_vals + (n,), c64_type)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + elif dtype == np.complex128: + real = False + eigvecs_type = c128_type + workspaces = [([n, n], c128_type), ([hlo_add(n, n)], f64_type)] + workspace_layouts = [[0, 1], [0]] + eigvals = [(batch_dims_vals + (n,), c128_type)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") - workspaces: list[ShapeTypePair] - eigvals: list[ShapeTypePair] - if dtype == np.float32: - fn = "lapack_sgeev" - real = True - eigvecs_type = c64_type - workspaces = [([n, n], f32_type)] * 3 - workspace_layouts = [[0, 1]] * 3 - eigvals = [(batch_dims_vals + (n,), f32_type)] * 2 - eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 - elif dtype == np.float64: - fn = "lapack_dgeev" - real = True - eigvecs_type = c128_type - workspaces = [([n, n], f64_type)] * 3 - workspace_layouts = [[0, 1]] * 3 - eigvals = [(batch_dims_vals + (n,), f64_type)] * 2 - eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 - elif dtype == np.complex64: - fn = "lapack_cgeev" - real = False - eigvecs_type = c64_type - workspaces = [([n, n], c64_type), ([hlo_add(n, n)], f32_type)] - workspace_layouts = [[0, 1], [0]] - eigvals = [(batch_dims_vals + (n,), c64_type)] - eigvals_layouts = [tuple(range(num_bd, -1, -1))] - elif dtype == np.complex128: - fn = "lapack_zgeev" - real = False - eigvecs_type = c128_type - workspaces = [([n, n], c128_type), ([hlo_add(n, n)], f64_type)] - workspace_layouts = [[0, 1], [0]] - eigvals = [(batch_dims_vals + (n,), c128_type)] - eigvals_layouts = [tuple(range(num_bd, -1, -1))] - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] - info_layout = tuple(range(num_bd - 1, -1, -1)) + scalar_layout = [] + info_layout = tuple(range(num_bd - 1, -1, -1)) - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - shape_type_pairs: Sequence[ShapeTypePair] = workspaces + eigvals + [ + shape_type_pairs: Sequence[ShapeTypePair] = workspaces + eigvals + [ + (input_shape_vals, eigvecs_type), + (input_shape_vals, eigvecs_type), + (batch_dims_vals, i32_type)] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + out = custom_call( + fn, + result_types=result_types, + operands=[batch_size_val, ensure_hlo_s32(n), + hlo_u8(compute_left.value), + hlo_u8(compute_right.value), + input], + operand_layouts=[scalar_layout] * 4 + [layout], + result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 + + [info_layout]), + result_shapes=result_shapes, + ).results + if real: + return (hlo.complex(out[3], out[4]), out[5], out[6], out[7]) + else: + return out[2:6] + fn = fn_base + "_ffi" + real = dtype == np.float32 or dtype == np.float64 + eigvecs_type = ( + c64_type if dtype == np.float32 or dtype == np.complex64 else c128_type + ) + input_type = ir.RankedTensorType(input.type) + eigvals = [(batch_dims_vals + (n,), input_type.element_type)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + if real: + eigvals = eigvals * 2 + eigvals_layouts = eigvals_layouts * 2 + info_layout = tuple(range(num_bd - 1, -1, -1)) + shape_type_pairs: Sequence[ShapeTypePair] = [ + *eigvals, (input_shape_vals, eigvecs_type), (input_shape_vals, eigvecs_type), - (batch_dims_vals, i32_type)] + (batch_dims_vals, i32_type), + ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) out = custom_call( fn, result_types=result_types, - operands=[batch_size_val, ensure_hlo_s32(n), - hlo_u8(jobvl_c), - hlo_u8(jobvr_c), - input], - operand_layouts=[scalar_layout] * 4 + [layout], - result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 + - [info_layout]), + operands=[input], + operand_layouts=[layout], + result_layouts=( + *eigvals_layouts, + layout, + layout, + info_layout, + ), result_shapes=result_shapes, + backend_config={ + "compute_left": _enum_to_char_attr(compute_left), + "compute_right": _enum_to_char_attr(compute_right), + }, + api_version=4, ).results if real: - return (hlo.complex(out[3], out[4]), out[5], out[6], out[7]) + return (hlo.complex(out[0], out[1]), out[2], out[3], out[4]) else: - return out[2:6] + return out[:4] # # gees : Schur factorization diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 26eb34088460..4e7898d57fe0 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -115,6 +115,8 @@ def test_custom_call_coverage(self): cpu_ffi_testdatas = [ cpu_cholesky_lapack_potrf.data_2024_05_31, cpu_qr_lapack_geqrf.data_2024_08_22, + cpu_eig_lapack_geev.data_2024_08_19, + cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, cpu_svd_lapack_gesdd.data_2024_08_13, ] @@ -256,6 +258,14 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_eig_results) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 32) + if has_xla_ffi_support: + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=check_eig_results) @staticmethod def eigh_input(shape, dtype): @@ -306,6 +316,14 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 32) + if has_xla_ffi_support: + # FFI Kernel test + with config.export_ignore_forward_compatibility(True): + data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_eigh_results, operand)) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{variant}", From fc1af8d050923950c831ae2fef5d36a3b5a4e5f8 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 28 Aug 2024 05:10:08 -0700 Subject: [PATCH 261/702] Support strided load / store in interpret mode This is a part of the efforts to fix the indexing implementation in JAX state. This PR adds support for strides in array indexing. In other words, the aim of the PR is to support this test: https://github.com/google/jax/blob/bb160cf54ef5f69c2a59a4001a3210fd56a7b286/tests/pallas/ops_test.py#L772-L786 This PR adds a set of test cases that makes it easier to track the completeness of the indexing implementation in JAX state. Test cases that are not yet supported are temporarily commented out. PiperOrigin-RevId: 668402290 --- jax/_src/pallas/primitives.py | 4 +- jax/_src/state/discharge.py | 90 +++++++++++++++++++++++------------ tests/pallas/indexing_test.py | 73 ++++++++++++++++++++++++++++ tests/pallas/ops_test.py | 4 -- 4 files changed, 135 insertions(+), 36 deletions(-) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 2d4ca1b8ca5b..53227478c312 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -471,7 +471,7 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): raise NotImplementedError("Only one indexer supported in discharge rule.") idx = indexers[0] if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): - # TODO(b/329733289): support strided load/store in interpret mode. + # TODO(ayx): support strided load/store in interpret mode. for s in idx.indices: if isinstance(s, Slice) and s.stride > 1: raise NotImplementedError("Unimplemented stride support.") @@ -583,7 +583,7 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): raise NotImplementedError("Only one indexer supported in discharge rule.") idx = indexers[0] if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): - # TODO(b/329733289): support strided load/store in interpret mode. + # TODO(ayx): support strided load/store in interpret mode. for s in idx.indices: if isinstance(s, Slice) and s.stride > 1: raise NotImplementedError("Unimplemented stride support.") diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 1feac75eb530..4795af054280 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -158,33 +158,27 @@ def _is_trivial_indexer(indexer: indexing.NDIndexer): return False return True -def _convert_to_array_indexer(indexer: indexing.NDIndexer - ) -> tuple[int | Array, ...]: - # This is the general gather case. We need to create the gather arrays. - is_integer_indexer, _, integer_indexer = ( - indexing.unpack_ndindexer(indexer) - ) - total_shape = indexer.get_indexer_shape() - int_indexer_shape = indexer.int_indexer_shape - slice_shape = total_shape[len(int_indexer_shape):] - slice_dims = tuple( - i + len(int_indexer_shape) for i in range(len(slice_shape)) - ) - slice_dim_iter = iter(slice_dims) - slice_indexer: list[Array] = [] - for idx, is_int_index in zip(indexer.indices, is_integer_indexer): - if not is_int_index: - assert isinstance(idx, indexing.Slice) - slice_indices = lax.broadcasted_iota( - np.dtype("int32"), total_shape, next(slice_dim_iter) - ) + idx.start - slice_indexer.append(slice_indices) - integer_indexer = tuple( - lax.expand_dims(idx, (-1,)) for idx in integer_indexer - ) - continue - assert next(slice_dim_iter, None) is None - return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer)) + +def _maybe_convert_to_slice( + indexer: indexing.NDIndexer +) -> list[tuple[int, int, int]] | None: + args = [] + + for i in indexer.indices: + if not isinstance(i, indexing.Slice): + return None + + start = i.start + end = i.start + i.size * i.stride + stride = i.stride + + # cannot convert to static `slice` if `start` or `end` is dynamic + if not isinstance(start, int) or not isinstance(end, int): + return None + + args.append((start, end, stride)) + + return args def _maybe_convert_to_dynamic_slice( @@ -198,10 +192,12 @@ def _maybe_convert_to_dynamic_slice( if not all(isinstance(i, indexing.Slice) or not np.shape(i) for i in indexer.indices): return None - # TODO(b/329733289): support strided load/store in interpret mode. + + # `lax.dynamic_slice` does not handle striding for i in indexer.indices: if isinstance(i, indexing.Slice) and i.stride > 1: - raise NotImplementedError("Unimplemented stride support.") + return None + _convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32")) starts = tuple( _convert_i32(i.start) if isinstance(i, indexing.Slice) @@ -218,6 +214,35 @@ def _maybe_convert_to_dynamic_slice( return starts, sizes, squeeze_dims +def _convert_to_array_indexer(indexer: indexing.NDIndexer + ) -> tuple[int | Array, ...]: + # This is the general gather case. We need to create the gather arrays. + is_integer_indexer, _, integer_indexer = ( + indexing.unpack_ndindexer(indexer) + ) + total_shape = indexer.get_indexer_shape() + int_indexer_shape = indexer.int_indexer_shape + slice_shape = total_shape[len(int_indexer_shape):] + slice_dims = tuple( + i + len(int_indexer_shape) for i in range(len(slice_shape)) + ) + slice_dim_iter = iter(slice_dims) + slice_indexer: list[Array] = [] + for idx, is_int_index in zip(indexer.indices, is_integer_indexer): + if not is_int_index: + assert isinstance(idx, indexing.Slice) + slice_indices = lax.broadcasted_iota( + np.dtype("int32"), total_shape, next(slice_dim_iter) + ) * idx.stride + idx.start + slice_indexer.append(slice_indices) + integer_indexer = tuple( + lax.expand_dims(idx, (-1,)) for idx in integer_indexer + ) + continue + assert next(slice_dim_iter, None) is None + return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer)) + + @register_discharge_rule(get_p) def _get_discharge_rule( in_avals: Sequence[core.AbstractValue], @@ -249,10 +274,15 @@ def index_array(x, indexers): continue if indexer is None: continue + + # Try the three APIs in the following order: `lax.slice`, + # `lax.dynamic_slice` and gather + if maybe_slice := _maybe_convert_to_slice(indexer): + result = lax_slicing.slice(result, *zip(*maybe_slice)) # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. - if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice y = lax_slicing.dynamic_slice(result, starts, sizes) result = lax.squeeze(y, squeeze_dims) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 2818712c0359..59e28db6d9e2 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -645,5 +645,78 @@ class IndexerOpsInterpretTest(IndexerOpsTest): INTERPRET = True +# TODO(ayx): Fix all test cases here +_ADVANCED_INDEXER_TEST_CASES = [ + ((8, 2), lambda arr, a, b, c, d: arr[2]), + ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]), + ((16, 3), lambda arr, a, b, c, d: arr[a, a]), + ((16, 16), lambda arr, a, b, c, d: arr[::4, ::4]), + ((16, 16), lambda arr, a, b, c, d: arr[1:14:2, 2:13:4]), + ((16, 3), lambda arr, a, b, c, d: arr[a, :]), + # ((16, 3), lambda arr, a, b, c, d: arr[:, a]), + ((16, 3), lambda arr, a, b, c, d: arr[a, ::4]), + # ((16, 3), lambda arr, a, b, c, d: arr[::4, a]), + # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), + # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), + # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, ::2, ::2]), + # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, ::2, ::2]), + # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]), + # ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]), + # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]), + # ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, c]), + ((8, 6, 4), lambda arr, a, b, c, d: arr[a]), + ((6, 8, 4), lambda arr, a, b, c, d: arr[c, c]), + ((6, 8, 4), lambda arr, a, b, c, d: arr[c, ::3]), + # ((8, 6, 4), lambda arr, a, b, c, d: arr[::3, c]), + # ((6, 2), lambda arr, a, b, c, d: arr[d]), + # ((8, 6), lambda arr, a, b, c, d: arr[::4, d]), +] + + +class AdvancedIndexerOpsTest(PallasBaseTest): + + def setUp(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Advanced indexers are not supported on TPU") + + # 4 arrays that are used in test cases of advanced indexing + self.a = jnp.array([1, 1, 1, 1, 1], dtype=jnp.int32) + self.b = jnp.array([1, 2, 2, 2, 2], dtype=jnp.int32) + self.c = jnp.array([1, 0, 2, 2, -1, 1], dtype=jnp.int32) + self.d = jnp.array([1, 0, 0, 0, 0, 1], dtype=jnp.bool_) + + super().setUp() + + @parameterized.parameters(_ADVANCED_INDEXER_TEST_CASES) + def test_advanced_indexer(self, in_shape: tuple[int, ...], indexing_func): + a, b, c, d = self.a, self.b, self.c, self.d + + x = jnp.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape) + y = indexing_func(x, a, b, c, d) + + # `a_ref`, `b_ref`, `c_ref` and `d_ref` are for testing purposes. + # We have them here because we need to have a unified function signature + # for all test cases, even if the arrays are actually not used in any + # computation. + def kernel(x_ref, a_ref, b_ref, c_ref, d_ref, o_ref): + a = a_ref[...] + b = b_ref[...] + c = c_ref[...] + d = d_ref[...] + o = indexing_func(x_ref, a, b, c, d) + o_ref[...] = o + + y_ = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32), + )(x, a, b, c, d) + + np.testing.assert_array_equal(y_, y) + + +class AdvancedIndexerOpsInterpretTest(AdvancedIndexerOpsTest): + INTERPRET = True + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index cf247ac3f6a4..849631d969d9 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1167,10 +1167,6 @@ def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): np.testing.assert_array_equal(out, o_new) def test_strided_load(self): - if self.INTERPRET: - # TODO(b/329733289): Remove this once the bug is fixed. - self.skipTest("Strided load not yet supported in interpret mode") - # Reproducer from https://github.com/google/jax/issues/20895. @functools.partial( self.pallas_call, From ced012f5eda2da851ed1c905496b884ccca0c84e Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 28 Aug 2024 20:16:09 +0530 Subject: [PATCH 262/702] Update jnp.fabs to emulate the behavior of np.fabs for complex inputs --- CHANGELOG.md | 2 ++ jax/_src/numpy/ufuncs.py | 41 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d09d610bdb5..9d0b55b36476 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `jax.config.update('jax_cpu_enable_async_dispatch', False)`. * Added new {func}`jax.process_indices` function to replace the `jax.host_ids()` function that was deprecated in JAX v0.2.13. + * To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been + modified to no longer support `complex dtypes`. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 2893b14f7059..c4f9009eb877 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -52,9 +52,48 @@ def _replace_inf(x: ArrayLike) -> Array: def _to_bool(x: Array) -> Array: return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0)) -@implements(np.fabs, module='numpy') + @partial(jit, inline=True) def fabs(x: ArrayLike, /) -> Array: + """Compute the element-wise absolute values of the real-valued input. + + JAX implementation of :func:`numpy.fabs`. + + Args: + x: input array or scalar. Must not have a complex dtype. + + Returns: + An array with same shape as ``x`` and dtype float, containing the element-wise + absolute values. + + See also: + - :func:`jax.numpy.absolute`: Computes the absolute values of the input including + complex dtypes. + - :func:`jax.numpy.abs`: Computes the absolute values of the input including + complex dtypes. + + Examples: + For integer inputs: + + >>> x = jnp.array([-5, -9, 1, 10, 15]) + >>> jnp.fabs(x) + Array([ 5., 9., 1., 10., 15.], dtype=float32) + + For float type inputs: + + >>> x1 = jnp.array([-1.342, 5.649, 3.927]) + >>> jnp.fabs(x1) + Array([1.342, 5.649, 3.927], dtype=float32) + + For boolean inputs: + + >>> x2 = jnp.array([True, False]) + >>> jnp.fabs(x2) + Array([1., 0.], dtype=float32) + """ + check_arraylike('fabs', x) + if dtypes.issubdtype(dtypes.dtype(x), np.complexfloating): + raise TypeError("ufunc 'fabs' does not support complex dtypes") return lax.abs(*promote_args_inexact('fabs', x)) @implements(getattr(np, 'bitwise_invert', np.invert), module='numpy') From 78d5b75b0dfefeb3feaf753b15cbf7fbbca40071 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Wed, 28 Aug 2024 09:00:30 -0700 Subject: [PATCH 263/702] Trim StableHLO python binding dependencies With proper CAPI in place these dependencies are no longer needed, llvm support needed for string ostream for string APIs. PiperOrigin-RevId: 668476145 --- jaxlib/mlir/_mlir_libs/BUILD.bazel | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index b7101a07d989..db02eb8bbff1 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -333,15 +333,12 @@ py_extension( linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", + "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonHeaders", "@local_config_python//:headers", "@pybind11", - "@stablehlo//:reference_api", "@stablehlo//:stablehlo_capi_headers", - "@stablehlo//:stablehlo_portable_api", - "@stablehlo//:stablehlo_serialization", ], ) From 46957052c5cb797d38f1cefdd501ad7abfe5f751 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 28 Aug 2024 10:12:51 -0700 Subject: [PATCH 264/702] Don't share the same global jit cpp cache between jit and pjit PiperOrigin-RevId: 668503956 --- jax/_src/interpreters/pxla.py | 3 ++- jax/_src/pjit.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7a231c06fd61..f6f413307f9a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2924,7 +2924,7 @@ class MeshExecutableFastpathData(NamedTuple): in_device_local_layouts: Sequence[DeviceLocalLayout | None] -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class JitGlobalCppCacheKeys: donate_argnums: tuple[int, ...] | None = None donate_argnames: tuple[str, ...] | None = None @@ -2938,6 +2938,7 @@ class JitGlobalCppCacheKeys: in_layouts_leaves: tuple[Any, ...] | None = None out_layouts_treedef: PyTreeDef | None = None out_layouts_leaves: tuple[Any, ...] | None = None + use_resource_env: bool = False @functools.cached_property def contains_explicit_attributes(self): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 7badefab7922..197388afe84b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -379,7 +379,8 @@ def cache_miss(*args, **kwargs): in_layouts_treedef=jit_info.in_layouts_treedef, in_layouts_leaves=jit_info.in_layouts_leaves, out_layouts_treedef=jit_info.out_layouts_treedef, - out_layouts_leaves=jit_info.out_layouts_leaves) + out_layouts_leaves=jit_info.out_layouts_leaves, + use_resource_env=jit_info.use_resource_env) cpp_pjit_f = xc._xla.pjit( fun_name(fun), fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore @@ -1767,7 +1768,8 @@ def call_impl_cache_miss(*args_, **kwargs_): in_shardings_treedef=None, in_shardings_leaves=in_shardings, out_shardings_treedef=None, out_shardings_leaves=out_shardings, in_layouts_treedef=None, in_layouts_leaves=in_layouts, - out_layouts_treedef=None, out_layouts_leaves=out_layouts) + out_layouts_treedef=None, out_layouts_leaves=out_layouts, + use_resource_env=resource_env is not None) return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], cache_key, tree_util.dispatch_registry, pxla.cc_shard_arg, From 672a013b3a90782a7ebc7c5646fb9b44d8183e3a Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 10:46:47 -0700 Subject: [PATCH 265/702] docs: prefer "summary" to "tl;dr" This was common and we typically just mean "summary." --- docs/debugging.md | 4 ++-- docs/debugging/checkify_guide.md | 2 +- docs/debugging/flags.md | 4 ++-- docs/debugging/index.md | 8 ++++---- docs/debugging/print_breakpoint.md | 5 +++-- docs/installation.md | 2 +- .../Custom_derivative_rules_for_Python_code.ipynb | 2 +- docs/notebooks/Custom_derivative_rules_for_Python_code.md | 2 +- docs/notebooks/autodiff_remat.ipynb | 2 +- docs/notebooks/autodiff_remat.md | 2 +- 10 files changed, 17 insertions(+), 16 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index 94384035ca9d..1e8501f99e39 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -23,7 +23,7 @@ Let's begin with {func}`jax.debug.print`. ## JAX `debug.print` for high-level -**TL;DR** Here is a rule of thumb: +Here is a rule of thumb: - Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others. - Use Python {func}`print` for static values, such as dtypes and array shapes. @@ -113,7 +113,7 @@ To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`a ## JAX `debug.breakpoint` for `pdb`-like debugging -**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. +**Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack. diff --git a/docs/debugging/checkify_guide.md b/docs/debugging/checkify_guide.md index 2dad9b863b06..8b012e97ef28 100644 --- a/docs/debugging/checkify_guide.md +++ b/docs/debugging/checkify_guide.md @@ -2,7 +2,7 @@ -**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: +**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python from jax.experimental import checkify diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 1cf1829e5152..13e34a6c3ac4 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -6,7 +6,7 @@ JAX offers flags and context managers that enable catching errors more easily. ## `jax_debug_nans` configuration option and context manager -**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code). +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code). `jax_debug_nans` is a JAX flag that when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled -- when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN. @@ -41,7 +41,7 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! ## `jax_disable_jit` configuration option and context manager -**TL;DR** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb` +**Summary:** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb` `jax_disable_jit` is a JAX flag that when enabled, disables JIT-compilation throughout JAX (including in control flow functions like `jax.lax.cond` and `jax.lax.scan`). diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 724827f837e3..46523d681512 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -2,7 +2,7 @@ -Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more. +Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has summaries and you can click the "Read more" links at the bottom to learn more. Table of contents: @@ -12,7 +12,7 @@ Table of contents: ## [Interactive inspection with `jax.debug`](print_breakpoint) - **TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, + **Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack: ```python @@ -38,7 +38,7 @@ Click [here](print_breakpoint) to learn more! ## [Functional error checks with `jax.experimental.checkify`](checkify_guide) - **TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: + **Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python from jax.experimental import checkify @@ -81,7 +81,7 @@ Click [here](checkify_guide) to learn more! ## [Throwing Python errors with JAX's debug flags](flags) -**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. ```python import jax diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index d7cb68bd1b0b..d33498697bb3 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -7,7 +7,8 @@ inside of JIT-ted functions. ## Debugging with `jax.debug.print` and other debugging callbacks -**TL;DR** Use {func}`jax.debug.print` to print traced array values to stdout in `jit`- and `pmap`-decorated functions: +**Summary:** Use {func}`jax.debug.print` to print traced array values to +stdout in compiled (e.g. `jax.jit` or `jax.pmap`-decorated) functions: ```python import jax @@ -236,7 +237,7 @@ Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronizat ## Interactive inspection with `jax.debug.breakpoint()` -**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values: +**Summary:** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values: ```python @jax.jit diff --git a/docs/installation.md b/docs/installation.md index bd0473d89201..4a831750e42b 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -7,7 +7,7 @@ Using JAX requires installing two packages: `jax`, which is pure Python and cross-platform, and `jaxlib` which contains compiled binaries, and requires different builds for different operating systems and accelerators. -**TL;DR** For most users, a typical JAX installation may look something like this: +**Summary:** For most users, a typical JAX installation may look something like this: * **CPU-only (Linux/macOS/Windows)** ``` diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index ec85f6e63159..6767b33a20f0 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -30,7 +30,7 @@ "id": "9Fg3NFNY-2RY" }, "source": [ - "## TL;DR" + "## Summary" ] }, { diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 6c948650fcc1..000d48c49b18 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -32,7 +32,7 @@ For an introduction to JAX's automatic differentiation API, see [The Autodiff Co +++ {"id": "9Fg3NFNY-2RY"} -## TL;DR +## Summary +++ {"id": "ZgMNRtXyWIW8"} diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index 041cf65314f2..82381838a5aa 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -27,7 +27,7 @@ "id": "qaIsQSh1XoKF" }, "source": [ - "### TL;DR\n", + "### Summary\n", "\n", "Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n", "\n", diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index a4fb27c58128..0a6c84b2d88f 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -24,7 +24,7 @@ import jax.numpy as jnp +++ {"id": "qaIsQSh1XoKF"} -### TL;DR +### Summary Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs. From 763d600508a8307a381f5d1f75efae3e75ffd43b Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 10:59:05 -0700 Subject: [PATCH 266/702] docs: remove inline authors/dates on misc doc pages Git history covers this better and automatically. --- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 2 -- docs/notebooks/Common_Gotchas_in_JAX.md | 2 -- docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb | 2 -- docs/notebooks/Custom_derivative_rules_for_Python_code.md | 2 -- docs/notebooks/autodiff_cookbook.ipynb | 2 -- docs/notebooks/autodiff_cookbook.md | 2 -- 6 files changed, 12 deletions(-) diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index c143b520af51..0cffc22f1e8d 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -19,8 +19,6 @@ "id": "4k5PVzEo2uJO" }, "source": [ - "*levskaya@ mattjj@*\n", - "\n", "When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_\"una anima di pura programmazione funzionale\"_](https://www.sscardapane.it/iaml-backup/jax-intro/).\n", "\n", "__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 3324fdb53bcd..543d9ecb1558 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -22,8 +22,6 @@ kernelspec: +++ {"id": "4k5PVzEo2uJO"} -*levskaya@ mattjj@* - When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_"una anima di pura programmazione funzionale"_](https://www.sscardapane.it/iaml-backup/jax-intro/). __JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU). diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 6767b33a20f0..88d446723dba 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -12,8 +12,6 @@ "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", - "*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n", - "\n", "There are two ways to define differentiation rules in JAX:\n", "\n", "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 000d48c49b18..fdf4b3ed0c8d 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -19,8 +19,6 @@ kernelspec: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) -*mattjj@ Mar 19 2020, last updated Oct 14 2020* - There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 478d84935b7f..3f2f0fd5650d 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -12,8 +12,6 @@ "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", "\n", - "*alexbw@, mattjj@* \n", - "\n", "JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics." ] }, diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index 6615e65352d7..496d676f794a 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -20,8 +20,6 @@ kernelspec: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) -*alexbw@, mattjj@* - JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics. ```{code-cell} ipython3 From 42de34263fd2f3bd9976d193f0a9dd892596202b Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 11:01:43 -0700 Subject: [PATCH 267/702] docs: runtime debugging tweaks Mainly make titles/headings easier to read, by swapping code for words and not using headings as links. --- docs/debugging/index.md | 18 ++++++++++++------ docs/debugging/print_breakpoint.md | 5 ++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 46523d681512..bcf561d06807 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -10,7 +10,9 @@ Table of contents: * [Functional error checks with jax.experimental.checkify](checkify_guide) * [Throwing Python errors with JAX’s debug flags](flags) -## [Interactive inspection with `jax.debug`](print_breakpoint) +## Interactive inspection with `jax.debug` + +Complete guide [here](print_breakpoint) **Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack: @@ -34,9 +36,11 @@ Table of contents: # 🤯 0.9092974662780762 🤯 ``` -Click [here](print_breakpoint) to learn more! +[Read more](print_breakpoint). + +## Functional error checks with `jax.experimental.checkify` -## [Functional error checks with `jax.experimental.checkify`](checkify_guide) +Complete guide [here](checkify_guide) **Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: @@ -77,9 +81,11 @@ Click [here](print_breakpoint) to learn more! # ValueError: nan generated by primitive sin at <...>:8 (f) ``` -Click [here](checkify_guide) to learn more! +[Read more](checkify_guide). + +## Throwing Python errors with JAX's debug flags -## [Throwing Python errors with JAX's debug flags](flags) +Complete guide [here](flags) **Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. @@ -92,7 +98,7 @@ def f(x, y): jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! ``` -Click [here](flags) to learn more! +[Read more](flags). ```{toctree} :caption: Read more diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index d33498697bb3..73ac0262851d 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -1,9 +1,9 @@ -# `jax.debug.print` and `jax.debug.breakpoint` +# Compiled prints and breakpoints The {mod}`jax.debug` package offers some useful tools for inspecting values -inside of JIT-ted functions. +inside of compiled functions. ## Debugging with `jax.debug.print` and other debugging callbacks @@ -27,7 +27,6 @@ f(2.) # 🤯 0.9092974662780762 🤯 ``` - With some transformations, like `jax.grad` and `jax.vmap`, you can use Python's builtin `print` function to print out numerical values. But `print` won't work with `jax.jit` or `jax.pmap` because those transformations delay numerical evaluation. So use `jax.debug.print` instead! Semantically, `jax.debug.print` is roughly equivalent to the following Python function From ef33cf5acee2668b3e847aa19c91a52f6c18328d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 28 Aug 2024 11:05:45 -0700 Subject: [PATCH 268/702] Standardize default layout to `None` in internals (dispatch, lowering and compilation) and non-default layouts to concrete layouts. This massively simplifies the amount of checks we need and improves dispatch time too. It also fixes a donation bug being hit in serving code related to layouts and non-standardization of default layout in JAX. PiperOrigin-RevId: 668527139 --- jax/_src/array.py | 11 ++--- jax/_src/interpreters/mlir.py | 23 +++++---- jax/_src/interpreters/pxla.py | 93 ++++++++++++----------------------- jax/_src/pjit.py | 21 +++++--- tests/layout_test.py | 24 +++++++-- 5 files changed, 80 insertions(+), 92 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 0f554a86a655..7659c180ddc9 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -751,8 +751,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: and sharding.is_fully_replicated and first_value.is_fully_replicated and first_value.sharding._device_assignment == tuple(devices) - and (first_value.layout.device_local_layout == - pxla._maybe_get_default_layout(Layout(dll, sharding), None, sharding, aval))): + and first_value.layout.device_local_layout == dll): return first_value if dtypes.issubdtype(aval.dtype, dtypes.extended): @@ -1105,11 +1104,6 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): dst_indices = dst_sharding.addressable_devices_indices_map(shape).values() return dst_indices, tuple(src_indices) == tuple(dst_indices) -def _layout_eq(x, dst_layout, sharding): - if pxla.is_default_layout(dst_layout, sharding, x.aval): - return True - return x.layout.device_local_layout == dst_layout - def _array_shard_arg(xs, shardings, layouts): results = [] @@ -1118,7 +1112,8 @@ def _array_shard_arg(xs, shardings, layouts): for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)): x._check_if_deleted() indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) - same_layout = _layout_eq(x, layout, sharding) + same_layout = (True if layout is None else + x.layout.device_local_layout == layout) if not x.is_fully_addressable: if same_indices and same_layout: diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index ab2c77833f15..0e7e0146e984 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1053,6 +1053,7 @@ def lower_jaxpr_to_module( result_memory_kinds = (map(_get_mem_kind, result_shardings) if result_shardings is not None else None) + # TODO(yashkatariya): Simplify the donation logic. xla_donated_args = None platforms_with_donation = [p for p in platforms if p in _platforms_with_donation] @@ -1071,9 +1072,6 @@ def lower_jaxpr_to_module( input_output_aliases, donated_args, xla_donated_args = _set_up_aliases( input_output_aliases, in_avals, out_avals, donated_args, arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts) - unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects) - if unlowerable_effects: - raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}') if any(donated_args): unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d] msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation." @@ -1082,10 +1080,13 @@ def lower_jaxpr_to_module( if unused_donations: warnings.warn("Some donated buffers were not usable:" f" {', '.join(unused_donations)}.\n{msg}") - # Delete donated_args by default here, since it's not needed beyond this point del donated_args + unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects) + if unlowerable_effects: + raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}') + # HLO channels need to start at 1 channel_iter = itertools.count(1) # Create a keepalives list that will be mutated during the lowering. @@ -1167,8 +1168,7 @@ def emit_diagnostic_info(d): def _set_up_aliases(input_output_aliases, avals_in, avals_out, - donated_args, - arg_memory_kinds, result_memory_kinds, + donated_args, arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts): if input_output_aliases is None: input_output_aliases = [None] * len(avals_in) @@ -1207,15 +1207,14 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, if donations.get(key, ()): input_id = donations[key].popleft() out_donated_args[input_id] = False + # We can alias if XLA performs layout assignment because XLA will + # respect the aliases when assigning layouts. Its only for two + # mismatched explicitly assigned layouts that XLA will certainly fail. if (in_layouts is None or out_layouts is None or in_layouts[input_id] == out_layouts[i] or - # We can alias if XLA performs layout assignment because XLA will - # respect the aliases when assigning layouts. Its only for two - # mismatched explicitly assigned layouts that XLA will certainly - # fail. - isinstance(in_layouts[input_id], (AutoLayout, type(None))) or - isinstance(out_layouts[i], (AutoLayout, type(None)))): + isinstance(in_layouts[input_id], AutoLayout) or + isinstance(out_layouts[i], AutoLayout)): input_output_aliases[input_id] = i else: # Fallback to xla donation if layouts don't match. diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f6f413307f9a..4e0986cfaf36 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -150,7 +150,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, args, @lru_cache(maxsize=2048) def is_default_layout(curr_layout, sharding, aval): - if curr_layout is None or sharding is None: + if curr_layout is None or sharding is None or is_unspecified(sharding): return True if (aval is core.abstract_token or aval.dtype == dtypes.float0 or dtypes.issubdtype(aval.dtype, dtypes.extended)): @@ -191,7 +191,7 @@ def _shard_np_array(xs, shardings, layouts): if x.dtype == dtypes.float0: x = np.zeros(x.shape, dtype=np.dtype(bool)) aval = api_util.shaped_abstractify(x) - if not is_default_layout(layout, sharding, aval): + if layout is not None: results.append(api.device_put(x, Layout(layout, sharding))) else: if sharding.is_fully_replicated: @@ -1884,35 +1884,6 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "extra data movement anyway, so maybe you don't want it after all).") -@lru_cache(maxsize=2048) -def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval - ) -> DeviceLocalLayout | None: - if is_unspecified_or_auto(sharding): - return None - # TODO(yashkatariya): Figure out how layouts work with extended dtypes. - if aval is core.abstract_token or dtypes.issubdtype(aval.dtype, dtypes.extended): - return None - if not core.is_constant_shape(aval.shape): - return None - shard_shape = sharding.shard_shape(aval.shape) - d = sharding._device_assignment[0] - # If a backend doesn't implement `get_default_layout` return `None` to avoid - # cache misses. This can happen when you have `jit(f, in_shardings=s)`. On - # first call you pass it a sharded array with layout and on second call you - # pass a numpy array. The layouts should be the same to get cache hits. - try: - al = DeviceLocalLayout.from_pjrt_layout( - d.client.get_default_layout(aval.dtype, shard_shape, d)) - except: - return None - # argument does not have `.layout` property. ShapedArray, numpy array, etc - # are some examples. - if arg_layout is None: - return al if jit_in_layout is None else arg_layout # arg_layout is None - # If arg has a `.layout` property, then return device_local_layout as is. - return arg_layout.device_local_layout - - @weakref_lru_cache def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, semantic_out_shardings, @@ -2775,13 +2746,14 @@ class UnloadedMeshExecutable: kept_var_idx: set[int] mut: MutationData | None auto_spmd_lowering: bool - in_layouts: Sequence[DeviceLocalLayout | None] - out_layouts: Sequence[DeviceLocalLayout | None] + xla_in_layouts: Sequence[DeviceLocalLayout | None] + dispatch_in_layouts: Sequence[DeviceLocalLayout | None] + xla_out_layouts: Sequence[DeviceLocalLayout | None] all_args_info: AllArgsInfo | None pgle_profiler: profiler.PGLEProfiler | None def build_unsafe_call(self): - handle_args = InputsHandler(self.input_shardings, self.in_layouts) + handle_args = InputsHandler(self.input_shardings, self.dispatch_in_layouts) handle_outs = global_avals_to_results_handler( self.output_avals, self.output_shardings, self.committed) @@ -2797,8 +2769,8 @@ def load(self) -> MeshExecutable: self.input_avals, self.output_avals, self.input_shardings, self.output_shardings, self.auto_spmd_lowering, self.kept_var_idx, - self.in_layouts, self.out_layouts, - self.all_args_info, self) + self.xla_in_layouts, self.dispatch_in_layouts, + self.xla_out_layouts, self.all_args_info, self) @staticmethod def from_hlo(name: str, @@ -2881,8 +2853,18 @@ def from_hlo(name: str, in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap( xla_executable.local_devices(), len(in_shardings), len(out_shardings)) - in_layouts, out_layouts = _get_layouts_from_executable( + # xla_in_layouts are all either None or DeviceLocalLayout. Even default + # layout are concrete layouts and they are used in `compiled.input_layouts` + # to return concrete layouts to users. + # `dispatch_in_layouts` replaces default layouts with `None` to simplify + # dispatch logic downstream. + xla_in_layouts, xla_out_layouts = _get_layouts_from_executable( xla_executable, in_layouts, out_layouts, len(ordered_effects)) + del in_layouts, out_layouts + dispatch_in_layouts = [ + None if is_default_layout(l, s, a) else l + for l, s, a, in safe_zip(xla_in_layouts, in_shardings, global_in_avals) + ] out_shardings = maybe_recover_user_shardings( in_shardings, out_shardings, global_in_avals, global_out_avals, @@ -2907,8 +2889,9 @@ def from_hlo(name: str, kept_var_idx=kept_var_idx, mut=mut, auto_spmd_lowering=auto_spmd_lowering, - in_layouts=in_layouts, - out_layouts=out_layouts, + xla_in_layouts=xla_in_layouts, + dispatch_in_layouts=dispatch_in_layouts, + xla_out_layouts=xla_out_layouts, all_args_info=all_args_info, pgle_profiler=pgle_profiler).load() @@ -2964,13 +2947,13 @@ class MeshExecutable(stages.XlaExecutable): __slots__ = [ "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals", "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering", - "_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info", - "_unloaded_executable", + "_kept_var_idx", "_xla_in_layouts", "_dispatch_in_layouts", + "_xla_out_layouts", "_all_args_info", "_unloaded_executable", ] def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals, in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx, - in_layouts, out_layouts, + xla_in_layouts, dispatch_in_layouts, xla_out_layouts, all_args_info: AllArgsInfo | None = None, unloaded_executable=None): self.xla_executable = xla_executable @@ -2984,8 +2967,9 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals, self._out_shardings = out_shardings self._auto_spmd_lowering = auto_spmd_lowering self._kept_var_idx = kept_var_idx - self._in_layouts = in_layouts - self._out_layouts = out_layouts + self._xla_in_layouts = xla_in_layouts + self._dispatch_in_layouts = dispatch_in_layouts + self._xla_out_layouts = xla_out_layouts self._all_args_info = all_args_info self._unloaded_executable = unloaded_executable @@ -3013,9 +2997,8 @@ def call(self, *args): all_arg_avals = map(xla.abstractify, kept_args) check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info) - # Check the GDA sharding and the input sharding. check_array_xla_sharding_layout_match( - args_after_dce, self._in_shardings, self._in_layouts, debug_info, + args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info, self._kept_var_idx) return self.unsafe_call(*args) # pylint: disable=not-callable @@ -3027,11 +3010,11 @@ def output_shardings(self) -> Sequence[JSharding]: def input_layouts(self): return [Layout(l, s) - for l, s in safe_zip(self._in_layouts, self._in_shardings)] + for l, s in safe_zip(self._xla_in_layouts, self._in_shardings)] def output_layouts(self): return [Layout(l, s) - for l, s in safe_zip(self._out_layouts, self._out_shardings)] + for l, s in safe_zip(self._xla_out_layouts, self._out_shardings)] def create_cpp_call(self, no_kwargs, in_tree, out_tree): if not (isinstance(self.unsafe_call, ExecuteReplicated) and @@ -3057,12 +3040,10 @@ def aot_cache_miss(*args, **kwargs): else s for s, a in zip(self._in_shardings, self.in_avals) ] - in_dlls = get_layouts_for_fasthpath_data( - self._in_layouts, in_shardings, self.in_avals) fastpath_data = MeshExecutableFastpathData( self.xla_executable, out_tree_dispatch, in_shardings, self._out_shardings, out_avals, out_committed, kept_var_bitvec, - in_dlls) + self._dispatch_in_layouts) else: fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry @@ -3084,16 +3065,6 @@ def cc_shard_arg(x, sharding, layout): # type: ignore return shard_args([sharding], [layout], [x])[0] -def get_layouts_for_fasthpath_data(in_layouts, in_shardings, in_avals): - in_dlls = [] - for l, s, a in zip(in_layouts, in_shardings, in_avals): - if is_default_layout(l, s, a): - in_dlls.append(None) - else: - in_dlls.append(l) - return in_dlls - - def check_arg_avals_for_call(ref_avals, arg_avals, jaxpr_debug_info: core.JaxprDebugInfo | None = None): if len(ref_avals) != len(arg_avals): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 197388afe84b..bcd31ec09313 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -279,12 +279,10 @@ def _get_fastpath_data( else s for s, a in zip(executable._in_shardings, executable.in_avals) ] - in_dlls = pxla.get_layouts_for_fasthpath_data( - executable._in_layouts, in_shardings, executable.in_avals) fastpath_data = pxla.MeshExecutableFastpathData( executable.xla_executable, out_tree, in_shardings, executable._out_shardings, out_avals, out_committed, kept_var_bitvec, - in_dlls) + executable._dispatch_in_layouts) else: fastpath_data = None return fastpath_data @@ -1479,10 +1477,17 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): resolved_in_layouts = [] for arg, jit_in_l, rs, aval in safe_zip( args, jit_in_layouts, resolved_in_shardings, in_avals): - arg_layout, committed = ( - pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, - rs, aval), - getattr(arg, '_committed', True)) + committed = getattr(arg, '_committed', True) + # `arg_layout` is only used for checking purposes in the `else` branch + # below. We cannot replace default layout with None to raise nicer errors. + # `dispatch_arg_layout` replaces default layouts with `None` to simplify + # dispatch and lowering logic downstream. + if hasattr(arg, 'layout'): + arg_layout = arg.layout.device_local_layout + dispatch_arg_layout = (None if pxla.is_default_layout(arg_layout, rs, aval) + else arg_layout) + else: + arg_layout, dispatch_arg_layout = None, None # Sharding can be unspecified when array is committed if it's a PmapSharding. is_pmap_sharding = (is_unspecified(rs) or isinstance(getattr(arg, 'sharding', None), PmapSharding)) @@ -1491,7 +1496,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): if is_pmap_sharding: resolved_in_layouts.append(None) else: - resolved_in_layouts.append(arg_layout) + resolved_in_layouts.append(dispatch_arg_layout) else: resolved_in_layouts.append(None) else: diff --git a/tests/layout_test.py b/tests/layout_test.py index 33f3318bf4bb..7a587d099498 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -540,7 +540,7 @@ def test_layout_donation(self): def f(x): return x - out = f(arr) + f(arr) self.assertTrue(arr.is_deleted()) def test_layout_donation_auto(self): @@ -555,7 +555,7 @@ def test_layout_donation_auto(self): def f(x): return x * x - out = f(arr) + f(arr) self.assertTrue(arr.is_deleted()) def test_layout_donation_matching_in_and_out(self): @@ -572,9 +572,27 @@ def test_layout_donation_matching_in_and_out(self): def f(x): return x * x - out = f(arr) + f(arr) self.assertTrue(arr.is_deleted()) + def test_layout_donation_mismatching_in_and_out_fails(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (16*2, 32016*2) + np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) + + custom_dll1 = DLL(major_to_minor=(1, 0), _tiling=((8,128), (2,1))) + l1 = Layout(custom_dll1, s) + arr = jax.device_put(np_inp, s) + + @partial(jax.jit, out_shardings=l1, donate_argnums=0) + def f(x): + return x * x + + sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) + f.lower(sds).compile()(arr) + self.assertFalse(arr.is_deleted()) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 63004b9be96a5c6dfc669660b8782bc45243dbe7 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 28 Aug 2024 23:46:06 +0530 Subject: [PATCH 269/702] Better docs for jnp.floor and jnp.ceil --- jax/_src/numpy/ufuncs.py | 64 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index c4f9009eb877..36ce9f5135a3 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -126,17 +126,77 @@ def positive(x: ArrayLike, /) -> Array: def sign(x: ArrayLike, /) -> Array: return lax.sign(*promote_args('sign', x)) -@implements(np.floor, module='numpy') + @partial(jit, inline=True) def floor(x: ArrayLike, /) -> Array: + """Round input to the nearest integer downwards. + + JAX implementation of :func:`numpy.floor`. + + Args: + x: input array or scalar. Must not have complex dtype. + + Returns: + An array with same shape and dtype as ``x`` containing the values rounded to + the nearest integer that is less than or equal to the value itself. + + See also: + - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + zero. + - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. + + Examples: + >>> key = jax.random.key(42) + >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(x) + [[ 1.44 -1.77 -3.07] + [ 3.86 2.25 -3.08] + [-1.55 -2.48 1.32]] + >>> jnp.floor(x) + Array([[ 1., -2., -4.], + [ 3., 2., -4.], + [-2., -3., 1.]], dtype=float32) + """ check_arraylike('floor', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax.asarray(x) return lax.floor(*promote_args_inexact('floor', x)) -@implements(np.ceil, module='numpy') + @partial(jit, inline=True) def ceil(x: ArrayLike, /) -> Array: + """Round input to the nearest integer upwards. + + JAX implementation of :func:`numpy.ceil`. + + Args: + x: input array or scalar. Must not have complex dtype. + + Returns: + An array with same shape and dtype as ``x`` containing the values rounded to + the nearest integer that is greater than or equal to the value itself. + + See also: + - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + zero. + - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. + + Examples: + >>> key = jax.random.key(1) + >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(x) + [[ 2.55 -1.87 -3.76] + [ 0.48 3.85 -1.94] + [ 3.2 4.56 -1.43]] + >>> jnp.ceil(x) + Array([[ 3., -1., -3.], + [ 1., 4., -1.], + [ 4., 5., -1.]], dtype=float32) + """ check_arraylike('ceil', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax.asarray(x) From 7b297912008ffe2a2b6f0417a73494bd88ccb696 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 11:19:53 -0700 Subject: [PATCH 270/702] docs: shorten/clarify some page titles Shorten the titles for derivative rules and manual parallelism, and go for canonical wording in the parallel programming intro (typically we "shard" data, and "partition" computation, as part of parallel programming). --- docs/debugging.md | 6 +++--- docs/distributed_data_loading.md | 2 +- docs/multi_process.md | 2 +- .../notebooks/Custom_derivative_rules_for_Python_code.ipynb | 2 +- docs/notebooks/Custom_derivative_rules_for_Python_code.md | 2 +- docs/notebooks/shard_map.ipynb | 2 +- docs/notebooks/shard_map.md | 2 +- docs/sharded-computation.ipynb | 2 +- docs/sharded-computation.md | 2 +- 9 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index 1e8501f99e39..d07f42da5c85 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -21,7 +21,7 @@ This section introduces you to a set of built-in JAX debugging methods — {func Let's begin with {func}`jax.debug.print`. -## JAX `debug.print` for high-level +## `jax.debug.print` for simple inspection Here is a rule of thumb: @@ -111,7 +111,7 @@ f(1, 2) To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`. -## JAX `debug.breakpoint` for `pdb`-like debugging +## `jax.debug.breakpoint` for `pdb`-like debugging **Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. @@ -160,7 +160,7 @@ f(2., 1.) # ==> No breakpoint f(2., 0.) # ==> Pauses during execution ``` -## JAX `debug.callback` for more control during debugging +## `jax.debug.callback` for more control during debugging Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using the more flexible {func}`jax.debug.callback`, which gives greater control over the diff --git a/docs/distributed_data_loading.md b/docs/distributed_data_loading.md index 14fb1bb55c35..d7b88be44178 100644 --- a/docs/distributed_data_loading.md +++ b/docs/distributed_data_loading.md @@ -12,7 +12,7 @@ kernelspec: name: python3 --- -# Distributed data loading in multi-host / multi-process environments +# Distributed data loading diff --git a/docs/multi_process.md b/docs/multi_process.md index 7d7083bde10f..32cfae126784 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -1,4 +1,4 @@ -# Using JAX in multi-host and multi-process environments +# Multi-host and multi-process environments diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 88d446723dba..dd7a36e57079 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -6,7 +6,7 @@ "id": "LqiaKasFjH82" }, "source": [ - "# Custom derivative rules for JAX-transformable Python functions\n", + "# Custom derivative rules\n", "\n", "\n", "\n", diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index fdf4b3ed0c8d..930887af1e1b 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "LqiaKasFjH82"} -# Custom derivative rules for JAX-transformable Python functions +# Custom derivative rules diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 157d6c567b24..1315783c340c 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -5,7 +5,7 @@ "id": "41a7e222", "metadata": {}, "source": [ - "# SPMD multi-device parallelism with `shard_map`\n", + "# Manual parallelism with `shard_map`\n", "\n", "\n", "\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 21e4111d5bf9..96667e709ac6 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -14,7 +14,7 @@ kernelspec: name: python3 --- -# SPMD multi-device parallelism with `shard_map` +# Manual parallelism with `shard_map` diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index b7bc919ebcde..22d9156f607b 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -5,7 +5,7 @@ "metadata": {}, "source": [ "(sharded-computation)=\n", - "# Introduction to sharded computation\n", + "# Introduction to parallel programming\n", "\n", "\n", "\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 4f8c1b0201bc..e6e1948f9902 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -12,7 +12,7 @@ kernelspec: --- (sharded-computation)= -# Introduction to sharded computation +# Introduction to parallel programming From 180573a8c89f6dc42f42e52cd92c90bee7135daf Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 11:21:52 -0700 Subject: [PATCH 271/702] docs: tweak JAX-Toolbox url anchor for clarity --- docs/investigating_a_regression.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index 9b712056e9bc..389cc0b5a9e8 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -25,7 +25,7 @@ Here is a suggested investigation strategy: ## Nightly investigation -This can be done by using [JAX-Toolbox nightly +This can be done by using the [NVIDIA JAX-Toolbox nightly containers](https://github.com/NVIDIA/JAX-Toolbox). - Some days, bugs prevent the container from being built, or there are temporary regressions. Just discard those days. From 82c305719f9dc2d0bf72e563d0a05e8cab195d40 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 28 Aug 2024 13:48:17 -0700 Subject: [PATCH 272/702] [pallas] Disable two very slow tests in pallas_vmap_test on CPU. These take over a minute each, causing timeouts in CI. PiperOrigin-RevId: 668591954 --- tests/pallas/pallas_vmap_test.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index 3c33702b63e3..4b3f47e6f5c1 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -21,7 +21,6 @@ from absl.testing import absltest import jax from jax import random -from jax._src.lib import xla_extension from jax._src import config from jax._src import test_util as jtu from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr @@ -208,10 +207,8 @@ def sin(x_ref, o_ref): np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) @jtu.skip_on_flag("jax_skip_slow_tests", True) + @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_large_vmap(self): - if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]): - self.skipTest("Test is very slow under TSAN") - # Catches https://github.com/google/jax/issues/18361 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), @@ -229,9 +226,8 @@ def add_one(x_ref, o_ref): np.testing.assert_allclose(out, out_ref) + @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_small_large_vmap(self): - if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]): - self.skipTest("Test is very slow under TSAN") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), From 9e06cbe3b3b41d72a8a38bcc673781567629278b Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 28 Aug 2024 13:57:26 -0700 Subject: [PATCH 273/702] [jax:pallas] Minor cleanup in Triton add lowering. PiperOrigin-RevId: 668595441 --- jax/_src/pallas/triton/lowering.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 6db00a53671e..c2009d7d24c9 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -990,21 +990,19 @@ def _minus(x: ir.Value) -> ir.Value: def _add(x: ir.Value, y: ir.Value): x_element_type = _element_type(x.type) y_element_type = _element_type(y.type) - if tt_dialect.PointerType.isinstance(y_element_type): - assert not tt_dialect.PointerType.isinstance(x_element_type) - x, y = y, x - x_element_type, y_element_type = y_element_type, x_element_type if tt_dialect.PointerType.isinstance(x_element_type): + assert not tt_dialect.PointerType.isinstance(y_element_type) return tt_dialect.addptr(x.type, x, y) + if tt_dialect.PointerType.isinstance(y_element_type): + return tt_dialect.addptr(y.type, y, x) assert x.type == y.type, (str(x.type), str(y.type)) if isinstance(x_element_type, ir.IntegerType): return arith_dialect.addi(x, y) - elif isinstance(x_element_type, ir.FloatType): + if isinstance(x_element_type, ir.FloatType): return arith_dialect.addf(x, y) - else: - raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}") + raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}") def _sub(x: ir.Value, y: ir.Value) -> ir.Value: From 26619e2d10848a9047030b008e9ac6b5da86e05f Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 28 Aug 2024 13:58:30 -0700 Subject: [PATCH 274/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/09982ee2b99dc895ea4d038610f3dcfc2b10a9df. PiperOrigin-RevId: 668595829 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index fa41f6f7d101..d9cd8c56672f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "4170b9b1a900e8fcdee42e0810aacb7e0618701c" -XLA_SHA256 = "5b8bb058d802e8fa83aa70210637d2f8d903348b3f1b5f7e315a04e144f6c123" +XLA_COMMIT = "09982ee2b99dc895ea4d038610f3dcfc2b10a9df" +XLA_SHA256 = "42e3e82c95b095d0fb39c83c497886f3353292381e6aa6403b7e616e584ea32c" def repo(): tf_http_archive( From 1cba0970d83025574ef991d0d60a38ca8060046d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 28 Aug 2024 10:13:58 -0700 Subject: [PATCH 275/702] refactor lax.loops to avoid importing from jax.numpy --- jax/_src/lax/control_flow/loops.py | 4 +- jax/_src/lax/other.py | 66 +++++++++++++++++++++++++---- jax/_src/lax/windowed_reductions.py | 2 +- jax/_src/numpy/ufuncs.py | 45 ++++++++++---------- 4 files changed, 82 insertions(+), 35 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index f7f09424a9e8..5084ec43c2fa 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -50,9 +50,9 @@ _abstractify, _avals_short, _initial_style_jaxpr, _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, _typecheck_param) +from jax._src.lax.other import logaddexp from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.numpy.ufuncs import logaddexp from jax._src.state import discharge as state_discharge from jax._src.traceback_util import api_boundary from jax._src.tree_util import equality_errors @@ -2170,7 +2170,7 @@ def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) return (new_keys, bits), (0, 0) -batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule +batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore[has-type] ### associative_scan diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 7bdfabb92df8..45f9167ab807 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -19,9 +19,11 @@ from typing import Any import jax -from jax._src.numpy import lax_numpy as jnp +from jax._src import dtypes from jax._src.lax import lax from jax._src.lax import convolution +from jax._src import util +import numpy as np DType = Any @@ -88,7 +90,7 @@ def conv_general_dilated_patches( (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`). """ - lhs_array = jnp.asarray(lhs) + lhs_array = lax.asarray(lhs) filter_shape = tuple(filter_shape) dimension_numbers = convolution.conv_dimension_numbers( lhs_array.shape, (1, 1) + filter_shape, dimension_numbers) @@ -99,11 +101,10 @@ def conv_general_dilated_patches( n_channels = lhs_array.shape[lhs_spec[1]] # Move separate `lhs` spatial locations into separate `rhs` channels. - rhs = jnp.eye(spatial_size, dtype=lhs_array.dtype).reshape(filter_shape * 2) - - rhs = rhs.reshape((spatial_size, 1) + filter_shape) - rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1)) - rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1])) + rhs = lax._eye(lhs_array.dtype, shape=(spatial_size, spatial_size), offset=0) + rhs = lax.broadcast_in_dim(rhs, (n_channels, spatial_size, spatial_size), (1, 2)) + rhs = lax.reshape(rhs, (n_channels * spatial_size, 1, *filter_shape)) + rhs = util.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1])) out = convolution.conv_general_dilated( lhs=lhs_array, @@ -200,7 +201,7 @@ def conv_general_dilated_local( If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')` (for a 2D convolution). """ - lhs_array = jnp.asarray(lhs) + lhs_array = lax.asarray(lhs) c_precision = lax.canonicalize_precision(precision) lhs_precision = ( @@ -234,5 +235,52 @@ def conv_general_dilated_local( dn = ((lhs_c_dims, rhs_c_dims), (lhs_b_dims, rhs_b_dims)) out = lax.dot_general(patches, rhs, dimension_numbers=dn, precision=precision) - out = jnp.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1])) + out = util.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1])) return out + + +def _wrap_between(x, _a): + """Wraps `x` between `[-a, a]`.""" + a = lax._const(x, _a) + two_a = lax._const(x, 2 * _a) + zero = lax._const(x, 0) + rem = lax.rem(lax.add(x, a), two_a) + rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem) + return lax.sub(rem, a) + + +def _replace_inf(x: jax.Array) -> jax.Array: + re_x = lax.real(x) if dtypes.issubdtype(x.dtype, np.complexfloating) else x + inf = lax._const(re_x, float('inf')) + return lax.select(lax.eq(re_x, inf), lax._zeros(x), x) + + +@jax.custom_jvp +def logaddexp(x1: jax.typing.ArrayLike, x2: jax.typing.ArrayLike, /) -> jax.Array: + """Compute log(exp(x1) + exp(x2)) avoiding overflow.""" + x1_arr = lax.asarray(x1) + x2_arr = lax.asarray(x2) + assert x1_arr.dtype == x2_arr.dtype + + amax = lax.max(x1_arr, x2_arr) + if dtypes.isdtype(x1_arr.dtype, "real floating"): + delta = lax.sub(x1_arr, x2_arr) + return lax.select(lax._isnan(delta), + lax.add(x1_arr, x2_arr), # NaNs or infinities of the same sign. + lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))) + elif dtypes.isdtype(x1_arr.dtype, "complex floating"): + delta = lax.sub(lax.add(x1, x2), lax.mul(amax, lax._const(amax, 2))) + out = lax.add(amax, lax.log1p(lax.exp(delta))) + return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi)) + else: + raise ValueError(f"logaddexp requires floating-point or complex inputs; got {x1_arr.dtype}") + + +@logaddexp.defjvp +def _logaddexp_jvp(primals, tangents): + x1, x2 = primals + t1, t2 = tangents + primal_out = logaddexp(x1, x2) + tangent_out = lax.add(lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + return primal_out, tangent_out diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 5d6eddad0e4d..dd8e664a095a 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -30,9 +30,9 @@ from jax._src.lax import convolution from jax._src.lax import lax from jax._src.lax import slicing +from jax._src.lax.other import logaddexp from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.numpy.ufuncs import logaddexp from jax._src.typing import Array import numpy as np from jax._src.core import ClosedJaxpr diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 36ce9f5135a3..dfeff38df0fe 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -29,6 +29,7 @@ from jax._src.api import jit from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax +from jax._src.lax import other as lax_other from jax._src.typing import Array, ArrayLike from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, @@ -857,21 +858,30 @@ def _pow_int_int(x1, x2): return acc -@custom_jvp -@implements(np.logaddexp, module='numpy') @jit def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. + + JAX implementation of :func:`numpy.logaddexp` + + Args: + x1: input array + x2: input array + + Returns: + array containing the result. + + Examples: + + >>> x1 = jnp.array([1, 2, 3]) + >>> x2 = jnp.array([4, 5, 6]) + >>> result1 = jnp.logaddexp(x1, x2) + >>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2)) + >>> print(jnp.allclose(result1, result2)) + True + """ x1, x2 = promote_args_inexact("logaddexp", x1, x2) - amax = lax.max(x1, x2) - if dtypes.issubdtype(x1.dtype, np.floating): - delta = lax.sub(x1, x2) - return lax.select(lax._isnan(delta), - lax.add(x1, x2), # NaNs or infinities of the same sign. - lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))) - else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) - out = lax.add(amax, lax.log1p(lax.exp(delta))) - return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi)) + return lax_other.logaddexp(x1, x2) def _wrap_between(x, _a): @@ -884,17 +894,6 @@ def _wrap_between(x, _a): return lax.sub(rem, a) -@logaddexp.defjvp -def _logaddexp_jvp(primals, tangents): - x1, x2 = primals - t1, t2 = tangents - x1, x2, t1, t2 = promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2) - primal_out = logaddexp(x1, x2) - tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) - return primal_out, tangent_out - - @custom_jvp @implements(np.logaddexp2, module='numpy') @jit From b01075054a0f28f115a7323127e17e64583b53a5 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 28 Aug 2024 15:00:04 -0700 Subject: [PATCH 276/702] [Mosaic TPU] Support memref bitcast. If element bitwidth changes, the ratio of bitwidth is multiplied to the 2nd minormost dim size and the leading dim in tiling. For example, we can bitcast Memref<8x128xf32> with tiling (8, 128) to Memref<16x128xi16> with tiling (16, 128). PiperOrigin-RevId: 668619683 --- jaxlib/mosaic/dialect/tpu/tpu.td | 10 ++++ jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 87 ++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index b1a9ac910998..c1fba60f4cc5 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -526,6 +526,16 @@ def TPU_MemRefReshapeOp : TPU_Op<"memref_reshape", [Pure]> { let hasCanonicalizeMethod = 1; } +def TPU_MemRefBitcastOp : TPU_Op<"memref_bitcast", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> { let arguments = (ins AnyMemRef:$input); let results = (outs AnyMemRef:$result); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 5baec61ad138..ff349160dc50 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -303,6 +303,93 @@ LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op, return success(); } +LogicalResult MemRefBitcastOp::verify() { + auto src_ty = getMemRefType(getInput()); + auto tgt_ty = getType(); + if (tgt_ty.getMemorySpace() != nullptr && + tgt_ty.getMemorySpace() != src_ty.getMemorySpace()) { + return emitOpError("Memory spaces do not match."); + } + if (src_ty.getRank() != tgt_ty.getRank()) { + return emitOpError("Ranks do not match."); + } + if (src_ty.getRank() <= 1) { + return emitOpError("Not implemented: 1d memref bitcast."); + } + auto src_bitwidth = src_ty.getElementTypeBitWidth(); + auto tgt_bitwidth = tgt_ty.getElementTypeBitWidth(); + for (int i = 0; i < src_ty.getRank(); ++i) { + auto src_dim_size = src_ty.getDimSize(i); + auto tgt_dim_size = tgt_ty.getDimSize(i); + if (i == src_ty.getRank() - 2) { + src_dim_size *= src_bitwidth; + tgt_dim_size *= tgt_bitwidth; + } + if (src_dim_size != tgt_dim_size) { + return emitOpError( + "Expected the same dim size on the 2nd minormost dim: ") + << src_dim_size << " vs " << tgt_dim_size; + } + } + // Source and target attributes may be different before propagation is done by + // the canonicalizer, so we allow this when attributes are "unset" in the + // target type. + auto tgt_layout = dyn_cast(tgt_ty.getLayout()); + if (!tgt_layout) { + return success(); + } + auto src_layout = dyn_cast(src_ty.getLayout()); + if (!src_layout) { + return emitOpError("Expected a tiled layout for the input memref."); + } + // TODO(jevinjiang): verify memref tiling is valid. Here we just assume the + // source and target tilings are valid. + auto src_tile = src_layout.getTiles().front().dimensions(); + auto tgt_tile = tgt_layout.getTiles().front().dimensions(); + if (src_tile[0] * src_bitwidth != tgt_tile[0] * tgt_bitwidth) { + return emitOpError("Invalid memref bitcast."); + } + return success(); +} + +LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op, + PatternRewriter &rewriter) { + auto src_ty = op.getInput().getType(); + auto dst_ty = op.getType(); + if (src_ty == dst_ty) { + rewriter.replaceOp(op, op.getInput()); + return success(); + } + auto erase_layout_op = op.getInput().getDefiningOp(); + if (!erase_layout_op) { + return failure(); + } + auto src_bitwidth = src_ty.getElementTypeBitWidth(); + auto tgt_bitwidth = dst_ty.getElementTypeBitWidth(); + auto layout_ref = erase_layout_op.getOperand(); + auto layout_ty = layout_ref.getType(); + auto layout = cast(layout_ty.getLayout()); + CHECK(!layout.getTiles().empty()); + auto tile = layout.getTiles().front().dimensions(); + if (tile[0] * src_bitwidth % tgt_bitwidth != 0) { + return failure(); + } + SmallVector new_tiles = + {xla::Tile({tile[0] * src_bitwidth / tgt_bitwidth, 128})}; + if (tgt_bitwidth < 32) { + new_tiles.push_back(xla::Tile({32 / tgt_bitwidth, 1})); + } + auto new_layout = tpu::TiledLayoutAttr::get(src_ty.getContext(), new_tiles, + layout.getTileStrides()); + auto new_result_ty = + MemRefType::get(dst_ty.getShape(), dst_ty.getElementType(), new_layout, + layout_ty.getMemorySpace()); + auto bitcast = + rewriter.create(op.getLoc(), new_result_ty, layout_ref); + rewriter.replaceOpWithNewOp(op, op.getType(), bitcast); + return success(); +} + template LogicalResult verifyStridedOp(Op op, MemRefType memref_ty, VectorType vector_ty) { From a3cccd34e2da316e358415f6961b9b9c7e18d233 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 28 Aug 2024 15:32:59 -0700 Subject: [PATCH 277/702] [Mosaic TPU] Print expected Mosaic version after finding unsupported version. PiperOrigin-RevId: 668632116 --- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 000bfe3eaea2..3f6050f31dab 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -166,8 +166,8 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { return; } if (version_attr.getInt() > kVersion) { - module->emitError("Unsupported Mosaic version: ") - << version_attr.getInt(); + module->emitError("Unsupported Mosaic version: expected <= ") + << kVersion << " but got " << version_attr.getInt(); signalPassFailure(); return; } From 28a65589f7489d9d0a1e62ca502e7c70e00b7e80 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 28 Aug 2024 15:35:54 -0700 Subject: [PATCH 278/702] `lazy_loader.attach` now only imports the submodule once on first access This is a partial roll-forward of #22998. PiperOrigin-RevId: 668633307 --- jax/_src/lazy_loader.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/jax/_src/lazy_loader.py b/jax/_src/lazy_loader.py index cf6e68e49c81..14822bff3eff 100644 --- a/jax/_src/lazy_loader.py +++ b/jax/_src/lazy_loader.py @@ -16,6 +16,7 @@ from collections.abc import Callable, Sequence import importlib +import sys from typing import Any @@ -26,17 +27,27 @@ def attach(package_name: str, submodules: Sequence[str]) -> tuple[ ]: """Lazily loads submodules of a package. - Example use: - ``` - __getattr__, __dir__, __all__ = lazy_loader.attach(__name__, ["sub1", "sub2"]) - ``` + Returns: + A tuple of ``__getattr__``, ``__dir__`` function and ``__all__`` -- + a list of available global names, which can be used to replace the + corresponding definitions in the package. + + Raises: + RuntimeError: If the ``__name__`` of the caller cannot be determined. """ + owner_name = sys._getframe(1).f_globals.get("__name__") + if owner_name is None: + raise RuntimeError("Cannot determine the ``__name__`` of the caller.") - __all__: list[str] = list(submodules) + __all__ = list(submodules) def __getattr__(name: str) -> Any: if name in submodules: - return importlib.import_module(f"{package_name}.{name}") + value = importlib.import_module(f"{package_name}.{name}") + # Update module-level globals to avoid calling ``__getattr__`` again + # for this ``name``. + setattr(sys.modules[owner_name], name, value) + return value raise AttributeError(f"module '{package_name}' has no attribute '{name}") def __dir__() -> list[str]: From 2c11a91d98c72122c2ad869080d31a83e82d2935 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 28 Aug 2024 15:46:53 -0700 Subject: [PATCH 279/702] [Pallas TPU] Fix itemsize check for int4 in bitcast lowering. PiperOrigin-RevId: 668637257 --- jax/_src/pallas/mosaic/lowering.py | 4 +++- jax/_src/pallas/mosaic/primitives.py | 21 ++++++++++++++------- jax/_src/pallas/utils.py | 4 ++++ 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 11ef428ae279..6a9fe216d4b1 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2546,7 +2546,9 @@ def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype): (in_aval, ) = ctx.avals_in (out_aval,) = ctx.avals_out - if in_aval.dtype.itemsize != new_dtype.itemsize: + old_bitwidth = pallas_utils.dtype_bitwidth(in_aval.dtype) + new_bitwidth = pallas_utils.dtype_bitwidth(new_dtype) + if old_bitwidth != new_bitwidth: raise NotImplementedError("Changing bitwidths not supported.") return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index b7d02e5ccf8b..7b4faa002d1a 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -26,12 +26,13 @@ from jax._src import state from jax._src import tree_util from jax._src import util -from jax._src.state import indexing -from jax._src.state import primitives as sp from jax._src.interpreters import mlir from jax._src.pallas import core as pl_core +from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge +from jax._src.state import indexing +from jax._src.state import primitives as sp from jax._src.typing import DTypeLike import jax.numpy as jnp @@ -65,7 +66,9 @@ def bitcast(x, ty: DTypeLike): ty = dtypes.canonicalize_dtype(ty) if len(x.shape) < 2: raise ValueError("Not implemented: bitcast 1D") - if x.shape[-2] * x.dtype.itemsize % ty.itemsize: + src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype) + dst_bitwidth = pallas_utils.dtype_bitwidth(ty) + if x.shape[-2] * src_bitwidth % dst_bitwidth: raise ValueError( "Not implemented: the 2nd minor dim can not be perfectly packed or" " unpacked" @@ -76,19 +79,23 @@ def bitcast(x, ty: DTypeLike): @bitcast_p.def_abstract_eval def _bitcast_abstract_eval(x, *, ty): shape = list(x.shape) - shape[-2] = shape[-2] * x.dtype.itemsize // ty.itemsize + src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype) + dst_bitwidth = pallas_utils.dtype_bitwidth(ty) + shape[-2] = shape[-2] * src_bitwidth // dst_bitwidth return jax_core.ShapedArray(shape, ty) def _bitcast_lowering_rule(ctx: mlir.LoweringRuleContext, x, *, ty): def _bitcast(x): - if x.dtype.itemsize < ty.itemsize: + src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype) + dst_bitwidth = pallas_utils.dtype_bitwidth(ty) + if src_bitwidth < dst_bitwidth: *leading, m, n = x.shape - packing = ty.itemsize // x.dtype.itemsize + packing = dst_bitwidth // src_bitwidth x = x.reshape(*leading, m // packing, packing, n) x = jnp.swapaxes(x, -1, -2) return jax.lax.bitcast_convert_type(x, ty) - if x.dtype.itemsize > ty.itemsize: + if src_bitwidth > dst_bitwidth: y = jax.lax.bitcast_convert_type(x, ty) *leading, m, n, packing = y.shape return jnp.swapaxes(y, -1, -2).reshape(*leading, m * packing, n) diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 41466be0822d..295134bd9855 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -71,6 +71,10 @@ def next_power_of_2(x: int) -> int: raise ValueError("`next_power_of_2` requires a non-negative integer.") return 1 if x == 0 else 2 ** (x - 1).bit_length() +def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int: + if isinstance(dtype, jnp.integer): + return jnp.iinfo(dtype).bits + return np.dtype(dtype).itemsize * 8 def pattern_match_scan_to_fori_loop( jaxpr: jax_core.Jaxpr, num_consts: int, num_carry: int From f4793501168e2f25c584f751750f0196cfa03b2a Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 29 Aug 2024 00:50:37 +0100 Subject: [PATCH 280/702] Disable implicit type conversion during type matching --- jax/_src/pallas/triton/lowering.py | 2 +- tests/pallas/ops_test.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 6db00a53671e..e4e2dc0791ad 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -564,7 +564,7 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: if len(avals) != len(self.arg_types): return False return all( - aval.weak_type or aval.dtype.name == arg_type + aval.dtype.name == arg_type for aval, arg_type in zip(avals, self.arg_types) ) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index cf247ac3f6a4..ae099abc0eb3 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -739,6 +739,17 @@ def kernel(x_ref, o_ref): x = jnp.array([0.42, 2.4]).astype(dtype) np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) + def test_abs_weak_type(self): + # see https://github.com/google/jax/issues/23191 + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.abs(x_ref[...]) + + x = jnp.broadcast_to(-3.2, (4, 4)) # sets `weak_type` to `True` + np.testing.assert_allclose(kernel(x), jnp.abs(x), rtol=1e-6) + @parameterized.parameters( ("float32", "int32"), ("float64", "int32"), From 57a9b1807f90bc231b4c165ba6cfe5a81542be50 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 28 Aug 2024 17:10:42 -0700 Subject: [PATCH 281/702] Skip test_layout_donation_mismatching_in_and_out_fails on CPU and GPU PiperOrigin-RevId: 668666057 --- tests/layout_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/layout_test.py b/tests/layout_test.py index 7a587d099498..0a9a72e8f48b 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -575,6 +575,7 @@ def f(x): f(arr) self.assertTrue(arr.is_deleted()) + @jtu.skip_on_devices('cpu', 'gpu') def test_layout_donation_mismatching_in_and_out_fails(self): mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) From 1dff3a2c71275c083cc3feb4f3526be50d2ba06c Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 28 Aug 2024 17:58:42 -0700 Subject: [PATCH 282/702] [Pallas GPU] Add 32-bit lowering rule for `lax.erf_inv` Add 32-bit lowering rule for `lax.erf_inv` for Pallas GPU, and move the original TPU test case into the general test PiperOrigin-RevId: 668681910 --- jax/_src/pallas/mosaic/lowering.py | 29 +---------------- jax/_src/pallas/triton/lowering.py | 51 +++++++++++++++++++++--------- jax/_src/pallas/utils.py | 27 ++++++++++++++++ tests/pallas/ops_test.py | 31 +++++++++--------- 4 files changed, 79 insertions(+), 59 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 6a9fe216d4b1..c8f910e185f3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2497,37 +2497,10 @@ def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d): skip_mlir_conversions.add(lax.shift_right_logical_p) -# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802 -def _erf_inv_32_helper(x): - k_degree = 9 - w_lt_5_constants = [ - 2.81022636e-08, 3.43273939e-07, -3.5233877e-06, - -4.39150654e-06, 0.00021858087, -0.00125372503, - -0.00417768164, 0.246640727, 1.50140941, - ] - w_gt_5_constants = [ - -0.000200214257, 0.000100950558, 0.00134934322, - -0.00367342844, 0.00573950773, -0.0076224613, - 0.00943887047, 1.00167406, 2.83297682, - ] - - w = -jnp.log1p(x * -x) - w_lt_5 = w < 5.0 - - w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0) - - p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0]) - for i in range(1, k_degree): - c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i]) - p = c + p * w - - return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x) - - def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): (x_aval,) = ctx.avals_in if x_aval.dtype == jnp.float32: - return lower_fun(_erf_inv_32_helper, multiple_results=False)(ctx, x) + return lower_fun(pallas_utils.erf_inv_32_lowering_helper, multiple_results=False)(ctx, x) else: raise NotImplementedError diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 44b198c0f9eb..4057b125fcdc 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -404,6 +404,21 @@ def write_env(var: jax_core.Var, val): return map(read_env, jaxpr.outvars) +def lower_fun( + fun: Callable[..., Any], *, multiple_results: bool +) -> Callable[..., Any]: + fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) + + def f_lowered(ctx: LoweringRuleContext, *args, **params): + wrapped_fun = lu.wrap_init(fn, params) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) + out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr) + return out if multiple_results else out[0] + + return f_lowered + + # # Primitive lowering rules # ## Programming model primitives @@ -978,6 +993,27 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): _Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"), ], ), + lax.erf_inv_p: _make_dispatch_table( + "erf_inv", + cuda=[ + _Fallback( + ["float32"], + lower_fun( + pallas_utils.erf_inv_32_lowering_helper, + multiple_results=False, + ), + ), + ], + rocm=[ + _Fallback( + ["float32"], + lower_fun( + pallas_utils.erf_inv_32_lowering_helper, + multiple_results=False, + ), + ), + ], + ), }) @@ -1255,21 +1291,6 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): return acc -def lower_fun( - fun: Callable[..., Any], *, multiple_results: bool -) -> Callable[..., Any]: - fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) - - def f_lowered(ctx: LoweringRuleContext, *args, **params): - wrapped_fun = lu.wrap_init(fn, params) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) - jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) - out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr) - return out if multiple_results else out[0] - - return f_lowered - - _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 295134bd9855..6fc816e27b53 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -183,3 +183,30 @@ def pattern_match_while_to_fori_loop( outvars=new_outvars, ) return jaxpr, None + + +# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802 +def erf_inv_32_lowering_helper(x): + k_degree = 9 + w_lt_5_constants = [ + 2.81022636e-08, 3.43273939e-07, -3.5233877e-06, + -4.39150654e-06, 0.00021858087, -0.00125372503, + -0.00417768164, 0.246640727, 1.50140941, + ] + w_gt_5_constants = [ + -0.000200214257, 0.000100950558, 0.00134934322, + -0.00367342844, 0.00573950773, -0.0076224613, + 0.00943887047, 1.00167406, 2.83297682, + ] + + w = -jnp.log1p(x * -x) + w_lt_5 = w < 5.0 + + w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0) + + p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0]) + for i in range(1, k_degree): + c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i]) + p = c + p * w + + return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index b2171c46210d..ab7ffcb480b2 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1513,6 +1513,21 @@ def reduce(x_ref, y_ref): y_ref = jnp.cumsum(x, axis=axis) np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) + @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4]) + def test_erf_inv(self, x): + @functools.partial( + self.pallas_call, + # TODO(ayx): add float64 support for `erf_inv` + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.erf_inv(x_ref[...]) + + x = jnp.full((8, 128), x) + out = kernel(x) + expected = lax.erf_inv(x) + np.testing.assert_array_equal(out, expected) + class OpsExtraInterpretTest(OpsExtraTest): INTERPRET = True @@ -1583,22 +1598,6 @@ def setUp(self): super().setUp() - @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4]) - def test_erf_inv(self, x): - @jax.jit - @functools.partial( - pl.pallas_call, - # TODO(ayx): add float64 support for `erf_inv` - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - ) - def kernel(x_ref, o_ref): - o_ref[...] = lax.erf_inv(x_ref[...]) - - x = jnp.full((8, 128), x) - out = kernel(x) - expected = lax.erf_inv(x) - np.testing.assert_array_equal(out, expected) - SIGN_PARAMS = [ (jnp.int32, (-3, 0, 5)), (jnp.uint32, (0, 5)), From dd6f0e2e2ed13869eb7641ebff854ff5e8077e82 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 29 Aug 2024 08:35:00 -0700 Subject: [PATCH 283/702] Add weak_type to ShapeDtypeStruct because jax.Array also has it and SDS is a duck of jax.Array This fixes a tracing cache miss issue when you eval shape with a weak_type input and get a strong type output back and pass that back in leading to a cache miss. Fixes: https://github.com/google/jax/issues/23302 PiperOrigin-RevId: 668949430 --- jax/_src/api.py | 23 ++++++++++++++--------- jax/_src/pjit.py | 3 ++- tests/api_test.py | 12 ++++++++++++ 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index f83a8d73165d..aacb6b55618a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2693,10 +2693,11 @@ class ShapeDtypeStruct: dtype: a dtype-like object sharding: (optional) a :class:`jax.Sharding` object """ - __slots__ = ["shape", "dtype", "sharding", "_dll"] + __slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"] named_shape = {} # type: ignore - def __init__(self, shape, dtype, named_shape=None, sharding=None): + def __init__(self, shape, dtype, named_shape=None, sharding=None, + weak_type=False): del named_shape # ignored, vestigial self.shape = tuple(shape) if dtype is None: @@ -2714,6 +2715,7 @@ def __init__(self, shape, dtype, named_shape=None, sharding=None): f" layout in a `ShapeDtypeStruct`. Got {sharding}") self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None + self.weak_type = weak_type size = property(lambda self: math.prod(self.shape)) ndim = property(lambda self: len(self.shape)) @@ -2731,8 +2733,9 @@ def __len__(self): def __repr__(self): sh = f", sharding={self.sharding}" if self.sharding is not None else "" l = f", layout={self.layout}" if self._dll is not None else "" + wt = f", weak_type={self.weak_type}" if self.weak_type else "" return (f"{type(self).__name__}(shape={self.shape}, " - f"dtype={self.dtype.name}{sh}{l})") + f"dtype={self.dtype.name}{sh}{l}{wt})") __str__ = __repr__ @@ -2740,17 +2743,19 @@ def __eq__(self, other): if not isinstance(other, ShapeDtypeStruct): return False else: - return ((other.shape, other.dtype, other.sharding, other.layout) == - (self.shape, self.dtype, self.sharding, self.layout)) + return ((self.shape, self.dtype, self.sharding, self.layout, self.weak_type) == + (other.shape, other.dtype, other.sharding, other.layout, other.weak_type)) def __hash__(self): # TODO(frostig): avoid the conversion from dict by addressing # https://github.com/google/jax/issues/8182 - return hash((self.shape, self.dtype, self.sharding, self.layout)) + return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type)) -core.pytype_aval_mappings[ShapeDtypeStruct] = ( - lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), - weak_type=False)) +def _sds_aval_mapping(x): + return ShapedArray( + x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), + weak_type=x.weak_type) +core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping @api_boundary diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index bcd31ec09313..1c93850f87b6 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -524,7 +524,8 @@ def eval_shape(*args, **kwargs): p, _ = _infer_params(fun, jit_info, args, kwargs) out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']] # TODO(yashkatariya): Add `Layout` to SDS. - out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s) + out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, + weak_type=x.weak_type) for x, s in zip(p.params['jaxpr'].out_avals, out_s)] return tree_unflatten(p.out_tree, out) diff --git a/tests/api_test.py b/tests/api_test.py index 3d3d08c092c3..1a119846be9c 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4099,6 +4099,18 @@ def __jax_array__(self): a2 = jnp.array(((x, x), [x, x])) self.assertAllClose(np.array(((1, 1), (1, 1))), a2) + def test_eval_shape_weak_type(self): + # https://github.com/google/jax/issues/23302 + arr = jax.numpy.array(1) + + with jtu.count_jit_tracing_cache_miss() as count: + jax.eval_shape(jax.numpy.array, 1) + out = jax.eval_shape(jax.numpy.array, 1) + + self.assertEqual(count[0], 1) + self.assertTrue(out.weak_type) + self.assertEqual(out.weak_type, arr.weak_type) + def test_dunder_jax_array_bug(self): @jax.tree_util.register_pytree_node_class class A: From 7dd9adba05b34552950591ef0b3eb7495ced82e9 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 29 Aug 2024 09:07:16 -0700 Subject: [PATCH 284/702] Fixed stack-use-after-scope in Mosaic GPU PiperOrigin-RevId: 668958750 --- jaxlib/mosaic/gpu/custom_call.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 8fdd34cd91dc..d9b1e0775ecc 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -337,8 +337,10 @@ absl::StatusOr> Compile( } // Create a transformer to run all LLVM optimization passes at the // specified optimization level. + auto transformer = mlir::makeOptimizingTransformer( + /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/nullptr); mlir::ExecutionEngineOptions options; - options.transformer = mlir::makeOptimizingTransformer(3, 0, nullptr); + options.transformer = transformer; options.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; options.sharedLibPaths = runtime_lib; auto maybe_execution_engine = mlir::ExecutionEngine::create(module, options); From b615266175effe4aefeb903620a19f3719a604da Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 29 Aug 2024 09:42:35 -0700 Subject: [PATCH 285/702] Reverts 82c9da020a78997862a8f7ccd494bed363f7ed01 PiperOrigin-RevId: 668969133 --- jax/_src/api.py | 6 +- jax/_src/interpreters/pxla.py | 41 +--------- jax/_src/pjit.py | 117 +++++++++------------------- jax/experimental/multihost_utils.py | 11 +-- tests/pjit_test.py | 32 +++----- 5 files changed, 56 insertions(+), 151 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index aacb6b55618a..d19f751f1251 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2970,8 +2970,7 @@ def clear_backends(): pjit._infer_params_cached.cache_clear() pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error - pjit._cpp_pjit_cache_fun_only.clear() - pjit._cpp_pjit_cache_explicit_attributes.clear() + pjit._cpp_pjit_cache.clear() xc._xla.PjitFunctionCache.clear_all() @atexit.register @@ -2999,8 +2998,7 @@ def clear_caches(): util.clear_all_weakref_lru_caches() # Clear all C++ compiled executable caches for pjit - pjit._cpp_pjit_cache_fun_only.clear() - pjit._cpp_pjit_cache_explicit_attributes.clear() + pjit._cpp_pjit_cache.clear() pjit._infer_params_cached.cache_clear() xc._xla.PjitFunctionCache.clear_all() diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 4e0986cfaf36..7db928f12704 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -22,7 +22,6 @@ from collections.abc import Callable, Sequence, Iterable, Iterator import dataclasses from functools import partial, lru_cache, cached_property -import functools import itertools as it import logging import math @@ -90,7 +89,6 @@ class WeakRefList(list): logger = logging.getLogger(__name__) Index = Union[int, slice, tuple[Union[int, slice], ...]] -PyTreeDef = tree_util.PyTreeDef NoSharding = sharding_specs.NoSharding Chunked = sharding_specs.Chunked @@ -2907,34 +2905,6 @@ class MeshExecutableFastpathData(NamedTuple): in_device_local_layouts: Sequence[DeviceLocalLayout | None] -@dataclasses.dataclass(frozen=True, kw_only=True) -class JitGlobalCppCacheKeys: - donate_argnums: tuple[int, ...] | None = None - donate_argnames: tuple[str, ...] | None = None - device: xc.Device | None = None - backend: str | None = None - in_shardings_treedef: PyTreeDef | None = None - in_shardings_leaves: tuple[Any, ...] | None = None - out_shardings_treedef: PyTreeDef | None = None - out_shardings_leaves: tuple[Any, ...] | None = None - in_layouts_treedef: PyTreeDef | None = None - in_layouts_leaves: tuple[Any, ...] | None = None - out_layouts_treedef: PyTreeDef | None = None - out_layouts_leaves: tuple[Any, ...] | None = None - use_resource_env: bool = False - - @functools.cached_property - def contains_explicit_attributes(self): - return (self.donate_argnums is not None or - self.donate_argnames is not None or - self.device is not None or - self.backend is not None or - any(not is_unspecified(i) for i in self.in_shardings_leaves) or - any(not is_unspecified(o) for o in self.out_shardings_leaves) or - any(i is not None for i in self.in_layouts_leaves) or - any(o is not None for o in self.out_layouts_leaves)) - - def reflatten_outputs_for_dispatch(out_tree, out_flat): # We arrive at dispatch having flattened according to the default # pytree registry, but we want to re-flatten according to our @@ -3048,14 +3018,9 @@ def aot_cache_miss(*args, **kwargs): fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry - if xla_extension_version >= 283: - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], - JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) - else: - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, cc_shard_arg) + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], [], + tree_util.dispatch_registry, cc_shard_arg) if xla_extension_version < 282: def cc_shard_arg(x, sharding): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 1c93850f87b6..ed5b825c62b4 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -63,7 +63,6 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src import sharding from jax._src.mesh import AbstractMesh from jax._src.sharding_impls import ( @@ -166,6 +165,7 @@ class PjitInfo(NamedTuple): keep_unused: bool inline: bool abstracted_axes: Any | None + has_explicit_sharding: bool use_resource_env: bool # False for jit, True for pjit # Hash and compare PjitInfo by identity when used as a cache key. @@ -312,39 +312,14 @@ def _cpp_pjit_evict_fn(self): # The entries are doubled here from the default 4096 because _pjit_call_impl # also has a cpp dispatch path and that would double the number of entries in # the global shared cache. -# This cache is only used for jit's with only fun. For example: jax.jit(f) -_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192) +_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192) -# This cache is used for jit where extra arguments are defined other than the -# fun. For example: jax.jit(f, donate_argnums=...) OR -# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the -# capacity might get full very fast because of all the jitted function in JAX -# which might evict train_step for example. -_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192) - -if xla_extension_version < 283: - def _get_cpp_global_cache(pjit_has_explicit_sharding): - if pjit_has_explicit_sharding: - return xc._xla.PjitFunctionCache() - else: - return _cpp_pjit_cache_fun_only - - def _pjit_explicit_sharding_and_layout( - in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, - device, backend) -> bool: - return (device is not None or - backend is not None or - any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(o) for o in out_shardings_flat) or - any(i is not None for i in in_layouts_flat) or - any(o is not None for o in out_layouts_flat)) -else: - def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore - if contains_explicit_attributes: - return _cpp_pjit_cache_explicit_attributes - else: - return _cpp_pjit_cache_fun_only +def _get_cpp_global_cache(pjit_has_explicit_sharding): + if pjit_has_explicit_sharding: + return xc._xla.PjitFunctionCache() + else: + return _cpp_pjit_cache def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @@ -365,35 +340,11 @@ def cache_miss(*args, **kwargs): return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) - if xla_extension_version >= 283: - cache_key = pxla.JitGlobalCppCacheKeys( - donate_argnums=jit_info.donate_argnums, - donate_argnames=jit_info.donate_argnames, - device=jit_info.device, backend=jit_info.backend, - in_shardings_treedef=jit_info.in_shardings_treedef, - in_shardings_leaves=jit_info.in_shardings_leaves, - out_shardings_treedef=jit_info.out_shardings_treedef, - out_shardings_leaves=jit_info.out_shardings_leaves, - in_layouts_treedef=jit_info.in_layouts_treedef, - in_layouts_leaves=jit_info.in_layouts_leaves, - out_layouts_treedef=jit_info.out_layouts_treedef, - out_layouts_leaves=jit_info.out_layouts_leaves, - use_resource_env=jit_info.use_resource_env) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), fun, cache_miss, jit_info.static_argnums, - jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore - pxla.cc_shard_arg, - _get_cpp_global_cache(cache_key.contains_explicit_attributes)) - else: - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - jit_info.in_shardings_leaves, jit_info.out_shardings_leaves, - jit_info.in_layouts_leaves, jit_info.out_layouts_leaves, - jit_info.device, jit_info.backend) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), fun, cache_miss, jit_info.static_argnums, - jit_info.static_argnames, jit_info.donate_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding)) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), + fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, + jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(jit_info.has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -401,6 +352,17 @@ def cache_miss(*args, **kwargs): return cpp_pjitted_f +def _pjit_explicit_sharding_and_layout( + in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, + device, backend) -> bool: + return (device is not None or + backend is not None or + any(not is_unspecified(i) for i in in_shardings_flat) or + any(not is_unspecified(o) for o in out_shardings_flat) or + any(i is not None for i in in_layouts_flat) or + any(o is not None for o in out_layouts_flat)) + + def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) layouts, shardings = [], [] @@ -484,6 +446,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings_leaves, out_shardings_leaves, in_layouts_leaves, + out_layouts_leaves, device, backend) + return PjitInfo( fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, @@ -501,6 +467,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, abstracted_axes=abstracted_axes, + has_explicit_sharding=has_explicit_sharding, use_resource_env=use_resource_env) @@ -1766,27 +1733,13 @@ def call_impl_cache_miss(*args_, **kwargs_): f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) - donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) - if xla_extension_version >= 283: - cache_key = pxla.JitGlobalCppCacheKeys( - donate_argnums=donated_argnums, donate_argnames=None, - device=None, backend=None, - in_shardings_treedef=None, in_shardings_leaves=in_shardings, - out_shardings_treedef=None, out_shardings_leaves=out_shardings, - in_layouts_treedef=None, in_layouts_leaves=in_layouts, - out_layouts_treedef=None, out_layouts_leaves=out_layouts, - use_resource_env=resource_env is not None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], cache_key, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) - else: - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings, out_shardings, in_layouts, out_layouts, None, None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding))(*args) + donated_argnums = [i for i, d in enumerate(donated_invars) if d] + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings, out_shardings, in_layouts, out_layouts, None, None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], donated_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 56003ea7af5d..554bf2641769 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -90,17 +90,19 @@ def sync_global_devices(name: str): assert_equal(h, f"sync_global_devices name mismatch ('{name}')") -# Identity function is at the top level so that `process_allgather` doesn't -# recompile on every invocation. def _identity_fn(x): return x +@lru_cache(maxsize=128) +def _jitted_identity_fn(sharding): + return jax.jit(_identity_fn, out_shardings=sharding) + def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: reps = sharding_impls.GSPMDSharding.get_replicated( inp.sharding._device_assignment) - out = jax.jit(_identity_fn, out_shardings=reps)(inp) + out = _jitted_identity_fn(reps)(inp) else: # All inputs here will be fully addressable. if jax.process_count() == 1: @@ -123,8 +125,7 @@ def _handle_array_process_allgather(inp, tiled): bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( global_aval.shape, s, bufs) - out = jax.jit(_identity_fn, - out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr) + out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr) return np.asarray(out.addressable_data(0)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 392a25f32612..b4878b02199e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -653,16 +653,18 @@ def testAutodiff(self, mesh, resources): @jtu.with_mesh([('x', 2), ('y', 1)]) def testAutodiffCache(self): - f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None) + f = pjit( + lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None + ) x = jnp.arange(16, dtype=jnp.float32) - jax.grad(f)(x) # Warm up the cache. - with jtu.count_pjit_cpp_cache_miss() as count: - jax.grad(f)(x) - if xla_extension_version >= 283: - self.assertEqual(count[0], 0) # no cache miss i.e. cache hit - else: - self.assertEqual(count[0], 2) + before = pjit_lib._pjit_lower_cached.cache_info() + jax.grad(f)(x) + after = pjit_lib._pjit_lower_cached.cache_info() + + # One hit for the forward pass, one hit for backward. + self.assertEqual(after.hits, before.hits + 2) + self.assertEqual(after.misses, before.misses) @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): @@ -4536,20 +4538,6 @@ def test_wsc_abstract_mesh_errors(self): ' match the mesh shape of the target sharding.*'): with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y'))) - @unittest.skipIf(xla_extension_version < 283, - "Requires xla_extension_version >= 283") - def test_global_jit_cpp_cache_hit_out_shardings(self): - mesh = jtu.create_global_mesh((2,), 'x') - s = NamedSharding(mesh, P('x')) - - def f(x): - return x * 2 - - with jtu.count_pjit_cpp_cache_miss() as count: - jax.jit(f, out_shardings=s)(np.arange(8)) - jax.jit(f, out_shardings=s)(np.arange(8)) - self.assertEqual(count[0], 1) - def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") From bcfe95e98ef51bd28a41dc7fb05fc2a9b950bfad Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 29 Aug 2024 10:49:30 -0700 Subject: [PATCH 286/702] Initial integration of sharding in types in JAX. Currently we just support `nary` ops in forward only sharding propagation. Currently this functionality is experimental and hidden behind `jax_sharding_in_types` config flag. There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX. Co-authored-by: Dougal Maclaurin PiperOrigin-RevId: 668991384 --- jax/_src/array.py | 10 ++++-- jax/_src/config.py | 13 +++++++ jax/_src/core.py | 26 ++++++++++---- jax/_src/lax/lax.py | 80 +++++++++++++++++++++++++++++++++++++++---- jax/_src/lax/utils.py | 12 ++++--- tests/pjit_test.py | 26 ++++++++++++++ 6 files changed, 147 insertions(+), 20 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 7659c180ddc9..909f5acf0d43 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -41,7 +41,7 @@ from jax._src.lib import xla_extension as xe from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, + PmapSharding, SingleDeviceSharding, NamedSharding, device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache @@ -1012,7 +1012,13 @@ def make_array_from_single_device_arrays( core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval') xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity -api_util._shaped_abstractify_handlers[ArrayImpl] = op.attrgetter('aval') +def _get_aval_array(self): + if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding): + return self.aval.update(sharding=NamedSharding( + self.sharding.mesh.abstract_mesh, self.sharding.spec)) + else: + return self.aval +api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array # TODO(jakevdp) replace this with true inheritance at the C++ level. basearray.Array.register(ArrayImpl) diff --git a/jax/_src/config.py b/jax/_src/config.py index b6d2358f4c26..f17ca59385fe 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -213,6 +213,7 @@ def trace_context(): default_device.value, random_seed_offset.value, threefry_partitionable.value, threefry_gpu_kernel_lowering.value, + sharding_in_types.value, softmax_custom_jvp.value, enable_memories.value, disable_jit.value, @@ -826,6 +827,7 @@ class _GlobalExtraJitContext(NamedTuple): random_seed_offset: int = 0 threefry_partitionable: bool = False threefry_gpu_kernel_lowering: bool = False + sharding_in_types: bool = False softmax_custom_jvp: bool = False xla_profile_version: int = 0 pgle_profiling_runs: int = 0 @@ -864,6 +866,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): random_seed_offset: int | None = None threefry_partitionable: bool | None = None threefry_gpu_kernel_lowering: bool | None = None + sharding_in_types: bool | None = None softmax_custom_jvp: bool | None = None xla_profile_version: int | None = None pgle_profiling_runs: int | None = None @@ -1139,6 +1142,16 @@ def _update_jax_memories_thread_local(val): update_thread_local_hook=lambda val: update_thread_local_jit_state( threefry_gpu_kernel_lowering=val)) +sharding_in_types = bool_state( + name='jax_sharding_in_types', + default=False, + help=('When True, enables forward only sharding propagation in JAX and ' + 'avals have sharding on them.'), + update_global_hook=lambda val: _update_global_jit_state( + sharding_in_types=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + sharding_in_types=val)) + softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', diff --git a/jax/_src/core.py b/jax/_src/core.py index f80cd0418b81..dbd4dc58e2e2 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1733,17 +1733,23 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) class ShapedArray(UnshapedArray): - __slots__ = ['shape'] + __slots__ = ['shape', 'dtype', 'weak_type', 'sharding'] array_abstraction_level = 2 named_shape = {} # type: ignore - def __init__(self, shape, dtype, weak_type=False, named_shape=None): + def __init__(self, shape, dtype, weak_type=False, named_shape=None, + sharding=None): del named_shape # unused, vestigial self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type + if config.sharding_in_types.value: + self.sharding = sharding + else: + self.sharding = None - def update(self, shape=None, dtype=None, weak_type=None, named_shape=None): + def update(self, shape=None, dtype=None, weak_type=None, named_shape=None, + sharding=None): del named_shape # unused, vestigial if shape is None: shape = self.shape @@ -1751,7 +1757,9 @@ def update(self, shape=None, dtype=None, weak_type=None, named_shape=None): dtype = self.dtype if weak_type is None: weak_type = self.weak_type - return ShapedArray(shape, dtype, weak_type) + if sharding is None: + sharding = self.sharding + return ShapedArray(shape, dtype, weak_type, sharding=sharding) ndim = property(lambda self: len(self.shape)) size = property(lambda self: @@ -1766,13 +1774,14 @@ def update(self, shape=None, dtype=None, weak_type=None, named_shape=None): def __eq__(self, other): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape - and self.weak_type == other.weak_type) + and self.weak_type == other.weak_type + and self.sharding == other.sharding) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) - return hash((self.shape, self.dtype, self.weak_type)) + return hash((self.shape, self.dtype, self.weak_type, self.sharding)) def at_least_vspace(self): return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), @@ -1791,7 +1800,10 @@ def str_short(self, short_dtypes=False): dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name dt_str = dt_str.replace('void', 'float0') shapestr = ','.join(map(str, self.shape)) - return f'{dt_str}[{shapestr}]' + if self.sharding is None: + return f'{dt_str}[{shapestr}]' + else: + return f'{dt_str}[{shapestr}]({self.sharding})' def _len(self, ignored_tracer): try: diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e81c6c157627..29fa41012f80 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -65,7 +65,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.sharding_impls import PmapSharding +from jax._src.sharding_impls import PmapSharding, NamedSharding, PartitionSpec from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis, split_list, NumpyComplexWarning) @@ -1709,13 +1709,54 @@ def broadcasting_shape_rule(name, *avals): return tuple(result_shape) +def broadcasting_sharding_rule(name, *avals): + shapes = [aval.shape for aval in avals if aval.shape] + if not shapes: + return () + if len({len(shape) for shape in shapes}) != 1: + msg = '{}: arrays must have same number of dimensions, got {}.' + raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes))))) + + specs = [a.sharding.spec for a in avals if a.shape] + + mesh = None + for a in avals: + if a.shape: + mesh = a.sharding.mesh + if mesh is not None and mesh != a.sharding.mesh: + raise ValueError( + f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' + f' another mesh: {a.sharding.mesh}') + assert mesh is not None + + result_specs = [] + for ss, ds in zip(zip(*specs), zip(*shapes)): + if all(s == ss[0] for s in ss[1:]): + # if all dimension shardings are same, the resulting dimension sharding is + # the same. + result_specs.append(ss[0]) + else: + non_trivial_s = [s for s, d in zip(ss, ds) + if not (core.definitely_equal(d, 1) and s is None)] + if not non_trivial_s: + result_specs.append(None) + elif all(non_trivial_s[0] == s for s in non_trivial_s[1:]): + result_specs.append(non_trivial_s[0]) + else: + raise TypeError(f'{name} got incompatible shardings for broadcasting: ' + f'{", ".join(map(str, map(tuple, specs)))}.') + return NamedSharding(mesh, PartitionSpec(*result_specs)) + + def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same_dtypes=False): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, allow_extended_dtype=allow_extended_dtype, require_same=require_same_dtypes) shape_rule = partial(broadcasting_shape_rule, name) - prim = standard_primitive(shape_rule, dtype_rule, name) + sharding_rule = partial(broadcasting_sharding_rule, name) + prim = standard_primitive(shape_rule, dtype_rule, name, + sharding_rule=sharding_rule) batching.defbroadcasting(prim) pe.def_trivial_padding(prim) return prim @@ -1772,6 +1813,20 @@ def broadcast_hlo( out.append(arg) return out +def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): + out = [] + for op, in_aval in zip(ops, in_avals): + if in_aval.sharding == out_aval.sharding or in_aval.sharding is None: + out.append(op) + else: + # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains + # CompilerShardingAxis, then specify `unspecified_dims` via + # `wrap_with_sharding_op`. + sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() + out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp)) + return out + + def _nary_lower_hlo(op: Callable, ctx, *args: ir.Value, explicit_type=False, **params) -> Sequence[ir.Value]: @@ -1782,13 +1837,19 @@ def _nary_lower_hlo(op: Callable, ctx, """ del params avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out - broadcasted_args = mlir.multi_broadcast_in_dim( - ctx, args, avals_in, aval_out.shape) + args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape) # type: ignore + if config.sharding_in_types.value: + args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) if explicit_type: - return [op(mlir.aval_to_ir_type(aval_out), *broadcasted_args)] + out = op(mlir.aval_to_ir_type(aval_out), *args) else: - return [op(*broadcasted_args)] + out = op(*args) + if config.sharding_in_types.value: + out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)] + else: + return [out] _float = {np.floating} @@ -2445,6 +2506,10 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type, sharding): return operand.shape +def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type, + sharding): + return sharding + def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type, sharding): if (operand.dtype != new_dtype and @@ -2538,7 +2603,8 @@ def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding): convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, - _convert_element_type_weak_type_rule)) + _convert_element_type_weak_type_rule, + _convert_element_type_sharding_rule)) ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule batching.defvectorized(convert_element_type_p) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 01301db1a9a0..5e3e9bcd8df2 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -20,6 +20,7 @@ from jax._src import core from jax._src import dispatch +from jax._src import config from jax._src import dtypes from jax._src.util import safe_zip from jax._src.lib import xla_client @@ -37,19 +38,19 @@ def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) def standard_primitive(shape_rule, dtype_rule, name, - weak_type_rule=None): + weak_type_rule=None, sharding_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule)) + weak_type_rule, sharding_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, - *avals, **kwargs): + sharding_rule, *avals, **kwargs): assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) @@ -58,8 +59,11 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, out = prim.impl(*[x.val for x in avals], **kwargs) return core.ConcreteArray(out.dtype, out, weak_type=weak_type) elif least_specialized is core.ShapedArray: + out_sharding = (sharding_rule(*avals, **kwargs) + if config.sharding_in_types.value else None) return core.ShapedArray(shape_rule(*avals, **kwargs), - dtype_rule(*avals, **kwargs), weak_type=weak_type) + dtype_rule(*avals, **kwargs), weak_type=weak_type, + sharding=out_sharding) elif least_specialized is core.DShapedArray: shape = shape_rule(*avals, **kwargs) ty = (core.ShapedArray if all(type(d) is int for d in shape) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b4878b02199e..1e4b5685d45a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4543,6 +4543,32 @@ def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") +class ShardingInTypesTest(jtu.JaxTestCase): + + @config.sharding_in_types(True) + def test_basic_mul(self): + mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + self.assertEqual(x.sharding.spec, s.spec) + x = x * 2 + self.assertEqual(x.sharding.spec, s.spec) + x = x * x + self.assertEqual(x.sharding.spec, s.spec) + return x + + out = f(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2)) + + lowered_text = f.lower(arr).as_text() + self.assertEqual(lowered_text.count('@Sharding'), 2) + + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From c140103a02b6b097f8547760fd80648ad14d8b79 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 29 Aug 2024 13:23:14 -0700 Subject: [PATCH 287/702] new README tagline --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8f50aa42a125..878543304b25 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ logo -# JAX: Autograd and XLA +# Pushing back the limits on numerical computing ![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg) ![PyPI version](https://img.shields.io/pypi/v/jax) From 93ba65e23970788e44e88e82ca4210f7b4f17138 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 29 Aug 2024 14:31:00 -0700 Subject: [PATCH 288/702] Get StableHLO version from compatibility requirements in JAX and PJRT. PiperOrigin-RevId: 669064292 --- jax/_src/export/_export.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 54defa0e9c54..65f3d2852348 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -681,8 +681,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: def _module_to_bytecode(module: ir.Module) -> bytes: mlir_str = mlir.module_to_bytecode(module) # `target_version` is used to manage situations when a StableHLO producer - # (in this case, jax2tf) and a StableHLO consumer were built using - # different versions of StableHLO. + # and a StableHLO consumer were built using different versions of StableHLO. # # Each StableHLO version `producer_version` has a compatibility window, # i.e. range of versions [`consumer_version_min`, `consumer_version_max`], @@ -691,12 +690,19 @@ def _module_to_bytecode(module: ir.Module) -> bytes: # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md # for the exact extent of these compatibility guarantees. # - # `hlo.get_minimum_version()` returns `consumer_version_min` - # for the current version of StableHLO. We are using it here to maximize - # forward compatibility, i.e. to maximize how far into the past we can go - # and still have the payloads produced by `serialize_portable_artifact` - # compatible with potential consumers from the past. - target_version = hlo.get_minimum_version() + # `hlo.get_version_from_compatibility_requirement(WEEK_4)` returns a version + # of StableHLO >= 4w old. This allows new StableHLO features to be used after + # ~4w and be compatible with any consumer that is updated on at least a + # monthly cadence. + # + # Note that this does not verify any JAX custom calls, which are only + # guaranteed 3w of forward compatibility, and only prevents use of new + # StableHLO features from failing on older hardware. + if hlo.get_api_version() < 9: + target_version = hlo.get_minimum_version() + else: + target_version = hlo.get_version_from_compatibility_requirement( + hlo.StablehloCompatibilityRequirement.WEEK_4) module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore mlir_str, target_version) return module_serialized From 45099b8e09ee87e7ab6531a363ccb832ff4f2a72 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 29 Aug 2024 14:32:17 -0700 Subject: [PATCH 289/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ed742254c6f9ec81a5c760fe3e709e72610eeffe. PiperOrigin-RevId: 669064685 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d9cd8c56672f..bbfd289abb17 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "09982ee2b99dc895ea4d038610f3dcfc2b10a9df" -XLA_SHA256 = "42e3e82c95b095d0fb39c83c497886f3353292381e6aa6403b7e616e584ea32c" +XLA_COMMIT = "ed742254c6f9ec81a5c760fe3e709e72610eeffe" +XLA_SHA256 = "508b0fa82c42a9f18507985467703825a1fa62a2d4f816a6575e4236fe99258b" def repo(): tf_http_archive( From e691f1f36b7fbe9c947b10e053e5379bcc203008 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 29 Aug 2024 17:45:43 -0700 Subject: [PATCH 290/702] Fix inaccuracy in jnp.matmul docs --- jax/_src/numpy/lax_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ae7ff06cc5c2..3e4d52bf97e1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6472,7 +6472,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *, JAX implementation of :func:`numpy.matmul`. Args: - a: first input array, of shape ``(..., N)``. + a: first input array, of shape ``(N,)`` or ``(..., K, N)``. b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``. In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of ``a``. From 727eb6a8dc7d25b766c26b31ea2cb1692d421cdb Mon Sep 17 00:00:00 2001 From: David Mis Date: Thu, 29 Aug 2024 21:43:48 -0500 Subject: [PATCH 291/702] Typo in documentation Fixed typo in documentation of `slogdet` --- jax/_src/numpy/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 2af25bcf80c4..79b47d9090af 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -501,7 +501,7 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ Compute the sign and (natural) logarithm of the determinant of an array. - JAX implementation of :func:`numpy.linalg.slotdet`. + JAX implementation of :func:`numpy.linalg.slogdet`. Args: a: array of shape ``(..., M, M)`` for which to compute the sign and log determinant. From 283770a0cbe2518182eb1a77c278c5cf46f91d8f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 30 Aug 2024 01:19:34 -0700 Subject: [PATCH 292/702] Relax test tolerance to fix a test failure on ARM CPU. Relax a test tolerance to fix a test failure with AVX enabled on a TPU host. PiperOrigin-RevId: 669233976 --- tests/linalg_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 6cd0110538eb..944880066437 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -325,7 +325,7 @@ def testEigvals(self, shape, dtype): a, = args_maker() w1, _ = jnp.linalg.eig(a) w2 = jnp.linalg.eigvals(a) - self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 1e-14}) + self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14}) @jtu.run_on_devices("cpu") def testEigvalsInf(self): @@ -489,7 +489,7 @@ def testEighRankDeficient(self, rank): with jax.numpy_rank_promotion("allow"): self.assertLessEqual( np.linalg.norm(np.matmul(a, v) - w * v), - 81 * eps * np.linalg.norm(a), + 85 * eps * np.linalg.norm(a), ) @jtu.sample_product( From 6a1adc842b536715847508b5c7e57247bfa8980b Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 30 Aug 2024 01:27:48 -0700 Subject: [PATCH 293/702] Disable scoped tracing in pipelining PiperOrigin-RevId: 669235827 --- jax/_src/pallas/mosaic/pipeline.py | 82 +++++++++++++++++++----------- 1 file changed, 51 insertions(+), 31 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index af3d55f581be..fca9ee471e6a 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -16,12 +16,13 @@ from __future__ import annotations from collections.abc import Sequence +from contextlib import contextmanager import dataclasses import enum import functools import itertools import operator -from typing import Union, Any +from typing import Any, Union import jax from jax import lax @@ -516,30 +517,34 @@ def accumulate(self): class Scheduler: """Sequences input and output copies and waits for a pipeline.""" - def __init__(self, - step: jax.Array, - grid: tuple[int | jax.Array, ...], - grid_offsets: tuple[int | jax.Array, ...], - first_cycle=None, - last_cycle=None, - init_accumulators=None, - ): + def __init__( + self, + step: jax.Array, + grid: tuple[int | jax.Array, ...], + grid_offsets: tuple[int | jax.Array, ...], + first_cycle=None, + last_cycle=None, + init_accumulators=None, + trace_scopes=True, + ): """Initializes scheduler. - Args: - step: inner step number. - grid: pallas grid for BufferedRefs. - grid_offsets: offsets for grid indices (used for megacore). - first_cycle: whether this is the first invocation of the pipeline. - last_cycle: whether this is the last invocation of the pipeline. - init_accumulators: do we zero-initialize accumulator state for this - invocation of the pipeline. + Args: + step: inner step number. + grid: pallas grid for BufferedRefs. + grid_offsets: offsets for grid indices (used for megacore). + first_cycle: whether this is the first invocation of the pipeline. + last_cycle: whether this is the last invocation of the pipeline. + init_accumulators: do we zero-initialize accumulator state for this + invocation of the pipeline. + trace_scopes: whether to use named_scope to trace blocks in the pipeline. """ self.step = step self.grid = grid self.first_cycle = first_cycle self.last_cycle = last_cycle self.init_accumulators = init_accumulators + self.trace_scopes = trace_scopes # Total number of linear steps. self.num_steps = _grid_size(grid) @@ -565,6 +570,14 @@ def __init__(self, self.next_step, grid, grid_offsets ) + @contextmanager + def _named_scope(self, name): + if self.trace_scopes: + with jax.named_scope(name): + yield + else: + yield + def grid_env(self): return pallas_core.grid_env( list(map(pallas_core.GridAxis, self.indices, self.grid))) @@ -592,7 +605,7 @@ def initialize(self, buffered_ref, src_ref, schedule=None): schedule = _default_schedule pred = schedule["prologue_copy_in"](self, buffered_ref, src_ref) - with jax.named_scope("ep_initialize"): + with self._named_scope("ep_initialize"): @pl.when(self.first_step_ever) def _init_slots(): buffered_ref.init_slots() @@ -611,7 +624,7 @@ def wait_in(self, buffered_ref, src_ref, schedule=None): schedule = _default_schedule pred = schedule["wait_in"](self, buffered_ref, src_ref) - @jax.named_scope("ep_wait_in") + @self._named_scope("ep_wait_in") def _wait(): if buffered_ref.is_input: buffered_ref.wait_in(src_ref, self.indices) @@ -619,7 +632,8 @@ def _wait(): # In most cases we won't be waiting when init_accumulators is True, # so this is usually just setting what we just copied. buffered_ref.set_accumulator(self.init_accumulators) - @jax.named_scope("ep_set_accum") + + @self._named_scope("ep_set_accum") def _no_wait(): if buffered_ref.is_accumulator: @@ -636,7 +650,7 @@ def copy_in(self, buffered_ref, src_ref, schedule=None): pred = schedule['copy_in'](self, buffered_ref, src_ref) @pl.when(pred) - @jax.named_scope("ep_copy_in") + @self._named_scope("ep_copy_in") def _send(): if buffered_ref.is_input: # We skip the last step because that's what prefetch is for. @@ -653,7 +667,7 @@ def prefetch(self, buffered_ref, src_ref, schedule=None): pred = schedule['prefetch'](self, buffered_ref, src_ref) @pl.when(pred) - @jax.named_scope("ep_prefetch") + @self._named_scope("ep_prefetch") def _send(): if buffered_ref.is_input: # Prefetch should only run on the last step. @@ -667,7 +681,7 @@ def wait_out(self, buffered_ref, dst_ref, schedule=None): pred = schedule['wait_out'](self, buffered_ref, dst_ref) @pl.when(pred) - @jax.named_scope("ep_wait_out") + @self._named_scope("ep_wait_out") def _wait(): if buffered_ref.is_output: buffered_ref.wait_out(dst_ref, self.prev_indices) @@ -680,13 +694,14 @@ def copy_out(self, buffered_ref, dst_ref, schedule=None): schedule = _default_schedule pred = schedule['copy_out'](self, buffered_ref, dst_ref) - @jax.named_scope("ep_copy_out") + @self._named_scope("ep_copy_out") def _copy_out_and_accumulate(): if buffered_ref.is_accumulator: buffered_ref.accumulate() if buffered_ref.is_output: buffered_ref.copy_out(dst_ref, self.indices) - @jax.named_scope("ep_accum") + + @self._named_scope("ep_accum") def _just_accumulate(): if buffered_ref.is_accumulator: # We accumulate on the last step because we will set the accumulator @@ -705,7 +720,7 @@ def finalize(self, buffered_ref, dst_ref, schedule=None): pred = schedule['epilogue_wait_out'](self, buffered_ref, dst_ref) @pl.when(pred) - @jax.named_scope("ep_finalize") + @self._named_scope("ep_finalize") def _end(): if buffered_ref.is_output: buffered_ref.swap_slots() # formally correct, not actually necessary. @@ -948,7 +963,8 @@ def emit_pipeline( out_specs=None, should_accumulate_out=False, core_axis: int | None = None, - dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None + dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None, + trace_scopes: bool = True, ): """Creates a function to emit a manual pallas pipeline. @@ -971,6 +987,8 @@ def emit_pipeline( along the core axis. dimension_semantics: optional tuple of GridDimensionSemantics (e.g. PARALLEL or ARBITRARY). + trace_scopes: optional bool, indicates whether to annotate each region in + the pipeline using named_scope. """ if any(not isinstance(d, (int, jax.Array)) for d in grid): grid_types = tuple(type(d) for d in grid) @@ -1065,7 +1083,9 @@ def loop_body(step, _): grid_offsets=grid_offsets, first_cycle=first_cycle, last_cycle=last_cycle, - init_accumulators=init_accumulators) + init_accumulators=init_accumulators, + trace_scopes=trace_scopes, + ) # prepare any local VMEM aliases brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) @@ -1076,7 +1096,7 @@ def loop_body(step, _): map_brefs(scheduler.wait_in, brefs, refs, schedule) # prefetch inputs for the *next* invocation of this pipeline - with jax.named_scope("ep_prefetch"): + with scheduler._named_scope("ep_prefetch"): if prefetch is not None: lax.cond(step == num_steps - 1, lambda: prefetch(*brefs, scheduler), @@ -1084,7 +1104,7 @@ def loop_body(step, _): # run the kernel! current_refs = map_brefs(lambda x: x.current_ref, brefs) - with jax.named_scope("ep_run_kernel"): + with scheduler._named_scope("ep_run_kernel"): with scheduler.grid_env(): body(*current_refs, *scratches) @@ -1092,7 +1112,7 @@ def loop_body(step, _): map_brefs(scheduler.copy_out, brefs, refs, schedule) map_brefs(scheduler.wait_out, brefs, refs, schedule) # handle writes for the *last* invocation of this pipeline's outputs - with jax.named_scope("ep_postyeet"): + with scheduler._named_scope("ep_postyeet"): if postyeet is not None: lax.cond(step == 0, lambda: postyeet(*brefs, scheduler), From 02bb884357e9bf060c8b4b90bff3c6b233f0a097 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 30 Aug 2024 02:01:02 -0700 Subject: [PATCH 294/702] ``jax.tree_util.register_dataclass`` now validates ``data_fields`` and ``meta_fields`` A well-behaved registration call must list all ``init=True`` fields in either ``data_fields`` or ``meta_fields``. Otherwise, ``flatten . unflatten`` could potentially *not* be an identity PiperOrigin-RevId: 669244669 --- CHANGELOG.md | 3 +++ jax/_src/tree_util.py | 37 +++++++++++++++++++------- tests/tree_util_test.py | 58 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d0b55b36476..5a3dac8c3152 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `jax.host_ids()` function that was deprecated in JAX v0.2.13. * To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been modified to no longer support `complex dtypes`. + * ``jax.tree_util.register_dataclass`` now checks that ``data_fields`` + and ``meta_fields`` includes all dataclass fields with ``init=True`` + and only them, if ``nodetype`` is a dataclass. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 07beed7276a8..b1c18a48263f 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -15,6 +15,7 @@ import collections from collections.abc import Callable, Hashable, Iterable, Sequence +import dataclasses from dataclasses import dataclass import difflib import functools @@ -925,7 +926,10 @@ class that defines how it could be flattened with keys. @export def register_dataclass( - nodetype: Typ, data_fields: Sequence[str], meta_fields: Sequence[str] + nodetype: Typ, + data_fields: Sequence[str], + meta_fields: Sequence[str], + drop_fields: Sequence[str] = (), ) -> Typ: """Extends the set of types that are considered internal nodes in pytrees. @@ -941,16 +945,14 @@ def register_dataclass( attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among ``meta_fields`` or ``data_fields``. + meta_fields: auxiliary data field names. These fields *must* contain static, + hashable, immutable objects, as these objects are used to generate JIT cache + keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or + :class:`numpy.ndarray` objects. data_fields: data field names. These fields *must* be JAX-compatible objects such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or - pytrees whose leaves are arrays or scalars. Note that ``None`` is valid, as - this is recognized by JAX as an empty pytree. - meta_fields: auxiliary data field names. These fields will be considered static - within JAX transformations such as :func:`jax.jit`. The listed fields *must* - contain static, hashable, immutable objects, as these objects are used to - generate JIT cache keys: for example strings, Python scalars, or array shapes - and dtypes. In particular, ``meta_fields`` cannot contain :class:`jax.Array` - or :class:`numpy.ndarray` objects, as they are not hashable. + pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be + ``None``, as this is recognized by JAX as an empty pytree. Returns: The input class ``nodetype`` is returned unchanged after being added to JAX's @@ -1003,6 +1005,23 @@ def register_dataclass( meta_fields = tuple(meta_fields) data_fields = tuple(data_fields) + if dataclasses.is_dataclass(nodetype): + init_fields = {f.name for f in dataclasses.fields(nodetype) if f.init} + init_fields.difference_update(*drop_fields) + if {*meta_fields, *data_fields} != init_fields: + msg = ( + "data_fields and meta_fields must include all dataclass fields with" + " ``init=True`` and only them." + ) + if missing := init_fields - {*meta_fields, *data_fields}: + msg += ( + f" Missing fields: {missing}. Add them to drop_fields to suppress" + " this error." + ) + if unexpected := {*meta_fields, *data_fields} - init_fields: + msg += f" Unexpected fields: {unexpected}." + raise ValueError(msg) + def flatten_with_keys(x): meta = tuple(getattr(x, name) for name in meta_fields) data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 23ddf73904b5..bc741702ce58 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -1257,5 +1257,63 @@ def test_tree_unflatten(self): ) +class RegistrationTest(jtu.JaxTestCase): + + def test_register_dataclass_missing_fields(self): + @dataclasses.dataclass + class Foo: + x: int + y: int + z: float = dataclasses.field(init=False) + + with self.assertRaisesRegex( + ValueError, + "data_fields and meta_fields must include all dataclass fields.*" + "Missing fields: {'y'}", + ): + tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=[]) + + # ``z`` is not required, because it's not included in ``__init__``. + tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=["y"]) + + def test_register_dataclass_unexpected_fields(self): + @dataclasses.dataclass + class Foo: + x: int + y: float + + with self.assertRaisesRegex( + ValueError, + "data_fields and meta_fields must include all dataclass fields.*" + "Unexpected fields: {'z'}", + ): + tree_util.register_dataclass( + Foo, data_fields=["x"], meta_fields=["y", "z"] + ) + + def test_register_dataclass_drop_fields(self): + @dataclasses.dataclass + class Foo: + x: int + y: int = dataclasses.field(default=42) + + # ``y`` is explicitly excluded. + tree_util.register_dataclass( + Foo, data_fields=["x"], meta_fields=[], drop_fields=["y"] + ) + + def test_register_dataclass_invalid_plain_class(self): + class Foo: + x: int + y: int + + def __init__(self, x, y): + self.x = x + self.y = y + + # ``y`` is missing, but no validation is done for plain classes. + tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=[]) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 4342c0c0f37e34772246f1d4e6dd6c0ffea882d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Fri, 30 Aug 2024 02:05:32 -0700 Subject: [PATCH 295/702] Determine LAPACK workspace during Householder Product Kernel runtime Workspace dependency was removed, and the info parameter is ignored now. PiperOrigin-RevId: 669246058 --- jaxlib/cpu/lapack_kernels.cc | 41 +++++++++++++++++------------------- jaxlib/cpu/lapack_kernels.h | 4 +--- 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index fd0a12ef2ed1..a6c5993f43b6 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -408,34 +408,34 @@ template struct Orgqr>; template ffi::Error OrthogonalQr::Kernel(ffi::Buffer x, ffi::Buffer tau, - ffi::ResultBuffer x_out, - ffi::ResultBuffer info, - ffi::ResultBuffer work) { + ffi::ResultBuffer x_out) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); auto* tau_data = tau.typed_data(); auto* x_out_data = x_out->typed_data(); - auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); + lapack_int info; CopyIfDiffBuffer(x, x_out); - FFI_ASSIGN_OR_RETURN(auto tau_size_v, MaybeCastNoOverflow( - tau.dimensions().back())); + // Prepare LAPACK workspaces. + int64_t work_size = GetWorkspaceSize(x_rows, x_cols, tau.dimensions().back()); + FFI_ASSIGN_OR_RETURN(auto work_size_v, + MaybeCastNoOverflow(work_size)); + auto work_data = AllocateScratchMemory(work_size); + FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow(x_rows)); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto tau_size_v, MaybeCastNoOverflow( + tau.dimensions().back())); auto x_leading_dim_v = x_rows_v; const int64_t x_out_step{x_rows * x_cols}; const int64_t tau_step{tau_size_v}; for (int64_t i = 0; i < batch_count; ++i) { fn(&x_rows_v, &x_cols_v, &tau_size_v, x_out_data, &x_leading_dim_v, - tau_data, work_data, &workspace_dim_v, info_data); + tau_data, work_data.get(), &work_size_v, &info); x_out_data += x_out_step; tau_data += tau_step; - ++info_data; } return ffi::Error::Success(); } @@ -1380,8 +1380,7 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( if (is_finite(x_copy.get(), x_size)) { fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v, eigvals_data, eigvecs_left_data, &x_cols_v, eigvecs_right_data, - &x_cols_v, work_data.get(), &work_size_v, rwork_data.get(), - info_data); + &x_cols_v, work_data.get(), &work_size_v, rwork_data.get(), info_data); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_left_data, x_size_bytes); @@ -1716,15 +1715,13 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*x_out*/) \ .Ret<::xla::ffi::Buffer>(/*tau*/)) -#define JAX_CPU_DEFINE_ORGQR(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, OrthogonalQr::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Arg<::xla::ffi::Buffer>(/*tau*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/)) +#define JAX_CPU_DEFINE_ORGQR(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, OrthogonalQr::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Arg<::xla::ffi::Buffer>(/*tau*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/)) #define JAX_CPU_DEFINE_POTRF(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 4d021b688de9..a571de5dd6de 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -226,9 +226,7 @@ struct OrthogonalQr { static ::xla::ffi::Error Kernel(::xla::ffi::Buffer x, ::xla::ffi::Buffer tau, - ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work); + ::xla::ffi::ResultBuffer x_out); static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, lapack_int tau_size); From 4c218fbf3b8431a5f75cdf20942d5d62433a8657 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 30 Aug 2024 05:24:33 -0700 Subject: [PATCH 296/702] Add cost estimate for forward pass of flash attention. PiperOrigin-RevId: 669291182 --- .../pallas/ops/tpu/flash_attention.py | 58 +++++++++++++++++-- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 6ce3a1886b1c..82bcde8153ef 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -17,6 +17,7 @@ import dataclasses import functools +import math from typing import Any, NamedTuple import jax @@ -565,6 +566,40 @@ def _flash_attention_kernel_single_batch_single_step( ).astype(o_tile_ref.dtype) +def _bytes(x: jax.Array | jax.ShapeDtypeStruct) -> int: + return math.prod(x.shape) * x.dtype.itemsize + + +def _fwd_cost_estimate( + q: jax.Array, + k: jax.Array, + v: jax.Array, + ab: jax.Array | None, + segment_ids: SegmentIds | None, + *, + causal: bool, + sm_scale: jax.Array | None, + kernel_inputs_specs, + kernel_outputs_specs, +) -> pl.CostEstimate | None: + full_cost = ( + mha_reference.lower( + q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale + ) + .compile() + .cost_analysis() + ) + if not full_cost: + return None + input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs)) + output_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_outputs_specs)) + return pl.CostEstimate( + flops=full_cost[0]["flops"], + transcendentals=full_cost[0]["transcendentals"], + bytes_accessed=input_bytes + output_bytes, + ) + + def _flash_attention_impl( q, k, @@ -746,12 +781,23 @@ def kv_segment_ids_index_map( out_shape=out_shape, debug=debug, compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=( - "parallel", - "parallel", - "parallel", - "arbitrary", - ) + dimension_semantics=( + "parallel", + "parallel", + "parallel", + "arbitrary", + ) + ), + cost_estimate=_fwd_cost_estimate( + q, + k, + v, + ab, + segment_ids, + causal=causal, + sm_scale=sm_scale, + kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids), + kernel_outputs_specs=out_shape, ), )(q, k, v, ab, q_segment_ids, kv_segment_ids) if save_residuals: From 5b1a3b5375f0fef49591efbe3a57daef1a397a69 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 30 Aug 2024 06:23:14 -0700 Subject: [PATCH 297/702] Use dtypes instead of dtype names in Pallas GPU extern tables This allows proper checking/casting of weak dtypes, which will be implemented in a follow up. PiperOrigin-RevId: 669304338 --- jax/_src/pallas/triton/lowering.py | 320 +++++++++++++++-------------- 1 file changed, 164 insertions(+), 156 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 4057b125fcdc..3259ee634e1f 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -571,7 +571,7 @@ def _not_lowering_rule(ctx: LoweringRuleContext, x): @dataclasses.dataclass(frozen=True) class _Extern: - arg_types: Sequence[str] + arg_types: Sequence[jax.typing.DTypeLike] symbol: str result_type: str @@ -579,7 +579,7 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: if len(avals) != len(self.arg_types): return False return all( - aval.dtype.name == arg_type + aval.dtype == jnp.dtype(arg_type) for aval, arg_type in zip(avals, self.arg_types) ) @@ -600,7 +600,7 @@ def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]): @dataclasses.dataclass(frozen=True) class _Fallback: - arg_types: Sequence[str] + arg_types: Sequence[jax.typing.DTypeLike] lower: Callable[..., ir.Value] matches = _Extern.matches @@ -614,7 +614,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: - arg_aval_dtypes = tuple(aval.dtype.name for aval in ctx.avals_in) + arg_aval_dtypes = tuple(aval.dtype for aval in ctx.avals_in) raise NotImplementedError( f"unsupported types for {name}: {arg_aval_dtypes}" ) @@ -623,7 +623,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: bcast_args = [] for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types): bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape) - if aval.weak_type and aval.dtype.name != arg_type: + if aval.weak_type and aval.dtype != jnp.dtype(arg_type): bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type)) bcast_args.append(bcast_arg) return h.lower(ctx, *bcast_args) @@ -634,16 +634,16 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: _abs_dispatch_table = _make_dispatch_table( "abs", cuda=[ - _Extern(["int32"], "__nv_abs", "int32"), - _Extern(["int64"], "__nv_llabs", "int64"), - _Extern(["float32"], "__nv_fabsf", "float32"), - _Extern(["float64"], "__nv_fabs", "float64"), + _Extern([jnp.int32], "__nv_abs", jnp.int32), + _Extern([jnp.int64], "__nv_llabs", jnp.int64), + _Extern([jnp.float32], "__nv_fabsf", jnp.float32), + _Extern([jnp.float64], "__nv_fabs", jnp.float64), ], rocm=[ - _Fallback(["int32"], lambda ctx, x: math_dialect.absi(x)), - _Fallback(["int64"], lambda ctx, x: math_dialect.absi(x)), - _Fallback(["float32"], lambda ctx, x: math_dialect.absf(x)), - _Fallback(["float64"], lambda ctx, x: math_dialect.absf(x)), + _Fallback([jnp.int32], lambda ctx, x: math_dialect.absi(x)), + _Fallback([jnp.int64], lambda ctx, x: math_dialect.absi(x)), + _Fallback([jnp.float32], lambda ctx, x: math_dialect.absf(x)), + _Fallback([jnp.float64], lambda ctx, x: math_dialect.absf(x)), ], ) @@ -667,337 +667,345 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): lax.ceil_p: _make_dispatch_table( "ceil", cuda=[ - _Extern(["float32"], "__nv_ceilf", "float32"), - _Extern(["float64"], "__nv_ceil", "float64"), + _Extern([jnp.float32], "__nv_ceilf", jnp.float32), + _Extern([jnp.float64], "__nv_ceil", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_ceil_f32", "float32"), - _Extern(["float64"], "__ocml_ceil_f64", "float64"), + _Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64), ], ), lax.floor_p: _make_dispatch_table( "floor", cuda=[ - _Extern(["float32"], "__nv_floorf", "float32"), - _Extern(["float64"], "__nv_floor", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), + _Extern([jnp.float32], "__nv_floorf", jnp.float32), + _Extern([jnp.float64], "__nv_floor", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)), ], rocm=[ - _Extern(["float32"], "__ocml_floor_f32", "float32"), - _Extern(["float64"], "__ocml_floor_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), + _Extern([jnp.float32], "__ocml_floor_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_floor_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)), ], ), lax.exp_p: _make_dispatch_table( "exp", cuda=[ - _Extern(["float32"], "__nv_expf", "float32"), - _Extern(["float64"], "__nv_exp", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), + _Extern([jnp.float32], "__nv_expf", jnp.float32), + _Extern([jnp.float64], "__nv_exp", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)), ], rocm=[ - _Fallback(["float32"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["float64"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.float32], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.float64], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)), ], ), lax.exp2_p: _make_dispatch_table( "exp2", cuda=[ - _Extern(["float32"], "__nv_exp2f", "float32"), - _Extern(["float64"], "__nv_exp2", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), + _Extern([jnp.float32], "__nv_exp2f", jnp.float32), + _Extern([jnp.float64], "__nv_exp2", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)), ], rocm=[ - _Extern(["float32"], "__ocml_exp2_f32", "float32"), - _Extern(["float64"], "__ocml_exp2_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), + _Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)), ], ), lax.expm1_p: _make_dispatch_table( "expm1", cuda=[ - _Extern(["float32"], "__nv_expm1f", "float32"), - _Extern(["float64"], "__nv_expm1", "float64"), + _Extern([jnp.float32], "__nv_expm1f", jnp.float32), + _Extern([jnp.float64], "__nv_expm1", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_expm1_f32", "float32"), - _Extern(["float64"], "__ocml_expm1_f64", "float64"), + _Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64), ], ), lax.log_p: _make_dispatch_table( "log", cuda=[ - _Extern(["float32"], "__nv_logf", "float32"), - _Extern(["float64"], "__nv_log", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), + _Extern([jnp.float32], "__nv_logf", jnp.float32), + _Extern([jnp.float64], "__nv_log", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)), ], rocm=[ - _Extern(["float32"], "__ocml_log_f32", "float32"), - _Extern(["float64"], "__ocml_log_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), + _Extern([jnp.float32], "__ocml_log_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_log_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)), ], ), lax.log1p_p: _make_dispatch_table( "log1p", cuda=[ - _Extern(["float32"], "__nv_log1pf", "float32"), - _Extern(["float64"], "__nv_log1p", "float64"), + _Extern([jnp.float32], "__nv_log1pf", jnp.float32), + _Extern([jnp.float64], "__nv_log1p", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_log1p_f32", "float32"), - _Extern(["float64"], "__ocml_log1p_f64", "float64"), + _Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64), ], ), lax.sqrt_p: _make_dispatch_table( "sqrt", cuda=[ - _Extern(["float32"], "__nv_sqrtf", "float32"), - _Extern(["float64"], "__nv_sqrt", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), + _Extern([jnp.float32], "__nv_sqrtf", jnp.float32), + _Extern([jnp.float64], "__nv_sqrt", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)), ], rocm=[ - _Extern(["float32"], "__ocml_sqrt_f32", "float32"), - _Extern(["float64"], "__ocml_sqrt_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), + _Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)), ], ), lax.pow_p: _make_dispatch_table( "pow", cuda=[ - _Extern(["float32", "int32"], "__nv_powif", "float32"), - _Extern(["float64", "int32"], "__nv_powi", "float64"), - _Extern(["float32", "float32"], "__nv_powf", "float32"), - _Extern(["float64", "float64"], "__nv_pow", "float64"), + _Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32), + _Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64), + _Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32), + _Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64), ], rocm=[ - _Extern(["float32", "int32"], "__ocml_pown_f32", "float32"), - _Extern(["float64", "int32"], "__ocml_pown_f64", "float64"), - _Extern(["float32", "float32"], "__ocml_pow_f32", "float32"), - _Extern(["float64", "float64"], "__ocml_pow_f64", "float64"), + _Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32), + _Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64), + _Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32), + _Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64), ], ), lax.cbrt_p: _make_dispatch_table( "cbrt", cuda=[ - _Extern(["float32"], "__nv_cbrtf", "float32"), - _Extern(["float64"], "__nv_cbrt", "float64"), + _Extern([jnp.float32], "__nv_cbrtf", jnp.float32), + _Extern([jnp.float64], "__nv_cbrt", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_cbrt_f32", "float32"), - _Extern(["float64"], "__ocml_cbrt_f64", "float64"), + _Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64), ], ), lax.rsqrt_p: _make_dispatch_table( "rsqrt", cuda=[ - _Extern(["float32"], "__nv_rsqrtf", "float32"), - _Extern(["float64"], "__nv_rsqrt", "float64"), + _Extern([jnp.float32], "__nv_rsqrtf", jnp.float32), + _Extern([jnp.float64], "__nv_rsqrt", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_rsqrt_f32", "float32"), - _Extern(["float64"], "__ocml_rsqrt_f64", "float64"), + _Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64), ], ), lax.sin_p: _make_dispatch_table( "sin", cuda=[ - _Extern(["float32"], "__nv_sinf", "float32"), - _Extern(["float64"], "__nv_sin", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), + _Extern([jnp.float32], "__nv_sinf", jnp.float32), + _Extern([jnp.float64], "__nv_sin", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)), ], rocm=[ - _Extern(["float32"], "__ocml_sin_f32", "float32"), - _Extern(["float64"], "__ocml_sin_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), + _Extern([jnp.float32], "__ocml_sin_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_sin_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)), ], ), lax.cos_p: _make_dispatch_table( "cos", cuda=[ - _Extern(["float32"], "__nv_cosf", "float32"), - _Extern(["float64"], "__nv_cos", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), + _Extern([jnp.float32], "__nv_cosf", jnp.float32), + _Extern([jnp.float64], "__nv_cos", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)), ], rocm=[ - _Extern(["float32"], "__ocml_cos_f32", "float32"), - _Extern(["float64"], "__ocml_cos_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), + _Extern([jnp.float32], "__ocml_cos_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_cos_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)), ], ), lax.tan_p: _make_dispatch_table( "tan", cuda=[ - _Extern(["float32"], "__nv_tanf", "float32"), - _Extern(["float64"], "__nv_tan", "float64"), + _Extern([jnp.float32], "__nv_tanf", jnp.float32), + _Extern([jnp.float64], "__nv_tan", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_tan_f32", "float32"), - _Extern(["float64"], "__ocml_tan_f64", "float64"), + _Extern([jnp.float32], "__ocml_tan_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_tan_f64", jnp.float64), ], ), lax.asin_p: _make_dispatch_table( "asin", cuda=[ - _Extern(["float32"], "__nv_asinf", "float32"), - _Extern(["float64"], "__nv_asin", "float64"), + _Extern([jnp.float32], "__nv_asinf", jnp.float32), + _Extern([jnp.float64], "__nv_asin", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_asin_f32", "float32"), - _Extern(["float64"], "__ocml_asin_f64", "float64"), + _Extern([jnp.float32], "__ocml_asin_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_asin_f64", jnp.float64), ], ), lax.acos_p: _make_dispatch_table( "acos", cuda=[ - _Extern(["float32"], "__nv_acosf", "float32"), - _Extern(["float64"], "__nv_acos", "float64"), + _Extern([jnp.float32], "__nv_acosf", jnp.float32), + _Extern([jnp.float64], "__nv_acos", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_acos_f32", "float32"), - _Extern(["float64"], "__ocml_acos_f64", "float64"), + _Extern([jnp.float32], "__ocml_acos_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_acos_f64", jnp.float64), ], ), lax.atan_p: _make_dispatch_table( "atan", cuda=[ - _Extern(["float32"], "__nv_atanf", "float32"), - _Extern(["float64"], "__nv_atan", "float64"), + _Extern([jnp.float32], "__nv_atanf", jnp.float32), + _Extern([jnp.float64], "__nv_atan", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_atan_f32", "float32"), - _Extern(["float64"], "__ocml_atan_f64", "float64"), + _Extern([jnp.float32], "__ocml_atan_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_atan_f64", jnp.float64), ], ), lax.atan2_p: _make_dispatch_table( "atan2", cuda=[ - _Extern(["float32", "float32"], "__nv_atan2f", "float32"), - _Extern(["float64", "float64"], "__nv_atan2", "float64"), + _Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32), + _Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64), ], rocm=[ - _Extern(["float32", "float32"], "__ocml_atan2_f32", "float32"), - _Extern(["float64", "float64"], "__ocml_atan2_f64", "float64"), + _Extern( + [jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32 + ), + _Extern( + [jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64 + ), ], ), lax.sinh_p: _make_dispatch_table( "sinh", cuda=[ - _Extern(["float32"], "__nv_sinhf", "float32"), - _Extern(["float64"], "__nv_sinh", "float64"), + _Extern([jnp.float32], "__nv_sinhf", jnp.float32), + _Extern([jnp.float64], "__nv_sinh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_sinh_f32", "float32"), - _Extern(["float64"], "__ocml_sinh_f64", "float64"), + _Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64), ], ), lax.cosh_p: _make_dispatch_table( "cosh", cuda=[ - _Extern(["float32"], "__nv_coshf", "float32"), - _Extern(["float64"], "__nv_cosh", "float64"), + _Extern([jnp.float32], "__nv_coshf", jnp.float32), + _Extern([jnp.float64], "__nv_cosh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_cosh_f32", "float32"), - _Extern(["float64"], "__ocml_cosh_f64", "float64"), + _Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64), ], ), lax.tanh_p: _make_dispatch_table( "tanh", cuda=[ - _Extern(["float32"], "__nv_tanhf", "float32"), - _Extern(["float64"], "__nv_tanh", "float64"), + _Extern([jnp.float32], "__nv_tanhf", jnp.float32), + _Extern([jnp.float64], "__nv_tanh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_tanh_f32", "float32"), - _Extern(["float64"], "__ocml_tanh_f64", "float64"), + _Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64), ], ), lax.asinh_p: _make_dispatch_table( "asinh", cuda=[ - _Extern(["float32"], "__nv_asinhf", "float32"), - _Extern(["float64"], "__nv_asinh", "float64"), + _Extern([jnp.float32], "__nv_asinhf", jnp.float32), + _Extern([jnp.float64], "__nv_asinh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_asinh_f32", "float32"), - _Extern(["float64"], "__ocml_asinh_f64", "float64"), + _Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64), ], ), lax.acosh_p: _make_dispatch_table( "acosh", cuda=[ - _Extern(["float32"], "__nv_acoshf", "float32"), - _Extern(["float64"], "__nv_acosh", "float64"), + _Extern([jnp.float32], "__nv_acoshf", jnp.float32), + _Extern([jnp.float64], "__nv_acosh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_acosh_f32", "float32"), - _Extern(["float64"], "__ocml_acosh_f64", "float64"), + _Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64), ], ), lax.atanh_p: _make_dispatch_table( "atanh", cuda=[ - _Extern(["float32"], "__nv_atanhf", "float32"), - _Extern(["float64"], "__nv_atanh", "float64"), + _Extern([jnp.float32], "__nv_atanhf", jnp.float32), + _Extern([jnp.float64], "__nv_atanh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_atanh_f32", "float32"), - _Extern(["float64"], "__ocml_atanh_f64", "float64"), + _Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64), ], ), lax.population_count_p: _make_dispatch_table( "population_count", cuda=[ - _Extern(["int32"], "__nv_popc", "int32"), - _Extern(["int64"], "__nv_popcll", "int32"), + _Extern([jnp.int32], "__nv_popc", jnp.int32), + _Extern([jnp.int64], "__nv_popcll", jnp.int32), ], rocm=[ - _Fallback(["int32"], lambda ctx, x: math_dialect.ctpop(x)), - _Fallback(["int64"], lambda ctx, x: math_dialect.ctpop(x)), + _Fallback([jnp.int32], lambda ctx, x: math_dialect.ctpop(x)), + _Fallback([jnp.int64], lambda ctx, x: math_dialect.ctpop(x)), ], ), lax.clz_p: _make_dispatch_table( "clz", cuda=[ - _Extern(["int32"], "__nv_clz", "int32"), - _Extern(["int64"], "__nv_clzll", "int32"), + _Extern([jnp.int32], "__nv_clz", jnp.int32), + _Extern([jnp.int64], "__nv_clzll", jnp.int32), ], rocm=[ - _Fallback(["int32"], lambda ctx, x: math_dialect.ctlz(x)), - _Fallback(["int64"], lambda ctx, x: math_dialect.ctlz(x)), + _Fallback([jnp.int32], lambda ctx, x: math_dialect.ctlz(x)), + _Fallback([jnp.int64], lambda ctx, x: math_dialect.ctlz(x)), ], ), lax.nextafter_p: _make_dispatch_table( "nextafter", cuda=[ - _Extern(["float32", "float32"], "__nv_nextafterf", "float32"), - _Extern(["float64", "float64"], "__nv_nextafter", "float64"), + _Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32 ), + _Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64), ], rocm=[ - _Extern(["float32", "float32"], "__ocml_nextafter_f32", "float32"), - _Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"), + _Extern( + [jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32 + ), + _Extern( + [jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64 + ), ], ), lax.erf_inv_p: _make_dispatch_table( "erf_inv", cuda=[ _Fallback( - ["float32"], + [jnp.float32], lower_fun( pallas_utils.erf_inv_32_lowering_helper, multiple_results=False, @@ -1006,7 +1014,7 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): ], rocm=[ _Fallback( - ["float32"], + [jnp.float32], lower_fun( pallas_utils.erf_inv_32_lowering_helper, multiple_results=False, @@ -2215,7 +2223,7 @@ def _argreduce_lowering( if i != axis: index = _expand_dims(index, i) index = _bcast_to(index, a_aval.shape) - ctx = ctx.replace(avals_in=[a_aval, a_aval.update(dtype=jnp.dtype("int32"))]) + ctx = ctx.replace(avals_in=[a_aval, a_aval.update(dtype=jnp.dtype(jnp.int32))]) _, indices = _reduction_lowering(body, ctx, (a, index), axes=axes) return indices From 24aab154d5c8eb7eec844ed1b6122ab2f6ffd11b Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 30 Aug 2024 21:31:26 +0530 Subject: [PATCH 298/702] Better dos for jax.numpy: exp, exp2, and expm1 --- jax/_src/numpy/ufuncs.py | 108 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index dfeff38df0fe..83b3e3a08b23 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -203,9 +203,45 @@ def ceil(x: ArrayLike, /) -> Array: return lax.asarray(x) return lax.ceil(*promote_args_inexact('ceil', x)) -@implements(np.exp, module='numpy') + @partial(jit, inline=True) def exp(x: ArrayLike, /) -> Array: + """Calculate element-wise exponential of the input. + + JAX implementation of :obj:`numpy.exp`. + + Args: + x: input array or scalar + + Returns: + An array containing the exponential of each element in ``x``, promotes to + inexact dtype. + + See also: + - :func:`jax.numpy.log`: Calculates element-wise logarithm of the input. + - :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the + input. + - :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of + the input. + + Examples: + ``jnp.exp`` follows the properties of exponential such as :math:`e^{(a+b)} + = e^a * e^b`. + + >>> x1 = jnp.array([2, 4, 3, 1]) + >>> x2 = jnp.array([1, 3, 2, 3]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.exp(x1+x2)) + [ 20.09 1096.63 148.41 54.6 ] + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.exp(x1)*jnp.exp(x2)) + [ 20.09 1096.63 148.41 54.6 ] + + This property holds for complex input also: + + >>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j)) + Array(True, dtype=bool) + """ return lax.exp(*promote_args_inexact('exp', x)) @implements(np.log, module='numpy') @@ -213,9 +249,48 @@ def exp(x: ArrayLike, /) -> Array: def log(x: ArrayLike, /) -> Array: return lax.log(*promote_args_inexact('log', x)) -@implements(np.expm1, module='numpy') + @partial(jit, inline=True) def expm1(x: ArrayLike, /) -> Array: + """Calculate ``exp(x)-1`` of each element of the input. + + JAX implementation of :obj:`numpy.expm1`. + + Args: + x: input array or scalar. + + Returns: + An array containing ``exp(x)-1`` of each element in ``x``, promotes to inexact + dtype. + + Note: + ``jnp.expm1`` has much higher precision than the naive computation of + ``exp(x)-1`` for small values of ``x``. + + See also: + - :func:`jax.numpy.log1p`: Calculates element-wise logarithm of one plus input. + - :func:`jax.numpy.exp`: Calculates element-wise exponential of the input. + - :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of + the input. + + Examples: + >>> x = jnp.array([2, -4, 3, -1]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.expm1(x)) + [ 6.39 -0.98 19.09 -0.63] + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.exp(x)-1) + [ 6.39 -0.98 19.09 -0.63] + + For values very close to 0, ``jnp.expm1(x)`` is much more accurate than + ``jnp.exp(x)-1``: + + >>> x1 = jnp.array([1e-4, 1e-6, 2e-10]) + >>> jnp.expm1(x1) + Array([1.0000500e-04, 1.0000005e-06, 2.0000000e-10], dtype=float32) + >>> jnp.exp(x1)-1 + Array([1.00016594e-04, 9.53674316e-07, 0.00000000e+00], dtype=float32) + """ return lax.expm1(*promote_args_inexact('expm1', x)) @implements(np.log1p, module='numpy') @@ -968,9 +1043,36 @@ def log10(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) -@implements(np.exp2, module='numpy') @partial(jit, inline=True) def exp2(x: ArrayLike, /) -> Array: + """Calculate element-wise base-2 exponential of input. + + JAX implementation of :obj:`numpy.exp2`. + + Args: + x: input array or scalar + + Returns: + An array containing the base-2 exponential of each element in ``x``, promotes + to inexact dtype. + + See also: + - :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input. + - :func:`jax.numpy.exp`: Calculates exponential of each element of the input. + - :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the + input. + + Examples: + ``jnp.exp2`` follows the properties of the exponential such as :math:`2^{a+b} + = 2^a * 2^b`. + + >>> x1 = jnp.array([2, -4, 3, -1]) + >>> x2 = jnp.array([-1, 3, -2, 3]) + >>> jnp.exp2(x1+x2) + Array([2. , 0.5, 2. , 4. ], dtype=float32) + >>> jnp.exp2(x1)*jnp.exp2(x2) + Array([2. , 0.5, 2. , 4. ], dtype=float32) + """ x, = promote_args_inexact("exp2", x) return lax.exp2(x) From 9d757cdb85d5f4b854abd23fb9178280c9c1f481 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 30 Aug 2024 21:35:44 +0530 Subject: [PATCH 299/702] Use :obj: instead of :func: in ufuncs.py --- jax/_src/numpy/ufuncs.py | 48 +++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index dfeff38df0fe..1f66ee22265b 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -58,7 +58,7 @@ def _to_bool(x: Array) -> Array: def fabs(x: ArrayLike, /) -> Array: """Compute the element-wise absolute values of the real-valued input. - JAX implementation of :func:`numpy.fabs`. + JAX implementation of :obj:`numpy.fabs`. Args: x: input array or scalar. Must not have a complex dtype. @@ -132,7 +132,7 @@ def sign(x: ArrayLike, /) -> Array: def floor(x: ArrayLike, /) -> Array: """Round input to the nearest integer downwards. - JAX implementation of :func:`numpy.floor`. + JAX implementation of :obj:`numpy.floor`. Args: x: input array or scalar. Must not have complex dtype. @@ -170,7 +170,7 @@ def floor(x: ArrayLike, /) -> Array: def ceil(x: ArrayLike, /) -> Array: """Round input to the nearest integer upwards. - JAX implementation of :func:`numpy.ceil`. + JAX implementation of :obj:`numpy.ceil`. Args: x: input array or scalar. Must not have complex dtype. @@ -466,7 +466,7 @@ def bitwise_count(x: ArrayLike, /) -> Array: r"""Counts the number of 1 bits in the binary representation of the absolute value of each element of ``x``. - LAX-backend implementation of :func:`numpy.bitwise_count`. + JAX implementation of :obj:`numpy.bitwise_count`. Args: x: Input array, only accepts integer subtypes @@ -500,7 +500,7 @@ def bitwise_count(x: ArrayLike, /) -> Array: def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. - LAX-backend implementation of :func:`numpy.right_shift`. + JAX implementation of :obj:`numpy.right_shift`. Args: x1: Input array, only accepts unsigned integer subtypes @@ -559,7 +559,7 @@ def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: def absolute(x: ArrayLike, /) -> Array: r"""Calculate the absolute value element-wise. - LAX-backend implementation of :func:`numpy.absolute`. + JAX implementation of :obj:`numpy.absolute`. This is the same function as :func:`jax.numpy.abs`. @@ -600,7 +600,7 @@ def abs(x: ArrayLike, /) -> Array: def rint(x: ArrayLike, /) -> Array: """Rounds the elements of x to the nearest integer - LAX-backend implementation of :func:`numpy.rint`. + JAX implementation of :obj:`numpy.rint`. Args: x: Input array @@ -639,7 +639,7 @@ def rint(x: ArrayLike, /) -> Array: def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Copies the sign of each element in ``x2`` to the corresponding element in ``x1``. - LAX-backend implementation of :func:`numpy.copysign`. + JAX implementation of :obj:`numpy.copysign`. Args: x1: Input array @@ -687,7 +687,7 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the floor division of x1 by x2 element-wise - LAX-backend implementation of :func:`numpy.floor_divide`. + JAX implementation of :obj:`numpy.floor_divide`. Args: x1: Input array, the dividend @@ -698,6 +698,14 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: to the nearest integer towards negative infinity. This is equivalent to ``x1 // x2`` in Python. + Note: + ``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` + and ``x2`` + + See Also: + :func:`jax.numpy.divide` and :func:`jax.numpy.true_divide` for floating point + division. + Examples: >>> x1 = jnp.array([10, 20, 30]) >>> x2 = jnp.array([3, 4, 7]) @@ -713,12 +721,6 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> x2 = jnp.array([2.0, 2.5, 3.0], dtype=jnp.float32) >>> jnp.floor_divide(x1, x2) Array([3., 2., 2.], dtype=float32) - - Note: - ``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` and ``x2`` - - See Also: - :func:`jnp.divide` and :func:`jnp.true_divide` for floating point division """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) dtype = dtypes.dtype(x1) @@ -739,7 +741,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: """Calculates the integer quotient and remainder of x1 by x2 element-wise - LAX-backend implementation of :func:`numpy.divmod`. + JAX implementation of :obj:`numpy.divmod`. Args: x1: Input array, the dividend @@ -748,6 +750,10 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: Returns: A tuple of arrays ``(x1 // x2, x1 % x2)``. + See Also: + - :func:`jax.numpy.floor_divide`: floor division function + - :func:`jax.numpy.remainder`: remainder function + Examples: >>> x1 = jnp.array([10, 20, 30]) >>> x2 = jnp.array([3, 4, 7]) @@ -765,10 +771,6 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: >>> jnp.divmod(x1, x2) (Array([3., 2., 1.], dtype=float32), Array([0.30000007, 1. , 2.9 ], dtype=float32)) - - See Also: - - :func:`jax.numpy.floor_divide`: floor division function - - :func:`jax.numpy.remainder`: remainder function """ x1, x2 = promote_args_numeric("divmod", x1, x2) if dtypes.issubdtype(dtypes.dtype(x1), np.integer): @@ -862,7 +864,7 @@ def _pow_int_int(x1, x2): def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. - JAX implementation of :func:`numpy.logaddexp` + JAX implementation of :obj:`numpy.logaddexp` Args: x1: input array @@ -927,7 +929,7 @@ def _logaddexp2_jvp(primals, tangents): def log2(x: ArrayLike, /) -> Array: """Calculates the base-2 logarithm of x element-wise - LAX-backend implementation of :func:`numpy.log2`. + JAX implementation of :obj:`numpy.log2`. Args: x: Input array @@ -949,7 +951,7 @@ def log2(x: ArrayLike, /) -> Array: def log10(x: ArrayLike, /) -> Array: """Calculates the base-10 logarithm of x element-wise - LAX-backend implementation of :func:`numpy.log10`. + JAX implementation of :obj:`numpy.log10`. Args: x: Input array From db4be03f0255ac26ed30f2ff149cc17995e2c8d3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 30 Aug 2024 09:08:56 -0700 Subject: [PATCH 300/702] Disable many eigh tests. These started failing due to a compiler change internally at Google, but the tests themselves are buggy. It is not correct to compare an eigendecomposition for equality up to a tolerance, because the eigenvalues are sorted, and all it takes is a tiny perturbation to reorder the eigenvalues and eigenvectors, which leads to a result that looks very different. PiperOrigin-RevId: 669346013 --- jax/experimental/jax2tf/tests/primitives_test.py | 5 +++++ jax/experimental/jax2tf/tests/shape_poly_test.py | 5 +++++ tests/export_harnesses_multi_platform_test.py | 4 ++++ tests/shape_poly_test.py | 5 +++++ 4 files changed, 19 insertions(+) diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 5169ba8ab252..485fa6e5831f 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -119,6 +119,11 @@ def test_prim(self, harness: test_harnesses.Harness): device == "tpu"): raise unittest.SkipTest("b/264716764: error on tf.cast from c64 to f32") + if ("eigh" == harness.group_name and + device == "cpu"): + raise unittest.SkipTest( + "Equality comparisons on eigendecompositions are not stable.") + if (config.jax2tf_default_native_serialization.value and device == "gpu" and "lu" in harness.fullname): diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index a34c431edab9..2475a062f5ec 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -2665,6 +2665,11 @@ def test_harness(self, harness: PolyHarness): if 0 < shape[-1] <= 32: harness.check_result = False + if harness.group_name == "vmap_eigh": + raise unittest.SkipTest( + "Should not compare eigendecompositions for equality directly" + "because eigenvalues are sorted.") + if harness.group_name == "vmap_tan": # Tan (b/274462307) require support for custom call stablehlo.tan. raise unittest.SkipTest( diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 21ad29c7a4c9..0f0c20fd78e3 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -100,6 +100,10 @@ def setUpClass(cls): ) @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_prim(self, harness: test_harnesses.Harness): + if "eigh_" in harness.fullname: + self.skipTest("Eigenvalues are sorted and it is not correct to compare " + "decompositions for equality.") + if (jtu.device_under_test() == "gpu" and _known_failures_gpu.search(harness.fullname)): self.skipTest("failure to be investigated") diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 0a6f955cd5c3..ee10a0ce2637 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -3452,6 +3452,11 @@ def test_harness(self, harness: PolyHarness): if 0 < shape[-1] <= 32: harness.check_result = False + if harness.group_name == "vmap_eigh": + raise unittest.SkipTest( + "Should not compare eigendecompositions for equality directly" + "because eigenvalues are sorted.") + if harness.group_name == "vmap_tan": # Tan (b/274462307) require support for custom call stablehlo.tan. raise unittest.SkipTest( From 164b884f33e3f886605d9ef16a7da9899e54ca00 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 30 Aug 2024 09:49:22 -0700 Subject: [PATCH 301/702] Fix failing tests in CI PiperOrigin-RevId: 669357019 --- jax/_src/lax/lax.py | 10 ++++++++-- tests/memories_test.py | 2 ++ tests/pjit_test.py | 5 ++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 29fa41012f80..618c715ba763 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1822,7 +1822,10 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains # CompilerShardingAxis, then specify `unspecified_dims` via # `wrap_with_sharding_op`. - sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() + if config.use_shardy_partitioner.value: + sp = in_aval.sharding._to_sdy_sharding(in_aval.ndim) + else: + sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp)) return out @@ -1846,7 +1849,10 @@ def _nary_lower_hlo(op: Callable, ctx, else: out = op(*args) if config.sharding_in_types.value: - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() + if config.use_shardy_partitioner.value: + out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim) + else: + out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)] else: return [out] diff --git a/tests/memories_test.py b/tests/memories_test.py index 6140c6945df5..f761f490ad4e 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -559,6 +559,7 @@ def test_identity_jit_host_to_device_and_vice_versa(self): self.assertArraysEqual(out_host, np_inp) self.assertEqual(out_host.sharding, s_host) + @jtu.skip_on_devices("gpu") def test_parameter_streaming_inside_scan(self): mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z")) np_inp = np.arange(4096.0).reshape(16, 16, 16) @@ -1439,6 +1440,7 @@ def f(x): if jtu.pjrt_c_api_version_at_least(0, 43): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + @jtu.skip_on_devices("gpu") def test_remat_scan_jaxpr_offloadable(self): mesh = jtu.create_global_mesh((2,), ("x",)) shape = (256, 128) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 1e4b5685d45a..97dd65176a79 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4566,7 +4566,10 @@ def f(x): self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2)) lowered_text = f.lower(arr).as_text() - self.assertEqual(lowered_text.count('@Sharding'), 2) + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + else: + self.assertEqual(lowered_text.count('@Sharding'), 2) @jtu.pytest_mark_if_available('multiaccelerator') From 2f3990d13cbc9093bf11255a08a0c324314837fa Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 30 Aug 2024 09:57:53 -0700 Subject: [PATCH 302/702] Remove CPU test variant. PiperOrigin-RevId: 669359594 --- tests/BUILD | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index ef5f27f9bccb..45743d306fd6 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -35,7 +35,6 @@ jax_test( name = "api_test", srcs = ["api_test.py"], shard_count = 10, - tags = ["test_cpu_thunks"], ) jax_test( @@ -339,7 +338,6 @@ jax_test( jax_test( name = "infeed_test", srcs = ["infeed_test.py"], - tags = ["test_cpu_thunks"], deps = [ "//jax:experimental_host_callback", ], @@ -349,7 +347,6 @@ jax_test( name = "jax_jit_test", srcs = ["jax_jit_test.py"], main = "jax_jit_test.py", - tags = ["test_cpu_thunks"], ) py_test( @@ -440,7 +437,6 @@ jax_test( "gpu": 30, "tpu": 40, }, - tags = ["test_cpu_thunks"], ) jax_test( @@ -451,7 +447,6 @@ jax_test( "gpu": 20, "tpu": 20, }, - tags = ["test_cpu_thunks"], ) jax_test( @@ -472,7 +467,6 @@ jax_test( "gpu": 10, "tpu": 10, }, - tags = ["test_cpu_thunks"], ) jax_test( @@ -483,13 +477,11 @@ jax_test( "gpu": 10, "tpu": 10, }, - tags = ["test_cpu_thunks"], ) jax_test( name = "lax_numpy_vectorize_test", srcs = ["lax_numpy_vectorize_test.py"], - tags = ["test_cpu_thunks"], ) jax_test( @@ -554,7 +546,6 @@ jax_test( "gpu": 40, "tpu": 40, }, - tags = ["test_cpu_thunks"], deps = [ "//jax:internal_test_util", "//jax:lax_reference", @@ -584,7 +575,6 @@ jax_test( "gpu": 40, "tpu": 20, }, - tags = ["test_cpu_thunks"], ) jax_test( @@ -595,7 +585,6 @@ jax_test( "gpu": 40, "tpu": 40, }, - tags = ["test_cpu_thunks"], deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), ) @@ -607,7 +596,6 @@ jax_test( "gpu": 40, "tpu": 40, }, - tags = ["test_cpu_thunks"], deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), ) @@ -650,7 +638,6 @@ jax_test( "gpu": 40, "tpu": 40, }, - tags = ["test_cpu_thunks"], ) jax_test( @@ -1169,7 +1156,6 @@ py_test( jax_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.py"], - tags = ["test_cpu_thunks"], deps = [ "//jax:compilation_cache_internal", "//jax:compiler", From 8ccc439d4a0ec23bd613f9a200cd1c154295873d Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Fri, 30 Aug 2024 10:11:19 -0700 Subject: [PATCH 303/702] PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention` Imported from GitHub PR https://github.com/google/jax/pull/23223 While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping. For the new tests, we categorize them as follows: 1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc. 2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations. Additionally, we will no longer maintain separate tests for inference and training. Copybara import of the project: -- dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih : Reduce attn tests Merging this change closes #23223 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5 PiperOrigin-RevId: 669364738 --- tests/nn_test.py | 173 +++++++++++++++++++---------------------------- 1 file changed, 70 insertions(+), 103 deletions(-) diff --git a/tests/nn_test.py b/tests/nn_test.py index a79cf738714b..3722db42671c 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -45,149 +45,116 @@ def _is_required_cudnn_version_satisfied(): cuda_versions.cudnn_get_version() >= 8904 ) -def _get_causal_mask(T, S): - causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) - return causal_mask[jnp.newaxis, jnp.newaxis, :, :] +def _check_cudnn_backend(fn, *args, **kwargs): + lowered = jax.jit(fn).lower(*args, **kwargs) + hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) + return '__cudnn$fmha' in hlo @jtu.with_config(jax_legacy_prng_key="allow", jax_numpy_dtype_promotion="standard") class NNFunctionsTest(jtu.JaxTestCase): @parameterized.product( - dtype=[jnp.float32, jnp.bfloat16, jnp.float16], - use_bias=[False, True], - causal_mode=[None, 'attr', 'mask'], + dtype=[jnp.bfloat16, jnp.float16], group_num=[1, 2, 4], use_vmap=[False, True], - use_seqlen=[False, True], - impl=['xla', 'cudnn'], + impl=['cudnn', 'xla'], ) - def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, - group_num, use_vmap, use_seqlen, impl): + def testDotProductAttention(self, dtype, group_num, use_vmap, impl): if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") - if use_vmap and use_seqlen: - raise unittest.SkipTest("vmap cannot be used together with variable " - "seqence lengths") - sdpa = nn.dot_product_attention B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num - keys = random.split(random.PRNGKey(0), 4) + keys = random.split(random.PRNGKey(0), 5) Q = random.normal(keys[0], (B, T, N, H), dtype) K = random.normal(keys[1], (B, S, N // G, H), dtype) V = random.normal(keys[2], (B, S, N // G, H), dtype) - if use_bias: - bias = random.normal(keys[3], (1, N, T, S), dtype) - else: - bias = None - if use_seqlen: - q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32) - kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32) - else: - q_seqlen = None - kv_seqlen = None + grad = random.normal(keys[3], (B, T, N, H), dtype) + bias, mask = None, None - is_causal = causal_mode == 'attr' - causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None + sdpa = nn.dot_product_attention + sdpa_ref = partial(sdpa, implementation=None) + sdpa_ans = partial(sdpa, implementation=impl) + if use_vmap: + sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) - sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) - sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) + # For testing purposes, we call the non-GQA version without vmap in the + # reference code + K_ref = jnp.repeat(K, G, axis=2) + V_ref = jnp.repeat(V, G, axis=2) + out_ref, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, mask) + out_ans, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, mask) + + dQ_ref, dK_ref, dV_ref = sdpa_vjp_ref(grad)[:3] + dQ_ans, dK_ans, dV_ans = sdpa_vjp_ans(grad)[:3] + dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) + dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) if impl == 'cudnn': - lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias, causal_mask, - query_seq_lengths=q_seqlen, - key_value_seq_lengths=kv_seqlen) - hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) - self.assertIn('__cudnn$fmha', hlo) - - K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K - V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V - out_ref = sdpa_ref(Q, K_ref, V_ref, bias, causal_mask, - query_seq_lengths=q_seqlen, - key_value_seq_lengths=kv_seqlen) - if use_vmap: - sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) + self.assertTrue(_check_cudnn_backend(sdpa_ans, Q, K, V, bias, mask)) + self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) - out_ans = sdpa_ans(Q, K, V, bias, causal_mask, - query_seq_lengths=q_seqlen, - key_value_seq_lengths=kv_seqlen) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) + self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) + self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) + self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) @parameterized.product( - dtype=[jnp.float32, jnp.bfloat16, jnp.float16], - use_bias=[False, True], - causal_mode=[None, 'attr', 'mask'], - group_num=[1, 2, 4], - use_vmap=[False, True], - use_seqlen=[False, True], - impl=['xla', 'cudnn'], + mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'), + ('custom', 'padding'), ('bias', 'causal')], ) - def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, - group_num, use_vmap, use_seqlen, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): + def testDotProductAttentionMask(self, mask_mode): + if not _is_required_cudnn_version_satisfied(): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") - if impl == 'cudnn' and dtype == jnp.float32: - raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") - if use_vmap and use_seqlen: - raise unittest.SkipTest("vmap cannot be used together with variable " - "seqence lengths") - if use_seqlen and use_bias and impl == 'cudnn': - raise unittest.SkipTest("cudnn has limited support for dbias when using " - "variable seqence lengths") + if isinstance(mask_mode, str): + mask_mode = (mask_mode,) - sdpa = nn.dot_product_attention - B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num - keys = random.split(random.PRNGKey(0), 5) + dtype = jnp.bfloat16 + B, S, T, N, H = 2, 128, 128, 4, 32 + keys = random.split(random.PRNGKey(0), 4) Q = random.normal(keys[0], (B, T, N, H), dtype) - K = random.normal(keys[1], (B, S, N // G, H), dtype) - V = random.normal(keys[2], (B, S, N // G, H), dtype) + K = random.normal(keys[1], (B, S, N, H), dtype) + V = random.normal(keys[2], (B, S, N, H), dtype) grad = random.normal(keys[3], (B, T, N, H), dtype) - if use_bias: - bias = random.normal(keys[4], (1, N, T, S), dtype) - else: - bias = None - if use_seqlen: + bias, mask = None, None + q_seqlen, kv_seqlen = None, None + + is_causal = 'causal' in mask_mode + if 'padding' in mask_mode: q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32) kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32) - else: - q_seqlen = None - kv_seqlen = None - - is_causal = causal_mode == 'attr' - causal_mask = _get_causal_mask(T, S) if causal_mode == 'mask' else None + if 'custom' in mask_mode: + # Use a generated causal mask as the custom mask. + custom_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) + mask = custom_mask[None, None, :, :] + if 'bias' in mask_mode: + bias = random.normal(keys[4], (1, N, T, S), dtype) - K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K - V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V + sdpa = nn.dot_product_attention sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) - # Convert the keyword arguments to positional ones. + sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn') + + args = (Q, K, V, bias, mask) + kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen} + + # Convert the kargs to positional args for the jax.vjp. fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref( - q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs + q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, + ) + fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans( + q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, ) - _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask, - q_seqlen, kv_seqlen) + out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen) + out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen) dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4] - if G != 1: - dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) - dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) - - sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) - if use_vmap and not use_seqlen: - sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) - _, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, causal_mask) - else: - fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans( - q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs - ) - _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask, q_seqlen, - kv_seqlen) dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4] - if impl == 'cudnn': - lowered = jax.jit(sdpa_vjp_ans).lower(grad) - hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) - self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(') + # Check if cudnn backend is called. + self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs)) + self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) + self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) From fb7fa2a09eade7b3ac2683ae14056a73b03f4a15 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 30 Aug 2024 10:52:58 -0700 Subject: [PATCH 304/702] Improved extern selection in Pallas GPU Previously, * weakly typed avals matched the wrong externs; * this was addressed by #23193, which disallowed weakly typed avals entirely. Here we check if a weakly typed aval can be casted to the extern input dtype when selecting an extern. PiperOrigin-RevId: 669378582 --- jax/_src/pallas/triton/lowering.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 3259ee634e1f..f2a4229223bd 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -580,6 +580,7 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: return False return all( aval.dtype == jnp.dtype(arg_type) + or (aval.weak_type and aval.dtype.kind == jnp.dtype(arg_type).kind) for aval, arg_type in zip(avals, self.arg_types) ) From 1ab3119d438551b1b75d8b6fb567df8364db5681 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 30 Aug 2024 11:12:15 -0700 Subject: [PATCH 305/702] Add some msan suppressions to the LAPACK symmetric eigendecomposition FFI call. This fixes some msan false positives in our CI, since we do not msan-instrument Fortran code. PiperOrigin-RevId: 669385248 --- jaxlib/cpu/lapack_kernels.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index a6c5993f43b6..2bc62542e4da 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -1007,6 +1007,11 @@ ffi::Error EigenvalueDecompositionSymmetric::Kernel( fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v, eigenvalues_data, work_data.get(), &work_size_v, iwork_data.get(), &iwork_size_v, info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_out_data, + sizeof(*x_out_data) * x_cols * x_cols); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigenvalues_data, + sizeof(*eigenvalues_data) * x_cols); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int)); x_out_data += x_out_step; eigenvalues_data += eigenvalues_step; ++info_data; From a4c060c790a37c131f2bf9914a99b2fd1a502e2f Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 30 Aug 2024 23:45:12 +0530 Subject: [PATCH 306/702] Better docs for jax.numpy: positive, negative and sign --- jax/_src/numpy/ufuncs.py | 134 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 131 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index b9ba1317b313..37691e029f94 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -112,19 +112,147 @@ def bitwise_not(x: ArrayLike, /) -> Array: def invert(x: ArrayLike, /) -> Array: return lax.bitwise_not(*promote_args('invert', x)) -@implements(np.negative, module='numpy') + @partial(jit, inline=True) def negative(x: ArrayLike, /) -> Array: + """Return element-wise negative values of the input. + + JAX implementation of :obj:`numpy.negative`. + + Args: + x: input array or scalar. + + Returns: + An array with same shape and dtype as ``x`` containing ``-x``. + + See also: + - :func:`jax.numpy.positive`: Returns element-wise positive values of the input. + - :func:`jax.numpy.sign`: Returns element-wise indication of sign of the input. + + Note: + ``jnp.negative``, when applied over ``unsigned integer``, produces the result + of their two's complement negation, which typically results in unexpected + large positive values due to integer underflow. + + Examples: + For real-valued inputs: + + >>> x = jnp.array([0., -3., 7]) + >>> jnp.negative(x) + Array([-0., 3., -7.], dtype=float32) + + For complex inputs: + + >>> x1 = jnp.array([1-2j, -3+4j, 5-6j]) + >>> jnp.negative(x1) + Array([-1.+2.j, 3.-4.j, -5.+6.j], dtype=complex64) + + For unit32: + + >>> x2 = jnp.array([5, 0, -7]).astype(jnp.uint32) + >>> x2 + Array([ 5, 0, 4294967289], dtype=uint32) + >>> jnp.negative(x2) + Array([4294967291, 0, 7], dtype=uint32) + """ return lax.neg(*promote_args('negative', x)) -@implements(np.positive, module='numpy') + @partial(jit, inline=True) def positive(x: ArrayLike, /) -> Array: + """Return element-wise positive values of the input. + + JAX implementation of :obj:`numpy.positive`. + + Args: + x: input array or scalar + + Returns: + An array of same shape and dtype as ``x`` containing ``+x``. + + Note: + ``jnp.positive`` is equivalent to ``x.copy()`` and is defined only for the + types that support arithmetic operations. + + See also: + - :func:`jax.numpy.negative`: Returns element-wise negative values of the input. + - :func:`jax.numpy.sign`: Returns element-wise indication of sign of the input. + + Examples: + For real-valued inputs: + + >>> x = jnp.array([-5, 4, 7., -9.5]) + >>> jnp.positive(x) + Array([-5. , 4. , 7. , -9.5], dtype=float32) + >>> x.copy() + Array([-5. , 4. , 7. , -9.5], dtype=float32) + + For complex inputs: + + >>> x1 = jnp.array([1-2j, -3+4j, 5-6j]) + >>> jnp.positive(x1) + Array([ 1.-2.j, -3.+4.j, 5.-6.j], dtype=complex64) + >>> x1.copy() + Array([ 1.-2.j, -3.+4.j, 5.-6.j], dtype=complex64) + + For uint32: + + >>> x2 = jnp.array([6, 0, -4]).astype(jnp.uint32) + >>> x2 + Array([ 6, 0, 4294967292], dtype=uint32) + >>> jnp.positive(x2) + Array([ 6, 0, 4294967292], dtype=uint32) + """ return lax.asarray(*promote_args('positive', x)) -@implements(np.sign, module='numpy') + @partial(jit, inline=True) def sign(x: ArrayLike, /) -> Array: + r"""Return an element-wise indication of sign of the input. + + JAX implementation of :obj:`numpy.sign`. + + The sign of ``x`` for real-valued input is: + + .. math:: + \mathrm{sign}(x) = \begin{cases} + 1, & x > 0\\ + 0, & x = 0\\ + -1, & x < 0 + \end{cases} + + For complex valued input, ``jnp.sign`` returns a unit vector repesenting the + phase. For generalized case, the sign of ``x`` is given by: + + .. math:: + \mathrm{sign}(x) = \begin{cases} + \frac{x}{abs(x)}, & x \ne 0\\ + 0, & x = 0 + \end{cases} + + Args: + x: input array or scalar. + + Returns: + An array with same shape and dtype as ``x`` containing the sign indication. + + See also: + - :func:`jax.numpy.positive`: Returns element-wise positive values of the input. + - :func:`jax.numpy.negative`: Returns element-wise negative values of the input. + + Examples: + For Real-valued inputs: + + >>> x = jnp.array([0., -3., 7.]) + >>> jnp.sign(x) + Array([ 0., -1., 1.], dtype=float32) + + For complex-inputs: + + >>> x1 = jnp.array([1, 3+4j, 5j]) + >>> jnp.sign(x1) + Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64) + """ return lax.sign(*promote_args('sign', x)) From cd2040415957a2074b0879dec9046dada442e723 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 30 Aug 2024 11:28:59 -0700 Subject: [PATCH 307/702] Disable mosaic gpu tests that are failing at head. PiperOrigin-RevId: 669390680 --- tests/mosaic/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 255b03d3a002..9149891a2dea 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -49,6 +49,7 @@ jax_test( disable_configs = DISABLED_CONFIGS, enable_configs = ["gpu_h100_2gpu"], shard_count = 4, + tags = ["notap"], # Broken at head. deps = [ "//jax:mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), @@ -60,6 +61,7 @@ jax_test( disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, shard_count = 5, + tags = ["notap"], # Broken at head. deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", From a3d6cf007e998223bd1fb0cc01255e0b5ca3f656 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 30 Aug 2024 11:53:02 -0700 Subject: [PATCH 308/702] First pass at ufunc interfaces for several jax.numpy functions --- CHANGELOG.md | 6 + jax/_src/numpy/reductions.py | 10 +- jax/_src/numpy/ufunc_api.py | 408 +++++++++++++++---- jax/_src/numpy/ufuncs.py | 151 ++++++- jax/experimental/jax2tf/tests/jax2tf_test.py | 4 +- jax/numpy/__init__.pyi | 48 ++- tests/lax_numpy_ufuncs_test.py | 243 ++++++++--- 7 files changed, 687 insertions(+), 183 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a3dac8c3152..f59b07cd237e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * ``jax.tree_util.register_dataclass`` now checks that ``data_fields`` and ``meta_fields`` includes all dataclass fields with ``init=True`` and only them, if ``nodetype`` is a dataclass. + * Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc` + interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`, + {obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`, + {obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`, + {obj}`~jax.numpy.logical_and`, and {obj}`~jax.numpy.logical_and`. + In future releases we plan to expand these to other ufuncs. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index e8815c943ce6..dddb44dc9207 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -30,7 +30,6 @@ from jax._src import core from jax._src import deprecations from jax._src import dtypes -from jax._src.numpy import ufuncs from jax._src.numpy.util import ( _broadcast_to, check_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where, implements) @@ -2039,9 +2038,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, a_shape = a.shape if squash_nans: - a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. + a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. a = lax.sort(a, dimension=axis) - counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) + counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) shape_after_reduction = counts.shape q = lax.expand_dims( q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) @@ -2067,7 +2066,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a) + a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) q = lax.mul(q, n - 1) @@ -2223,7 +2222,8 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, Array([1.5, 3. , 4.5], dtype=float32) """ check_arraylike("nanpercentile", a, q) - q = ufuncs.true_divide(q, 100.0) + q, = promote_dtypes_inexact(q) + q = q / 100 if not isinstance(interpolation, DeprecatedArg): deprecations.warn( "jax-numpy-quantile-interpolation", diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 2e114193af13..3473e8a7468a 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -25,13 +25,11 @@ import jax from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.lax import lax as lax_internal -from jax._src.numpy import reductions -from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take +import jax._src.numpy.lax_numpy as jnp from jax._src.numpy.reductions import _moveaxis -from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where +from jax._src.numpy.util import check_arraylike, _broadcast_to, _where from jax._src.numpy.vectorize import vectorize from jax._src.util import canonicalize_axis, set_module -from jax._src import pjit import numpy as np @@ -42,81 +40,126 @@ """ -def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Primitive | None: - """ - If fun(*args) lowers to a single primitive with inputs and outputs matching - function inputs and outputs, return that primitive. Otherwise return None. - """ - try: - jaxpr = jax.make_jaxpr(fun)(*args) - except: - return None - while len(jaxpr.eqns) == 1: - eqn = jaxpr.eqns[0] - if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars): - return None - elif (eqn.primitive == pjit.pjit_p and - all(pjit.is_unspecified(sharding) for sharding in - (*eqn.params['in_shardings'], *eqn.params['out_shardings']))): - jaxpr = jaxpr.eqns[0].params['jaxpr'] - else: - return jaxpr.eqns[0].primitive - return None +@set_module('jax.numpy') +class ufunc: + """Universal functions which operation element-by-element on arrays. + JAX implementation of :class:`numpy.ufunc`. -_primitive_reducers: dict[jax.core.Primitive, Callable[..., Any]] = { - lax_internal.add_p: reductions.sum, - lax_internal.mul_p: reductions.prod, -} + This is a class for JAX-backed implementations of NumPy's ufunc APIs. + Most users will never need to instantiate :class:`ufunc`, but rather + will use the pre-defined ufuncs in :mod:`jax.numpy`. + For constructing your own ufuncs, see :func:`jax.numpy.frompyfunc`. -_primitive_accumulators: dict[jax.core.Primitive, Callable[..., Any]] = { - lax_internal.add_p: reductions.cumsum, - lax_internal.mul_p: reductions.cumprod, -} + Examples: + Universal functions are functions that apply element-wise to broadcasted + arrays, but they also come with a number of extra attributes and methods. + As an example, consider the function :obj:`jax.numpy.add`. The object + acts as a function that applies addition to broadcasted arrays in an + element-wise manner: -@set_module('jax.numpy') -class ufunc: - """Functions that operate element-by-element on whole arrays. + >>> x = jnp.array([1, 2, 3, 4, 5]) + >>> jnp.add(x, 1) + Array([2, 3, 4, 5, 6], dtype=int32) + + Each :class:`ufunc` object includes a number of attributes that describe + its behavior: + + >>> jnp.add.nin # number of inputs + 2 + >>> jnp.add.nout # number of outputs + 1 + >>> jnp.add.identity # identity value, or None if no identity exists + 0 + + Binary ufuncs like :obj:`jax.numpy.add` include number of methods to + apply the function to arrays in different manners. + + The :meth:`~ufunc.outer` method applies the function to the + pair-wise outer-product of the input array values: + + >>> jnp.add.outer(x, x) + Array([[ 2, 3, 4, 5, 6], + [ 3, 4, 5, 6, 7], + [ 4, 5, 6, 7, 8], + [ 5, 6, 7, 8, 9], + [ 6, 7, 8, 9, 10]], dtype=int32) - This is a class for LAX-backed implementations of numpy ufuncs. + The :meth:`ufunc.reduce` method perfoms a reduction over the array. + For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``: + + >>> jnp.add.reduce(x) + Array(15, dtype=int32) + + The :meth:`ufunc.accumulate` method performs a cumulative reduction + over the array. For example, :meth:`jnp.add.accumulate` is equivalent + to :func:`jax.numpy.cumulative_sum`: + + >>> jnp.add.accumulate(x) + Array([ 1, 3, 6, 10, 15], dtype=int32) + + The :meth:`ufunc.at` method applies the function at particular indices in the + array; for ``jnp.add`` the computation is similar to :func:`jax.lax.scatter_add`: + + >>> jnp.add.at(x, 0, 100, inplace=False) + Array([101, 2, 3, 4, 5], dtype=int32) + + And the :meth:`ufunc.reduceat` method performs a number of ``reduce`` + operations bewteen specified indices of an array; for ``jnp.add`` the + operation is similar to :func:`jax.ops.segment_sum`: + + >>> jnp.add.reduceat(x, jnp.array([0, 2])) + Array([ 3, 12], dtype=int32) + + In this case, the first element is ``x[0:2].sum()``, and the second element + is ``x[2:].sum()``. """ def __init__(self, func: Callable[..., Any], /, nin: int, nout: int, *, name: str | None = None, nargs: int | None = None, - identity: Any = None, update_doc=False): + identity: Any = None, + call: Callable[..., Any] | None = None, + reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None, + at: Callable[..., Any] | None = None, + reduceat: Callable[..., Any] | None = None, + ): + self.__doc__ = func.__doc__ + self.__name__ = name or func.__name__ # We want ufunc instances to work properly when marked as static, # and for this reason it's important that their properties not be # mutated. We prevent this by storing them in a dunder attribute, # and accessing them via read-only properties. - if update_doc: - self.__doc__ = func.__doc__ - self.__name__ = name or func.__name__ self.__static_props = { 'func': func, - 'call': vectorize(func), 'nin': operator.index(nin), 'nout': operator.index(nout), 'nargs': operator.index(nargs or nin), - 'identity': identity + 'identity': identity, + 'call': call, + 'reduce': reduce, + 'accumulate': accumulate, + 'at': at, + 'reduceat': reduceat, } _func = property(lambda self: self.__static_props['func']) - _call = property(lambda self: self.__static_props['call']) nin = property(lambda self: self.__static_props['nin']) nout = property(lambda self: self.__static_props['nout']) nargs = property(lambda self: self.__static_props['nargs']) identity = property(lambda self: self.__static_props['identity']) def __hash__(self) -> int: - # Do not include _call, because it is computed from _func. + # In both __hash__ and __eq__, we do not consider call, reduce, etc. + # because they are considered implementation details rather than + # necessary parts of object identity. return hash((self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs)) def __eq__(self, other: Any) -> bool: - # Do not include _call, because it is computed from _func. return isinstance(other, ufunc) and ( (self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) == (other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs)) @@ -124,20 +167,71 @@ def __eq__(self, other: Any) -> bool: def __repr__(self) -> str: return f"" - def __call__(self, *args: ArrayLike, - out: None = None, where: None = None, - **kwargs: Any) -> Any: + def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> Any: + check_arraylike(self.__name__, *args) if out is not None: raise NotImplementedError(f"out argument of {self}") if where is not None: raise NotImplementedError(f"where argument of {self}") - return self._call(*args, **kwargs) + call = self.__static_props['call'] or self._call_vectorized + return call(*args) + + @partial(jax.jit, static_argnames=['self']) + def _call_vectorized(self, *args): + return vectorize(self._func)(*args) - @implements(np.ufunc.reduce, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) - def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + def reduce(self, a: ArrayLike, axis: int = 0, + dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + """Reduction operation derived from a binary function. + + JAX implementation of :meth:`numpy.ufunc.reduce`. + + Args: + a: Input array. + axis: integer specifying the axis over which to reduce. default=0 + dtype: optionally specify the type of the output array. + out: Unused by JAX + keepdims: If True, reduced axes are left in the result with size 1. + If False (default) then reduced axes are squeezed out. + initial: int or array, Default=None. Initial value for the reduction. + where: boolean mask, default=None. The elements to be used in the sum. Array + should be broadcast compatible to the input. + + Returns: + array containing the result of the reduction operation. + + Examples: + Consider the following array: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + + :meth:`jax.numpy.add.reduce` is equivalent to :func:`jax.numpy.sum` + along ``axis=0``: + + >>> jnp.add.reduce(x) + Array([5, 7, 9], dtype=int32) + >>> x.sum(0) + Array([5, 7, 9], dtype=int32) + + Similarly, :meth:`jax.numpy.logical_and.reduce` is equivalent to + :func:`jax.numpy.all`: + + >>> jnp.logical_and.reduce(x > 2) + Array([False, False, True], dtype=bool) + >>> jnp.all(x > 2, axis=0) + Array([False, False, True], dtype=bool) + + Some reductions do not correspond to any built-in aggregation function; + for example here is the reduction of :func:`jax.numpy.bitwise_or` along + the first axis of ``x``: + + >>> jnp.bitwise_or.reduce(x, axis=1) + Array([3, 7], dtype=int32) + """ check_arraylike(f"{self.__name__}.reduce", a) if self.nin != 2: raise ValueError("reduce only supported for binary ufuncs") @@ -154,14 +248,10 @@ def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, "so to use a where mask one has to specify 'initial'.") if lax_internal._dtype(where) != bool: raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}") - primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - if primitive is None: - reducer = self._reduce_via_scan - else: - reducer = _primitive_reducers.get(primitive, self._reduce_via_scan) - return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) + reduce = self.__static_props['reduce'] or self._reduce_via_scan + return reduce(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 @@ -202,9 +292,9 @@ def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | Non def body_fun(i, val): if where is None: - return self._call(val, arr[i].astype(dtype)) + return self(val, arr[i].astype(dtype)) else: - return _where(where[i], self._call(val, arr[i].astype(dtype)), val) + return _where(where[i], self(val, arr[i].astype(dtype)), val) start_value: ArrayLike if initial is None: @@ -221,22 +311,63 @@ def body_fun(i, val): result = result.reshape(final_shape) return result - @implements(np.ufunc.accumulate, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: + """Accumulate operation derived from binary ufunc. + + JAX implementation of :func:`numpy.ufunc.accumulate`. + + Args: + a: N-dimensional array over which to accumulate. + axis: integer axis over which accumulation will be performed (default = 0) + dtype: optionally specify the type of the output array. + out: Unused by JAX + + Returns: + An array containing the accumulated result. + + Examples: + Consider the following array: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + + :meth:`jax.numpy.add.accumulate` is equivalent to + :func:`jax.numpy.cumsum` along the specified axis: + >>> jnp.add.accumulate(x, axis=1) + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + >>> jnp.cumsum(x, axis=1) + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + + Similarly, :meth:`jax.numpy.multiply.accumulate` is equivalent to + :func:`jax.numpy.cumprod` along the specified axis: + + >>> jnp.multiply.accumulate(x, axis=1) + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + >>> jnp.cumprod(x, axis=1) + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + + For other binary ufuncs, the accumulation is an operation not available + via standard APIs. For example, :meth:`jax.numpy.bitwise_or.accumulate` + is essentially a bitwise cumulative ``any``: + + >>> jnp.bitwise_or.accumulate(x, axis=1) + Array([[1, 3, 3], + [4, 5, 7]], dtype=int32) + """ if self.nin != 2: raise ValueError("accumulate only supported for binary ufuncs") if self.nout != 1: raise ValueError("accumulate only supported for functions returning a single value") if out is not None: raise NotImplementedError(f"out argument of {self.__name__}.accumulate()") - primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - if primitive is None: - accumulator = self._accumulate_via_scan - else: - accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan) - return accumulator(a, axis=axis, dtype=dtype) + accumulate = self.__static_props['accumulate'] or self._accumulate_via_scan + return accumulate(a, axis=axis, dtype=dtype) def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None) -> Array: @@ -254,21 +385,54 @@ def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, arr = _moveaxis(arr, axis, 0) def scan_fun(carry, _): i, x = carry - y = _where(i == 0, arr[0].astype(dtype), self._call(x.astype(dtype), arr[i].astype(dtype))) + y = _where(i == 0, arr[0].astype(dtype), self(x.astype(dtype), arr[i].astype(dtype))) return (i + 1, y), y _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) return _moveaxis(result, 0, axis) - @implements(np.ufunc.at, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: + """Update elements of an array via the specified unary or binary ufunc. + + JAX implementation of :func:`numpy.ufunc.at`. + + Note: + :meth:`numpy.ufunc.at` mutates arrays in-place. JAX arrays are immutable, + so :meth:`jax.numpy.ufunc.at` cannot replicate these semantics. Instead, JAX + will return the updated value, but requires explicitly passing ``inplace=False`` + as a reminder of this difference. + + Args: + a: N-dimensional array to update + indices: index, slice, or tuple of indices and slices. + b: array of values for binary ufunc updates. + inplace: must be set to False to indicate that an updated copy will be returned. + + Returns: + an updated copy of the input array. + + Examples: + + Add numbers to specified indices: + + >>> x = jnp.ones(10, dtype=int) + >>> indices = jnp.array([2, 5, 7]) + >>> values = jnp.array([10, 20, 30]) + >>> jnp.add.at(x, indices, values, inplace=False) + Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32) + + This is roughly equivalent to JAX's :meth:`jax.numpy.ndarray.at` method + called this way: + + >>> x.at[indices].add(values) + Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32) + """ if inplace: raise NotImplementedError(_AT_INPLACE_WARNING) - if b is None: - return self._at_via_scan(a, indices) - else: - return self._at_via_scan(a, indices, b) + + at = self.__static_props['at'] or self._at_via_scan + return at(a, indices) if b is None else at(a, indices, b) def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: assert len(args) in {0, 1} @@ -276,14 +440,14 @@ def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype a = lax_internal.asarray(a).astype(dtype) args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args) - indices = _eliminate_deprecated_list_indexing(indices) + indices = jnp._eliminate_deprecated_list_indexing(indices) if not indices: return a shapes = [np.shape(i) for i in indices if not isinstance(i, slice)] shape = shapes and jax.lax.broadcast_shapes(*shapes) if not shape: - return a.at[indices].set(self._call(a.at[indices].get(), *args)) + return a.at[indices].set(self(a.at[indices].get(), *args)) if args: arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):])) @@ -293,28 +457,65 @@ def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: def scan_fun(carry, x): i, a = carry idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices) - a = a.at[idx].set(self._call(a.at[idx].get(), *(arg[i] for arg in args))) + a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args))) return (i + 1, a), x carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0])) return carry[1] - @implements(np.ufunc.reduceat, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: + """Reduce an array between specified indices via a binary ufunc. + + JAX implementation of :meth:`numpy.ufunc.reduceat` + + Args: + a: N-dimensional array to reduce + indices: a 1-dimensional array of increasing integer values which encodes + segments of the array to be reduced. + axis: integer specifying the axis along which to reduce: default=0. + dtype: optionally specify the dtype of the output array. + out: unused by JAX + Returns: + An array containing the reduced values. + + Examples: + The ``reduce`` method lets you efficiently compute reduction operations + over array segments. For example: + + >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) + >>> indices = jnp.array([0, 2, 5]) + >>> jnp.add.reduce(x, indices) + Array([ 3, 12, 21], dtype=int32) + + This is more-or-less equivalent to the following: + + >>> jnp.array([x[0:2].sum(), x[2:5].sum(), x[5:].sum()]) + Array([ 3, 12, 21], dtype=int32) + + For some binary ufuncs, JAX provides similar APIs within :mod:`jax.ops`. + For example, :meth:`jax.add.reduceat` is similar to :func:`jax.ops.segment_sum`, + although in this case the segments are defined via an array of segment ids: + + >>> segments = jnp.array([0, 0, 1, 1, 1, 2, 2, 2]) + >>> jax.ops.segment_sum(x, segments) + Array([ 3, 12, 21], dtype=int32) + """ if self.nin != 2: raise ValueError("reduceat only supported for binary ufuncs") if self.nout != 1: raise ValueError("reduceat only supported for functions returning a single value") if out is not None: raise NotImplementedError(f"out argument of {self.__name__}.reduceat()") - return self._reduceat_via_scan(a, indices, axis=axis, dtype=dtype) + + reduceat = self.__static_props['reduceat'] or self._reduceat_via_scan + return reduceat(a, indices, axis=axis, dtype=dtype) def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None) -> Array: check_arraylike(f"{self.__name__}.reduceat", a, indices) a = lax_internal.asarray(a) - idx_tuple = _eliminate_deprecated_list_indexing(indices) + idx_tuple = jnp._eliminate_deprecated_list_indexing(indices) assert len(idx_tuple) == 1 indices = idx_tuple[0] if a.ndim == 0: @@ -326,27 +527,62 @@ def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, if axis is None or isinstance(axis, (tuple, list)): raise ValueError("reduceat requires a single integer axis.") axis = canonicalize_axis(axis, a.ndim) - out = take(a, indices, axis=axis) - ind = jax.lax.expand_dims(append(indices, a.shape[axis]), + out = jnp.take(a, indices, axis=axis) + ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]), list(np.delete(np.arange(out.ndim), axis))) ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) def loop_body(i, out): return _where((i > ind_start) & (i < ind_end), - self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), + self(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), out) return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) - @implements(np.ufunc.outer, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0]) - def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array: + def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: + """Apply the function to all pairs of values in ``A`` and ``B``. + + JAX implementation of :meth:`numpy.ufunc.outer`. + + Args: + A: N-dimensional array + B: N-dimensional array + + Returns: + An array of shape `tuple(*A.shape, *B.shape)` + + Examples: + A times-table for integers 1...10 created via + :meth:`jax.numpy.multiply.outer`: + + >>> x = jnp.arange(1, 11) + >>> print(jnp.multiply.outer(x, x)) + [[ 1 2 3 4 5 6 7 8 9 10] + [ 2 4 6 8 10 12 14 16 18 20] + [ 3 6 9 12 15 18 21 24 27 30] + [ 4 8 12 16 20 24 28 32 36 40] + [ 5 10 15 20 25 30 35 40 45 50] + [ 6 12 18 24 30 36 42 48 54 60] + [ 7 14 21 28 35 42 49 56 63 70] + [ 8 16 24 32 40 48 56 64 72 80] + [ 9 18 27 36 45 54 63 72 81 90] + [ 10 20 30 40 50 60 70 80 90 100]] + + For input arrays with ``N`` and ``M`` dimensions respectively, the output + will have dimesion ``N + M``: + + >>> x = jnp.ones((1, 3, 5)) + >>> y = jnp.ones((2, 4)) + >>> jnp.add.outer(x, y).shape + (1, 3, 5, 2, 4) + """ if self.nin != 2: raise ValueError("outer only supported for binary ufuncs") if self.nout != 1: raise ValueError("outer only supported for functions returning a single value") check_arraylike(f"{self.__name__}.outer", A, B) _ravel = lambda A: jax.lax.reshape(A, (np.size(A),)) - result = jax.vmap(jax.vmap(partial(self._call, **kwargs), (None, 0)), (0, None))(_ravel(A), _ravel(B)) + result = jax.vmap(jax.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B)) return result.reshape(*np.shape(A), *np.shape(B)) @@ -363,4 +599,4 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, Returns: wrapped : jax.numpy.ufunc wrapper of func. """ - return ufunc(func, nin, nout, identity=identity, update_doc=True) + return ufunc(func, nin, nout, identity=identity) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index dfeff38df0fe..aa8ac4e95325 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -30,11 +30,13 @@ from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax from jax._src.lax import other as lax_other -from jax._src.typing import Array, ArrayLike +from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, implements, check_no_float0s) +from jax._src.numpy.ufunc_api import ufunc +from jax._src.numpy import reductions _lax_const = lax._const @@ -298,31 +300,81 @@ def sqrt(x: ArrayLike, /) -> Array: def cbrt(x: ArrayLike, /) -> Array: return lax.cbrt(*promote_args_inexact('cbrt', x)) -@implements(np.add, module='numpy') @partial(jit, inline=True) -def add(x: ArrayLike, y: ArrayLike, /) -> Array: +def _add(x: ArrayLike, y: ArrayLike, /) -> Array: + """Add two arrays element-wise. + + JAX implementation of :obj:`numpy.add`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: arrays to add. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise addition. + """ x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) -@implements(np.multiply, module='numpy') @partial(jit, inline=True) -def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: +def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: + """Multiply two arrays element-wise. + + JAX implementation of :obj:`numpy.multiply`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: arrays to multiply. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise multiplication. + """ x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) -@implements(np.bitwise_and, module='numpy') @partial(jit, inline=True) -def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the bitwise AND operation elementwise. + + JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: integer or boolean arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise bitwise AND. + """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) -@implements(np.bitwise_or, module='numpy') @partial(jit, inline=True) -def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the bitwise OR operation elementwise. + + JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: integer or boolean arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise bitwise OR. + """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) -@implements(np.bitwise_xor, module='numpy') @partial(jit, inline=True) -def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the bitwise XOR operation elementwise. + + JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: integer or boolean arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise bitwise XOR. + """ return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) @implements(np.left_shift, module='numpy') @@ -376,19 +428,49 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.nextafter(*promote_args_inexact("nextafter", x, y)) # Logical ops -@implements(np.logical_and, module='numpy') @partial(jit, inline=True) -def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the logical AND operation elementwise. + + JAX implementation of :obj:`numpy.logical_and`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: input arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise logical AND. + """ return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) -@implements(np.logical_or, module='numpy') @partial(jit, inline=True) -def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the logical OR operation elementwise. + + JAX implementation of :obj:`numpy.logical_or`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: input arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise logical OR. + """ return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) -@implements(np.logical_xor, module='numpy') @partial(jit, inline=True) -def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the logical XOR operation elementwise. + + JAX implementation of :obj:`numpy.logical_xor`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: input arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise logical XOR. + """ return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) @implements(np.logical_not, module='numpy') @@ -1281,3 +1363,38 @@ def _sinc_maclaurin(k, x): def _sinc_maclaurin_jvp(k, primals, tangents): (x,), (t,) = primals, tangents return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t + + +def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_and.reduce()") + result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + + +def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_or.reduce()") + result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + + +# Generate ufunc interfaces for several common binary functions. +# We start with binary ufuncs that have well-defined identities.' +# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience? +# TODO(jakevdp): optimize some implementations. +# - define add.at/multiply.at in terms of scatter_add/scatter_mul +# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod +# - define all monoidal reductions in terms of lax.reduce +add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum) +multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod) +bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and) +bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or) +bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor) +logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce) +logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce) +logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index ffa3a103e7e4..64d461fe9996 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -965,8 +965,8 @@ def caller_jax(x): self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) else: graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()) - if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def: - self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def) + if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def: + self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def) def test_bfloat16_constant(self): # Re: https://github.com/google/jax/issues/3942 diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 583f6886e915..5e2c1dce4c3d 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -3,7 +3,7 @@ from __future__ import annotations import builtins from collections.abc import Callable, Sequence -from typing import Any, Literal, NamedTuple, TypeVar, Union, overload +from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload from jax._src import core as _core from jax._src import dtypes as _dtypes @@ -28,6 +28,34 @@ _Device = Device ComplexWarning: type +class BinaryUfunc(Protocol): + @property + def nin(self) -> int: ... + @property + def nout(self) -> int: ... + @property + def nargs(self) -> int: ... + @property + def identity(self) -> builtins.bool | int | float: ... + def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ... + def reduce(self, arr: ArrayLike, /, *, + axis: int | None = 0, + dtype: DTypeLike | None = None, + keepdims: builtins.bool = False, + initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... + def accumulate(self, a: ArrayLike, /, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, + inplace: builtins.bool = True) -> Array: ... + def reduceat(self, a: ArrayLike, indices: Any, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ... + __array_api_version__: str def __array_namespace_info__() -> ArrayNamespaceInfo: ... @@ -36,7 +64,7 @@ def abs(x: ArrayLike, /) -> Array: ... def absolute(x: ArrayLike, /) -> Array: ... def acos(x: ArrayLike, /) -> Array: ... def acosh(x: ArrayLike, /) -> Array: ... -def add(x: ArrayLike, y: ArrayLike, /) -> Array: ... +add: BinaryUfunc def amax(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... @@ -162,14 +190,14 @@ def bartlett(M: int) -> Array: ... bfloat16: Any def bincount(x: ArrayLike, weights: ArrayLike | None = ..., minlength: int = ..., *, length: int | None = ...) -> Array: ... -def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_and: BinaryUfunc def bitwise_count(x: ArrayLike, /) -> Array: ... def bitwise_invert(x: ArrayLike, /) -> Array: ... def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_not(x: ArrayLike, /) -> Array: ... -def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_or: BinaryUfunc def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_xor: BinaryUfunc def blackman(M: int) -> Array: ... def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ... bool: Any @@ -251,7 +279,7 @@ def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., - include_initial: bool = ...) -> Array: ... + include_initial: builtins.bool = ...) -> Array: ... def deg2rad(x: ArrayLike, /) -> Array: ... degrees = rad2deg @@ -557,10 +585,10 @@ def log1p(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ... def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logical_and: BinaryUfunc def logical_not(x: ArrayLike, /) -> Array: ... -def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logical_or: BinaryUfunc +logical_xor: BinaryUfunc def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... @@ -588,7 +616,7 @@ def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: ... -def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ... +multiply: BinaryUfunc nan: float def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ..., posinf: ArrayLike | None = ..., diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 16eb9321c822..c65df8aa87a2 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -14,6 +14,7 @@ """Tests for jax.numpy.ufunc and its methods.""" +import itertools from functools import partial from absl.testing import absltest @@ -22,7 +23,6 @@ import jax import jax.numpy as jnp from jax._src import test_util as jtu -from jax._src.numpy.ufunc_api import get_if_single_primitive jax.config.parse_flags_with_absl() @@ -54,19 +54,22 @@ def scalar_sub(x, y): {'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None}, ] -FASTPATH_FUNCS = [ - {'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0, - 'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p}, - {'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1, - 'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p}, -] +def _jnp_ufunc_props(name): + jnp_func = getattr(jnp, name) + assert isinstance(jnp_func, jnp.ufunc) + np_func = getattr(np, name) + dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types] + return [dict(name=name, dtype=dtype) for dtype in dtypes] + -NON_FASTPATH_FUNCS = [ - {'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0}, - {'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1}, - {'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1}, +JAX_NUMPY_UFUNCS = [ + name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc) ] +JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable( + _jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS +)) + broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] nonscalar_shapes = [(3,), (4,), (4, 3)] @@ -80,23 +83,40 @@ def wrapped(*args, **kwargs): class LaxNumpyUfuncTests(jtu.JaxTestCase): @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_properties(self, func, nin, nout, identity): + def test_frompyfunc_properties(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) self.assertEqual(jnp_fun.identity, identity) self.assertEqual(jnp_fun.nin, nin) self.assertEqual(jnp_fun.nout, nout) self.assertEqual(jnp_fun.nargs, nin) + @jtu.sample_product(name=JAX_NUMPY_UFUNCS) + def test_ufunc_properties(self, name): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + self.assertEqual(jnp_fun.identity, np_fun.identity) + self.assertEqual(jnp_fun.nin, np_fun.nin) + self.assertEqual(jnp_fun.nout, np_fun.nout) + self.assertEqual(jnp_fun.nargs, np_fun.nargs - 1) # -1 because NumPy accepts `out` + @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_properties_readonly(self, func, nin, nout, identity): + def test_frompyfunc_properties_readonly(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) - for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']: + for attr in ['nargs', 'nin', 'nout', 'identity', '_func']: + getattr(jnp_fun, attr) # no error on attribute access. + with self.assertRaises(AttributeError): + setattr(jnp_fun, attr, None) # error when trying to mutate. + + @jtu.sample_product(name=JAX_NUMPY_UFUNCS) + def test_ufunc_properties_readonly(self, name): + jnp_fun = getattr(jnp, name) + for attr in ['nargs', 'nin', 'nout', 'identity', '_func']: getattr(jnp_fun, attr) # no error on attribute access. with self.assertRaises(AttributeError): setattr(jnp_fun, attr, None) # error when trying to mutate. @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_hash(self, func, nin, nout, identity): + def test_frompyfunc_hash(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) self.assertEqual(jnp_fun, jnp_fun_2) @@ -113,7 +133,7 @@ def test_ufunc_hash(self, func, nin, nout, identity): dtype=jtu.dtypes.floating, ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): + def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity)) @@ -123,13 +143,28 @@ def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + lhs_shape=broadcast_compatible_shapes, + rhs_shape=broadcast_compatible_shapes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def test_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( SCALAR_FUNCS, lhs_shape=broadcast_compatible_shapes, rhs_shape=broadcast_compatible_shapes, dtype=jtu.dtypes.floating, ) - def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): + def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer @@ -141,6 +176,23 @@ def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + lhs_shape=broadcast_compatible_shapes, + rhs_shape=broadcast_compatible_shapes, + ) + def test_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker) + self._CompileAndCheck(jnp_fun.outer, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} @@ -148,7 +200,7 @@ def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) - def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) @@ -160,6 +212,26 @@ def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_ufunc_reduce(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis) + np_fun_reduce = partial(np_fun.reduce, axis=axis) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} @@ -167,7 +239,7 @@ def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) - def test_reduce_where(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_reduce_where(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") @@ -194,42 +266,28 @@ def np_fun(arr, where): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - FASTPATH_FUNCS, - [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, - ) - def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator): - del accumulator # unused - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") - rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) - self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer) - - @jtu.sample_product( - NON_FASTPATH_FUNCS, + JAX_NUMPY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, + for axis in [None, *range(-len(shape), len(shape))]], ) - def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype): - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") - rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - - _ = func(0, 0) # function should not error. + def test_ufunc_reduce_where(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + if jnp_fun.identity is None: + self.skipTest("reduce with where requires identity") - reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) - self.assertIsNone(get_if_single_primitive(reduce_fun, *args)) + jnp_fun_reduce = lambda a, where: jnp_fun.reduce(a, axis=axis, where=where) + np_fun_reduce = lambda a, where: np_fun.reduce(a, axis=axis, where=where) - accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) - self.assertIsNone(get_if_single_primitive(accum_fun, *args)) + rng = jtu.rand_default(self.rng()) + rng_where = jtu.rand_bool(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)] + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -238,7 +296,7 @@ def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype): for axis in range(-len(shape), len(shape))], dtype=jtu.dtypes.floating, ) - def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) @@ -251,20 +309,28 @@ def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - FASTPATH_FUNCS, + JAX_NUMPY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, + for axis in range(-len(shape), len(shape))] ) - def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator): - del reducer # unused - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") + def test_ufunc_accumulate(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) - self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator) + args_maker = lambda: [rng(shape, dtype)] + + jnp_fun_accumulate = partial(jnp_fun.accumulate, axis=axis) + def np_fun_accumulate(x): + # numpy accumulate has different dtype casting behavior. + result = np_fun.accumulate(x, axis=axis) + return result if x.dtype == bool else result.astype(x.dtype) + + self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker) + self._CompileAndCheck(jnp_fun_accumulate, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -272,7 +338,7 @@ def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype idx_shape=[(), (2,)], dtype=jtu.dtypes.floating, ) - def test_at(self, func, nin, nout, identity, shape, idx_shape, dtype): + def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False) @@ -288,7 +354,31 @@ def np_fun(x, idx, y): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def test_at_broadcasting(self): + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + shape=nonscalar_shapes, + idx_shape=[(), (2,)], + ) + def test_ufunc_at(self, name, shape, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)] + + jnp_fun_at = partial(jnp_fun.at, inplace=False) + def np_fun_at(x, idx, y): + x_copy = x.copy() + np_fun.at(x_copy, idx, y) + return x_copy + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + self._CompileAndCheck(jnp_fun_at, args_maker) + + def test_frompyfunc_at_broadcasting(self): # Regression test for https://github.com/google/jax/issues/18004 args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]), np.arange(9.0).reshape(3, 3)] @@ -309,7 +399,7 @@ def np_fun(x, idx, y): idx_shape=[(0,), (3,), (5,)], dtype=jtu.dtypes.floating, ) - def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype): + def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis) @@ -322,6 +412,33 @@ def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + JAX_NUMPY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [*range(-len(shape), len(shape))]], + idx_shape=[(0,), (3,), (5,)], + ) + def test_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + if name in ['add', 'multiply'] and dtype == bool: + # TODO(jakevdp): figure out how to fix thest cases. + self.skipTest(f"known failure for {name}.reduceat with {dtype=}") + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')] + + def np_fun_reduceat(x, i): + # Numpy has different casting behavior. + return np_fun.reduceat(x, i).astype(x.dtype) + + self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker) + self._CompileAndCheck(jnp_fun.reduceat, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From e3110c18f8bce83901cff42458d4204df9e3abeb Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 30 Aug 2024 14:35:21 -0700 Subject: [PATCH 309/702] Remove `dtype` and `weak_type` from `__slots__` of `ShapedArray` since it comes from `UnShapedArray` PiperOrigin-RevId: 669447416 --- jax/_src/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index dbd4dc58e2e2..32677ade1b75 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1733,7 +1733,7 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'dtype', 'weak_type', 'sharding'] + __slots__ = ['shape', 'sharding'] # inherits slots from parent array_abstraction_level = 2 named_shape = {} # type: ignore From bd486fce88bc0065d38ae0dca20cc6f0ee39ec4a Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 30 Aug 2024 15:24:24 -0700 Subject: [PATCH 310/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/63252d04c328dceeca6a7460a5783a4d2abd2f17. PiperOrigin-RevId: 669462683 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index bbfd289abb17..15f224dec2c1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ed742254c6f9ec81a5c760fe3e709e72610eeffe" -XLA_SHA256 = "508b0fa82c42a9f18507985467703825a1fa62a2d4f816a6575e4236fe99258b" +XLA_COMMIT = "63252d04c328dceeca6a7460a5783a4d2abd2f17" +XLA_SHA256 = "69b3a09f45d0e39e92be855572be30f219859334184b2f1cf30897f9315c8ddf" def repo(): tf_http_archive( From a64b9a543ea34800fd3ba32d70ecc68683fe545b Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Fri, 30 Aug 2024 22:43:05 +0000 Subject: [PATCH 311/702] add sliding window attn --- jax/_src/cudnn/fused_attention_stablehlo.py | 96 +++++++++++++-------- tests/fused_attention_stablehlo_test.py | 63 ++++++++++++-- 2 files changed, 119 insertions(+), 40 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 7ceac8940147..e1ed7f094f35 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -109,6 +109,7 @@ def create_dot_product_attention_backend_config(batch, dropout_rate, mask_type, layout, + sliding_window_length, is_bwd): # Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H # P: BMM1 output in shape of BNTS @@ -119,7 +120,8 @@ def create_dot_product_attention_backend_config(batch, # BMM1Grad2: dP @ K -> dQ # BMM2Grad1: P @ dO -> dV # BMM2Grad2: dO @ V -> dP - + if sliding_window_length is None: + sliding_window_length = 0 cudnn_fmha_backend_config = { "algorithm": { "algo_id": "0", @@ -151,6 +153,7 @@ def create_dot_product_attention_backend_config(batch, "seed": seed, "is_flash_attention": True, "mask_type": convert_mask_type_to_string(mask_type), + "sliding_window_length": sliding_window_length, } # We define the contracting and batch dims in the format of @@ -319,34 +322,38 @@ def check_compute_capability(capability): def _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, cudnn_version): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, cudnn_version): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, False) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=False) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=False) output = outputs[0] return output def _dot_product_attention_fwd_rule( query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, cudnn_version): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, cudnn_version): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, True) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=True) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=True) res = (query, key, value, bias, q_seqlen, kv_seqlen, outputs[1], outputs[0]) return outputs[0], res def _dot_product_attention_bwd_rule( - scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, - res, grad_output): + scale, seed, dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, is_training, res, grad_output): (query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output) = res grads = _dot_product_attention_bwd_p_wrapper.bind( @@ -354,33 +361,39 @@ def _dot_product_attention_bwd_rule( fwd_output, grad_output, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length ) grads = (*grads,) + (None,) * (6 - len(grads)) return grads def _dot_product_attention_fwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, is_training): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, is_training): # args: {Q, K, V, mask*, bias*} outputs = _dot_product_attention_fwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=is_training) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=is_training) return outputs def _dot_product_attention_bwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output, - grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout): + grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout, + sliding_window_length): grads = _dot_product_attention_bwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output, grad_output, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length) return grads def _dot_product_attention_fwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, *, scale, seed, - dropout_rate, variadic_args, mask_type, layout, is_training): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, is_training): query_dtype = dtypes.canonicalize_dtype(query.dtype) if layout == AttentionLayout.BNTH.value: B, N, T, _ = query.shape @@ -404,7 +417,7 @@ def _dot_product_attention_fwd_abstract( def _dot_product_attention_bwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output, grad_output, *, scale, seed, dropout_rate, variadic_args, mask_type, - layout): + layout, sliding_window_length): query_dtype = dtypes.canonicalize_dtype(query.dtype) key_dtype = dtypes.canonicalize_dtype(key.dtype) value_dtype = dtypes.canonicalize_dtype(value.dtype) @@ -442,7 +455,8 @@ def _dot_product_attention_bwd_abstract( def _dot_product_attention_fwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, is_training): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, is_training): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) @@ -465,7 +479,7 @@ def _dot_product_attention_fwd_cuda_lowering( workspace_type = ir.IntegerType.get_unsigned(8) backend_config = create_dot_product_attention_backend_config( B, N, T, S, query_type.element_type, scale, seed, dropout_rate, - mask_type, layout, is_bwd=False, + mask_type, layout, sliding_window_length, is_bwd=False, ) # {Q, K, V, bias*, q_seqlen*, kv_seqlen*} # {output, activation*, workspace} @@ -512,7 +526,7 @@ def _dot_product_attention_fwd_cuda_lowering( def _dot_product_attention_bwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output, grad_output, scale, seed, dropout_rate, variadic_args, - mask_type, layout): + mask_type, layout, sliding_window_length): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) @@ -538,7 +552,7 @@ def _dot_product_attention_bwd_cuda_lowering( grad_value_shape = (B, k_N, S, H) backend_config = create_dot_product_attention_backend_config( B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate, - mask_type, layout, is_bwd=True, + mask_type, layout, sliding_window_length, is_bwd=True, ) # {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*} # {dQ, dK, dV, dbias*, workspace} @@ -601,7 +615,7 @@ def _check_valid_batch_dims(bdims): def _dot_product_attention_fwd_batcher( batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, - mask_type, layout, is_training): + mask_type, layout, sliding_window_length, is_training): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, kv_seqlen = batched_args query_bdim = batch_dims[0] @@ -646,7 +660,7 @@ def _dot_product_attention_fwd_batcher( def _dot_product_attention_bwd_batcher( batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, - mask_type, layout): + mask_type, layout, sliding_window_length): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, \ kv_seqlen, activation, fwd_output, grad_output = batched_args @@ -749,16 +763,16 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training): return [out_sharding] _dot_product_attention_fwd_lower = custom_partitioning( - _dot_product_attention_fwd_impl, static_argnums=(6, 7, 8, 9, 10, 11, 12)) + _dot_product_attention_fwd_impl, static_argnums=(6, 7, 8, 9, 10, 11, 12, 13)) def _dot_product_attention_fwd_infer_sharding_from_operands( - scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, - mesh, arg_shapes, result_shape): + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, + is_training, mesh, arg_shapes, result_shape): return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training) def _dot_product_attention_fwd_partition( - scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, - mesh, arg_shapes, result_shape): + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, + is_training, mesh, arg_shapes, result_shape): # args sharding arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) out_shardings = _infer_fwd_output_sharding( @@ -771,6 +785,7 @@ def _dot_product_attention_fwd_partition( variadic_args=variadic_args, mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=is_training, ) return mesh, impl, out_shardings, arg_shardings @@ -797,17 +812,17 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args): return out_shardings _dot_product_attention_bwd_lower = custom_partitioning( - _dot_product_attention_bwd_impl, static_argnums=(9, 10, 11, 12, 13, 14) + _dot_product_attention_bwd_impl, static_argnums=(9, 10, 11, 12, 13, 14, 15) ) def _dot_product_attention_bwd_infer_sharding_from_operands( - scale, seed, dropout_rate, variadic_args, mask_type, layout, mesh, - arg_shapes, result_shape): + scale, seed, dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, mesh, arg_shapes, result_shape): return _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) def _dot_product_attention_bwd_partition( - scale, seed, dropout_rate, variadic_args, mask_type, layout, mesh, - arg_shapes, result_shape): + scale, seed, dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, mesh, arg_shapes, result_shape): out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) # args sharding arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) @@ -820,6 +835,7 @@ def sharded_impl(*args): variadic_args=variadic_args, mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, ) grads = impl(*args) _, has_dbias = variadic_args @@ -917,7 +933,7 @@ def sharded_impl(*args): ) -@functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12)) +@functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12, 13)) def _dot_product_attention(query: Array, key: Array, value: Array, @@ -930,11 +946,13 @@ def _dot_product_attention(query: Array, variadic_args: tuple[bool, ...], mask_type: bool, layout: int, + sliding_window_length: int | None, cudnn_version: int): output = _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, cudnn_version=cudnn_version) + mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, + cudnn_version=cudnn_version) return output # _dot_product_attention_fwd must have the same func signature as _dot_product_attention @@ -953,7 +971,8 @@ def dot_product_attention(query: Array, mask_type: MaskType = MaskType.NO_MASK, seed: int = 42, dropout_rate: float = 0., - qkv_layout: str = "BTNH"): + qkv_layout: str = "BTNH", + sliding_window_length: int | None = None): """Computes dot-product attention given query (Q), key (K), and value (V). This function serves as the core operation for applying attention @@ -984,7 +1003,11 @@ def dot_product_attention(query: Array, scale: Scale for the query. dropout_rate: Dropout rate. qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH, - BNSH. + BNSH. + sliding_window_length: Window size to make attention only attend to each + token's left local window (pos - sliding_window_length, pos] where `pos` + is the index of each token. E.g., if sliding_window_length == 3 and the + sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. Returns: Output of the same shape as the query. @@ -997,6 +1020,9 @@ def dot_product_attention(query: Array, layout = _normalize_layout(qkv_layout) if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask") + if sliding_window_length is not None and sliding_window_length <= 0: + raise ValueError( + f"Require sliding_window_length > 0, got {sliding_window_length}") if bias is not None: # reshape bias to have 4D shape @@ -1032,6 +1058,6 @@ def dot_product_attention(query: Array, kv_seqlen = jnp.zeros(0, dtype=query.dtype) output = _dot_product_attention( query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout.value, cudnn_version - ) + dropout_rate, variadic_args, mask_type, layout.value, sliding_window_length, + cudnn_version) return output diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index bc05c4b2e85c..705cdfdbdc8f 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -47,7 +47,8 @@ def sdpa_train(query: Array, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, is_bnth: bool = False, - dropout_rate: float = 0.1) -> Array: + dropout_rate: float = 0.1, + sliding_window_length: int | None = None) -> Array: if mask_type == MaskType.PADDING: if is_bnth: B, _, S, _ = query.shape @@ -59,7 +60,8 @@ def sdpa_train(query: Array, out, sdpa_vjp = jax.vjp( partial(dot_product_attention, scale=scale, mask_type=mask_type, dropout_rate=dropout_rate, - qkv_layout="BNTH" if is_bnth else "BTNH"), + qkv_layout="BNTH" if is_bnth else "BTNH", + sliding_window_length=sliding_window_length), query, key, value, bias, mask, q_seqlen, kv_seqlen) query_grad, key_grad, value_grad, bias_grad, _, _, _ = sdpa_vjp(grad) if bias is not None and len(bias.shape) == 3: @@ -74,7 +76,8 @@ def sdpa_ref(query: Array, mask: Array | None = None, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, - dropout_rate: float = 0.1) -> Array: + dropout_rate: float = 0.1, + sliding_window_length: int | None = None) -> Array: def get_causal_mask(logits): large_negative_number = get_large_negative_number(logits.dtype) @@ -99,6 +102,16 @@ def get_encoded_padding_mask(encoded): return jax.lax.broadcast_in_dim( encoded_padding, encoded.shape, broadcast_dimensions=[1]) + def get_sliding_window_mask(logits, window_length): + large_negative_number = get_large_negative_number(logits.dtype) + T = logits.shape[-2] + col_idx = jax.lax.broadcasted_iota(np.int32, (T, T), 1) + row_idx = jax.lax.broadcasted_iota(np.int32, (T, T), 0) + mask = jnp.logical_or( + row_idx < col_idx, + col_idx <= row_idx - window_length).astype(logits.dtype) * large_negative_number + return mask[(*([jnp.newaxis]*(len(logits.shape) - 2)), ...)] + B, T, qN, H = query.shape _, _, kN, _ = key.shape logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) @@ -108,6 +121,11 @@ def get_encoded_padding_mask(encoded): bias = get_causal_mask(logits) elif mask_type == MaskType.PADDING: bias = get_padding_mask(logits) + elif sliding_window_length is not None: + if sliding_window_length <= 0: + raise ValueError( + f"Expect sliding_window_length > 0, got {sliding_window_length}.") + bias = get_sliding_window_mask(logits, sliding_window_length) if mask is not None: large_negative_number = get_large_negative_number(logits.dtype) mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number) @@ -141,10 +159,12 @@ def sdpa_train_ref(query: Array, mask: Array | None = None, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, - dropout_rate: float = 0.1) -> Array: + dropout_rate: float = 0.1, + sliding_window_length: int | None = None) -> Array: out_ref, sdpa_vjp_ref = jax.vjp( partial( - sdpa_ref, scale=scale, mask_type=mask_type, dropout_rate=dropout_rate), + sdpa_ref, scale=scale, mask_type=mask_type, dropout_rate=dropout_rate, + sliding_window_length=sliding_window_length), query, key, value, bias, mask) query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref, _ = sdpa_vjp_ref(grad) if bias is not None and len(bias.shape) == 3: @@ -399,6 +419,39 @@ def test_sdpa_broadcast_bias_and_dbias(self): self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) self.assertArraysAllClose(bias_grad_ref, bias_grad, rtol=1e-5, atol=1e-5) + @jtu.run_on_devices("cuda") + def test_sdpa_sliding_window_length(self): + k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) + query = jax.random.normal( + k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 64), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad = jax.random.normal( + k4, (4, 1024, 4, 64), dtype=jnp.bfloat16) + jitted_sdpa_train = jax.jit( + partial( + sdpa_train, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0, + sliding_window_length=64), + ) + # for reference implementation + # sliding_window_length option itself will setup correct mask + jitted_sdpa_train_ref = jax.jit( + partial( + sdpa_train_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0, + sliding_window_length=64), + ) + + out, (query_grad, key_grad, value_grad) = \ + jitted_sdpa_train(query, key, value, grad, None, None) + out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ + jitted_sdpa_train_ref(query, key, value, grad, None, None) + self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + @jtu.run_on_devices("cuda") def test_layouts(self): dtype = "bfloat16" From 5ba8d6b811489e7def513f04f9f23ccbbd9d6b9c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 30 Aug 2024 16:08:33 -0700 Subject: [PATCH 312/702] Make `input_layouts` and `output_layouts` properties of `Compiled`. This is equivalent to `input_shardings` and `output_shardings` which are also properties on `Compiled`. PiperOrigin-RevId: 669476027 --- jax/_src/stages.py | 2 + .../array_serialization/serialization_test.py | 2 +- tests/layout_test.py | 50 +++++++++---------- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 549da2d39e77..b924072fc044 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -513,6 +513,7 @@ def output_shardings(self): # PyTree[sharding.Sharding] shardings_flat = self._executable.output_shardings() return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error + @property def input_layouts(self): layouts_flat = self._executable.input_layouts() assert all(isinstance(l, Layout) for l in layouts_flat) @@ -523,6 +524,7 @@ def input_layouts(self): else Layout() for i in range(self.in_tree.num_leaves)] return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error + @property def output_layouts(self): layouts_flat = self._executable.output_layouts() assert all(isinstance(l, Layout) for l in layouts_flat) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index e60bfaa1dc89..004f03a85b04 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -556,7 +556,7 @@ def test_load_with_layout(self): arr = jax.device_put(np_inp, s) out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower( - arr).compile().output_layouts() + arr).compile().output_layouts self.assertEqual(arr.layout.device_local_layout.major_to_minor, out_layout.device_local_layout.major_to_minor[::-1]) diff --git a/tests/layout_test.py b/tests/layout_test.py index 0a9a72e8f48b..1af8e259ecce 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -70,18 +70,18 @@ def init(x, y): out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2) compiled_apply = lowered_apply.compile() - arg_layouts, kw_layouts = compiled_apply.input_layouts() + arg_layouts, kw_layouts = compiled_apply.input_layouts self.assertEmpty(kw_layouts) - for i, o in zip(arg_layouts, compiled_apply.output_layouts()): + for i, o in zip(arg_layouts, compiled_apply.output_layouts): self.assertEqual(i.device_local_layout.major_to_minor, o.device_local_layout.major_to_minor[::-1]) init_compiled = jax.jit( init, out_shardings=arg_layouts).lower(sds1, sds2).compile() - for i, o in zip(init_compiled.input_layouts()[0], - init_compiled.output_layouts()): + for i, o in zip(init_compiled.input_layouts[0], + init_compiled.output_layouts): self.assertEqual(i, o) arr1 = jax.device_put(np_inp1, s1) @@ -92,16 +92,16 @@ def init(x, y): init_compiled(arr1, arr2) self.assertEqual(init_count[0], 1) - self.assertEqual(init_out[0].layout, init_compiled.output_layouts()[0]) - self.assertEqual(init_out[1].layout, init_compiled.output_layouts()[1]) + self.assertEqual(init_out[0].layout, init_compiled.output_layouts[0]) + self.assertEqual(init_out[1].layout, init_compiled.output_layouts[1]) with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) compiled_apply(*init_out) self.assertEqual(apply_count[0], 1) - self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0]) - self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1]) + self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts[0]) + self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts[1]) self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor, init_out[0].layout.device_local_layout.major_to_minor[::-1]) @@ -132,10 +132,10 @@ def f(x): out = compiled(arr) self.assertTupleEqual( - compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled.output_layouts().device_local_layout.major_to_minor[::-1], + compiled.output_layouts.device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) @@ -143,10 +143,10 @@ def f(x): compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO), out_shardings=Layout(DLL.AUTO)).lower(sds).compile() self.assertTupleEqual( - compiled_auto.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + compiled_auto.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled_auto.output_layouts().device_local_layout.major_to_minor[::-1], + compiled_auto.output_layouts.device_local_layout.major_to_minor[::-1], (0, 1, 2)) with self.assertRaisesRegex( @@ -169,15 +169,15 @@ def f(x): compiled = jax.jit(f, in_shardings=Layout(), out_shardings=Layout(DLL.AUTO)).lower(arr).compile() self.assertTupleEqual( - compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (1, 0)) self.assertTupleEqual( - compiled.output_layouts().device_local_layout.major_to_minor[::-1], + compiled.output_layouts.device_local_layout.major_to_minor[::-1], (0, 1)) out = compiled(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout, compiled.output_layouts()) + self.assertEqual(out.layout, compiled.output_layouts) self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): @@ -192,10 +192,10 @@ def test_sharding_and_layouts(self): out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile() out = compiled(np_inp) self.assertTupleEqual( - compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (1, 0)) self.assertTupleEqual( - compiled.output_layouts().device_local_layout.major_to_minor[::-1], + compiled.output_layouts.device_local_layout.major_to_minor[::-1], (0, 1)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, s) @@ -208,13 +208,13 @@ def f(x, y, z, a, b, c): inps = [np.arange(math.prod(shape)).reshape(shape)] * 6 compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), out_shardings=Layout(DLL.AUTO)).lower(*inps).compile() - arg_layouts, _ = compiled.input_layouts() + arg_layouts, _ = compiled.input_layouts out1, out2 = compiled(*inps) compiled2 = jax.jit(f, in_shardings=arg_layouts).lower(*inps).compile() out3, out4 = compiled2(*inps) - for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts()[0]): + for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts[0]): self.assertEqual(l1, l2) self.assertArraysEqual(out1, out3) @@ -240,7 +240,7 @@ def f(x, y): jf = jax.jit(f, in_shardings=Layout(DLL.AUTO, s), out_shardings=Layout(DLL.AUTO, s)) compiled = jf.lower(np_inp, np_inp).compile() - arg_layouts, _ = compiled.input_layouts() + arg_layouts, _ = compiled.input_layouts arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_layouts)] compiled(*arrs) @@ -291,7 +291,7 @@ def test_device_put_concrete_layout(self): compiled = jax.jit( lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile() - col = compiled.output_layouts() + col = compiled.output_layouts out = jax.device_put(np_inp, col) self.assertEqual(out.layout, col) @@ -323,7 +323,7 @@ def invalid_layout_spec(self): compiled = jax.jit(lambda x: x).lower(x).compile() with self.assertRaisesRegex( ValueError, 'Sharding has to be concrete when layout.*'): - Layout(compiled.output_layouts()[0], None) + Layout(compiled.output_layouts[0], None) def test_layout_on_sds(self): mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) @@ -332,10 +332,10 @@ def test_layout_on_sds(self): arr = jax.device_put(np_inp, s) out_layout = jax.jit(jnp.sin, out_shardings=Layout(DLL.AUTO)).lower( - arr).compile().output_layouts() + arr).compile().output_layouts sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_layout) - arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts() + arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts self.assertEqual(arg_layout[0], out_layout) with self.assertRaisesRegex( @@ -350,7 +350,7 @@ def test_make_array_from_callback(self): np_inp = np.arange(16).reshape(8, 2) sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) - layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts() + layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts out = jax.make_array_from_callback(np_inp.shape, layout, lambda idx: np_inp[idx]) From 969dd89040ec5bbdc0fb5ad936183cf68463e88f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sat, 31 Aug 2024 08:43:39 -0700 Subject: [PATCH 313/702] Reverts changelist 668370165 PiperOrigin-RevId: 669670355 --- tests/memories_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index f761f490ad4e..6140c6945df5 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -559,7 +559,6 @@ def test_identity_jit_host_to_device_and_vice_versa(self): self.assertArraysEqual(out_host, np_inp) self.assertEqual(out_host.sharding, s_host) - @jtu.skip_on_devices("gpu") def test_parameter_streaming_inside_scan(self): mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z")) np_inp = np.arange(4096.0).reshape(16, 16, 16) @@ -1440,7 +1439,6 @@ def f(x): if jtu.pjrt_c_api_version_at_least(0, 43): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) - @jtu.skip_on_devices("gpu") def test_remat_scan_jaxpr_offloadable(self): mesh = jtu.create_global_mesh((2,), ("x",)) shape = (256, 128) From b6aedbc41fff9b71c167e42e251766bbeb2ac41a Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 31 Aug 2024 15:53:23 -0700 Subject: [PATCH 314/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0e3e2263e30ef24560a9abe64a713e1692f07216. PiperOrigin-RevId: 669740628 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 15f224dec2c1..17cbc754b373 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "63252d04c328dceeca6a7460a5783a4d2abd2f17" -XLA_SHA256 = "69b3a09f45d0e39e92be855572be30f219859334184b2f1cf30897f9315c8ddf" +XLA_COMMIT = "0e3e2263e30ef24560a9abe64a713e1692f07216" +XLA_SHA256 = "939e8a71c115db8575d0c9252402026e732a6b6f8d7134588ec4190b6e872382" def repo(): tf_http_archive( From 7b415834145df96f0c80e5cf3cb34cfe796f85b6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sun, 1 Sep 2024 07:49:49 -0700 Subject: [PATCH 315/702] refactor jax.lax to not depend on jax.numpy --- jax/_src/lax/control_flow/conditionals.py | 7 +- jax/_src/lax/control_flow/for_loop.py | 12 +- jax/_src/lax/control_flow/loops.py | 24 +- jax/_src/lax/control_flow/solves.py | 9 +- jax/_src/lax/eigh.py | 134 ++++++---- jax/_src/lax/fft.py | 6 +- jax/_src/lax/lax.py | 31 ++- jax/_src/lax/linalg.py | 293 +++++++++++++--------- jax/_src/lax/other.py | 21 +- jax/_src/lax/parallel.py | 12 +- jax/_src/lax/qdwh.py | 80 +++--- jax/_src/lax/slicing.py | 5 +- jax/_src/lax/stack.py | 26 +- jax/_src/lax/svd.py | 79 +++--- jax/_src/scipy/linalg.py | 2 +- 15 files changed, 439 insertions(+), 302 deletions(-) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 8161606801c2..b96f9e8c6e40 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -23,7 +23,6 @@ import operator from typing import Any, TypeVar -import jax from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import config @@ -275,12 +274,8 @@ def cond(pred, true_fun, false_fun, *operands): num_consts = len(consts) out_ = iter(out) - def _cast_to_array(x): - _copy = isinstance(x, np.bool_) - return jax.numpy.asarray(x, copy=_copy) - out = [ - next(out_) if fwd is None else _cast_to_array(ops[fwd - num_consts]) + next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts]) for fwd in in_fwd ] assert next(out_, None) is None diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 15249e531144..61b9a24644ce 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -20,7 +20,6 @@ import operator from typing import Any, Generic, TypeVar -import jax.numpy as jnp from jax import lax from jax.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad @@ -46,6 +45,7 @@ split_list, split_dict, weakref_lru_cache) from jax._src.lax.control_flow import loops from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr +import numpy as np ## JAX utilities @@ -132,7 +132,7 @@ def wrapped_body(i, refs): nsteps, = nsteps flat_state, state_tree = tree_flatten(init_state) state_avals = map(state_utils.val_to_ref_aval, flat_state) - idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64)) + idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64)) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( body, state_tree, [idx_aval, *state_avals]) if out_tree != tree_structure(None): @@ -202,7 +202,7 @@ def _create_jaxpr(init): return jaxpr, out_tree jaxpr, out_tree = _create_jaxpr(init) _, ys_avals = tree_unflatten(out_tree, jaxpr.out_avals) - ys = tree_map(lambda aval: jnp.zeros([length, *aval.shape], aval.dtype), + ys = tree_map(lambda aval: lax.full([length, *aval.shape], 0, aval.dtype), ys_avals) def for_body(i, refs): carry_refs, xs_refs, ys_refs = refs @@ -251,7 +251,7 @@ def body(i, state): def _for_impl_unrolled(body, nsteps, unroll, *args): remainder = nsteps % unroll - i = jnp.astype(0, dtypes.canonicalize_dtype(jnp.int64)) + i = lax.full((), 0, dtypes.canonicalize_dtype(np.int64)) state = list(args) for _ in range(remainder): @@ -748,7 +748,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False): """ flat_state, state_tree = tree_flatten(init_state) state_avals = map(state_utils.val_to_ref_aval, flat_state) - idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64)) + idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64)) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( body, state_tree, [idx_aval, *state_avals]) if out_tree != tree_structure(None): @@ -756,7 +756,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False): discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) def fori_body(i, carry): - i = jnp.astype(i, dtypes.canonicalize_dtype(jnp.int64)) + i = lax.convert_element_type(i, dtypes.canonicalize_dtype(np.int64)) if reverse: i = nsteps - i - 1 out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 5084ec43c2fa..828728ebdbd2 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -22,7 +22,6 @@ from typing import Any, TypeVar import weakref -import jax from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api @@ -42,6 +41,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla +from jax._src import sharding_impls as sharding from jax._src.interpreters import xla from jax._src.lax import lax from jax._src.lax import slicing @@ -67,6 +67,7 @@ unzip2, weakref_lru_cache, ) +from jax._src import xla_bridge as xb from jax.tree_util import ( keystr, tree_flatten, @@ -85,6 +86,9 @@ ### Helper functions +def _stack(arrs: Sequence[Array], axis: int=0) -> Array: + return lax.concatenate([lax.expand_dims(arr, (axis,)) for arr in arrs], dimension=axis) + def _promote_weak_typed_inputs(in_vals, in_avals, out_avals): """Promote weakly-typed in_vals to be compatible with out_avals. @@ -254,7 +258,7 @@ def scan(f, init, xs, length=None): xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat] carry, y = f(carry, tree_unflatten(xs_tree, xs_slice)) ys.append(y) - stack = lambda *ys: jax.numpy.stack(ys) + stack = lambda *ys: _stack(ys) stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y @@ -449,11 +453,11 @@ def inner(n, carry, xs): carry, y = split_list(carry_y, [num_carry]) ys.append(y) ys = list(reversed(ys)) if reverse else ys - return carry, _map(jax.numpy.stack, zip(*ys)) + return carry, _map(_stack, zip(*ys)) if num_trips: i = lax._const(num_trips, 0) - _, carry, yss = jax.lax.while_loop(cond_fun, body_fun, (i, carry, yss)) + _, carry, yss = while_loop(cond_fun, body_fun, (i, carry, yss)) if unroll != 1: ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss] else: @@ -694,7 +698,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, def _maybe_put(x): if isinstance(x, np.ndarray): aval = shaped_abstractify(x) - s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]) + s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0]) result_handler = pxla.global_aval_to_result_handler(aval, s, False) return result_handler(pxla.shard_args([s], [None], [x])) else: @@ -2144,12 +2148,12 @@ def map(f, xs): """ if batch_size is not None: scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size) - g = lambda _, x: ((), jax.vmap(f)(x)) + g = lambda _, x: ((), api.vmap(f)(x)) _, scan_ys = scan(g, (), scan_xs) - remainder_ys = jax.vmap(f)(remainder_xs) + remainder_ys = api.vmap(f)(remainder_xs) flatten = lambda x: x.reshape(-1, *x.shape[2:]) ys = tree_map( - lambda x, y: jax.numpy.concatenate([flatten(x), y], axis=0), scan_ys, remainder_ys, + lambda x, y: lax.concatenate([flatten(x), y], dimension=0), scan_ys, remainder_ys, ) else: g = lambda _, x: ((), f(x)) @@ -2167,10 +2171,10 @@ def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, key = keys[0] new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), dtype=dtype, algorithm=algorithm) - new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) + new_keys = slicing.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) return (new_keys, bits), (0, 0) -batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore[has-type] +batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule ### associative_scan diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 4d55907f6b37..21105e20aaf8 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -16,11 +16,12 @@ from functools import partial import operator -import jax from jax.tree_util import (tree_flatten, treedef_children, tree_leaves, tree_unflatten, treedef_tuple) from jax._src import ad_util +from jax._src import api from jax._src import core +from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src.core import raise_to_shaped from jax._src.interpreters import ad @@ -99,7 +100,7 @@ def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False): _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) def linearize_and_solve(x, b): - unchecked_zeros, f_jvp = jax.linearize(f, x) + unchecked_zeros, f_jvp = api.linearize(f, x) return tangent_solve(f_jvp, b) l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( @@ -115,7 +116,7 @@ def linearize_and_solve(x, b): return tree_unflatten(solution_tree, solution_flat) -@partial(jax.custom_jvp, nondiff_argnums=(0, 1)) +@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0, 1)) def _custom_root(const_lengths, jaxprs, *args): params, initial_guess = _split_root_args(args, const_lengths) solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess)) @@ -169,7 +170,7 @@ def _split_linear_solve_args(args, const_lengths): def _transpose_one_output(linear_fun, primals): - transpose_fun = jax.linear_transpose(linear_fun, primals) + transpose_fun = api.linear_transpose(linear_fun, primals) def transposed_fun(x): (y,) = transpose_fun(x) return y diff --git a/jax/_src/lax/eigh.py b/jax/_src/lax/eigh.py index fc66b0f2e7ee..d872ff0f68cc 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/lax/eigh.py @@ -30,15 +30,16 @@ from functools import partial from typing import NamedTuple -import jax -import jax._src.numpy.lax_numpy as jnp -import jax._src.numpy.linalg as jnp_linalg -from jax._src.numpy import reductions -from jax._src.numpy import ufuncs +from jax._src import api +from jax._src import config +from jax._src import dtypes from jax import lax from jax._src.lax import qdwh from jax._src.lax import linalg as lax_linalg from jax._src.lax.stack import Stack +from jax._src.lax import lax as lax_internal +from jax._src.typing import Array +import numpy as np # QDWH-eigh is a recursive algorithm where the structure of the recursion @@ -52,19 +53,45 @@ def _round_up(i, n): return ((i+n-1) // n) * n +def _norm(x, axis=None): + return lax.sqrt((abs(x) ** 2).sum(axis=axis)) + +def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array: + assert x.ndim <= len(shape) + return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape))) + +def _construct_diagonal(s: Array) -> Array: + """Construct a (batched) diagonal matrix""" + # signature: (...,n)->(...,n,n) + i = lax.iota('int32', s.shape[-1]) + return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s) + +def _extract_diagonal(s: Array) -> Array: + """Extract the diagonal from a batched matrix""" + # signature: (...,n,m)->(...k) where k=min(n,m) + i = lax.iota('int32', min(s.shape[-2], s.shape[-1])) + return s[..., i, i] + def _mask(x, dims, alternative=0): """Masks `x` up to the dynamic shape `dims`. Replaces values outside those dimensions with `alternative`. `alternative` is broadcast with `x`. """ - assert jnp.ndim(x) == len(dims) + assert np.ndim(x) == len(dims) mask = None for i, d in enumerate(dims): if d is not None: - mask_dim_i = lax.broadcasted_iota(jnp.int32, x.shape, i) < d + mask_dim_i = lax.broadcasted_iota(np.int32, x.shape, i) < d mask = mask_dim_i if mask is None else (mask & mask_dim_i) - return x if mask is None else jnp.where(mask, x, alternative) + + alternative = _broadcast_to(lax_internal.asarray(alternative), x.shape).astype(x.dtype) + return x if mask is None else lax.select(mask, x, alternative) + +def _nanmedian(vals): + # note: NaNs will be sorted to the end. + num_nans = lax_internal._isnan(vals).sum() + return lax.sort(vals.ravel())[(vals.size - num_nans) // 2] def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes, fill_value=0): @@ -87,9 +114,9 @@ def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes, # We must pad the input array so the dynamic_slice is guaranteed to fall # entirely in bounds. padded = lax.pad(operand, - jnp.array(0, operand.dtype), + np.array(0, operand.dtype), [(0, d, 0) for d in static_slice_sizes]) - out = lax.dynamic_slice(padded, tuple(jnp.int32(i) for i in start_indices), + out = lax.dynamic_slice(padded, tuple(lax.convert_element_type(i, 'int32') for i in start_indices), static_slice_sizes) return _mask(out, dynamic_slice_sizes, fill_value) @@ -106,9 +133,9 @@ def _update_slice(operand, update, start_indices, update_dims): inside the rectangle given by `update_dims` will be overwritten.""" operand_shape = operand.shape operand = lax.pad(operand, - jnp.array(0, operand.dtype), + np.array(0, operand.dtype), [(0, d, 0) for d in update.shape]) - start_indices = tuple(jnp.int32(i) for i in start_indices) + start_indices = tuple(lax.convert_element_type(i, 'int32') for i in start_indices) t = lax.dynamic_slice(operand, start_indices, update.shape) t = _mask(update, update_dims, t) operand = lax.dynamic_update_slice(operand, t, start_indices) @@ -140,45 +167,45 @@ def _projector_subspace(P, H, n, rank, maxiter=2, swap=False): """ # Choose an initial guess: the `rank` largest-norm columns of P. N, _ = P.shape - negative_column_norms = -jnp_linalg.norm(P, axis=1) - # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN. - negative_column_norms = _mask(negative_column_norms, (n,), jnp.nan) - sort_idxs = jnp.argsort(negative_column_norms) + negative_column_norms = -_norm(P, axis=1) + # `sort_key_val` ensures NaNs sort last, so set masked-out column norms to NaN. + negative_column_norms = _mask(negative_column_norms, (n,), np.nan) + _, sort_idxs = lax.sort_key_val(negative_column_norms, lax.iota('int32', len(negative_column_norms))) X = P[:, sort_idxs] # X = X[:, :rank] X = _mask(X, (n, rank)) - H_norm = jnp_linalg.norm(H) - thresh = 10.0 * float(jnp.finfo(X.dtype).eps) * H_norm + H_norm = _norm(H) + thresh = 10.0 * float(dtypes.finfo(X.dtype).eps) * H_norm # First iteration skips the matmul. def body_f_after_matmul(X): - Q, _ = jnp_linalg.qr(X, mode="complete") + Q, _ = lax_linalg.qr(X) # V1 = Q[:, :rank] # V2 = Q[:, rank:] V1 = _mask(Q, (n, rank)) V2 = _slice(Q, (0, rank), (n, n - rank), (N, N)) # TODO: might be able to get away with lower precision here - error_matrix = jnp.dot(V2.conj().T, H) - error_matrix = jnp.dot(error_matrix, V1) - error = jnp_linalg.norm(error_matrix) + error_matrix = V2.conj().T @ H + error_matrix = error_matrix @ V1 + error = _norm(error_matrix) return V1, V2, error def cond_f(args): _, _, j, error = args still_counting = j < maxiter unconverged = error > thresh - return ufuncs.logical_and(still_counting, unconverged)[0] + return lax.bitwise_and(still_counting, unconverged)[0] def body_f(args): V1, _, j, _ = args - X = jnp.dot(P, V1) + X = P @ V1 V1, V2, error = body_f_after_matmul(X) return V1, V2, j + 1, error V1, V2, error = body_f_after_matmul(X) - one = jnp.ones(1, dtype=jnp.int32) + one = lax.full((1,), 1, np.dtype('int32')) V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error)) if swap: return V2, V1 @@ -210,11 +237,11 @@ def split_spectrum(H, n, split_point, V0=None): rank: The dynamic size of the m subblock. """ N, _ = H.shape - H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(H.dtype) + H_shift = H - (split_point * lax_internal._eye(split_point.dtype, (N, N))).astype(H.dtype) U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n)) - I = _mask(jnp.eye(N, dtype=H.dtype), (n, n)) + I = _mask(lax_internal._eye(H.dtype, (N, N)), (n, n)) P_minus = -0.5 * (U - I) - rank_minus = jnp.round(jnp.trace(ufuncs.real(P_minus))).astype(jnp.int32) + rank_minus = lax.round(_extract_diagonal(P_minus.real).sum(-1)).astype(np.int32) P_plus = 0.5 * (U + I) rank_plus = n - rank_minus @@ -232,8 +259,8 @@ def split_spectrum(H, n, split_point, V0=None): H_minus = (V_minus.conj().T @ H) @ V_minus H_plus = (V_plus.conj().T @ H) @ V_plus if V0 is not None: - V_minus = jnp.dot(V0, V_minus) - V_plus = jnp.dot(V0, V_plus) + V_minus = lax.dot(V0, V_minus) + V_plus = lax.dot(V0, V_plus) return H_minus, V_minus, H_plus, V_plus, rank_minus @@ -258,7 +285,7 @@ def split_spectrum(H, n, split_point, V0=None): # H, V: The result of the projection. # """ # if H.shape[0] <= termination_size: -# evals, evecs = jnp_linalg.eigh(H) +# evals, evecs = jnp.linalg.eigh(H) # if V is not None: # evecs = jnp.dot(V, evecs) # return evals, evecs @@ -279,13 +306,13 @@ class _Subproblem(NamedTuple): in the workspace. """ # The row offset of the block in the matrix of blocks. - offset: jax.Array + offset: Array # The size of the block. - size: jax.Array + size: Array -@partial(jax.jit, static_argnames=('termination_size', 'subset_by_index')) +@partial(api.jit, static_argnames=('termination_size', 'subset_by_index')) def _eigh_work(H, n, termination_size, subset_by_index): """ The main work loop performing the symmetric eigendecomposition of H. Each step recursively computes a projector into the space of eigenvalues @@ -308,20 +335,21 @@ def _eigh_work(H, n, termination_size, subset_by_index): # We turn what was originally a recursive algorithm into an iterative # algorithm with an explicit stack. N, _ = H.shape - n = jnp.asarray(n, jnp.int32) + n = n.astype('int32') + zero = lax.full((), 0, 'int32') agenda = Stack.create( - N + 1, _Subproblem(jnp.array(0, jnp.int32), jnp.array(0, jnp.int32))) - agenda = agenda.push(_Subproblem(offset=jnp.int32(0), size=n)) + N + 1, _Subproblem(zero, zero)) + agenda = agenda.push(_Subproblem(offset=zero, size=n)) # eigenvectors is the array in which we build the output eigenvectors. # We initialize it with the identity matrix so the initial matrix # multiplications in_split_spectrum_jittable are the identity. - eigenvectors = jnp.eye(N, dtype=H.dtype) + eigenvectors = lax_internal._eye(H.dtype, (N, N)) # Keep a copy of the initial matrix Frobenius norm, so we know when to stop # recursing. When the sub-matrix norm is less than eps*H0_norm, the contents are # pure numerical noise, and we should just stop. - H0_norm = jnp_linalg.norm(_mask(H, (n, n))) + H0_norm = _norm(_mask(H, (n, n))) # blocks is an array representing a stack of Hermitian matrix blocks that we # need to recursively decompose. Subproblems are different sizes, so the stack @@ -368,7 +396,7 @@ def base_case(B, offset, b, agenda, blocks, eigenvectors): eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False) eig_vecs = _mask(eig_vecs, (b, b)) eig_vals = _mask(eig_vals, (b,)) - eig_vecs = jnp.dot(V, eig_vecs) + eig_vecs = lax.dot(V, eig_vecs) eig_vals = eig_vals.astype(eig_vecs.dtype) blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, 1)) @@ -381,7 +409,7 @@ def recursive_case(B, offset, b, agenda, blocks, eigenvectors): H = _slice(blocks, (offset, 0), (b, b), (B, B)) def nearly_diagonal_case(agenda, blocks, eigenvectors): - blocks = _update_slice(blocks, jnp.diag(H)[:, None], (offset, 0), (b, 1)) + blocks = _update_slice(blocks, _extract_diagonal(H)[:, None], (offset, 0), (b, 1)) return agenda, blocks, eigenvectors def should_update_range(start, end, subset_by_index): @@ -394,7 +422,7 @@ def should_update_range(start, end, subset_by_index): def default_case(agenda, blocks, eigenvectors): V = _slice(eigenvectors, (0, offset), (n, b), (N, B)) # TODO: Improve this? - split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan)) + split_point = _nanmedian(_mask(_extract_diagonal(H.real), (b,), np.nan)) H_minus, V_minus, H_plus, V_plus, rank = split_spectrum( H, b, split_point, V0=V) @@ -439,11 +467,11 @@ def default_case(agenda, blocks, eigenvectors): # the original input matrix,, terminate the execution. This is necessary to # handle matrices with clusters of eigenvalues, including rank deficient # matrices. See Nakatsukasa and Higham section 5.2. - norm = jnp_linalg.norm(H) - eps = jnp.asarray(jnp.finfo(H.dtype).eps, dtype=norm.dtype) - off_diag_norm = jnp_linalg.norm( - H - jnp.diag(jnp.diag(ufuncs.real(H)).astype(H.dtype))) - nearly_diagonal = off_diag_norm <= 5 * eps * norm + norm = _norm(H) + eps = np.asarray(dtypes.finfo(H.dtype).eps, dtype=norm.dtype) + off_diag_norm = _norm( + H - _construct_diagonal(_extract_diagonal(H.real).astype(H.dtype))) + nearly_diagonal = off_diag_norm <= 5 * (eps * norm) tiny = norm < eps * H0_norm return lax.cond( nearly_diagonal | tiny, @@ -482,13 +510,13 @@ def loop_cond(state): buckets.append(bucket_size) branches.append(partial(recursive_case, bucket_size)) i = i // 2 - buckets = jnp.array(buckets, dtype='int32') + buckets = np.array(buckets, dtype='int32') def loop_body(state): agenda, blocks, eigenvectors = state (offset, b), agenda = agenda.pop() - which = jnp.where(buckets < b, jnp.iinfo(jnp.int32).max, buckets) - choice = jnp.argmin(which) + which = lax.select(buckets < b, lax.full_like(buckets, np.iinfo(np.int32).max), buckets) + choice = lax.argmin(which, 0, 'int32') return lax.switch(choice, branches, offset, b, agenda, blocks, eigenvectors) _, blocks, eigenvectors = lax.while_loop( @@ -557,13 +585,13 @@ def eigh( return eig_vals, eig_vecs n = N if n is None else n - with jax.default_matmul_precision(precision): + with config.default_matmul_precision(precision): eig_vals, eig_vecs = _eigh_work( H, n, termination_size=termination_size, subset_by_index=subset_by_index ) - eig_vals = _mask(ufuncs.real(eig_vals), (n,), jnp.nan) + eig_vals = _mask(eig_vals.real, (n,), np.nan) if sort_eigenvalues or compute_slice: - sort_idxs = jnp.argsort(eig_vals) + _, sort_idxs = lax.sort_key_val(eig_vals, lax.iota('int32', len(eig_vals))) if compute_slice: sort_idxs = sort_idxs[subset_by_index[0] : subset_by_index[1]] eig_vals = eig_vals[sort_idxs] diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index a1cce3500df1..0cbee6d2bfbc 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -23,6 +23,7 @@ from jax import lax from jax._src import dispatch +from jax._src import dtypes from jax._src.api import jit, linear_transpose, ShapeDtypeStruct from jax._src.core import Primitive, is_constant_shape from jax._src.interpreters import ad @@ -30,7 +31,6 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client -from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact __all__ = [ "fft", @@ -61,9 +61,9 @@ def fft(x, fft_type: xla_client.FftType | str, fft_lengths: Sequence[int]): if typ == xla_client.FftType.RFFT: if np.iscomplexobj(x): raise ValueError("only real valued inputs supported for rfft") - x, = promote_dtypes_inexact(x) + x = lax.convert_element_type(x, dtypes.to_inexact_dtype(dtypes.dtype(x))) else: - x, = promote_dtypes_complex(x) + x = lax.convert_element_type(x, dtypes.to_complex_dtype(dtypes.dtype(x))) if len(fft_lengths) == 0: # XLA FFT doesn't support 0-rank. return x diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 618c715ba763..2186e767ab37 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -27,7 +27,6 @@ import numpy as np -import jax from jax import tree_util from jax.sharding import Sharding from jax.tree_util import tree_map @@ -42,6 +41,7 @@ from jax._src import dtypes from jax._src import effects from jax._src import linear_util as lu +from jax._src import pjit from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src import state @@ -84,6 +84,10 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +def _matrix_transpose(x: Array) -> Array: + assert x.ndim >= 2 + return transpose(x, [*range(x.ndim - 2), x.ndim - 1, x.ndim - 2]) + def _clip_int_to_valid_range(val: DimSize, dtype, where: str) -> int: info = np.iinfo(dtype) val = core.concrete_dim_or_error(val, where) @@ -1327,7 +1331,7 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array: return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape), dimension=dimension) -def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: +def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array: """Like numpy.eye, create a 2D array with ones on a diagonal.""" offset = _clip_int_to_valid_range(offset, np.int32, "argument `offset` of jax.numpy.eye") @@ -2180,7 +2184,11 @@ def _pow_dtype_rule(x, y): def _pow_jvp_lhs(g, ans, x, y): y_dtype = dtypes.dtype(y) - x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this + result_dtype = dtypes.result_type(x, y) + if result_dtype == bool: + result_dtype = 'int32' + x = convert_element_type(x, result_dtype) + y = convert_element_type(y, result_dtype) if dtypes.issubdtype(y_dtype, np.integer): if x.shape != y.shape: shape = broadcast_shapes(x.shape, y.shape) @@ -2602,7 +2610,7 @@ def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding): new_dtype=new_dtype, weak_type=weak_type, sharding=sharding) if sharding is not None: - operand = jax.lax.with_sharding_constraint(operand, sharding) + operand = pjit.with_sharding_constraint(operand, sharding) return operand convert_element_type_p.def_custom_bind(_convert_element_type_bind) convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p)) @@ -3093,7 +3101,7 @@ def _ragged_dot_jvp_rule( preferred_element_type=preferred_element_type, ) if type(dx) is not ad_util.Zero - else jax.numpy.zeros_like(primal_out) + else _zeros(primal_out) ) dy_out = ( ragged_dot( @@ -3104,7 +3112,7 @@ def _ragged_dot_jvp_rule( preferred_element_type=preferred_element_type, ) if type(dy) is not ad_util.Zero - else jax.numpy.zeros_like(primal_out) + else _zeros(primal_out) ) tangent_out = dx_out + dy_out @@ -3112,10 +3120,11 @@ def _ragged_dot_jvp_rule( def _ragged_to_dense(x, y, group_sizes): + from jax._src.lax import control_flow # avoid circular imports shape = (y.shape[0], x.shape[0], x.shape[1]) x = broadcast_in_dim(x, shape, [1, 2]) iota = broadcasted_iota(group_sizes.dtype, shape, 1) - group_ends = jax.lax.cumsum(group_sizes) + group_ends = control_flow.cumsum(group_sizes) group_starts = concatenate( [_zeros(group_sizes)[:1], group_ends[:-1]], dimension=0, @@ -3137,7 +3146,7 @@ def _ragged_dot_transpose_rule( if ad.is_undefined_primal(y): grad_x = None else: - y_t = jax.numpy.matrix_transpose(y) + y_t = _matrix_transpose(y) grad_x = ragged_dot( ct, y_t, @@ -3153,7 +3162,7 @@ def _ragged_dot_transpose_rule( x_dense = _ragged_to_dense(x, y, group_sizes=gs) ct_dense = _ragged_to_dense(ct, y, group_sizes=gs) dimension_numbers = (([1], [1]), ([0], [0])) - grad_y = jax.lax.dot_general( + grad_y = dot_general( x_dense, ct_dense, dimension_numbers, @@ -4382,7 +4391,7 @@ def _canonicalize_float_for_sort(x): # and NaNs in the output. result = select(eq(x, _zero(x)), _zeros(x), x) - with jax.debug_nans(False): + with config.debug_nans(False): result = select(_isnan(x), full_like(result, np.nan), result) return result @@ -4897,7 +4906,7 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs): # method on jax.Array so that we can bypass the XLA compilation here. def _copy_impl(prim, *args, **kwargs): a, = args - if isinstance(a, jax.Array) and isinstance(a.sharding, PmapSharding): + if isinstance(a, Array) and isinstance(a.sharding, PmapSharding): sharded_dim = _which_dim_sharded(a.sharding) if sharded_dim is None: return dispatch.apply_primitive(prim, *args, **kwargs) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 1a792e3adc0c..83bfff1d752e 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -22,14 +22,15 @@ import numpy as np -import jax from jax import lax from jax._src import ad_util from jax._src import api +from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import util from jax._src.core import ( Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape) from jax._src.extend import ffi @@ -48,20 +49,43 @@ from jax._src.lib import gpu_sparse from jax._src.lib import lapack from jax._src.lib import version as jaxlib_version -from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.numpy import lax_numpy as jnp -from jax._src.numpy import reductions -from jax._src.numpy import ufuncs -from jax._src.numpy.vectorize import vectorize from jax._src.typing import Array, ArrayLike -xops = xla_client.ops TFun = TypeVar('TFun', bound=Callable[..., Any]) +def _broadcasted_iotas(*sizes): + ones = (1,) * (len(sizes) - 1) + shapes = (util.tuple_insert(ones, i, s) for i, s in enumerate(sizes)) + return [lax.broadcasted_iota('int32', shape, i) for i, shape in enumerate(shapes)] + +def _tril(m: Array, k:int = 0) -> Array: + *_, N, M = m.shape + mask = lax_internal._tri(bool, (N, M), k) + return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.zeros_like_array(m)) + +def _triu(m: Array, k:int = 0) -> Array: + *_, N, M = m.shape + mask = lax_internal._tri(bool, (N, M), k - 1) + return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.zeros_like_array(m), m) + +def _construct_diagonal(s: Array) -> Array: + """Construct a (batched) diagonal matrix""" + i = lax.iota('int32', s.shape[-1]) + return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s) + +def _extract_diagonal(s: Array) -> Array: + """Extract the diagonal from a batched matrix""" + i = lax.iota('int32', min(s.shape[-2], s.shape[-1])) + return s[..., i, i] + +def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array: + assert x.ndim <= len(shape) + return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape))) + # traceables def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: @@ -91,7 +115,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: """ if symmetrize_input: x = symmetrize(x) - return jnp.tril(cholesky_p.bind(x)) + return _tril(cholesky_p.bind(x)) def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, @@ -368,10 +392,10 @@ def triangular_solve(a: ArrayLike, b: ArrayLike, *, Returns: A batch of matrices the same shape and dtype as ``b``. """ - conjugate_a = conjugate_a and jnp.issubdtype(lax.dtype(a), jnp.complexfloating) - singleton = jnp.ndim(b) == jnp.ndim(a) - 1 + conjugate_a = conjugate_a and dtypes.issubdtype(lax.dtype(a), np.complexfloating) + singleton = np.ndim(b) == np.ndim(a) - 1 if singleton: - b = jnp.expand_dims(b, -1 if left_side else -2) + b = lax.expand_dims(b, (-1 if left_side else -2,)) out = triangular_solve_p.bind( a, b, left_side=left_side, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) @@ -381,9 +405,17 @@ def triangular_solve(a: ArrayLike, b: ArrayLike, *, # utilities -@partial(vectorize, signature='(n,m),(m)->(n)') -def _matvec_multiply(a: Array, b: Array) -> Array: - return lax.dot(a, b, precision=lax.Precision.HIGHEST) +def _broadcasted_matvec(a: Array, b: Array) -> Array: + # This is a broadcasted dot_general with signature (...,n,m),(...,m)->(...,n) + assert a.ndim >= 2 + assert b.ndim >= 1 + batch_shape = lax.broadcast_shapes(a.shape[:-2], b.shape[:-1]) + n_batch = len(batch_shape) + a = _broadcast_to(a, (*batch_shape, *a.shape[-2:])) + b = _broadcast_to(b, (*batch_shape, b.shape[-1])) + + dimension_numbers = (([a.ndim - 1], [b.ndim - 1]), (list(range(n_batch)), list(range(n_batch)))) + return lax.dot_general(a, b, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST) def _check_solve_shapes(a: Array, b: Array): if not (a.ndim >= 2 and b.ndim in [a.ndim, a.ndim - 1] and @@ -399,14 +431,14 @@ def _solve(a: Array, b: Array) -> Array: # custom_linear_solve. out_shape = tuple(d_a if d_b == 1 else d_b for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape)) - b = jnp.broadcast_to(b, out_shape) + b = lax.broadcast_in_dim(b, out_shape, range(b.ndim)) # With custom_linear_solve, we can reuse the same factorization when # computing sensitivities. This is considerably faster. lu_, _, permutation = lu(lax.stop_gradient(a)) custom_solve = partial( lax.custom_linear_solve, - lambda x: _matvec_multiply(a, x), + lambda x: _broadcasted_matvec(a, x), solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0), transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1)) if a.ndim == b.ndim + 1: @@ -416,8 +448,10 @@ def _solve(a: Array, b: Array) -> Array: # b.shape == [..., m, k] return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b) -def _T(x: Array) -> Array: return jnp.swapaxes(x, -1, -2) -def _H(x: Array) -> Array: return ufuncs.conj(_T(x)) +def _T(x: Array) -> Array: + return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) +def _H(x: Array) -> Array: + return _T(x).conj() def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 # primitives @@ -430,13 +464,13 @@ def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 def _cholesky_jvp_rule(primals, tangents): x, = primals sigma_dot, = tangents - L = jnp.tril(cholesky_p.bind(x)) + L = _tril(cholesky_p.bind(x)) # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf def phi(X): - l = jnp.tril(X) + l = _tril(X) return l / lax.expand_dims( - lax_internal._const(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype), + lax_internal._const(X, 1) + lax_internal._eye(X.dtype, (X.shape[-1], X.shape[-1])), range(l.ndim - 2)) tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True, @@ -532,22 +566,22 @@ def _cholesky_update_jax_fn(R, z): def _drotg(x, y): """Get coefs for Givens rotation in a numerically stable way.""" def _drotg_nonzero(x, y): - abs_x = jax.numpy.abs(x) - abs_y = jax.numpy.abs(y) - denominator = jnp.where(abs_x > abs_y, abs_x, abs_y) + abs_x = abs(x) + abs_y = abs(y) + denominator = lax.select(abs_x > abs_y, abs_x, abs_y) x /= denominator y /= denominator - rh = 1 / jax.numpy.sqrt(x ** 2 + y ** 2) + rh = 1 / lax.sqrt(x ** 2 + y ** 2) return x * rh, -y * rh one_and_zero = ( - jnp.array(1., dtype=x.dtype), - jnp.array(0., dtype=x.dtype), + np.array(1., dtype=x.dtype), + np.array(0., dtype=x.dtype), ) - return jax.lax.cond(y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y) + return lax.cond(y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y) def _drot( - first_vector: jax.Array, second_vector: jax.Array, - c_coef: float, s_coef: float) -> tuple[jax.Array, jax.Array]: + first_vector: Array, second_vector: Array, + c_coef: float, s_coef: float) -> tuple[Array, Array]: return ( c_coef * first_vector - s_coef * second_vector, c_coef * second_vector + s_coef * first_vector) @@ -681,7 +715,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, a, = primals da, = tangents l, v = eig(a, compute_left_eigenvectors=False) - return [l], [reductions.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)] + return [l], [(_solve(v, da.astype(v.dtype)) * _T(v)).sum(-1)] eig_p = Primitive('eig') eig_p.multiple_results = True @@ -875,16 +909,17 @@ def eigh_qdwh(x): # We should only look at elements from the lower/upper triangle. Reflects # that triangle into the other triangle to form a Hermitian matrix. if lower: - mask = jnp.tri(n, k=0, dtype=bool) + mask = lax_internal._tri(bool, (n, n), 0) else: - mask = ufuncs.logical_not(jnp.tri(n, k=-1, dtype=bool)) - if dtypes.issubdtype(x.dtype, jnp.complexfloating): + mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1)) + if dtypes.issubdtype(x.dtype, np.complexfloating): re = lax.select(mask, lax.real(x), _T(lax.real(x))) if lower: - im_mask = jnp.tri(n, k=-1, dtype=bool) + im_mask = lax_internal._tri(bool, (n, n), -1) else: - im_mask = ufuncs.logical_not(jnp.tri(n, k=0, dtype=bool)) - im = lax.select(im_mask, lax.imag(x), jnp.zeros_like(lax.imag(x))) + im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0)) + im = lax.imag(x) + im = lax.select(im_mask, im, lax.full_like(im, 0)) im = lax.select(mask, im, -_T(im)) x = lax.complex(re, im) else: @@ -929,15 +964,15 @@ def _eigh_jvp_rule( # for complex numbers we need eigenvalues to be full dtype of v, a: w = w_real.astype(a.dtype) - eye_n = jnp.eye(n, dtype=a.dtype) + eye_n = lax_internal._eye(a.dtype, (n, n)) # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs. - Fmat = ufuncs.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n + Fmat = lax.integer_pow(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis], -1) - eye_n # eigh impl doesn't support batch dims, but future-proof the grad. dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul, precision=lax.Precision.HIGHEST) vdag_adot_v = dot(dot(_H(v), a_dot), v) - dv = dot(v, ufuncs.multiply(Fmat, vdag_adot_v)) - dw = ufuncs.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1)) + dv = dot(v, Fmat * vdag_adot_v) + dw = _extract_diagonal(vdag_adot_v.real) return (v, w_real), (dv, dw) @@ -1011,10 +1046,10 @@ def _triangular_solve_jvp_rule_a( unit_diagonal): m, n = b.shape[-2:] k = 1 if unit_diagonal else 0 - g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k) + g_a = _tril(g_a, k=-k) if lower else _triu(g_a, k=k) g_a = lax.neg(g_a) - g_a = jnp.swapaxes(g_a, -1, -2) if transpose_a else g_a - g_a = ufuncs.conj(g_a) if conjugate_a else g_a + g_a = _T(g_a) if transpose_a else g_a + g_a = g_a.conj() if conjugate_a else g_a dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul, precision=lax.Precision.HIGHEST) @@ -1149,11 +1184,11 @@ def _lu_pivots_body_fn(i, permutation_and_swaps): permutation, swaps = permutation_and_swaps batch_dims = swaps.shape[:-1] j = swaps[..., i] - iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims)) + iotas = _broadcasted_iotas(*batch_dims) x = permutation[..., i] - y = permutation[iotas + (j,)] + y = permutation[(*iotas, j)] permutation = permutation.at[..., i].set(y) - return permutation.at[iotas + (j,)].set(x), swaps + return permutation.at[(*iotas, j)].set(x), swaps def _generic_lu_pivots_to_permutation(swaps, permutation_size): @@ -1173,7 +1208,7 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): k = swaps.shape[-1] m = permutation_size - permutation = lax.broadcasted_iota(jnp.int32, batch_dims + (m,), + permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,), len(batch_dims)) if m == 0 or k == 0: return permutation @@ -1250,30 +1285,32 @@ def _lu_unblocked(a): m, n = a.shape def body(k, state): pivot, perm, a = state - m_idx = jnp.arange(m) - n_idx = jnp.arange(n) + m_idx = lax.iota('int32', m) + n_idx = lax.iota('int32', n) - if jnp.issubdtype(a.dtype, jnp.complexfloating): + if dtypes.issubdtype(a.dtype, np.complexfloating): t = a[:, k] - magnitude = ufuncs.abs(ufuncs.real(t)) + ufuncs.abs(ufuncs.imag(t)) + magnitude = abs(t.real) + abs(t.imag) else: - magnitude = ufuncs.abs(a[:, k]) - i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf)) - pivot = pivot.at[k].set(i.astype(pivot.dtype)) + magnitude = abs(a[:, k]) + i = lax.argmax(lax.select(m_idx >= k, magnitude, lax.full_like(magnitude, -np.inf)), + axis=0, index_dtype=pivot.dtype) + pivot = pivot.at[k].set(i) a = a.at[[k, i],].set(a[[i, k],]) perm = perm.at[[i, k],].set(perm[[k, i],]) # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes x = a[k, k] - a = a.at[:, k].set(jnp.where((m_idx > k) & (x != 0), a[:, k] / x, a[:, k])) + a = a.at[:, k].set(lax.select((m_idx > k) & (x != 0), a[:, k] / x, a[:, k])) # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:]) - a = a - jnp.where((m_idx[:, None] > k) & (n_idx[None, :] > k), - jnp.outer(a[:, k], a[k, :]), jnp.array(0, dtype=a.dtype)) + a_outer = a[:, k, None] * a[k, None] + a = a - lax.select((m_idx[:, None] > k) & (n_idx[None, :] > k), + a_outer, lax_internal._zeros(a_outer)) return pivot, perm, a - pivot = jnp.zeros((min(m, n),), dtype=jnp.int32) - perm = jnp.arange(m, dtype=jnp.int32) + pivot = lax.full((min(m, n),), 0, dtype=np.int32) + perm = lax.iota('int32', m) if m == 0 and n == 0: # If the array is empty, the loop body never executes but tracing it to a # jaxpr fails because the indexing cannot succeed. @@ -1285,8 +1322,8 @@ def _lu_blocked(a, block_size=128): """Blocked LU decomposition, as an unrolled loop.""" m, n = a.shape r = min(m, n) - pivot = jnp.zeros((r,), dtype=jnp.int32) - perm = jnp.arange(m, dtype=jnp.int32) + pivot = lax.full((r,), 0, dtype=np.int32) + perm = lax.iota('int32', m) for k in range(0, r, block_size): b = min(r - k, block_size) block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b]) @@ -1327,8 +1364,8 @@ def _lu_abstract_eval(operand): m = operand.shape[-2] n = operand.shape[-1] pivot = operand.update(shape=batch_dims + (core.min_dim(m, n),), - dtype=jnp.int32) - perm = operand.update(shape=batch_dims + (m,), dtype=jnp.int32) + dtype=np.int32) + perm = operand.update(shape=batch_dims + (m,), dtype=np.int32) else: pivot = operand perm = operand @@ -1339,14 +1376,14 @@ def _lu_jvp_rule(primals, tangents): a_dot, = tangents lu, pivots, permutation = lu_p.bind(a) - a_shape = jnp.shape(a) + a_shape = np.shape(a) m, n = a_shape[-2:] dtype = lax.dtype(a) k = min(m, n) batch_dims = a_shape[:-2] - iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,))) - x = a_dot[iotas[:-1] + (permutation, slice(None))] + iotas = _broadcasted_iotas(*batch_dims, 1) + x = a_dot[(*iotas[:-1], permutation, slice(None))] # Differentiation of Matrix Functionals Using Triangular Factorization # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas @@ -1361,14 +1398,13 @@ def _lu_jvp_rule(primals, tangents): l_padding = [(0, 0, 0)] * ndims l_padding[-1] = (0, m - k, 0) zero = lax_internal._const(lu, 0) - l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding) - l = l + lax.expand_dims(jnp.eye(m, m, dtype=dtype), range(l.ndim - 2)) - - u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero, + l = lax.pad(_tril(lu[..., :, :k], -1), zero, l_padding) + l = l + lax.expand_dims(lax_internal._eye(dtype, (m, m)), range(l.ndim - 2)) + u_eye = lax.pad(lax_internal._eye(dtype, (n - k, n - k)), zero, ((k, 0, 0), (k, 0, 0))) u_padding = [(0, 0, 0)] * ndims u_padding[-2] = (0, n - k, 0) - u = (lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + + u = (lax.pad(_triu(lu[..., :k, :]), zero, u_padding) + lax.expand_dims(u_eye, range(lu.ndim - 2))) la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True, @@ -1376,8 +1412,9 @@ def _lu_jvp_rule(primals, tangents): lau = triangular_solve(u, la, left_side=False, transpose_a=False, lower=False) - l_dot = jnp.matmul(l, jnp.tril(lau, -1), precision=lax.Precision.HIGHEST) - u_dot = jnp.matmul(jnp.triu(lau), u, precision=lax.Precision.HIGHEST) + with config.default_matmul_precision("highest"): + l_dot = l @ _tril(lau, -1) + u_dot = _triu(lau) @ u lu_dot = l_dot + u_dot return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots), ad_util.Zero.from_value(permutation)) @@ -1493,10 +1530,9 @@ def _lu_tpu_lowering_rule(ctx, operand): mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu') -@partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)') def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array: m = lu.shape[0] - x = jnp.reshape(b, (m, math.prod(b.shape[1:]))) + x = lax.reshape(b, (m, math.prod(b.shape[1:]))) if trans == 0: x = x[permutation, :] x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True) @@ -1507,7 +1543,8 @@ def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array conjugate_a=conj) x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True, transpose_a=True, conjugate_a=conj) - x = x[jnp.argsort(permutation), :] + _, ind = lax.sort_key_val(permutation, lax.iota('int32', len(permutation))) + x = x[ind, :] else: raise ValueError(f"'trans' value must be 0, 1, or 2, got {trans}") return lax.reshape(x, b.shape) @@ -1531,7 +1568,7 @@ def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array: "number of dimensions, last axis of LU decomposition " "matrix (shape {}) and b array (shape {}) must match" .format(lu.shape, b.shape)) - b = b[..., jnp.newaxis] + b = b[..., np.newaxis] else: if b.shape[-2] != lu.shape[-1]: raise ValueError("When LU decomposition matrix and b different " @@ -1539,7 +1576,15 @@ def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array: "matrix (shape {}) and second to last axis of b array " "(shape {}) must match" .format(lu.shape, b.shape)) - x = _lu_solve_core(lu, permutation, b, trans) + + batch_shape = lax.broadcast_shapes(lu.shape[:-2], permutation.shape[:-1], b.shape[:-2]) + lu = _broadcast_to(lu, (*batch_shape, *lu.shape[-2:])) + permutation = _broadcast_to(permutation, (*batch_shape, permutation.shape[-1])) + b = _broadcast_to(b, (*batch_shape, *b.shape[-2:])) + fn = _lu_solve_core + for _ in batch_shape: + fn = api.vmap(fn, in_axes=(0, 0, 0, None)) + x = fn(lu, permutation, b, trans) return x[..., 0] if rhs_vector else x @@ -1823,14 +1868,14 @@ def qr_jvp_rule(primals, tangents, *, full_matrices): raise NotImplementedError( "Unimplemented case of QR decomposition derivative") dx_rinv = triangular_solve(r, dx) # Right side solve by default - qt_dx_rinv = jnp.matmul(_H(q), dx_rinv) - qt_dx_rinv_lower = jnp.tril(qt_dx_rinv, -1) + qt_dx_rinv = _H(q) @ dx_rinv + qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric # The following correction is necessary for complex inputs - I = lax.expand_dims(jnp.eye(n, dtype=do.dtype), range(qt_dx_rinv.ndim - 2)) + I = lax.expand_dims(lax_internal._eye(do.dtype, (n, n)), range(qt_dx_rinv.ndim - 2)) do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) - dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv - dr = jnp.matmul(qt_dx_rinv - do, r) + dq = q @ (do - qt_dx_rinv) + dx_rinv + dr = (qt_dx_rinv - do) @ r return (q, r), (dq, dr) def _qr_batching_rule(batched_args, batch_dims, *, full_matrices): @@ -1843,8 +1888,10 @@ def _qr_lowering(a, *, full_matrices): *batch_dims, m, n = a.shape if m == 0 or n == 0: k = m if full_matrices else min(m, n) - q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_dims, m, k)) - r = jnp.empty((*batch_dims, k, n), dtype=a.dtype) + q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)), + (*batch_dims, m, k), + (len(batch_dims), len(batch_dims) + 1)) + r = lax.full((*batch_dims, k, n), 0, dtype=a.dtype) return q, r r, taus = geqrf(a) @@ -1857,7 +1904,7 @@ def _qr_lowering(a, *, full_matrices): else: q = householder_product(r, taus) r = r[..., :n, :n] - r = jnp.triu(r) + r = _triu(r) return q, r @@ -1908,7 +1955,7 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index): raise NotImplementedError -@jax.default_matmul_precision("float32") +@config.default_matmul_precision("float32") def _svd_jvp_rule( primals, tangents, *, full_matrices, compute_uv, subset_by_index ): @@ -1926,13 +1973,13 @@ def _svd_jvp_rule( Ut, V = _H(U), _H(Vt) s_dim = s[..., None, :] dS = Ut @ dA @ V - ds = ufuncs.real(jnp.diagonal(dS, 0, -2, -1)) + ds = _extract_diagonal(dS.real) if not compute_uv: return (s,), (ds,) s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim)) - s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else + s_diffs_zeros = lax_internal._eye(s.dtype, (s.shape[-1], s.shape[-1])) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2)) F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s)) @@ -1940,7 +1987,7 @@ def _svd_jvp_rule( s_zeros = (s == 0).astype(s.dtype) s_inv = 1 / (s + s_zeros) - s_zeros - s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv) + s_inv_mat = _construct_diagonal(s_inv) dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype) dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag) dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS))) @@ -1959,15 +2006,17 @@ def _svd_jvp_rule( def _empty_svd(a, *, full_matrices, compute_uv): batch_shape = a.shape[:-2] m, n = a.shape[-2:] - s = jnp.empty(batch_shape + (0,), dtype=lax_internal._complex_basetype(a.dtype)) + s = lax.full(batch_shape + (0,), 0, dtype=lax_internal._complex_basetype(a.dtype)) if not compute_uv: return (s,) if full_matrices: size = max(m, n) - u = jnp.broadcast_to(jnp.eye(size, dtype=a.dtype), batch_shape + (size, size)) + u = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (size, size)), + (*batch_shape, size, size), + (len(batch_shape), len(batch_shape) + 1)) else: - u = jnp.empty(batch_shape + (m, n), dtype=a.dtype) - v = jnp.empty(batch_shape + (0, 0), dtype=a.dtype) + u = lax.full(batch_shape + (m, n), 0, dtype=a.dtype) + v = lax.full(batch_shape + (0, 0), 0, dtype=a.dtype) if m < n: u, v = v, u return s, u, v @@ -2207,24 +2256,25 @@ def _tridiagonal_solve_batching_rule( def _tridiagonal_solve_jax(dl, d, du, b, **kw): """Pure JAX implementation of `tridiagonal_solve`.""" def prepend_zero(x): - return jnp.append( - jnp.zeros((1,) + x.shape[1:], dtype=x.dtype), - x[:-1], axis=0) + return lax.concatenate( + [lax.full((1,) + x.shape[1:], 0, dtype=x.dtype), x[:-1]], dimension=0) fwd1 = lambda tu_, x: x[1] / (x[0] - x[2] * tu_) def fwd2(b_, x): - return (x[0] - x[3][jnp.newaxis, ...] * b_) / ( - x[1] - x[3] * x[2])[jnp.newaxis, ...] + return (x[0] - x[3][np.newaxis, ...] * b_) / ( + x[1] - x[3] * x[2])[np.newaxis, ...] - bwd1 = lambda x_, x: x[0] - x[1][jnp.newaxis, ...] * x_ + bwd1 = lambda x_, x: x[0] - x[1][np.newaxis, ...] * x_ double = lambda f, args: (f(*args), f(*args)) # Move relevant dimensions to the front for the scan. - dl = jnp.moveaxis(dl, -1, 0) - d = jnp.moveaxis(d, -1, 0) - du = jnp.moveaxis(du, -1, 0) - b = jnp.moveaxis(b, -1, 0) - b = jnp.moveaxis(b, -1, 0) + moveaxis_fwd = lambda x: lax.transpose(x, (x.ndim - 1, *range(x.ndim - 1))) + moveaxis_bwd = lambda x: lax.transpose(x, (*range(1, x.ndim), 0)) + dl = moveaxis_fwd(dl) + d = moveaxis_fwd(d) + du = moveaxis_fwd(du) + b = moveaxis_fwd(b) + b = moveaxis_fwd(b) # Forward pass. _, tu_ = lax.scan(lambda tu_, x: double(fwd1, (tu_, x)), @@ -2244,8 +2294,8 @@ def fwd2(b_, x): unroll=32) result = x_[::-1] - result = jnp.moveaxis(result, 0, -1) - result = jnp.moveaxis(result, 0, -1) + result = moveaxis_bwd(result) + result = moveaxis_bwd(result) return result @@ -2433,7 +2483,7 @@ def hessenberg(a: ArrayLike) -> tuple[Array, Array]: return hessenberg_p.bind(a) def _hessenberg_abstract_eval(a): - if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128): + if a.dtype not in (np.float32, np.float64, np.complex64, np.complex128): raise TypeError("hessenberg requires a.dtype to be float32, float64, " f"complex64, or complex128, got {a.dtype}.") if a.ndim < 2: @@ -2509,19 +2559,20 @@ def tridiagonal(a: ArrayLike, *, lower=True first superdiagonal. ``taus`` contains the scalar factors of the elementary Householder reflectors. """ - arr, d, e, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower) - nan = arr.dtype.type(jnp.nan) - if jnp.issubdtype(arr.dtype, np.complexfloating): - nan = nan + arr.dtype.type(jnp.nan * 1j) - arr = jnp.where((info == 0)[..., None, None], arr, nan) - real_type = jnp.finfo(arr.dtype).dtype.type - d = jnp.where((info == 0)[..., None], d, real_type(jnp.nan)) - e = jnp.where((info == 0)[..., None], e, real_type(jnp.nan)) - taus = jnp.where((info == 0)[..., None], taus, nan) + arr, d, e, taus, info = tridiagonal_p.bind(lax_internal.asarray(a), lower=lower) + def nans_like(arr): + if dtypes.issubdtype(arr.dtype, np.complexfloating): + return lax.full_like(arr, np.nan + 1j * np.nan) + return lax.full_like(arr, np.nan) + mask = lambda x: lax.broadcast_in_dim(info == 0, x.shape, range(info.ndim)) + arr = lax.select(mask(arr), arr, nans_like(arr)) + d = lax.select(mask(d), d, nans_like(d)) + e = lax.select(mask(e), e, nans_like(e)) + taus = lax.select(mask(taus), taus, nans_like(taus)) return arr, d, e, taus def _tridiagonal_abstract_eval(a, *, lower): - if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128): + if a.dtype not in (np.float32, np.float64, np.complex64, np.complex128): raise TypeError("tridiagonal requires a.dtype to be float32, float64, " f"complex64, or complex128, got {a.dtype}.") if a.ndim < 2: @@ -2533,7 +2584,7 @@ def _tridiagonal_abstract_eval(a, *, lower): if a.shape[-1] == 0: raise TypeError("tridiagonal requires the last two dimensions of a to be " f"non-zero, got a.shape of {a.shape}.") - real_dtype = jnp.finfo(a.dtype).dtype + real_dtype = dtypes.finfo(a.dtype).dtype return [ a, ShapedArray(a.shape[:-2] + (a.shape[-1],), real_dtype), @@ -2573,7 +2624,7 @@ def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower): # Utilities def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value: - if jnp.issubdtype(aval.dtype, np.complexfloating): + if dtypes.issubdtype(aval.dtype, np.complexfloating): return mlir.full_like_aval(ctx, np.nan + np.nan * 1j, aval) else: return mlir.full_like_aval(ctx, np.nan, aval) diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 45f9167ab807..69c7fdc0228b 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -18,17 +18,18 @@ import math from typing import Any -import jax +from jax._src.custom_derivatives import custom_jvp from jax._src import dtypes from jax._src.lax import lax from jax._src.lax import convolution from jax._src import util +from jax._src.typing import Array, ArrayLike import numpy as np DType = Any def conv_general_dilated_patches( - lhs: jax.typing.ArrayLike, + lhs: ArrayLike, filter_shape: Sequence[int], window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], @@ -37,7 +38,7 @@ def conv_general_dilated_patches( dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers | None = None, precision: lax.Precision | None = None, preferred_element_type: DType | None = None, -) -> jax.Array: +) -> Array: """Extract patches subject to the receptive field of `conv_general_dilated`. Runs the input through a convolution with given parameters. The kernel of the @@ -101,7 +102,7 @@ def conv_general_dilated_patches( n_channels = lhs_array.shape[lhs_spec[1]] # Move separate `lhs` spatial locations into separate `rhs` channels. - rhs = lax._eye(lhs_array.dtype, shape=(spatial_size, spatial_size), offset=0) + rhs = lax._eye(lhs_array.dtype, shape=(spatial_size, spatial_size)) rhs = lax.broadcast_in_dim(rhs, (n_channels, spatial_size, spatial_size), (1, 2)) rhs = lax.reshape(rhs, (n_channels * spatial_size, 1, *filter_shape)) rhs = util.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1])) @@ -123,8 +124,8 @@ def conv_general_dilated_patches( def conv_general_dilated_local( - lhs: jax.typing.ArrayLike, - rhs: jax.typing.ArrayLike, + lhs: ArrayLike, + rhs: ArrayLike, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], filter_shape: Sequence[int], @@ -132,7 +133,7 @@ def conv_general_dilated_local( rhs_dilation: Sequence[int] | None = None, dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers | None = None, precision: lax.PrecisionLike = None -) -> jax.Array: +) -> Array: """General n-dimensional unshared convolution operator with optional dilation. Also known as locally connected layer, the operation is equivalent to @@ -249,14 +250,14 @@ def _wrap_between(x, _a): return lax.sub(rem, a) -def _replace_inf(x: jax.Array) -> jax.Array: +def _replace_inf(x: Array) -> Array: re_x = lax.real(x) if dtypes.issubdtype(x.dtype, np.complexfloating) else x inf = lax._const(re_x, float('inf')) return lax.select(lax.eq(re_x, inf), lax._zeros(x), x) -@jax.custom_jvp -def logaddexp(x1: jax.typing.ArrayLike, x2: jax.typing.ArrayLike, /) -> jax.Array: +@custom_jvp +def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute log(exp(x1) + exp(x2)) avoiding overflow.""" x1_arr = lax.asarray(x1) x2_arr = lax.asarray(x2) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 4faa0bdd390b..c9a07072ddc7 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -35,7 +35,6 @@ from jax._src.lax import slicing from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.numpy import lax_numpy from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip, unzip2) import numpy as np @@ -231,7 +230,10 @@ def pargmax(x, axis_name): def _axis_index_of_val(x, val, axis_name): idx = axis_index(axis_name) - validx = lax_numpy.where(val == x, idx, dtypes.iinfo(dtypes.dtype(idx)).max) + mask = (val == x) + validx = lax.select(mask, + lax.full(mask.shape, idx), + lax.full(mask.shape, dtypes.iinfo(dtypes.dtype(idx)).max, dtype=idx.dtype)) return pmin(validx, axis_name) def _validate_reduce_axis_index_groups(axis_index_groups): @@ -779,7 +781,7 @@ def _ppermute_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, per perm_indices = np.zeros(axis_size, dtype=int) for src, dst in perm: perm_indices[dst] = src - return lax_numpy.take(v, perm_indices, d), d + return v.take(perm_indices, d), d def _collective_batcher(prim, args, dims, **params): return prim.bind(*args, **params), dims if prim.multiple_results else dims[0] @@ -795,7 +797,7 @@ def _collective_batcher(prim, args, dims, **params): def _pbroadcast_transpose_rule(t, x, source, axis_name): is_source = axis_index(axis_name) == source tsum = psum(t, axis_name) - return [lax_numpy.where(is_source, tsum, lax_numpy.zeros_like(t))] + return [lax.select(is_source, lax.full_like(t, tsum), lax.full_like(t, 0))] def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, source): (v,), (d,) = vals_in, dims_in @@ -810,7 +812,7 @@ def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, s return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d if d is batching.not_mapped: return v, d - return lax_numpy.take(v, [source] * axis_size, d), d + return v.take([source] * axis_size, d), d def _pbroadcast_lowering(ctx, x, *, axis_name, source): replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None) diff --git a/jax/_src/lax/qdwh.py b/jax/_src/lax/qdwh.py index bac3ea957955..0cfd67d8797a 100644 --- a/jax/_src/lax/qdwh.py +++ b/jax/_src/lax/qdwh.py @@ -28,11 +28,33 @@ import functools -import jax -import jax.numpy as jnp -from jax import lax +from jax._src import api +from jax._src import config from jax._src import core +from jax._src import dtypes +from jax._src.lax.control_flow import loops +from jax._src.lax import lax from jax._src.lax import linalg as lax_linalg +from jax._src.lax import slicing +from jax._src.typing import Array +import numpy as np + + +def _norm(x, axis=None): + return lax.sqrt((abs(x) ** 2).sum(axis=axis)) + +def _one_norm(x): + assert x.ndim == 2 + return abs(x).sum(0).max() + +def _inf_norm(x): + assert x.ndim == 2 + return abs(x).sum(1).max() + + +def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array: + assert x.ndim <= len(shape) + return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape))) # Helpers for working with padded shapes @@ -42,24 +64,26 @@ def _mask(x, dims, alternative=0): Replaces values outside those dimensions with `alternative`. `alternative` is broadcast with `x`. """ - assert jnp.ndim(x) == len(dims) + assert np.ndim(x) == len(dims) mask = None for i, d in enumerate(dims): if d is not None: - mask_dim_i = lax.broadcasted_iota(jnp.int32, x.shape, i) < d + mask_dim_i = lax.broadcasted_iota(np.int32, x.shape, i) < d mask = mask_dim_i if mask is None else (mask & mask_dim_i) - return x if mask is None else jnp.where(mask, x, alternative) + + alternative = _broadcast_to(lax.asarray(alternative), x.shape).astype(x.dtype) + return x if mask is None else lax.select(mask, x, alternative) def _pad_in_dim(x, low=0, high=0, interior=0, fill_value=0, axis=0): pads = [(0, 0, 0)] * x.ndim pads[axis] = (low, high, interior) - return lax.pad(x, jnp.array(fill_value, x.dtype), pads) + return lax.pad(x, lax.convert_element_type(fill_value, x.dtype), pads) def _dynamic_concat(a, b, m, axis=0): "Concatenates padded arrays `a` and `b` where the true size of `a` is `m`." if m is None: - return jnp.concatenate([a, b], axis=axis) - return lax.dynamic_update_slice_in_dim( + return lax.concatenate([a, b], dimension=axis) + return slicing.dynamic_update_slice_in_dim( _pad_in_dim(a, high=b.shape[axis], axis=axis), b, m, axis) @@ -74,12 +98,12 @@ def _use_qr(u, m, n, params): a_minus_e_by_sqrt_c, sqrt_c, e = params M, N = u.shape - y = _dynamic_concat(sqrt_c * u, jnp.eye(N, dtype=jnp.dtype(u)), m) + y = _dynamic_concat(sqrt_c * u, lax._eye(np.dtype(u), (N, N)), m) q, _ = lax_linalg.qr(y, full_matrices=False) # q1 = q[:m, :] - q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n)) + q1 = _mask(slicing.slice(q, (0, 0), (M, N)), (m, n)) # q2 = (q[m:, :]).T.conj() - q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0) + q2 = slicing.dynamic_slice_in_dim(q, m, N, axis=0) q2 = _mask(q2, (n, n)).T.conj() return e * u + a_minus_e_by_sqrt_c * (q1 @ q2) @@ -94,11 +118,11 @@ def _use_cholesky(u, m, n, params): """ a_minus_e, c, e = params _, N = u.shape - x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=jnp.dtype(u)) + x = c * (u.T.conj() @ u) + lax._eye(np.dtype(u), (N, N)) # Pads the lower-right corner with the identity matrix to prevent the Cholesky # decomposition from failing due to the matrix not being PSD if padded with # zeros. - x = _mask(x, (n, n), jnp.eye(N, dtype=x.dtype)) + x = _mask(x, (n, n), lax._eye(x.dtype, (N, N))) # `y` is lower triangular. y = lax_linalg.cholesky(x, symmetrize_input=False) @@ -119,18 +143,18 @@ def _qdwh(x, m, n, max_iterations, eps): # norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for # the smallest singular value of x. if eps is None: - eps = float(jnp.finfo(x.dtype).eps) - one_norm = jnp.linalg.norm(x, ord=1) - inf_norm = jnp.linalg.norm(x, ord=jnp.inf) + eps = float(dtypes.finfo(x.dtype).eps) + one_norm = _one_norm(x) + inf_norm = _inf_norm(x) alpha_inverse = lax.rsqrt(one_norm) * lax.rsqrt(inf_norm) - alpha_inverse = jnp.where(one_norm == 0, 1, alpha_inverse) + alpha_inverse = lax.select(one_norm == 0, lax._ones(alpha_inverse), alpha_inverse) u = x * alpha_inverse.astype(x.dtype) l = eps # Iteration tolerances. tol_l = 10.0 * eps / 2.0 - tol_norm = jnp.cbrt(tol_l) + tol_norm = lax.cbrt(tol_l) def get_qr_params(a, b, c): e = b / c @@ -169,22 +193,22 @@ def iteration(k, state, update_fn, coefs, test_convergence): # As l → 1, the coefficients a, b, c → 3, 1, 3, which is Halley's method. params = get_chol_params(3, 1, 3) else: - params = lax.dynamic_index_in_dim(coefs, k, keepdims=False) + params = slicing.dynamic_index_in_dim(coefs, k, keepdims=False) u_prev = u u = update_fn(u, m, n, params) is_not_converged = True if test_convergence: - is_not_converged = jnp.linalg.norm(u - u_prev) > tol_norm + is_not_converged = _norm(u - u_prev) > tol_norm return u, is_not_converged def iterate(u, coefs, **kwargs): if not coefs: return u, True - coefs = jnp.array(coefs).astype(x.dtype) + coefs = np.array(coefs).astype(x.dtype) body = functools.partial(iteration, coefs=coefs, **kwargs) - return lax.fori_loop(0, len(coefs), body, (u, True)) + return loops.fori_loop(0, len(coefs), body, (u, True)) u, _ = iterate( u, coefs=qr_coefs, update_fn=_use_qr, test_convergence=False @@ -197,7 +221,7 @@ def iterate(u, coefs, **kwargs): # (coef = None) until convergence. def cond_fun(state): k, _, is_not_converged = state - return jnp.logical_and(is_not_converged, k < max_iterations) + return lax.bitwise_and(is_not_converged, k < max_iterations) def body_fun(state): k, u, is_not_converged = state @@ -211,7 +235,7 @@ def body_fun(state): return k + 1, u, is_not_converged k = len(qr_coefs) + len(chol_coefs) - num_iters, u, is_not_converged = lax.while_loop( + num_iters, u, is_not_converged = loops.while_loop( cond_fun, body_fun, (k, u, is_not_converged) ) @@ -222,14 +246,14 @@ def body_fun(state): h = (h + h.T.conj()) / 2 # Converged within the maximum number of iterations. - is_converged = jnp.logical_not(is_not_converged) + is_converged = lax.bitwise_not(is_not_converged) return u, h, num_iters, is_converged # TODO: Add pivoting. @functools.partial( - jax.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps') + api.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps') ) def qdwh( x, @@ -279,7 +303,7 @@ def qdwh( else: m, n = M, N - with jax.default_matmul_precision('float32'): + with config.default_matmul_precision('float32'): u, h, num_iters, is_converged = _qdwh(x, m, n, max_iterations, eps) return u, h, num_iters, is_converged diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 206d52ba5ebd..2a3a63e89a35 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -24,9 +24,8 @@ import numpy as np -import jax - from jax._src import ad_util +from jax._src import api from jax._src import config from jax._src import core from jax._src import dispatch @@ -1401,7 +1400,7 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): inserted_window_dims=(), scatter_dims_to_operand_dims=dims) index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd) - return jax.vmap( + return api.vmap( partial(scatter, dimension_numbers=dnums, indices_are_sorted=True, unique_indices=True, mode=GatherScatterMode.CLIP), diff --git a/jax/_src/lax/stack.py b/jax/_src/lax/stack.py index 882195f17d51..e5a6fdf2163d 100644 --- a/jax/_src/lax/stack.py +++ b/jax/_src/lax/stack.py @@ -22,9 +22,9 @@ from typing import Any -import jax -from jax import lax -import jax.numpy as jnp +from jax._src.lax import lax +from jax._src.lax import slicing +from jax._src import tree_util class Stack: """A bounded functional stack implementation. Elements may be pytrees.""" @@ -44,9 +44,9 @@ def create(capacity: int, prototype: Any) -> Stack: structure; the specific values are ignored. """ return Stack( - jnp.array(0, jnp.int32), - jax.tree_util.tree_map( - lambda x: jnp.zeros((capacity,) + tuple(x.shape), x.dtype), prototype)) + lax.full((), 0, 'int32'), + tree_util.tree_map( + lambda x: lax.full((capacity, *x.shape), 0, x.dtype), prototype)) def empty(self) -> Any: """Returns true if the stack is empty.""" @@ -56,23 +56,23 @@ def push(self, elem: Any) -> Stack: """Pushes `elem` onto the stack, returning the updated stack.""" return Stack( self._size + 1, - jax.tree_util.tree_map( - lambda x, y: lax.dynamic_update_index_in_dim(x, y, self._size, 0), + tree_util.tree_map( + lambda x, y: slicing.dynamic_update_index_in_dim(x, y, self._size, 0), self._data, elem)) def pop(self) -> tuple[Any, Stack]: """Pops from the stack, returning an (elem, updated stack) pair.""" - elem = jax.tree_util.tree_map( - lambda x: lax.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False), + elem = tree_util.tree_map( + lambda x: slicing.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False), self._data) return elem, Stack(self._size - 1, self._data) def flatten(self): - leaves, treedef = jax.tree_util.tree_flatten(self._data) + leaves, treedef = tree_util.tree_flatten(self._data) return ([self._size] + leaves), treedef @staticmethod def unflatten(treedef, leaves): - return Stack(leaves[0], jax.tree_util.tree_unflatten(treedef, leaves[1:])) + return Stack(leaves[0], tree_util.tree_unflatten(treedef, leaves[1:])) -jax.tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) +tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index 77ff4297e137..6055bce44303 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -40,13 +40,30 @@ import operator from typing import Any -import jax -from jax import lax +from jax._src import api +from jax._src import config from jax._src import core -import jax.numpy as jnp - - -@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4)) +from jax._src import dtypes +from jax._src.lax.control_flow import loops +from jax._src.lax import lax +from jax._src.lax import linalg as lax_linalg +from jax._src.lax import qdwh +from jax._src.typing import Array +import numpy as np + +def _construct_diagonal(s: Array) -> Array: + """Construct a (batched) diagonal matrix""" + # signature: (...,n)->(...,n,n) + i = lax.iota('int32', s.shape[-1]) + return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s) + +def _extract_diagonal(s: Array) -> Array: + """Extract the diagonal from a batched matrix""" + # signature: (...,n,m)->(...k) where k=min(n,m) + i = lax.iota('int32', min(s.shape[-2], s.shape[-1])) + return s[..., i, i] + +@functools.partial(api.jit, static_argnums=(1, 2, 3, 4)) def _svd_tall_and_square_input( a: Any, hermitian: bool, @@ -69,22 +86,23 @@ def _svd_tall_and_square_input( `a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned. """ - u_p, h, _, _ = lax.linalg.qdwh( + u_p, h, _, _ = qdwh.qdwh( a, is_hermitian=hermitian, max_iterations=max_iterations ) # TODO: Uses `eigvals_only=True` if `compute_uv=False`. - v, s = lax.linalg.eigh( + v, s = lax_linalg.eigh( h, subset_by_index=subset_by_index, sort_eigenvalues=False ) # Singular values are non-negative by definition. But eigh could return small # negative values, so we clamp them to zero. - s = jnp.maximum(s, 0.0) + s = lax.max(s, lax._zeros(s)) # Sort or reorder singular values to be in descending order. - sort_idx = jnp.argsort(s, descending=True) - s_out = s[sort_idx] + s_out, sort_idx = lax.rev(s, (0,)), lax.rev(lax.iota('int32', len(s)), (0,)) + s_out, sort_idx = lax.sort_key_val(s_out, sort_idx) + s_out, sort_idx = lax.rev(s_out, (0,)), lax.rev(sort_idx, (0,)) if not compute_uv: return s_out @@ -99,18 +117,20 @@ def _svd_tall_and_square_input( # eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing # 35, no. 3 (2013): A1325-A1349. def correct_rank_deficiency(u_out): - u_out, r = lax.linalg.qr(u_out, full_matrices=False) - u_out = u_out @ jnp.diag(jnp.where(jnp.diag(r) >= 0, 1, -1)) + u_out, r = lax_linalg.qr(u_out, full_matrices=False) + r_diag = _extract_diagonal(r) + ones = lax.full(r_diag.shape, 1, u_out.dtype) + u_out = u_out @ _construct_diagonal(lax.select(r_diag >= 0, ones, -ones)) return u_out - eps = float(jnp.finfo(a.dtype).eps) + eps = float(dtypes.finfo(a.dtype).eps) do_correction = s_out[-1] <= a.shape[1] * eps * s_out[0] cond_f = lambda args: args[1] body_f = lambda args: (correct_rank_deficiency(args[0]), False) - u_out, _ = lax.while_loop(cond_f, body_f, (u_out, do_correction)) + u_out, _ = loops.while_loop(cond_f, body_f, (u_out, do_correction)) return (u_out, s_out, v_out) -@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) +@functools.partial(api.jit, static_argnums=(1, 2, 3, 4, 5)) def svd( a: Any, full_matrices: bool, @@ -197,7 +217,7 @@ def svd( reduce_to_square = False if full_matrices: - q_full, a_full = lax.linalg.qr(a, full_matrices=True) + q_full, a_full = lax_linalg.qr(a, full_matrices=True) q = q_full[:, :n] u_out_null = q_full[:, n:] a = a_full[:n, :] @@ -206,16 +226,16 @@ def svd( # The constant `1.15` comes from Yuji Nakatsukasa's implementation # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav if m > 1.15 * n: - q, a = lax.linalg.qr(a, full_matrices=False) + q, a = lax_linalg.qr(a, full_matrices=False) reduce_to_square = True if not compute_uv: - with jax.default_matmul_precision('float32'): + with config.default_matmul_precision('float32'): return _svd_tall_and_square_input( a, hermitian, compute_uv, max_iterations, subset_by_index ) - with jax.default_matmul_precision('float32'): + with config.default_matmul_precision('float32'): u_out, s_out, v_out = _svd_tall_and_square_input( a, hermitian, compute_uv, max_iterations, subset_by_index ) @@ -223,17 +243,20 @@ def svd( u_out = q @ u_out if full_matrices: - u_out = jnp.hstack((u_out, u_out_null)) + u_out = lax.concatenate([u_out, u_out_null], dimension=1) - is_finite = jnp.all(jnp.isfinite(a)) - cond_f = lambda args: jnp.logical_not(args[0]) + if dtypes.issubdtype(a.dtype, np.complexfloating): + is_finite = (lax.is_finite(a.real) & lax.is_finite(a.imag)).all() + else: + is_finite = lax.is_finite(a).all() + cond_f = lambda args: lax.bitwise_not(args[0].astype(bool)) body_f = lambda args: ( - jnp.array(True), - jnp.full_like(u_out, jnp.nan), - jnp.full_like(s_out, jnp.nan), - jnp.full_like(v_out, jnp.nan), + lax.full((), True, dtype=bool), + lax.full_like(u_out, np.nan), + lax.full_like(s_out, np.nan), + lax.full_like(v_out, np.nan), ) - _, u_out, s_out, v_out = lax.while_loop( + _, u_out, s_out, v_out = loops.while_loop( cond_f, body_f, (is_finite, u_out, s_out, v_out) ) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 5458d71dedf4..d014e5ceb24e 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -951,7 +951,7 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array: factors = cho_factor(lax.stop_gradient(a), lower=lower) custom_solve = partial( lax.custom_linear_solve, - lambda x: lax_linalg._matvec_multiply(a, x), + lambda x: lax_linalg._broadcasted_matvec(a, x), solve=lambda _, x: cho_solve(factors, x), symmetric=True) if a.ndim == b.ndim + 1: From b1f55fd64ac9d5a30a023409cea8e5e29c7c14e3 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Sun, 1 Sep 2024 12:41:21 -0700 Subject: [PATCH 316/702] use `jax.tree` consistently in example optimizers --- jax/example_libraries/optimizers.py | 32 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/jax/example_libraries/optimizers.py b/jax/example_libraries/optimizers.py index 71680ca61b96..f9b66ea1c082 100644 --- a/jax/example_libraries/optimizers.py +++ b/jax/example_libraries/optimizers.py @@ -98,11 +98,9 @@ def step(step, opt_state): import functools from functools import partial +import jax import jax.numpy as jnp from jax._src.util import safe_zip, safe_map, unzip2 -from jax import tree_util -from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, - register_pytree_node) map = safe_map zip = safe_zip @@ -117,7 +115,7 @@ def step(step, opt_state): OptimizerState = namedtuple("OptimizerState", ["packed_state", "tree_def", "subtree_defs"]) -register_pytree_node( +jax.tree_util.register_pytree_node( OptimizerState, lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)), lambda data, xs: OptimizerState(xs[0], data[0], data[1])) @@ -182,23 +180,23 @@ def tree_opt_maker(*args, **kwargs): @functools.wraps(init) def tree_init(x0_tree): - x0_flat, tree = tree_flatten(x0_tree) + x0_flat, tree = jax.tree.flatten(x0_tree) initial_states = [init(x0) for x0 in x0_flat] - states_flat, subtrees = unzip2(map(tree_flatten, initial_states)) + states_flat, subtrees = unzip2(map(jax.tree.flatten, initial_states)) return OptimizerState(states_flat, tree, subtrees) @functools.wraps(update) def tree_update(i, grad_tree, opt_state): states_flat, tree, subtrees = opt_state - grad_flat, tree2 = tree_flatten(grad_tree) + grad_flat, tree2 = jax.tree.flatten(grad_tree) if tree2 != tree: msg = ("optimizer update function was passed a gradient tree that did " "not match the parameter tree structure with which it was " "initialized: parameter tree {} and grad tree {}.") raise TypeError(msg.format(tree, tree2)) - states = map(tree_unflatten, subtrees, states_flat) + states = map(jax.tree.unflatten, subtrees, states_flat) new_states = map(partial(update, i), grad_flat, states) - new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states)) + new_states_flat, subtrees2 = unzip2(map(jax.tree.flatten, new_states)) for subtree, subtree2 in zip(subtrees, subtrees2): if subtree2 != subtree: msg = ("optimizer update function produced an output structure that " @@ -209,9 +207,9 @@ def tree_update(i, grad_tree, opt_state): @functools.wraps(get_params) def tree_get_params(opt_state): states_flat, tree, subtrees = opt_state - states = map(tree_unflatten, subtrees, states_flat) + states = map(jax.tree.unflatten, subtrees, states_flat) params = map(get_params, states) - return tree_unflatten(tree, params) + return jax.tree.unflatten(tree, params) return Optimizer(tree_init, tree_update, tree_get_params) return tree_opt_maker @@ -566,14 +564,14 @@ def make_schedule(scalar_or_schedule: float | Schedule) -> Schedule: def l2_norm(tree): """Compute the l2 norm of a pytree of arrays. Useful for weight decay.""" - leaves, _ = tree_flatten(tree) + leaves, _ = jax.tree.flatten(tree) return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves)) def clip_grads(grad_tree, max_norm): """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`.""" norm = l2_norm(grad_tree) normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm)) - return tree_map(normalize, grad_tree) + return jax.tree.map(normalize, grad_tree) ### serialization utilities @@ -600,9 +598,9 @@ def unpack_optimizer_state(opt_state): A pytree with JoinPoint leaves that contain a second level of pytrees. """ states_flat, tree_def, subtree_defs = opt_state - subtrees = map(tree_unflatten, subtree_defs, states_flat) + subtrees = map(jax.tree.unflatten, subtree_defs, states_flat) sentinels = [JoinPoint(subtree) for subtree in subtrees] - return tree_util.tree_unflatten(tree_def, sentinels) + return jax.tree.unflatten(tree_def, sentinels) def pack_optimizer_state(marked_pytree): """Converts a marked pytree to an OptimizerState. @@ -617,8 +615,8 @@ def pack_optimizer_state(marked_pytree): Returns: An equivalent OptimizerState to the input argument. """ - sentinels, tree_def = tree_flatten(marked_pytree) + sentinels, tree_def = jax.tree.flatten(marked_pytree) assert all(isinstance(s, JoinPoint) for s in sentinels) subtrees = [s.subtree for s in sentinels] - states_flat, subtree_defs = unzip2(map(tree_flatten, subtrees)) + states_flat, subtree_defs = unzip2(map(jax.tree.flatten, subtrees)) return OptimizerState(states_flat, tree_def, subtree_defs) From ebd59487c3295c1c2698dab1d24c785c5dc41ffd Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 1 Sep 2024 16:37:48 -0700 Subject: [PATCH 317/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/39a982c1757bbb7136431b2df48d067f122c5190. PiperOrigin-RevId: 670006314 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 17cbc754b373..9550acb5e320 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "0e3e2263e30ef24560a9abe64a713e1692f07216" -XLA_SHA256 = "939e8a71c115db8575d0c9252402026e732a6b6f8d7134588ec4190b6e872382" +XLA_COMMIT = "39a982c1757bbb7136431b2df48d067f122c5190" +XLA_SHA256 = "2d4b80bc97700fb70021926771bd497ba70a8f027e0380d3717c2d9b4ee9775f" def repo(): tf_http_archive( From e8730ddfe0022d150cb534d8c2fd2cce2331b5ab Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 2 Sep 2024 13:40:37 +0200 Subject: [PATCH 318/702] [NFC] Remove unused argument, fix help string. --- jax/tools/pgo_nsys_converter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 5460edd960f5..fc06cb360509 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -28,8 +28,7 @@ parser = argparse.ArgumentParser(description='Tool to convert NVIDIA Nsys Profiles to the .pbtxt format') parser.add_argument("--profile_path", type=str, help="path to nsys profile") - parser.add_argument("--post_process", help="post process pbtxt to get minimum cost value for each instruction", action="store_true") - parser.add_argument("--pgle_output_path", type=str, help="output directory", default="/opt/paxml/workspace/lhs_pbtxt/temp.pbtxt") + parser.add_argument("--pgle_output_path", type=str, help="output file", default="/opt/paxml/workspace/lhs_pbtxt/temp.pbtxt") args = parser.parse_args() From 8cb35961360fab88235c08036cf64c7d5010f771 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 2 Sep 2024 04:44:06 -0700 Subject: [PATCH 319/702] Partially rolling forward #22998 Reverts 322d0c2f31e92e68a531f95a53c3f040d6a76bdf PiperOrigin-RevId: 670173462 --- jax/_src/lib/mlir/dialects/__init__.py | 64 ++++++++++++++++---------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index 01dc7e2725b5..a9bae8821db5 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -13,35 +13,49 @@ # limitations under the License. # ruff: noqa: F401 -from typing import Any -import jaxlib.mlir.dialects.arith as arith -import jaxlib.mlir.dialects.builtin as builtin -import jaxlib.mlir.dialects.chlo as chlo -import jaxlib.mlir.dialects.func as func -import jaxlib.mlir.dialects.math as math -import jaxlib.mlir.dialects.memref as memref -import jaxlib.mlir.dialects.mhlo as mhlo -import jaxlib.mlir.dialects.scf as scf +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from jaxlib.mlir.dialects import arith as arith + from jaxlib.mlir.dialects import builtin as builtin + from jaxlib.mlir.dialects import chlo as chlo + from jaxlib.mlir.dialects import func as func + from jaxlib.mlir.dialects import gpu as gpu + from jaxlib.mlir.dialects import llvm as llvm + from jaxlib.mlir.dialects import math as math + from jaxlib.mlir.dialects import memref as memref + from jaxlib.mlir.dialects import mhlo as mhlo + from jaxlib.mlir.dialects import nvgpu as nvgpu + from jaxlib.mlir.dialects import nvvm as nvvm + from jaxlib.mlir.dialects import scf as scf + from jaxlib.mlir.dialects import sparse_tensor as sparse_tensor + from jaxlib.mlir.dialects import vector as vector +else: + from jax._src import lazy_loader as _lazy + __getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [ + "arith", + "builtin", + "chlo", + "func", + "gpu", + "llvm", + "math", + "memref", + "mhlo", + "nvgpu", + "nvvm", + "scf", + "sparse_tensor", + "vector", + ]) + del _lazy + # TODO(bartchr): Once JAX is released with SDY, remove the try/except. try: - import jaxlib.mlir.dialects.sdy as sdy + from jaxlib.mlir.dialects import sdy as sdy except ImportError: sdy: Any = None # type: ignore[no-redef] -import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor -import jaxlib.mlir.dialects.vector as vector -try: - # pytype: disable=import-error - import jaxlib.mlir.dialects.gpu as gpu - import jaxlib.mlir.dialects.nvgpu as nvgpu - import jaxlib.mlir.dialects.nvvm as nvvm - import jaxlib.mlir.dialects.llvm as llvm - # pytype: enable=import-error -except ImportError: - pass - -from jax._src import lib - # Alias that is set up to abstract away the transition from MHLO to StableHLO. -import jaxlib.mlir.dialects.stablehlo as hlo +from jaxlib.mlir.dialects import stablehlo as hlo From 414eb90f5bc339e8ab2a437c216c478acb4a3c8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Mon, 2 Sep 2024 06:18:13 -0700 Subject: [PATCH 320/702] Activate Householder Product to XLA's FFI PiperOrigin-RevId: 670196460 --- jax/_src/export/_export.py | 1 + .../cpu_qr_lapack_geqrf.py | 8 +- jax/_src/lax/linalg.py | 16 +++- jaxlib/lapack.py | 87 ++++++++++++------- 4 files changed, 72 insertions(+), 40 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 65f3d2852348..ee1b0dabba8d 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -931,6 +931,7 @@ def _check_lowering(lowering) -> None: _CPU_FFI_KERNELS = [ "lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi", "lapack_sgeqrf_ffi", "lapack_dgeqrf_ffi", "lapack_cgeqrf_ffi", "lapack_zgeqrf_ffi", + "lapack_sorgqr_ffi", "lapack_dorgqr_ffi", "lapack_cungqr_ffi", "lapack_zungqr_ffi", "lapack_ssyevd_ffi", "lapack_dsyevd_ffi", "lapack_cheevd_ffi", "lapack_zheevd_ffi", "lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi", "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py index 045e8df55cd2..94314a7ae518 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py @@ -357,7 +357,7 @@ data_2024_08_22['c128'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr'], + custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], serialized_date=datetime.date(2024, 8, 22), inputs=(), expected_outputs=( @@ -479,7 +479,7 @@ data_2024_08_22['c64'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr'], + custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr_ffi'], serialized_date=datetime.date(2024, 8, 22), inputs=(), expected_outputs=( @@ -595,7 +595,7 @@ data_2024_08_22['f32'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr'], + custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr_ffi'], serialized_date=datetime.date(2024, 8, 22), inputs=(), expected_outputs=( @@ -703,7 +703,7 @@ data_2024_08_22['f64'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr'], + custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr_ffi'], serialized_date=datetime.date(2024, 8, 22), inputs=(), expected_outputs=( diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 1a792e3adc0c..0ed8b3c62f08 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1759,11 +1759,21 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *, f"on GPU is not implemented; b/261671778; {a_aval.shape}") a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus) else: + # TODO(b/344892332): Remove the conditional after the compatibility period + ctx_args = ( + (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else () + ) a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape) - a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus, - a_shape_vals=a_shape_vals, - tau_shape_vals=tau_shape_vals) + a, *maybe_info_orgqr = orgqr_impl(*ctx_args, a_aval.dtype, a, taus, + a_shape_vals=a_shape_vals, + tau_shape_vals=tau_shape_vals) + if not ctx.is_forward_compat(): + # Skip the info parameter verification for the FFI kernel. + return [a] + # TODO(b/344892332): This parameter will no longer be needed after + # the forward compatibility period + info_orgqr = maybe_info_orgqr[0] zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED") select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_)) diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index a389380a61ec..e23cc0075139 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -281,10 +281,11 @@ def geqrf_hlo( # # ?orgqr: product of elementary Householder reflectors: -def orgqr_hlo(dtype, a: ir.Value, tau, *, +def orgqr_hlo(ctx, dtype, a: ir.Value, tau, *, a_shape_vals: tuple[DimensionSize, ...], tau_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() + fn_base = "un" if dtype == np.complex64 or dtype == np.complex128 else "or" + fn_base = prepare_lapack_call(fn_base=fn_base + "gqr", dtype=dtype) a_type = ir.RankedTensorType(a.type) dims = a_type.shape dims_vals = a_shape_vals @@ -294,55 +295,75 @@ def orgqr_hlo(dtype, a: ir.Value, tau, *, assert n != ir.ShapedType.get_dynamic_size() batch_dims_vals = dims_vals[:-2] num_bd = len(batch_dims_vals) - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - k = tau_shape_vals[-1] assert type(k) is int - - if dtype == np.float32: - fn = "lapack_sorgqr" - lwork = _lapack.lapack_sorgqr_workspace(m, n, k) - elif dtype == np.float64: - fn = "lapack_dorgqr" - lwork = _lapack.lapack_dorgqr_workspace(m, n, k) - elif dtype == np.complex64: - fn = "lapack_cungqr" - lwork = _lapack.lapack_cungqr_workspace(m, n, k) - elif dtype == np.complex128: - fn = "lapack_zungqr" - lwork = _lapack.lapack_zungqr_workspace(m, n, k) - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) i32_type = ir.IntegerType.get_signless(32) + + if ctx.is_forward_compat(): + fn = fn_base + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + + if dtype == np.float32: + lwork = _lapack.lapack_sorgqr_workspace(m, n, k) + elif dtype == np.float64: + lwork = _lapack.lapack_dorgqr_workspace(m, n, k) + elif dtype == np.complex64: + lwork = _lapack.lapack_cungqr_workspace(m, n, k) + elif dtype == np.complex128: + lwork = _lapack.lapack_zungqr_workspace(m, n, k) + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + scalar_layout = [] + shape_type_pairs: Sequence[ShapeTypePair] = [ + (a_shape_vals, a_type.element_type), + (batch_dims_vals, i32_type), + ([lwork], a_type.element_type), + ] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + return custom_call( + fn, + result_types=result_types, + operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k), + hlo_s32(lwork), a, tau], + operand_layouts=[scalar_layout] * 5 + [ + layout, + tuple(range(num_bd, -1, -1)), + ], + result_layouts=[ + layout, + tuple(range(num_bd - 1, -1, -1)), + [0], + ], + operand_output_aliases={5: 0}, + result_shapes=result_shapes, + ).results[:2] + fn = fn_base + "_ffi" shape_type_pairs: Sequence[ShapeTypePair] = [ (a_shape_vals, a_type.element_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type), ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - out = custom_call( + return custom_call( fn, result_types=result_types, - operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k), - hlo_s32(lwork), a, tau], - operand_layouts=[scalar_layout] * 5 + [ + operands=[ + a, tau + ], + operand_layouts=[ layout, tuple(range(num_bd, -1, -1)), ], result_layouts=[ layout, - tuple(range(num_bd - 1, -1, -1)), - [0], ], - operand_output_aliases={5: 0}, + operand_output_aliases={0: 0}, result_shapes=result_shapes, + backend_config={}, + api_version=4, ).results - return out[:2] # ?potrf: Cholesky decomposition From 214577ec3517027f39ef1b0fc132b21a7640c745 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 2 Sep 2024 20:35:33 +0530 Subject: [PATCH 321/702] Better docs for jax.numpy: sin, cos and tan --- jax/_src/numpy/ufuncs.py | 82 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 3605201ad51b..71adb5d1e1ec 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -428,19 +428,95 @@ def expm1(x: ArrayLike, /) -> Array: def log1p(x: ArrayLike, /) -> Array: return lax.log1p(*promote_args_inexact('log1p', x)) -@implements(np.sin, module='numpy') + @partial(jit, inline=True) def sin(x: ArrayLike, /) -> Array: + """Compute a trigonometric sine of each element of input. + + JAX implementation of :obj:`numpy.sin`. + + Args: + x: array or scalar. Angle in radians. + + Returns: + An array containing the sine of each element in ``x``, promotes to inexact + dtype. + + See also: + - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of + input. + - :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of + input. + - :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of + trigonometric sine of each element of input. + + Examples: + >>> pi = jnp.pi + >>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... print(jnp.sin(x)) + [ 0.707 1. 0.707 -0. ] + """ return lax.sin(*promote_args_inexact('sin', x)) -@implements(np.cos, module='numpy') + @partial(jit, inline=True) def cos(x: ArrayLike, /) -> Array: + """Compute a trigonometric cosine of each element of input. + + JAX implementation of :obj:`numpy.cos`. + + Args: + x: scalar or array. Angle in radians. + + Returns: + An array containing the cosine of each element in ``x``, promotes to inexact + dtype. + + See also: + - :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input. + - :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of + input. + - :func:`jax.numpy.arccos` and :func:`jax.numpy.acos`: Computes the inverse of + trigonometric cosine of each element of input. + + Examples: + >>> pi = jnp.pi + >>> x = jnp.array([pi/4, pi/2, 3*pi/4, 5*pi/6]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... print(jnp.cos(x)) + [ 0.707 -0. -0.707 -0.866] + """ return lax.cos(*promote_args_inexact('cos', x)) -@implements(np.tan, module='numpy') + @partial(jit, inline=True) def tan(x: ArrayLike, /) -> Array: + """Compute a trigonometric tangent of each element of input. + + JAX implementation of :obj:`numpy.tan`. + + Args: + x: scalar or array. Angle in radians. + + Returns: + An array containing the tangent of each element in ``x``, promotes to inexact + dtype. + + See also: + - :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input. + - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of + input. + - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of + trigonometric tangent of each element of input. + + Examples: + >>> pi = jnp.pi + >>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... print(jnp.tan(x)) + [ 0. 0.577 1. -1. -0.577] + """ return lax.tan(*promote_args_inexact('tan', x)) @implements(np.arcsin, module='numpy') From ab0e84b4e61c207d022670fa1380afd08f505950 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 2 Sep 2024 20:39:05 +0530 Subject: [PATCH 322/702] Better docs for jax.numpy: round, around and round_ --- jax/_src/numpy/lax_numpy.py | 54 ++++++++++++++++++++++++++++++++++--- tests/lax_numpy_test.py | 2 +- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3e4d52bf97e1..1d7d2361450a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2644,9 +2644,47 @@ def clip( arr = ufuncs.minimum(max, arr) return asarray(arr) -@util.implements(np.around, skip_params=['out']) + @partial(jit, static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: + """Round input evenly to the given number of decimals. + + JAX implementation of :func:`numpy.round`. + + Args: + a: input array or scalar. + decimals: int, default=0. Number of decimal points to which the input needs + to be rounded. It must be specified statically. Not implemented for + ``decimals < 0``. + out: Unused by JAX. + + Returns: + An array containing the rounded values to the specified ``decimals`` with + same shape and dtype as ``a``. + + Note: + ``jnp.round`` rounds to the nearest even integer for the values exactly halfway + between rounded decimal values. + + See also: + - :func:`jax.numpy.floor`: Rounds the input to the nearest integer downwards. + - :func:`jax.numpy.ceil`: Rounds the input to the nearest integer upwards. + - :func:`jax.numpy.fix` and :func:numpy.trunc`: Rounds the input to the + nearest integer towards zero. + + Examples: + >>> x = jnp.array([1.532, 3.267, 6.149]) + >>> jnp.round(x) + Array([2., 3., 6.], dtype=float32) + >>> jnp.round(x, decimals=2) + Array([1.53, 3.27, 6.15], dtype=float32) + + For values exactly halfway between rounded values: + + >>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5]) + >>> jnp.round(x1) + Array([10., 22., 12., 32.], dtype=float32) + """ util.check_arraylike("round", a) decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round") if out is not None: @@ -2676,8 +2714,18 @@ def _round_float(x: ArrayLike) -> Array: return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a))) else: return _round_float(a) -around = round -round_ = round + + +@partial(jit, static_argnames=('decimals',)) +def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: + """Alias of :func:`jax.numpy.round`""" + return round(a, decimals, out) + + +@partial(jit, static_argnames=('decimals',)) +def round_(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: + """Alias of :func:`jax.numpy.round`""" + return round(a, decimals, out) @jit diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6ca36ffe9035..0e1e7937e8dd 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6288,7 +6288,7 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', - 'amax', 'amin'] + 'amax', 'amin', 'around', 'round_'] for name in dir(jnp): if name.startswith('_') or name in unimplemented: From f1e074189047bc8dab7e47e67fad51cf94e21870 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 2 Sep 2024 08:46:31 -0700 Subject: [PATCH 323/702] Add `use_shardy_partitioner` to thread local jit state PiperOrigin-RevId: 670230769 --- jax/_src/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/config.py b/jax/_src/config.py index f17ca59385fe..a0b91e2ad6ed 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -871,6 +871,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): xla_profile_version: int | None = None pgle_profiling_runs: int | None = None enable_pgle: bool | None = None + use_shardy_partitioner: bool | None = None class _ThreadLocalStateCache(threading.local): From 59f825a23e002c03a1888eb91e679b832cf368bc Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 2 Sep 2024 21:44:18 +0100 Subject: [PATCH 324/702] Fixed the return type of ``jax.random.key_impl`` Closes #23363. --- jax/_src/random.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 6105d56f9148..7e006fac8319 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Hashable, Sequence +from collections.abc import Sequence from functools import partial import math from operator import index @@ -292,7 +292,7 @@ def _key_impl(keys: KeyArray) -> PRNGImpl: keys_dtype = typing.cast(prng.KeyTy, keys.dtype) return keys_dtype._impl -def key_impl(keys: KeyArrayLike) -> Hashable: +def key_impl(keys: KeyArrayLike) -> PRNGSpec: typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True) return PRNGSpec(_key_impl(typed_keys)) From 359e1b9d13f711fc704fa0dafaadb9a34b722548 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 2 Sep 2024 15:53:55 -0700 Subject: [PATCH 325/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d18cd64b7cd61a2ade10089665ac104f639101b1. PiperOrigin-RevId: 670316396 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 9550acb5e320..d21e51836919 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "39a982c1757bbb7136431b2df48d067f122c5190" -XLA_SHA256 = "2d4b80bc97700fb70021926771bd497ba70a8f027e0380d3717c2d9b4ee9775f" +XLA_COMMIT = "d18cd64b7cd61a2ade10089665ac104f639101b1" +XLA_SHA256 = "12d2a18d4f7549305c7949a0d13504e9c3de464792cf72c8d92d62dc414c8ff1" def repo(): tf_http_archive( From f2bef6bb5c0e779b1a308285cae65cd5631474b3 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 2 Sep 2024 23:53:10 +0000 Subject: [PATCH 326/702] tweak shmap implementation to work better with leak checker --- jax/experimental/shard_map.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 94d2e2693e87..bf331bbb913f 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1928,18 +1928,25 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) return _match_rep(fun, mesh, out_reps_src, out_reps_dst) +def _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps): + return _efficient_transpose_outer(_efficient_transpose_inner(fun), mesh, in_reps) + @lu.transformation_with_aux -def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args): +def _efficient_transpose_outer(mesh, in_reps, *args): lvl = core.dynamic_level() with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: - t = main.with_cur_sublevel() - in_tracers = map(partial(RewriteTracer, t), in_reps, args) - ans = yield in_tracers, {} - out_tracers = map(t.full_raise, ans) - out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers) - del main, t, in_tracers, out_tracers, ans + out_vals, out_reps = yield (main, mesh, in_reps, args), {} + del main yield out_vals, out_reps +@lu.transformation +def _efficient_transpose_inner(main, mesh, in_reps, args): + t = main.with_cur_sublevel() + in_tracers = map(partial(RewriteTracer, t), in_reps, args) + ans = yield in_tracers, {} + out_tracers = map(t.full_raise, ans) + yield unzip2((t.val, t.rep) for t in out_tracers) + @lu.transformation def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args): outs = yield args, {} From 443780e208d49e699a6a9a99403980bfe41cb374 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 2 Sep 2024 19:28:20 -0700 Subject: [PATCH 327/702] [Mosaic TPU] Add support for semaphore operands (inputs and outputs) This enables writing async kernels for collectives or prefetching. PiperOrigin-RevId: 670366575 --- jax/_src/tpu_custom_call.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 86b6e443e854..d3c61e5a1722 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -71,6 +71,7 @@ class MemorySpace(enum.Enum): HBM = enum.auto() VMEM = enum.auto() + SEMAPHORE_MEM = enum.auto() @property def color(self) -> int: @@ -78,6 +79,8 @@ def color(self) -> int: return 0 elif self == MemorySpace.VMEM: return 1 + elif self == MemorySpace.SEMAPHORE_MEM: + return 2 else: raise ValueError("invalid memory space: " + str(self)) From 530ed026b8926cba3cb3d06c855b516fd4c9fb38 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa Date: Mon, 2 Sep 2024 20:49:21 -0700 Subject: [PATCH 328/702] FIX typo on jax.numpy.where docstring this was preventing the link to be correctly rendered in the webpage https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html PiperOrigin-RevId: 670385290 --- jax/_src/numpy/lax_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 1d7d2361450a..e7bcc9ecc5cf 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2189,7 +2189,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): Returns: An array of dtype ``jnp.result_type(x, y)`` with values drawn from ``x`` where ``condition`` is True, and from ``y`` where condition is ``False``. If ``x`` and ``y`` are ``None``, the - function behaves differently; see `:func:`jax.numpy.nonzero` for a description of the return + function behaves differently; see :func:`jax.numpy.nonzero` for a description of the return type. See Also: From f9cb95ca08fb53624ceaccf04372d5e48f860a46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 3 Sep 2024 09:50:19 +0200 Subject: [PATCH 329/702] feat(lib): add real-valued implementation of jax.scipy.special.fresnel Add implementation, documentation, and tests, for both single-precision and double-precision floating-point arithmetic. --- docs/jax.scipy.rst | 1 + jax/_src/third_party/scipy/special.py | 322 ++++++++++++++++++++++ jax/scipy/special.py | 4 + tests/lax_scipy_special_functions_test.py | 4 + 4 files changed, 331 insertions(+) create mode 100644 jax/_src/third_party/scipy/special.py diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index f6d8a151440b..abdf5069ee08 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -164,6 +164,7 @@ jax.scipy.special expit expn factorial + fresnel gamma gammainc gammaincc diff --git a/jax/_src/third_party/scipy/special.py b/jax/_src/third_party/scipy/special.py new file mode 100644 index 000000000000..67ef09f6de37 --- /dev/null +++ b/jax/_src/third_party/scipy/special.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import jax.numpy as jnp +from jax import jit + +from jax._src import custom_derivatives, dtypes +from jax._src.numpy.lax_numpy import complexfloating +from jax._src.numpy.util import promote_args_inexact +from jax._src.typing import Array, ArrayLike + + +@jit +def sincospisquaredhalf( + x: Array, +) -> tuple[Array, Array]: + """ + Accurate evaluation of sin(pi * x**2 / 2) and cos(pi * x**2 / 2). + + As based on the sinpi and cospi functions from SciPy, see: + - https://github.com/scipy/scipy/blob/v1.14.0/scipy/special/special/cephes/trig.h + """ + x = jnp.abs(x) + # define s = x % 2, y = x - s, then + # r = (x * x / 2) % 2 + # = [(y + s)*(y + s)/2] % 2 + # = [y*y/2 + s*y + s*s/2] % 2 + # = [(y*y/2)%2 + (s*y + s*s/2)%2]%2 + # = [0 + (s*(y+s/2))%2]%2 + # = [s*(x-s/2)]%2 + s = jnp.fmod(x, 2.0) + r = jnp.fmod(s * (x - s / 2), 2.0) + + sinpi = jnp.where( + r < 0.5, + jnp.sin(jnp.pi * r), + jnp.where( + r > 1.5, + jnp.sin(jnp.pi * (r - 2.0)), + -jnp.sin(jnp.pi * (r - 1.0)), + ), + ) + cospi = jnp.where( + r == 0.5, + 0.0, + jnp.where(r < 1.0, -jnp.sin(jnp.pi * (r - 0.5)), jnp.sin(jnp.pi * (r - 1.5))), + ) + + return sinpi, cospi + + +@custom_derivatives.custom_jvp +def fresnel(x: ArrayLike) -> tuple[Array, Array]: + r"""The Fresnel integrals + + JAX implementation of :obj:`scipy.special.fresnel`. + + The Fresnel integrals are defined as + .. math:: + S(x) &= \int_0^x \sin(\pi t^2 /2) dt \\ + C(x) &= \int_0^x \cos(\pi t^2 /2) dt. + + Args: + x: arraylike, real-valued. + + Returns: + Arrays containing the values of the Fresnel integrals. + + Notes: + The JAX version only supports real-valued inputs, and + is based on the SciPy C++ implementation, see + `here + `_. + For ``float32`` dtypes, the implementation is directly based + on the Cephes implementation ``fresnlf``. + + As for the original Cephes implementation, the accuracy + is only guaranteed in the domain [-10, 10]. Outside of + that domain, one could observe divergence between the + theoretical derivatives and the custom JVP implementation, + especially for large input values. + + Finally, for half-precision data types, ``float16`` + and ``bfloat16``, the array elements are upcasted to + ``float32`` as the Cephes coefficients used in + series expansions would otherwise lead to poor results. + Other data types, like ``float8``, are not supported. + """ + + xxa, = promote_args_inexact("fresnel", x) + original_dtype = xxa.dtype + + # This part is mostly a direct translation of SciPy's C++ code, + # and the original Cephes implementation for single precision. + + if dtypes.issubdtype(xxa.dtype, complexfloating): + raise NotImplementedError( + 'Support for complex-valued inputs is not implemented yet.') + elif xxa.dtype in (jnp.float32, jnp.float16, jnp.bfloat16): + # Single-precision Cephes coefficients + + # For half-precision, series expansions have either + # produce overflow or poor accuracy. + # Upcasting to single-precision is hence needed. + xxa = xxa.astype(jnp.float32) # No-op for float32 + + fresnl_sn = jnp.array([ + +1.647629463788700e-9, + -1.522754752581096e-7, + +8.424748808502400e-6, + -3.120693124703272e-4, + +7.244727626597022e-3, + -9.228055941124598e-2, + +5.235987735681432e-1, + ], dtype=jnp.float32) + + fresnl_cn = jnp.array([ + +1.416802502367354e-8, + -1.157231412229871e-6, + +5.387223446683264e-5, + -1.604381798862293e-3, + +2.818489036795073e-2, + -2.467398198317899e-1, + +9.999999760004487e-1, + ], dtype=jnp.float32) + + fresnl_fn = jnp.array([ + -1.903009855649792e12, + +1.355942388050252e11, + -4.158143148511033e9, + +7.343848463587323e7, + -8.732356681548485e5, + +8.560515466275470e3, + -1.032877601091159e2, + +2.999401847870011e0, + ], dtype=jnp.float32) + + fresnl_gn = jnp.array([ + -1.860843997624650e11, + +1.278350673393208e10, + -3.779387713202229e8, + +6.492611570598858e6, + -7.787789623358162e4, + +8.602931494734327e2, + -1.493439396592284e1, + +9.999841934744914e-1, + ], dtype=jnp.float32) + elif xxa.dtype == jnp.float64: + # Double-precision Cephes coefficients + + fresnl_sn = jnp.array([ + -2.99181919401019853726e3, + +7.08840045257738576863e5, + -6.29741486205862506537e7, + +2.54890880573376359104e9, + -4.42979518059697779103e10, + +3.18016297876567817986e11, + ], dtype=jnp.float64) + + fresnl_sd = jnp.array([ + +1.00000000000000000000e0, + +2.81376268889994315696e2, + +4.55847810806532581675e4, + +5.17343888770096400730e6, + +4.19320245898111231129e8, + +2.24411795645340920940e10, + +6.07366389490084639049e11, + ], dtype=jnp.float64) + + fresnl_cn = jnp.array([ + -4.98843114573573548651e-8, + +9.50428062829859605134e-6, + -6.45191435683965050962e-4, + +1.88843319396703850064e-2, + -2.05525900955013891793e-1, + +9.99999999999999998822e-1, + ], dtype=jnp.float64) + + fresnl_cd = jnp.array([ + +3.99982968972495980367e-12, + +9.15439215774657478799e-10, + +1.25001862479598821474e-7, + +1.22262789024179030997e-5, + +8.68029542941784300606e-4, + +4.12142090722199792936e-2, + +1.00000000000000000118e0, + ], dtype=jnp.float64) + + fresnl_fn = jnp.array([ + +4.21543555043677546506e-1, + +1.43407919780758885261e-1, + +1.15220955073585758835e-2, + +3.45017939782574027900e-4, + +4.63613749287867322088e-6, + +3.05568983790257605827e-8, + +1.02304514164907233465e-10, + +1.72010743268161828879e-13, + +1.34283276233062758925e-16, + +3.76329711269987889006e-20, + ], dtype=jnp.float64) + + fresnl_fd = jnp.array([ + +1.00000000000000000000e0, + +7.51586398353378947175e-1, + +1.16888925859191382142e-1, + +6.44051526508858611005e-3, + +1.55934409164153020873e-4, + +1.84627567348930545870e-6, + +1.12699224763999035261e-8, + +3.60140029589371370404e-11, + +5.88754533621578410010e-14, + +4.52001434074129701496e-17, + +1.25443237090011264384e-20, + ], dtype=jnp.float64) + + fresnl_gn = jnp.array([ + +5.04442073643383265887e-1, + +1.97102833525523411709e-1, + +1.87648584092575249293e-2, + +6.84079380915393090172e-4, + +1.15138826111884280931e-5, + +9.82852443688422223854e-8, + +4.45344415861750144738e-10, + +1.08268041139020870318e-12, + +1.37555460633261799868e-15, + +8.36354435630677421531e-19, + +1.86958710162783235106e-22, + ], dtype=jnp.float64) + + fresnl_gd = jnp.array([ + +1.00000000000000000000e0, + +1.47495759925128324529e0, + +3.37748989120019970451e-1, + +2.53603741420338795122e-2, + +8.14679107184306179049e-4, + +1.27545075667729118702e-5, + +1.04314589657571990585e-7, + +4.60680728146520428211e-10, + +1.10273215066240270757e-12, + +1.38796531259578871258e-15, + +8.39158816283118707363e-19, + +1.86958710162783236342e-22, + ], dtype=jnp.float64) + else: + raise NotImplementedError( + f'Support for {xxa.dtype} dtype is not implemented yet.') + + assert xxa.dtype in (jnp.float32, jnp.float64) + single_precision = (xxa.dtype == jnp.float32) + + x = jnp.abs(xxa) + + x2 = x * x + + # Infinite x values + s_inf = c_inf = 0.5 + + # Small x values + t = x2 * x2 + + if single_precision: + s_small = x * x2 * jnp.polyval(fresnl_sn, t) + c_small = x * jnp.polyval(fresnl_cn, t) + else: + s_small = x * x2 * jnp.polyval(fresnl_sn[:6], t) / jnp.polyval(fresnl_sd[:7], t) + c_small = x * jnp.polyval(fresnl_cn[:6], t) / jnp.polyval(fresnl_cd[:7], t) + + # Large x values + + sinpi, cospi = sincospisquaredhalf(x) + + if single_precision: + c_large = c_inf + s_large = s_inf + else: + c_large = 0.5 + 1 / (jnp.pi * x) * sinpi + s_large = 0.5 - 1 / (jnp.pi * x) * cospi + + # Other x values + t = jnp.pi * x2 + u = 1.0 / (t * t) + t = 1.0 / t + + if single_precision: + f = 1.0 - u * jnp.polyval(fresnl_fn, u) + g = t * jnp.polyval(fresnl_gn, u) + else: + f = 1.0 - u * jnp.polyval(fresnl_fn, u) / jnp.polyval(fresnl_fd, u) + g = t * jnp.polyval(fresnl_gn, u) / jnp.polyval(fresnl_gd, u) + + t = jnp.pi * x + c_other = 0.5 + (f * sinpi - g * cospi) / t + s_other = 0.5 - (f * cospi + g * sinpi) / t + + isinf = jnp.isinf(xxa) + small = x2 < 2.5625 + large = x > 36974.0 + s = jnp.where( + isinf, s_inf, jnp.where(small, s_small, jnp.where(large, s_large, s_other)) + ) + c = jnp.where( + isinf, c_inf, jnp.where(small, c_small, jnp.where(large, c_large, c_other)) + ) + + neg = xxa < 0.0 + s = jnp.where(neg, -s, s) + c = jnp.where(neg, -c, c) + + if original_dtype != xxa.dtype: + s = s.astype(original_dtype) + c = c.astype(original_dtype) + + return s, c + +def _fresnel_jvp(primals, tangents): + x, = primals + x_dot, = tangents + result = fresnel(x) + sinpi, cospi = sincospisquaredhalf(x) + dSdx = sinpi * x_dot + dCdx = cospi * x_dot + return result, (dSdx, dCdx) +fresnel.defjvp(_fresnel_jvp) diff --git a/jax/scipy/special.py b/jax/scipy/special.py index e244c3705af3..5d72339eaec8 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -61,3 +61,7 @@ xlogy as xlogy, zeta as zeta, ) + +from jax._src.third_party.scipy.special import ( + fresnel as fresnel, +) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 38607cae883b..bd3bca5385b7 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -95,6 +95,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t op_record( "factorial", 1, float_dtypes, jtu.rand_default, True ), + op_record( + "fresnel", 1, float_dtypes, + functools.partial(jtu.rand_default, scale=30), True + ), op_record( "i0", 1, float_dtypes, jtu.rand_default, True ), From a1fd582ad610bbebfc8f6c2fe25916bcfc356579 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 3 Sep 2024 04:24:34 -0700 Subject: [PATCH 330/702] [jax:pallas] Simplify pointer offset calculation in Triton lowering. PiperOrigin-RevId: 670499398 --- jax/_src/pallas/triton/lowering.py | 95 ++++++++++++------------------ 1 file changed, 37 insertions(+), 58 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index f2a4229223bd..55c88a2d5d7a 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1708,74 +1708,53 @@ def _compute_pointers_from_indices( index = _i32_constant(0) else: index = next(indexer_iter) + + if isinstance(index, slice): + index = primitives.Slice.from_slice(index, dim_block_size) + if isinstance(index, primitives.Slice): - if index.is_dynamic_start: - # Compute the offset as start + range(0, size). - ptr_dim_offset = _add( - _bcast_to(index.start, [index.size]), - _ir_cast(_make_range(0, index.size), index.start.type, signed=False), - ) - elif index.stride > 1: - # Compute the offset as start + range(0, size) * stride. - iota = _make_range(0, index.size) - ptr_dim_offset = _add( - _bcast_to(_i32_constant(index.start), [index.size]), - _mul(iota, _full(iota.type, index.stride)), - ) + if index.is_dynamic_start or (index.stride != 1): + start = index.start + if not index.is_dynamic_start: + start = _i32_constant(start) + + iota = _ir_cast(_make_range(0, index.size), start.type, signed=False) + if index.stride != 1: + iota = _mul(iota, _full(iota.type, index.stride)) + dim_offsets = _add(_bcast_to(start, [index.size]), iota) else: - ptr_dim_offset = _make_range(index.start, index.start + index.size) + dim_offsets = _make_range(index.start, index.start + index.size) - # We need to add broadcastable dimensions for the advanced int indexing - # and for previous slices - num_left_expand_dims = len(int_indexer_shape) + other_shape_idx - num_right_expand_dims = len(other_shape) - other_shape_idx - 1 - other_shape_idx += 1 - elif isinstance(index, slice): - if index != slice(None): - raise NotImplementedError("Only `slice(None)` allowed.") - ptr_dim_offset = _make_range(0, dim_block_size) - num_left_expand_dims = len(int_indexer_shape) + other_shape_idx - num_right_expand_dims = len(other_shape) - other_shape_idx - 1 other_shape_idx += 1 + for _ in other_shape[other_shape_idx:]: + rank = ir.RankedTensorType(dim_offsets.type).rank + dim_offsets = _expand_dims(dim_offsets, rank) else: # indexer is either a *scalar* or an array of size `int_indexer_shape` - ptr_dim_offset = _ensure_ir_value( - index, jax_core.ShapedArray((), jnp.int32) - ) - num_left_expand_dims = 0 - num_right_expand_dims = len(other_shape) - if not ir.RankedTensorType.isinstance(ptr_dim_offset.type): - num_left_expand_dims = max(len(indexer_shape) - 1, 0) - else: - num_right_expand_dims = len(other_shape) + dim_offsets = _ensure_ir_value(index, jax_core.ShapedArray((), jnp.int32)) + + if ir.RankedTensorType.isinstance(dim_offsets.type): + for _ in other_shape: + rank = ir.RankedTensorType(dim_offsets.type).rank + dim_offsets = _expand_dims(dim_offsets, rank) + + if ir.RankedTensorType.isinstance(dim_offsets.type): + rank = ir.RankedTensorType(dim_offsets.type).rank + for _ in range(len(indexer_shape) - rank): + dim_offsets = _expand_dims(dim_offsets, 0) + dim_offsets = _bcast_to(dim_offsets, indexer_shape) + elif indexer_shape: + dim_offsets = _splat(dim_offsets, indexer_shape) - if indexer_shape and not ir.RankedTensorType.isinstance(ptr_dim_offset.type): - ptr_dim_offset = _splat(ptr_dim_offset, [1] * len(indexer_shape)) - else: - for _ in range(num_left_expand_dims): - ptr_dim_offset = _expand_dims(ptr_dim_offset, 0) - for _ in range(num_right_expand_dims): - ndim = len(getattr(ptr_dim_offset.type, "shape", [])) - ptr_dim_offset = _expand_dims(ptr_dim_offset, ndim) - - ptr_dim_offset = _bcast_to(ptr_dim_offset, indexer_shape) - index_type = ir.IntegerType(_element_type(ptr_dim_offset.type)) if start_offset is not None: - start_offset = _ir_cast(start_offset, index_type, signed=False) - ptr_dim_offset = _add( - ptr_dim_offset, _bcast_to(start_offset, indexer_shape) - ) + offset_type = _element_type(dim_offsets.type) + start_offset = _ir_cast(start_offset, offset_type, signed=False) + dim_offsets = _add(dim_offsets, _bcast_to(start_offset, indexer_shape)) - if index_type.width == 32: - stride_size = _i32_constant(dim_stride) - else: - stride_size = _i64_constant(dim_stride) - stride_size = _splat(stride_size, indexer_shape) - bcast_indices.append(_mul(ptr_dim_offset, stride_size)) + bcast_indices.append(_mul(dim_offsets, _full(dim_offsets.type, dim_stride))) - return functools.reduce( - _add, bcast_indices, _bcast_to(root_ptr, indexer_shape) - ) + ptrs = _bcast_to(root_ptr, indexer_shape) + return functools.reduce(_add, bcast_indices, ptrs) @register_lowering(sp.get_p) From 4c3111bf26103c40e22c26ba7b1a92d6901ed755 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 3 Sep 2024 06:06:54 -0700 Subject: [PATCH 331/702] [Mosaic GPU] Unbreak tests I mistakenly checked for `amount + 1` instead of `amount * 2`. It initially seemed right because both expressions evalute to 2 for 1 :) PiperOrigin-RevId: 670527107 --- jax/experimental/mosaic/gpu/fragmented_array.py | 2 +- tests/mosaic/BUILD | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 892cd2d09332..7e0d43a05551 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -757,7 +757,7 @@ def transfer_tiled(shape, dtype, swizzle: int | None): case _: raise AssertionError(swizzle) stagger_amount = swizzle // 64 - if (cols_per_tile // 8) % (stagger_amount + 1): + if (cols_per_tile // 8) % (stagger_amount * 2): raise NotImplementedError else: # We rely on canonicalization to clean up the selects. diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 9149891a2dea..255b03d3a002 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -49,7 +49,6 @@ jax_test( disable_configs = DISABLED_CONFIGS, enable_configs = ["gpu_h100_2gpu"], shard_count = 4, - tags = ["notap"], # Broken at head. deps = [ "//jax:mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), @@ -61,7 +60,6 @@ jax_test( disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, shard_count = 5, - tags = ["notap"], # Broken at head. deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", From 7b161fb76cc24ce44e499e020a05ab59ad262267 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 3 Sep 2024 06:19:56 -0700 Subject: [PATCH 332/702] [jax:pallas] Use 64-bit indexing when necessary when lowering to Triton. PiperOrigin-RevId: 670530776 --- jax/_src/pallas/triton/lowering.py | 56 +++++++++++++++++------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 55c88a2d5d7a..446d87a5f347 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -474,7 +474,7 @@ def _atomic_lowering_rule( raise NotImplementedError("Only single indexer is supported.") idx = indexers[0] ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape + ptr, ctx.block_infos[0], idx, ctx.avals_in[0] ) val = _ensure_ir_value(val, value_aval) if mask is not None: @@ -1674,12 +1674,12 @@ def _compute_pointers_from_indices( root_ptr: ir.Value, block_info: BlockInfo | None, nd_indexer: NDIndexer, - array_shape: tuple[int, ...], + array_shape_dtype: Any, ) -> ir.Value: if block_info is None: # TODO(necula): is this branch dead? - full_shape = array_shape + full_shape = array_shape_dtype.shape num_mapped_dims = 0 - block_shape = array_shape + block_shape = array_shape_dtype.shape else: full_shape = block_info.full_shape_dtype.shape num_mapped_dims = sum( @@ -1692,7 +1692,6 @@ def _compute_pointers_from_indices( _check_tensor_size(indexer_shape) indices = nd_indexer.indices other_shape = indexer_shape[len(int_indexer_shape) :] - bcast_indices = [] other_shape_idx = 0 if block_info is None: start_index_offsets = [None] * len(indices) @@ -1700,12 +1699,22 @@ def _compute_pointers_from_indices( start_index_offsets = block_info.start_indices assert len(indices) + num_mapped_dims == len(full_shape) assert len(start_index_offsets) == len(full_shape) + + array_dtype = jnp.dtype(array_shape_dtype.dtype) + full_size = math.prod(full_shape) * array_dtype.itemsize + # Use 64-bit indexing when offset might be >= 2**32 bytes. + offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32) + if indexer_shape: + offsets = _full(ir.RankedTensorType.get(indexer_shape, offset_eltype), 0) + else: + offsets = _ir_constant(0, offset_eltype) + indexer_iter = iter(indices) for dim_stride, dim_block_size, start_offset in zip( strides, block_shape, start_index_offsets ): if dim_block_size is pallas_core.mapped: - index = _i32_constant(0) + index = _ir_constant(0, offset_eltype) else: index = next(indexer_iter) @@ -1716,14 +1725,16 @@ def _compute_pointers_from_indices( if index.is_dynamic_start or (index.stride != 1): start = index.start if not index.is_dynamic_start: - start = _i32_constant(start) + start = _ir_constant(start, offset_eltype) + start = _ir_cast(start, offset_eltype, signed=False) - iota = _ir_cast(_make_range(0, index.size), start.type, signed=False) + iota = _ir_cast(_make_range(0, index.size), offset_eltype, signed=False) if index.stride != 1: iota = _mul(iota, _full(iota.type, index.stride)) dim_offsets = _add(_bcast_to(start, [index.size]), iota) else: - dim_offsets = _make_range(index.start, index.start + index.size) + iota = _make_range(index.start, index.start + index.size) + dim_offsets = _ir_cast(iota, offset_eltype, signed=False) other_shape_idx += 1 for _ in other_shape[other_shape_idx:]: @@ -1731,7 +1742,10 @@ def _compute_pointers_from_indices( dim_offsets = _expand_dims(dim_offsets, rank) else: # indexer is either a *scalar* or an array of size `int_indexer_shape` - dim_offsets = _ensure_ir_value(index, jax_core.ShapedArray((), jnp.int32)) + dim_offsets = index + if not isinstance(dim_offsets, ir.Value): + dim_offsets = _ir_constant(dim_offsets, offset_eltype) + dim_offsets = _ir_cast(dim_offsets, offset_eltype, signed=False) if ir.RankedTensorType.isinstance(dim_offsets.type): for _ in other_shape: @@ -1742,19 +1756,16 @@ def _compute_pointers_from_indices( rank = ir.RankedTensorType(dim_offsets.type).rank for _ in range(len(indexer_shape) - rank): dim_offsets = _expand_dims(dim_offsets, 0) - dim_offsets = _bcast_to(dim_offsets, indexer_shape) - elif indexer_shape: - dim_offsets = _splat(dim_offsets, indexer_shape) + dim_offsets = _bcast_to(dim_offsets, indexer_shape) if start_offset is not None: - offset_type = _element_type(dim_offsets.type) - start_offset = _ir_cast(start_offset, offset_type, signed=False) + start_offset = _ir_cast(start_offset, offset_eltype, signed=False) dim_offsets = _add(dim_offsets, _bcast_to(start_offset, indexer_shape)) - bcast_indices.append(_mul(dim_offsets, _full(dim_offsets.type, dim_stride))) + dim_offsets = _mul(dim_offsets, _full(dim_offsets.type, dim_stride)) + offsets = _add(offsets, dim_offsets) - ptrs = _bcast_to(root_ptr, indexer_shape) - return functools.reduce(_add, bcast_indices, ptrs) + return _add(_bcast_to(root_ptr, indexer_shape), offsets) @register_lowering(sp.get_p) @@ -1869,7 +1880,7 @@ def _masked_load_lowering_rule( assert len(ctx.avals_in) == 1 return ptr ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape + ptr, ctx.block_infos[0], idx, ctx.avals_in[0] ) if mask is not None: mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape()) @@ -1966,7 +1977,7 @@ def _masked_swap_lowering_rule( raise NotImplementedError("No support for multiple indexers yet.") idx = indexers[0] ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape + ptr, ctx.block_infos[0], idx, ctx.avals_in[0] ) other = None if value is not None: @@ -1991,10 +2002,7 @@ def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): raise NotImplementedError("No support for multiple indexers yet.") indexer = indexers[0] ptr = _compute_pointers_from_indices( - ptr, - ctx.block_infos[0], - indexer, - ctx.avals_in[0].shape, + ptr, ctx.block_infos[0], indexer, ctx.avals_in[0] ) op = tt_dialect.RMWOp.FADD if isinstance(_element_type(value.type), ir.IntegerType): From ccabd210840ca533f468729125984eeed8867c42 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Sep 2024 06:23:18 -0700 Subject: [PATCH 333/702] Fixed rules where ``sliding_window_length`` was not forwarded This is follow up to #23284. PiperOrigin-RevId: 670531634 --- jax/_src/cudnn/fused_attention_stablehlo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 931587c0cfcf..171954bc86c5 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -646,7 +646,8 @@ def _dot_product_attention_fwd_batcher( outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=is_training) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=is_training) # reshape to original shape output = outputs[0] @@ -698,6 +699,7 @@ def _dot_product_attention_bwd_batcher( fwd_output, grad_output, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, ) # reshape to original shape From bc415f915353ceb96c2bebfd2b287153340dc95b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 3 Sep 2024 09:45:28 -0400 Subject: [PATCH 334/702] Relax test tolerances to fix CI failures on Mac ARM. --- tests/for_loop_test.py | 6 ++++-- tests/lax_control_flow_test.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index b79c233e6f2e..0fea62c12cf7 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -319,8 +319,10 @@ def f(a, b): _, f_lin = jax.linearize(f, a, b) expected_tangents = f_lin(a, b) _, actual_tangents = jax.jvp(f, (a, b), (a, b)) - np.testing.assert_allclose(actual_tangents[0], expected_tangents[0]) - np.testing.assert_allclose(actual_tangents[1], expected_tangents[1]) + np.testing.assert_allclose(actual_tangents[0], expected_tangents[0], + rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(actual_tangents[1], expected_tangents[1], + rtol=1e-6, atol=1e-6) def body2(_, refs): # Here we use `i_ref` as a loop counter diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index fd83d269b41c..37ad22063c94 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -1680,7 +1680,8 @@ def f(c, a): tol = {np.float64: 1e-12, np.float32: 1e-4} self.assertAllClose(ans, expected, check_dtypes=False, rtol=tol, atol=tol) - jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["fwd"]) + jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["fwd"], + rtol={jnp.float32: 2e-1}) @parameterized.named_parameters( {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}", From f2ffe7f8f27c6179bd4a3d73c3d12e1ce5f99396 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Sep 2024 06:52:07 -0700 Subject: [PATCH 335/702] Deprecate jax.numpy.round_ NumPy removed np.round in version 2.0; jax.numpy.round is drop-in replacement. --- CHANGELOG.md | 2 ++ jax/_src/numpy/lax_numpy.py | 6 ------ jax/numpy/__init__.py | 17 +++++++++++++---- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f59b07cd237e..c633ebb40a50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and `jax.core.check_valid_jaxtype` are now deprecated, and will be removed in the future. + * `jax.numpy.round_` has been deprecated, following removal of the corresponding + API in NumPy 2.0. Use {func}`jax.numpy.round` instead. ## jaxlib 0.4.32 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index e7bcc9ecc5cf..7c0162d784c0 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2722,12 +2722,6 @@ def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: return round(a, decimals, out) -@partial(jit, static_argnames=('decimals',)) -def round_(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: - """Alias of :func:`jax.numpy.round`""" - return round(a, decimals, out) - - @jit def fix(x: ArrayLike, out: None = None) -> Array: """Round input to the nearest integer towards zero. diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 88e1840ef1c0..da79f7859bcd 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -212,7 +212,6 @@ rollaxis as rollaxis, rot90 as rot90, round as round, - round_ as round_, save as save, savez as savez, searchsorted as searchsorted, @@ -466,6 +465,11 @@ _deprecations = { + # Deprecated 03 Sept 2024 + "round_": ( + "jnp.round_ is deprecated; use jnp.round instead.", + round + ), # Deprecated 18 Sept 2023 and removed 06 Feb 2024 "trapz": ( "jnp.trapz is deprecated; use jnp.trapezoid instead.", @@ -473,6 +477,11 @@ ), } -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr +import typing +if typing.TYPE_CHECKING: + round_ = round +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing From fd897745d38e490250856748d62254a957b4b1d5 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Sep 2024 09:48:32 -0700 Subject: [PATCH 336/702] Partial rollback of https://github.com/google/jax/pull/23353 as discussed in https://github.com/google/jax/pull/23353#issuecomment-2326604708 Reverts eed273c106af699efefc726eea1ff2b0f548f669 PiperOrigin-RevId: 670596159 --- jax/_src/lax/eigh.py | 134 +++++++++++++++++------------------------- jax/_src/lax/qdwh.py | 80 +++++++++---------------- jax/_src/lax/stack.py | 26 ++++---- jax/_src/lax/svd.py | 79 +++++++++---------------- 4 files changed, 122 insertions(+), 197 deletions(-) diff --git a/jax/_src/lax/eigh.py b/jax/_src/lax/eigh.py index d872ff0f68cc..fc66b0f2e7ee 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/lax/eigh.py @@ -30,16 +30,15 @@ from functools import partial from typing import NamedTuple -from jax._src import api -from jax._src import config -from jax._src import dtypes +import jax +import jax._src.numpy.lax_numpy as jnp +import jax._src.numpy.linalg as jnp_linalg +from jax._src.numpy import reductions +from jax._src.numpy import ufuncs from jax import lax from jax._src.lax import qdwh from jax._src.lax import linalg as lax_linalg from jax._src.lax.stack import Stack -from jax._src.lax import lax as lax_internal -from jax._src.typing import Array -import numpy as np # QDWH-eigh is a recursive algorithm where the structure of the recursion @@ -53,45 +52,19 @@ def _round_up(i, n): return ((i+n-1) // n) * n -def _norm(x, axis=None): - return lax.sqrt((abs(x) ** 2).sum(axis=axis)) - -def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array: - assert x.ndim <= len(shape) - return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape))) - -def _construct_diagonal(s: Array) -> Array: - """Construct a (batched) diagonal matrix""" - # signature: (...,n)->(...,n,n) - i = lax.iota('int32', s.shape[-1]) - return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s) - -def _extract_diagonal(s: Array) -> Array: - """Extract the diagonal from a batched matrix""" - # signature: (...,n,m)->(...k) where k=min(n,m) - i = lax.iota('int32', min(s.shape[-2], s.shape[-1])) - return s[..., i, i] - def _mask(x, dims, alternative=0): """Masks `x` up to the dynamic shape `dims`. Replaces values outside those dimensions with `alternative`. `alternative` is broadcast with `x`. """ - assert np.ndim(x) == len(dims) + assert jnp.ndim(x) == len(dims) mask = None for i, d in enumerate(dims): if d is not None: - mask_dim_i = lax.broadcasted_iota(np.int32, x.shape, i) < d + mask_dim_i = lax.broadcasted_iota(jnp.int32, x.shape, i) < d mask = mask_dim_i if mask is None else (mask & mask_dim_i) - - alternative = _broadcast_to(lax_internal.asarray(alternative), x.shape).astype(x.dtype) - return x if mask is None else lax.select(mask, x, alternative) - -def _nanmedian(vals): - # note: NaNs will be sorted to the end. - num_nans = lax_internal._isnan(vals).sum() - return lax.sort(vals.ravel())[(vals.size - num_nans) // 2] + return x if mask is None else jnp.where(mask, x, alternative) def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes, fill_value=0): @@ -114,9 +87,9 @@ def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes, # We must pad the input array so the dynamic_slice is guaranteed to fall # entirely in bounds. padded = lax.pad(operand, - np.array(0, operand.dtype), + jnp.array(0, operand.dtype), [(0, d, 0) for d in static_slice_sizes]) - out = lax.dynamic_slice(padded, tuple(lax.convert_element_type(i, 'int32') for i in start_indices), + out = lax.dynamic_slice(padded, tuple(jnp.int32(i) for i in start_indices), static_slice_sizes) return _mask(out, dynamic_slice_sizes, fill_value) @@ -133,9 +106,9 @@ def _update_slice(operand, update, start_indices, update_dims): inside the rectangle given by `update_dims` will be overwritten.""" operand_shape = operand.shape operand = lax.pad(operand, - np.array(0, operand.dtype), + jnp.array(0, operand.dtype), [(0, d, 0) for d in update.shape]) - start_indices = tuple(lax.convert_element_type(i, 'int32') for i in start_indices) + start_indices = tuple(jnp.int32(i) for i in start_indices) t = lax.dynamic_slice(operand, start_indices, update.shape) t = _mask(update, update_dims, t) operand = lax.dynamic_update_slice(operand, t, start_indices) @@ -167,45 +140,45 @@ def _projector_subspace(P, H, n, rank, maxiter=2, swap=False): """ # Choose an initial guess: the `rank` largest-norm columns of P. N, _ = P.shape - negative_column_norms = -_norm(P, axis=1) - # `sort_key_val` ensures NaNs sort last, so set masked-out column norms to NaN. - negative_column_norms = _mask(negative_column_norms, (n,), np.nan) - _, sort_idxs = lax.sort_key_val(negative_column_norms, lax.iota('int32', len(negative_column_norms))) + negative_column_norms = -jnp_linalg.norm(P, axis=1) + # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN. + negative_column_norms = _mask(negative_column_norms, (n,), jnp.nan) + sort_idxs = jnp.argsort(negative_column_norms) X = P[:, sort_idxs] # X = X[:, :rank] X = _mask(X, (n, rank)) - H_norm = _norm(H) - thresh = 10.0 * float(dtypes.finfo(X.dtype).eps) * H_norm + H_norm = jnp_linalg.norm(H) + thresh = 10.0 * float(jnp.finfo(X.dtype).eps) * H_norm # First iteration skips the matmul. def body_f_after_matmul(X): - Q, _ = lax_linalg.qr(X) + Q, _ = jnp_linalg.qr(X, mode="complete") # V1 = Q[:, :rank] # V2 = Q[:, rank:] V1 = _mask(Q, (n, rank)) V2 = _slice(Q, (0, rank), (n, n - rank), (N, N)) # TODO: might be able to get away with lower precision here - error_matrix = V2.conj().T @ H - error_matrix = error_matrix @ V1 - error = _norm(error_matrix) + error_matrix = jnp.dot(V2.conj().T, H) + error_matrix = jnp.dot(error_matrix, V1) + error = jnp_linalg.norm(error_matrix) return V1, V2, error def cond_f(args): _, _, j, error = args still_counting = j < maxiter unconverged = error > thresh - return lax.bitwise_and(still_counting, unconverged)[0] + return ufuncs.logical_and(still_counting, unconverged)[0] def body_f(args): V1, _, j, _ = args - X = P @ V1 + X = jnp.dot(P, V1) V1, V2, error = body_f_after_matmul(X) return V1, V2, j + 1, error V1, V2, error = body_f_after_matmul(X) - one = lax.full((1,), 1, np.dtype('int32')) + one = jnp.ones(1, dtype=jnp.int32) V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error)) if swap: return V2, V1 @@ -237,11 +210,11 @@ def split_spectrum(H, n, split_point, V0=None): rank: The dynamic size of the m subblock. """ N, _ = H.shape - H_shift = H - (split_point * lax_internal._eye(split_point.dtype, (N, N))).astype(H.dtype) + H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(H.dtype) U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n)) - I = _mask(lax_internal._eye(H.dtype, (N, N)), (n, n)) + I = _mask(jnp.eye(N, dtype=H.dtype), (n, n)) P_minus = -0.5 * (U - I) - rank_minus = lax.round(_extract_diagonal(P_minus.real).sum(-1)).astype(np.int32) + rank_minus = jnp.round(jnp.trace(ufuncs.real(P_minus))).astype(jnp.int32) P_plus = 0.5 * (U + I) rank_plus = n - rank_minus @@ -259,8 +232,8 @@ def split_spectrum(H, n, split_point, V0=None): H_minus = (V_minus.conj().T @ H) @ V_minus H_plus = (V_plus.conj().T @ H) @ V_plus if V0 is not None: - V_minus = lax.dot(V0, V_minus) - V_plus = lax.dot(V0, V_plus) + V_minus = jnp.dot(V0, V_minus) + V_plus = jnp.dot(V0, V_plus) return H_minus, V_minus, H_plus, V_plus, rank_minus @@ -285,7 +258,7 @@ def split_spectrum(H, n, split_point, V0=None): # H, V: The result of the projection. # """ # if H.shape[0] <= termination_size: -# evals, evecs = jnp.linalg.eigh(H) +# evals, evecs = jnp_linalg.eigh(H) # if V is not None: # evecs = jnp.dot(V, evecs) # return evals, evecs @@ -306,13 +279,13 @@ class _Subproblem(NamedTuple): in the workspace. """ # The row offset of the block in the matrix of blocks. - offset: Array + offset: jax.Array # The size of the block. - size: Array + size: jax.Array -@partial(api.jit, static_argnames=('termination_size', 'subset_by_index')) +@partial(jax.jit, static_argnames=('termination_size', 'subset_by_index')) def _eigh_work(H, n, termination_size, subset_by_index): """ The main work loop performing the symmetric eigendecomposition of H. Each step recursively computes a projector into the space of eigenvalues @@ -335,21 +308,20 @@ def _eigh_work(H, n, termination_size, subset_by_index): # We turn what was originally a recursive algorithm into an iterative # algorithm with an explicit stack. N, _ = H.shape - n = n.astype('int32') - zero = lax.full((), 0, 'int32') + n = jnp.asarray(n, jnp.int32) agenda = Stack.create( - N + 1, _Subproblem(zero, zero)) - agenda = agenda.push(_Subproblem(offset=zero, size=n)) + N + 1, _Subproblem(jnp.array(0, jnp.int32), jnp.array(0, jnp.int32))) + agenda = agenda.push(_Subproblem(offset=jnp.int32(0), size=n)) # eigenvectors is the array in which we build the output eigenvectors. # We initialize it with the identity matrix so the initial matrix # multiplications in_split_spectrum_jittable are the identity. - eigenvectors = lax_internal._eye(H.dtype, (N, N)) + eigenvectors = jnp.eye(N, dtype=H.dtype) # Keep a copy of the initial matrix Frobenius norm, so we know when to stop # recursing. When the sub-matrix norm is less than eps*H0_norm, the contents are # pure numerical noise, and we should just stop. - H0_norm = _norm(_mask(H, (n, n))) + H0_norm = jnp_linalg.norm(_mask(H, (n, n))) # blocks is an array representing a stack of Hermitian matrix blocks that we # need to recursively decompose. Subproblems are different sizes, so the stack @@ -396,7 +368,7 @@ def base_case(B, offset, b, agenda, blocks, eigenvectors): eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False) eig_vecs = _mask(eig_vecs, (b, b)) eig_vals = _mask(eig_vals, (b,)) - eig_vecs = lax.dot(V, eig_vecs) + eig_vecs = jnp.dot(V, eig_vecs) eig_vals = eig_vals.astype(eig_vecs.dtype) blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, 1)) @@ -409,7 +381,7 @@ def recursive_case(B, offset, b, agenda, blocks, eigenvectors): H = _slice(blocks, (offset, 0), (b, b), (B, B)) def nearly_diagonal_case(agenda, blocks, eigenvectors): - blocks = _update_slice(blocks, _extract_diagonal(H)[:, None], (offset, 0), (b, 1)) + blocks = _update_slice(blocks, jnp.diag(H)[:, None], (offset, 0), (b, 1)) return agenda, blocks, eigenvectors def should_update_range(start, end, subset_by_index): @@ -422,7 +394,7 @@ def should_update_range(start, end, subset_by_index): def default_case(agenda, blocks, eigenvectors): V = _slice(eigenvectors, (0, offset), (n, b), (N, B)) # TODO: Improve this? - split_point = _nanmedian(_mask(_extract_diagonal(H.real), (b,), np.nan)) + split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), jnp.nan)) H_minus, V_minus, H_plus, V_plus, rank = split_spectrum( H, b, split_point, V0=V) @@ -467,11 +439,11 @@ def default_case(agenda, blocks, eigenvectors): # the original input matrix,, terminate the execution. This is necessary to # handle matrices with clusters of eigenvalues, including rank deficient # matrices. See Nakatsukasa and Higham section 5.2. - norm = _norm(H) - eps = np.asarray(dtypes.finfo(H.dtype).eps, dtype=norm.dtype) - off_diag_norm = _norm( - H - _construct_diagonal(_extract_diagonal(H.real).astype(H.dtype))) - nearly_diagonal = off_diag_norm <= 5 * (eps * norm) + norm = jnp_linalg.norm(H) + eps = jnp.asarray(jnp.finfo(H.dtype).eps, dtype=norm.dtype) + off_diag_norm = jnp_linalg.norm( + H - jnp.diag(jnp.diag(ufuncs.real(H)).astype(H.dtype))) + nearly_diagonal = off_diag_norm <= 5 * eps * norm tiny = norm < eps * H0_norm return lax.cond( nearly_diagonal | tiny, @@ -510,13 +482,13 @@ def loop_cond(state): buckets.append(bucket_size) branches.append(partial(recursive_case, bucket_size)) i = i // 2 - buckets = np.array(buckets, dtype='int32') + buckets = jnp.array(buckets, dtype='int32') def loop_body(state): agenda, blocks, eigenvectors = state (offset, b), agenda = agenda.pop() - which = lax.select(buckets < b, lax.full_like(buckets, np.iinfo(np.int32).max), buckets) - choice = lax.argmin(which, 0, 'int32') + which = jnp.where(buckets < b, jnp.iinfo(jnp.int32).max, buckets) + choice = jnp.argmin(which) return lax.switch(choice, branches, offset, b, agenda, blocks, eigenvectors) _, blocks, eigenvectors = lax.while_loop( @@ -585,13 +557,13 @@ def eigh( return eig_vals, eig_vecs n = N if n is None else n - with config.default_matmul_precision(precision): + with jax.default_matmul_precision(precision): eig_vals, eig_vecs = _eigh_work( H, n, termination_size=termination_size, subset_by_index=subset_by_index ) - eig_vals = _mask(eig_vals.real, (n,), np.nan) + eig_vals = _mask(ufuncs.real(eig_vals), (n,), jnp.nan) if sort_eigenvalues or compute_slice: - _, sort_idxs = lax.sort_key_val(eig_vals, lax.iota('int32', len(eig_vals))) + sort_idxs = jnp.argsort(eig_vals) if compute_slice: sort_idxs = sort_idxs[subset_by_index[0] : subset_by_index[1]] eig_vals = eig_vals[sort_idxs] diff --git a/jax/_src/lax/qdwh.py b/jax/_src/lax/qdwh.py index 0cfd67d8797a..bac3ea957955 100644 --- a/jax/_src/lax/qdwh.py +++ b/jax/_src/lax/qdwh.py @@ -28,33 +28,11 @@ import functools -from jax._src import api -from jax._src import config +import jax +import jax.numpy as jnp +from jax import lax from jax._src import core -from jax._src import dtypes -from jax._src.lax.control_flow import loops -from jax._src.lax import lax from jax._src.lax import linalg as lax_linalg -from jax._src.lax import slicing -from jax._src.typing import Array -import numpy as np - - -def _norm(x, axis=None): - return lax.sqrt((abs(x) ** 2).sum(axis=axis)) - -def _one_norm(x): - assert x.ndim == 2 - return abs(x).sum(0).max() - -def _inf_norm(x): - assert x.ndim == 2 - return abs(x).sum(1).max() - - -def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array: - assert x.ndim <= len(shape) - return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape))) # Helpers for working with padded shapes @@ -64,26 +42,24 @@ def _mask(x, dims, alternative=0): Replaces values outside those dimensions with `alternative`. `alternative` is broadcast with `x`. """ - assert np.ndim(x) == len(dims) + assert jnp.ndim(x) == len(dims) mask = None for i, d in enumerate(dims): if d is not None: - mask_dim_i = lax.broadcasted_iota(np.int32, x.shape, i) < d + mask_dim_i = lax.broadcasted_iota(jnp.int32, x.shape, i) < d mask = mask_dim_i if mask is None else (mask & mask_dim_i) - - alternative = _broadcast_to(lax.asarray(alternative), x.shape).astype(x.dtype) - return x if mask is None else lax.select(mask, x, alternative) + return x if mask is None else jnp.where(mask, x, alternative) def _pad_in_dim(x, low=0, high=0, interior=0, fill_value=0, axis=0): pads = [(0, 0, 0)] * x.ndim pads[axis] = (low, high, interior) - return lax.pad(x, lax.convert_element_type(fill_value, x.dtype), pads) + return lax.pad(x, jnp.array(fill_value, x.dtype), pads) def _dynamic_concat(a, b, m, axis=0): "Concatenates padded arrays `a` and `b` where the true size of `a` is `m`." if m is None: - return lax.concatenate([a, b], dimension=axis) - return slicing.dynamic_update_slice_in_dim( + return jnp.concatenate([a, b], axis=axis) + return lax.dynamic_update_slice_in_dim( _pad_in_dim(a, high=b.shape[axis], axis=axis), b, m, axis) @@ -98,12 +74,12 @@ def _use_qr(u, m, n, params): a_minus_e_by_sqrt_c, sqrt_c, e = params M, N = u.shape - y = _dynamic_concat(sqrt_c * u, lax._eye(np.dtype(u), (N, N)), m) + y = _dynamic_concat(sqrt_c * u, jnp.eye(N, dtype=jnp.dtype(u)), m) q, _ = lax_linalg.qr(y, full_matrices=False) # q1 = q[:m, :] - q1 = _mask(slicing.slice(q, (0, 0), (M, N)), (m, n)) + q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n)) # q2 = (q[m:, :]).T.conj() - q2 = slicing.dynamic_slice_in_dim(q, m, N, axis=0) + q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0) q2 = _mask(q2, (n, n)).T.conj() return e * u + a_minus_e_by_sqrt_c * (q1 @ q2) @@ -118,11 +94,11 @@ def _use_cholesky(u, m, n, params): """ a_minus_e, c, e = params _, N = u.shape - x = c * (u.T.conj() @ u) + lax._eye(np.dtype(u), (N, N)) + x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=jnp.dtype(u)) # Pads the lower-right corner with the identity matrix to prevent the Cholesky # decomposition from failing due to the matrix not being PSD if padded with # zeros. - x = _mask(x, (n, n), lax._eye(x.dtype, (N, N))) + x = _mask(x, (n, n), jnp.eye(N, dtype=x.dtype)) # `y` is lower triangular. y = lax_linalg.cholesky(x, symmetrize_input=False) @@ -143,18 +119,18 @@ def _qdwh(x, m, n, max_iterations, eps): # norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for # the smallest singular value of x. if eps is None: - eps = float(dtypes.finfo(x.dtype).eps) - one_norm = _one_norm(x) - inf_norm = _inf_norm(x) + eps = float(jnp.finfo(x.dtype).eps) + one_norm = jnp.linalg.norm(x, ord=1) + inf_norm = jnp.linalg.norm(x, ord=jnp.inf) alpha_inverse = lax.rsqrt(one_norm) * lax.rsqrt(inf_norm) - alpha_inverse = lax.select(one_norm == 0, lax._ones(alpha_inverse), alpha_inverse) + alpha_inverse = jnp.where(one_norm == 0, 1, alpha_inverse) u = x * alpha_inverse.astype(x.dtype) l = eps # Iteration tolerances. tol_l = 10.0 * eps / 2.0 - tol_norm = lax.cbrt(tol_l) + tol_norm = jnp.cbrt(tol_l) def get_qr_params(a, b, c): e = b / c @@ -193,22 +169,22 @@ def iteration(k, state, update_fn, coefs, test_convergence): # As l → 1, the coefficients a, b, c → 3, 1, 3, which is Halley's method. params = get_chol_params(3, 1, 3) else: - params = slicing.dynamic_index_in_dim(coefs, k, keepdims=False) + params = lax.dynamic_index_in_dim(coefs, k, keepdims=False) u_prev = u u = update_fn(u, m, n, params) is_not_converged = True if test_convergence: - is_not_converged = _norm(u - u_prev) > tol_norm + is_not_converged = jnp.linalg.norm(u - u_prev) > tol_norm return u, is_not_converged def iterate(u, coefs, **kwargs): if not coefs: return u, True - coefs = np.array(coefs).astype(x.dtype) + coefs = jnp.array(coefs).astype(x.dtype) body = functools.partial(iteration, coefs=coefs, **kwargs) - return loops.fori_loop(0, len(coefs), body, (u, True)) + return lax.fori_loop(0, len(coefs), body, (u, True)) u, _ = iterate( u, coefs=qr_coefs, update_fn=_use_qr, test_convergence=False @@ -221,7 +197,7 @@ def iterate(u, coefs, **kwargs): # (coef = None) until convergence. def cond_fun(state): k, _, is_not_converged = state - return lax.bitwise_and(is_not_converged, k < max_iterations) + return jnp.logical_and(is_not_converged, k < max_iterations) def body_fun(state): k, u, is_not_converged = state @@ -235,7 +211,7 @@ def body_fun(state): return k + 1, u, is_not_converged k = len(qr_coefs) + len(chol_coefs) - num_iters, u, is_not_converged = loops.while_loop( + num_iters, u, is_not_converged = lax.while_loop( cond_fun, body_fun, (k, u, is_not_converged) ) @@ -246,14 +222,14 @@ def body_fun(state): h = (h + h.T.conj()) / 2 # Converged within the maximum number of iterations. - is_converged = lax.bitwise_not(is_not_converged) + is_converged = jnp.logical_not(is_not_converged) return u, h, num_iters, is_converged # TODO: Add pivoting. @functools.partial( - api.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps') + jax.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps') ) def qdwh( x, @@ -303,7 +279,7 @@ def qdwh( else: m, n = M, N - with config.default_matmul_precision('float32'): + with jax.default_matmul_precision('float32'): u, h, num_iters, is_converged = _qdwh(x, m, n, max_iterations, eps) return u, h, num_iters, is_converged diff --git a/jax/_src/lax/stack.py b/jax/_src/lax/stack.py index e5a6fdf2163d..882195f17d51 100644 --- a/jax/_src/lax/stack.py +++ b/jax/_src/lax/stack.py @@ -22,9 +22,9 @@ from typing import Any -from jax._src.lax import lax -from jax._src.lax import slicing -from jax._src import tree_util +import jax +from jax import lax +import jax.numpy as jnp class Stack: """A bounded functional stack implementation. Elements may be pytrees.""" @@ -44,9 +44,9 @@ def create(capacity: int, prototype: Any) -> Stack: structure; the specific values are ignored. """ return Stack( - lax.full((), 0, 'int32'), - tree_util.tree_map( - lambda x: lax.full((capacity, *x.shape), 0, x.dtype), prototype)) + jnp.array(0, jnp.int32), + jax.tree_util.tree_map( + lambda x: jnp.zeros((capacity,) + tuple(x.shape), x.dtype), prototype)) def empty(self) -> Any: """Returns true if the stack is empty.""" @@ -56,23 +56,23 @@ def push(self, elem: Any) -> Stack: """Pushes `elem` onto the stack, returning the updated stack.""" return Stack( self._size + 1, - tree_util.tree_map( - lambda x, y: slicing.dynamic_update_index_in_dim(x, y, self._size, 0), + jax.tree_util.tree_map( + lambda x, y: lax.dynamic_update_index_in_dim(x, y, self._size, 0), self._data, elem)) def pop(self) -> tuple[Any, Stack]: """Pops from the stack, returning an (elem, updated stack) pair.""" - elem = tree_util.tree_map( - lambda x: slicing.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False), + elem = jax.tree_util.tree_map( + lambda x: lax.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False), self._data) return elem, Stack(self._size - 1, self._data) def flatten(self): - leaves, treedef = tree_util.tree_flatten(self._data) + leaves, treedef = jax.tree_util.tree_flatten(self._data) return ([self._size] + leaves), treedef @staticmethod def unflatten(treedef, leaves): - return Stack(leaves[0], tree_util.tree_unflatten(treedef, leaves[1:])) + return Stack(leaves[0], jax.tree_util.tree_unflatten(treedef, leaves[1:])) -tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) +jax.tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) diff --git a/jax/_src/lax/svd.py b/jax/_src/lax/svd.py index 6055bce44303..77ff4297e137 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/lax/svd.py @@ -40,30 +40,13 @@ import operator from typing import Any -from jax._src import api -from jax._src import config +import jax +from jax import lax from jax._src import core -from jax._src import dtypes -from jax._src.lax.control_flow import loops -from jax._src.lax import lax -from jax._src.lax import linalg as lax_linalg -from jax._src.lax import qdwh -from jax._src.typing import Array -import numpy as np - -def _construct_diagonal(s: Array) -> Array: - """Construct a (batched) diagonal matrix""" - # signature: (...,n)->(...,n,n) - i = lax.iota('int32', s.shape[-1]) - return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s) - -def _extract_diagonal(s: Array) -> Array: - """Extract the diagonal from a batched matrix""" - # signature: (...,n,m)->(...k) where k=min(n,m) - i = lax.iota('int32', min(s.shape[-2], s.shape[-1])) - return s[..., i, i] - -@functools.partial(api.jit, static_argnums=(1, 2, 3, 4)) +import jax.numpy as jnp + + +@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4)) def _svd_tall_and_square_input( a: Any, hermitian: bool, @@ -86,23 +69,22 @@ def _svd_tall_and_square_input( `a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned. """ - u_p, h, _, _ = qdwh.qdwh( + u_p, h, _, _ = lax.linalg.qdwh( a, is_hermitian=hermitian, max_iterations=max_iterations ) # TODO: Uses `eigvals_only=True` if `compute_uv=False`. - v, s = lax_linalg.eigh( + v, s = lax.linalg.eigh( h, subset_by_index=subset_by_index, sort_eigenvalues=False ) # Singular values are non-negative by definition. But eigh could return small # negative values, so we clamp them to zero. - s = lax.max(s, lax._zeros(s)) + s = jnp.maximum(s, 0.0) # Sort or reorder singular values to be in descending order. - s_out, sort_idx = lax.rev(s, (0,)), lax.rev(lax.iota('int32', len(s)), (0,)) - s_out, sort_idx = lax.sort_key_val(s_out, sort_idx) - s_out, sort_idx = lax.rev(s_out, (0,)), lax.rev(sort_idx, (0,)) + sort_idx = jnp.argsort(s, descending=True) + s_out = s[sort_idx] if not compute_uv: return s_out @@ -117,20 +99,18 @@ def _svd_tall_and_square_input( # eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing # 35, no. 3 (2013): A1325-A1349. def correct_rank_deficiency(u_out): - u_out, r = lax_linalg.qr(u_out, full_matrices=False) - r_diag = _extract_diagonal(r) - ones = lax.full(r_diag.shape, 1, u_out.dtype) - u_out = u_out @ _construct_diagonal(lax.select(r_diag >= 0, ones, -ones)) + u_out, r = lax.linalg.qr(u_out, full_matrices=False) + u_out = u_out @ jnp.diag(jnp.where(jnp.diag(r) >= 0, 1, -1)) return u_out - eps = float(dtypes.finfo(a.dtype).eps) + eps = float(jnp.finfo(a.dtype).eps) do_correction = s_out[-1] <= a.shape[1] * eps * s_out[0] cond_f = lambda args: args[1] body_f = lambda args: (correct_rank_deficiency(args[0]), False) - u_out, _ = loops.while_loop(cond_f, body_f, (u_out, do_correction)) + u_out, _ = lax.while_loop(cond_f, body_f, (u_out, do_correction)) return (u_out, s_out, v_out) -@functools.partial(api.jit, static_argnums=(1, 2, 3, 4, 5)) +@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) def svd( a: Any, full_matrices: bool, @@ -217,7 +197,7 @@ def svd( reduce_to_square = False if full_matrices: - q_full, a_full = lax_linalg.qr(a, full_matrices=True) + q_full, a_full = lax.linalg.qr(a, full_matrices=True) q = q_full[:, :n] u_out_null = q_full[:, n:] a = a_full[:n, :] @@ -226,16 +206,16 @@ def svd( # The constant `1.15` comes from Yuji Nakatsukasa's implementation # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav if m > 1.15 * n: - q, a = lax_linalg.qr(a, full_matrices=False) + q, a = lax.linalg.qr(a, full_matrices=False) reduce_to_square = True if not compute_uv: - with config.default_matmul_precision('float32'): + with jax.default_matmul_precision('float32'): return _svd_tall_and_square_input( a, hermitian, compute_uv, max_iterations, subset_by_index ) - with config.default_matmul_precision('float32'): + with jax.default_matmul_precision('float32'): u_out, s_out, v_out = _svd_tall_and_square_input( a, hermitian, compute_uv, max_iterations, subset_by_index ) @@ -243,20 +223,17 @@ def svd( u_out = q @ u_out if full_matrices: - u_out = lax.concatenate([u_out, u_out_null], dimension=1) + u_out = jnp.hstack((u_out, u_out_null)) - if dtypes.issubdtype(a.dtype, np.complexfloating): - is_finite = (lax.is_finite(a.real) & lax.is_finite(a.imag)).all() - else: - is_finite = lax.is_finite(a).all() - cond_f = lambda args: lax.bitwise_not(args[0].astype(bool)) + is_finite = jnp.all(jnp.isfinite(a)) + cond_f = lambda args: jnp.logical_not(args[0]) body_f = lambda args: ( - lax.full((), True, dtype=bool), - lax.full_like(u_out, np.nan), - lax.full_like(s_out, np.nan), - lax.full_like(v_out, np.nan), + jnp.array(True), + jnp.full_like(u_out, jnp.nan), + jnp.full_like(s_out, jnp.nan), + jnp.full_like(v_out, jnp.nan), ) - _, u_out, s_out, v_out = loops.while_loop( + _, u_out, s_out, v_out = lax.while_loop( cond_f, body_f, (is_finite, u_out, s_out, v_out) ) From 87350b712854593def45975348c7987e76be07e6 Mon Sep 17 00:00:00 2001 From: Damiano Amatruda Date: Tue, 3 Sep 2024 13:54:31 +0000 Subject: [PATCH 337/702] Fix pytype errors and args for jax.Array methods --- jax/_src/basearray.pyi | 26 ++++++++++++++------------ jax/_src/numpy/array_methods.py | 17 ++++++++--------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index bd546169d341..23389b392414 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -131,15 +131,17 @@ class Array(abc.ABC): # np.ndarray methods: def all(self, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ... - def any(self: Array, axis: Axis = None, out: None = None, + def any(self, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ... - def argmax(self: Array, axis: int | None = None, out: None = None, + def argmax(self, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: ... def argmin(self, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: ... - def argpartition(self, kth, axis=-1, kind='introselect', order: None = None) -> Array: ... - def argsort(self, axis: int | None = -1, kind='quicksort', order: None = None) -> Array: ... - def astype(self, dtype: DTypeLike | None = None, max: ArrayLike | None = None) -> Array: ... + def argpartition(self, kth: int, axis: int = -1) -> Array: ... + def argsort(self, axis: int | None = -1, *, kind: None = None, order: None = None, + stable: bool = True, descending: bool = False) -> Array: ... + def astype(self, dtype: DTypeLike | None = None, copy: bool = False, + device: Device | Sharding | None = None) -> Array: ... def choose(self, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: ... def clip(self, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: ... def compress(self, condition: ArrayLike, @@ -148,10 +150,10 @@ class Array(abc.ABC): def conj(self) -> Array: ... def conjugate(self) -> Array: ... def copy(self) -> Array: ... - def cumprod(self, axis: int | Sequence[int] | None = None, - dtype: DTypeLike | None = None, out: None = None) -> Array: ... - def cumsum(self, axis: int | Sequence[int] | None = None, - dtype: DTypeLike | None = None, out: None = None) -> Array: ... + def cumprod(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def cumsum(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None) -> Array: ... def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: ... def dot(self, b: ArrayLike, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: ... @@ -176,7 +178,7 @@ class Array(abc.ABC): out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: ... - def ptp(self, axis: Axis = None, out: None = None, + def ptp(self, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: ... def ravel(self, order: str = 'C') -> Array: ... @property @@ -189,7 +191,7 @@ class Array(abc.ABC): sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ... def sort(self, axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False) -> Array: ... - def squeeze(self, axis: int | Sequence[int] | None = None) -> Array: ... + def squeeze(self, axis: Axis = None) -> Array: ... def std(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ... @@ -212,7 +214,7 @@ class Array(abc.ABC): def var(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ... - def view(self, dtype=None, type=None) -> Array: ... + def view(self, dtype: DTypeLike | None = None, type: None = None) -> Array: ... # Even though we don't always support the NumPy array protocol, e.g., for # tracer types, for type checking purposes we must declare support so we diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index a0222b5c586d..547fe1247459 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -90,13 +90,12 @@ def _argmin(self: Array, axis: int | None = None, out: None = None, """ return lax_numpy.argmin(self, axis=axis, out=out, keepdims=keepdims) -def _argpartition(self: Array, kth: int, axis: int = -1, - kind: str = 'introselect', order: None = None) -> Array: +def _argpartition(self: Array, kth: int, axis: int = -1) -> Array: """Return the indices that partially sort the array. Refer to :func:`jax.numpy.argpartition` for the full documentation. """ - return lax_numpy.argpartition(self, kth=kth, axis=axis, kind=kind, order=order) + return lax_numpy.argpartition(self, kth=kth, axis=axis) def _argsort(self: Array, axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False) -> Array: @@ -123,7 +122,7 @@ def _choose(self: Array, choices: Sequence[ArrayLike], out: None = None, mode: s Refer to :func:`jax.numpy.choose` for the full documentation. """ - return lax_numpy.choose(self, choices=choices) + return lax_numpy.choose(self, choices=choices, out=out, mode=mode) def _clip(self: Array, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: """Return an array whose values are limited to a specified range. @@ -163,16 +162,16 @@ def _copy(self: Array) -> Array: """ return lax_numpy.copy(self) -def _cumprod(self: Array, axis: int | Sequence[int] | None = None, - dtype: DTypeLike | None = None, out: None = None) -> Array: +def _cumprod(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None) -> Array: """Return the cumulative product of the array. Refer to :func:`jax.numpy.cumprod` for the full documentation. """ return reductions.cumprod(self, axis=axis, dtype=dtype, out=out) -def _cumsum(self: Array, axis: int | Sequence[int] | None = None, - dtype: DTypeLike | None = None, out: None = None) -> Array: +def _cumsum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None) -> Array: """Return the cumulative sum of the array. Refer to :func:`jax.numpy.cumsum` for the full documentation. @@ -337,7 +336,7 @@ def _sort(self: Array, axis: int | None = -1, *, kind: None = None, return lax_numpy.sort(self, axis=axis, kind=kind, order=order, stable=stable, descending=descending) -def _squeeze(self: Array, axis: int | Sequence[int] | None = None) -> Array: +def _squeeze(self: Array, axis: reductions.Axis = None) -> Array: """Remove one or more length-1 axes from array. Refer to :func:`jax.numpy.squeeze` for full documentation. From f92d4e3e3d8f0bf26c1b596d4ccb2d54d2e9cf74 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 3 Sep 2024 14:19:14 -0400 Subject: [PATCH 338/702] Add TPU v6e to the list of known TPUs. JAX will warn if it sees a device ID on this list but the runtime doesn't find one. --- jax/_src/hardware_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py index dd3da5c4f58b..7ab5de297752 100644 --- a/jax/_src/hardware_utils.py +++ b/jax/_src/hardware_utils.py @@ -20,13 +20,16 @@ _TPU_PCI_DEVICE_IDS = [ # TPU v2, v3 '0x0027', + # No public name (plc) + '0x0056', # TPU v4 '0x005e', + # TPU v5p + '0x0062', # TPU v5e '0x0063', - # Testing only - '0x0056', - '0x0062', + # TPU v6e + '0x006f', ] _TPU_ENHANCED_BARRIER_SUPPORTED = [ From 9030aec09757c79de1c51d62e6058533449d3f7c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Sep 2024 11:26:22 -0700 Subject: [PATCH 339/702] Added a new Pallas Triton primitive -- ``plgpu.debug_barrier`` Closes #23400. PiperOrigin-RevId: 670636723 --- jax/_src/pallas/triton/primitives.py | 20 ++++++++++++++++++++ jax/experimental/pallas/gpu.py | 1 + tests/pallas/ops_test.py | 16 ++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 8518a94ed9cf..23fce50dc4f9 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -20,6 +20,7 @@ import jax from jax import core as jax_core +from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas.triton import lowering from jax.interpreters import mlir @@ -120,3 +121,22 @@ def _elementwise_inline_asm_lowering( packed_element=pack, args=args, ).result + + +def debug_barrier() -> None: + """Synchronizes all kernel executions in the grid.""" + return debug_barrier_p.bind() + + +debug_barrier_p = jax_core.Primitive("debug_barrier_p") +debug_barrier_p.multiple_results = True + +@debug_barrier_p.def_abstract_eval +def _debug_barrier_abstract_eval() -> Sequence[jax_core.ShapedArray]: + return () + +@lowering.register_lowering(debug_barrier_p) +def _debug_barrier_lowering(ctx: lowering.LoweringRuleContext): + del ctx # Unused. + gpu_dialect.barrier() + return [] diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py index adade4e8a72c..a24bfe4150df 100644 --- a/jax/experimental/pallas/gpu.py +++ b/jax/experimental/pallas/gpu.py @@ -15,4 +15,5 @@ """Triton-specific Pallas APIs.""" from jax._src.pallas.triton.primitives import approx_tanh +from jax._src.pallas.triton.primitives import debug_barrier from jax._src.pallas.triton.primitives import elementwise_inline_asm diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index ab7ffcb480b2..85bda21ec7f2 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -972,6 +972,22 @@ def kernel(x_ref, o_ref): x = jnp.arange(256).astype(jnp.float16) np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) + def test_debug_barrier(self): + if self.INTERPRET: + self.skipTest("debug_barrier is not supported in interpret mode") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + grid=1, + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + plgpu.debug_barrier() + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x) + def test_debug_print(self): # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): From 24bb8ae443638517dfa1dedf7253dda7d345818c Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Tue, 3 Sep 2024 05:08:13 -0700 Subject: [PATCH 340/702] [ffi] Add support for token inputs and outputs --- jax/_src/extend/ffi.py | 45 +++++++++++++++++++++++++++++------------- tests/extend_test.py | 17 ++++++++++++++++ 2 files changed, 48 insertions(+), 14 deletions(-) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 3965c8b72c67..aa092a768bb3 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -30,7 +30,7 @@ from jax._src.lib import jaxlib from jax._src.lib import xla_client from jax._src.lib.mlir import ir -from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray +from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, Shape map, unsafe_map = util.safe_map, map @@ -100,6 +100,14 @@ def include_dir() -> str: return os.path.join(jaxlib_dir, "include") +def _aval_shape(aval: core.AbstractValue) -> Shape: + return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error + + +def _default_layouts(avals: Iterable[core.AbstractValue]) -> list[list[DimSize]]: + return [list(reversed(range(len(_aval_shape(aval))))) for aval in avals] + + def ffi_lowering( call_target_name: str, *, @@ -139,17 +147,17 @@ def _lowering( if "result_types" not in kwargs: kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] if operand_layouts is None: - kwargs["operand_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_in) # pytype: disable=attribute-error + kwargs["operand_layouts"] = _default_layouts(ctx.avals_in) else: kwargs["operand_layouts"] = operand_layouts if result_layouts is None: - kwargs["result_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_out) + kwargs["result_layouts"] = _default_layouts(ctx.avals_out) else: kwargs["result_layouts"] = result_layouts if "result_shapes" not in kwargs and not all( - core.is_constant_shape(aval.shape) for aval in ctx.avals_out): + core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out): kwargs["result_shapes"] = [ - mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, aval.shape)) + mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval))) for aval in ctx.avals_out] return mlir.custom_call(call_target_name, operands=operands, **kwargs).results # type: ignore @@ -157,13 +165,23 @@ def _lowering( return _lowering -def _default_layouts(shapes: Iterable[Sequence[DimSize]]) -> list[list[DimSize]]: - return [list(reversed(range(len(shape)))) for shape in shapes] +ResultMetadata = DuckTypedArray | core.AbstractToken + + +def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]: + avals: list[core.AbstractValue] = [] + for result in results: + if isinstance(result, core.AbstractToken): + avals.append(result) + else: + _check_shape_dtype(result) + avals.append(core.ShapedArray(result.shape, result.dtype)) + return tuple(avals) def ffi_call( target_name: str, - result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray], + result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], *args: ArrayLike, vectorized: bool = False, **kwargs: Any, @@ -189,6 +207,7 @@ def ffi_call( ``dtype`` attributes which are expected to match the shape and dtype of the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often used to define the elements of ``result_shape_dtypes``. + ``jax.core.abstract_token`` may be used to represent a token-typed output. *args: the arguments passed to the custom call. vectorized: boolean specifying whether the callback function can operate in a vectorized manner, as described above. @@ -201,12 +220,10 @@ def ffi_call( """ if isinstance(result_shape_dtypes, Sequence): multiple_results = True - result_types = result_shape_dtypes + result_avals = _result_avals(result_shape_dtypes) else: multiple_results = False - result_types = (result_shape_dtypes,) - map(_check_shape_dtype, result_types) - result_avals = tuple(core.ShapedArray(x.shape, x.dtype) for x in result_types) + result_avals = _result_avals((result_shape_dtypes,)) results = ffi_call_p.bind( *args, result_avals=result_avals, @@ -222,7 +239,7 @@ def ffi_call( def ffi_call_abstract_eval( *avals_in, - result_avals: tuple[core.ShapedArray, ...], + result_avals: tuple[core.AbstractValue, ...], target_name: str, vectorized: bool, **kwargs: Any, @@ -248,7 +265,7 @@ def ffi_call_transpose(*args, target_name, **kwargs): def ffi_call_lowering( ctx: mlir.LoweringRuleContext, *operands: ir.Value, - result_avals: tuple[core.ShapedArray, ...], + result_avals: tuple[core.AbstractValue, ...], target_name: str, vectorized: bool, **kwargs: Any, diff --git a/tests/extend_test.py b/tests/extend_test.py index cdf3af8fbc4d..867e46fe6262 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -31,6 +31,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.interpreters import mlir +from jax._src.lib.mlir.dialects import hlo jax.config.parse_flags_with_absl() @@ -153,6 +154,22 @@ def fun(x): return self.fail("No custom_call found in the lowered IR") + def testToken(self): + def fun(): + token = lax.create_token() + return jex.ffi.ffi_call("test_ffi", core.abstract_token, token) + + # Ensure that token inputs and outputs are translated to the correct type + module = jax.jit(fun).lower().compiler_ir("stablehlo") + for func in module.body.operations: + for block in func.body.blocks: + for op in block.operations: + if op.OPERATION_NAME == "stablehlo.custom_call": + self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) + self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) + return + self.fail("No custom_call found in the lowered IR") + @jtu.sample_product( shape=[(1,), (4,), (5,)], dtype=(np.int32,), From 9b31b73d9d5b6617ca95b4ec800b0867dafc21d6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Sep 2024 13:37:05 -0700 Subject: [PATCH 341/702] Added basic pipelining to Pallas on Mosaic GPU The implementation only allows at most one sequential axis at the moment. PiperOrigin-RevId: 670687671 --- jax/_src/pallas/mosaic_gpu/lowering.py | 159 +++++++++++++++++++------ tests/pallas/mosaic_gpu_test.py | 33 +++++ 2 files changed, 156 insertions(+), 36 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index cd5e3bbddb8d..fb0119025d4c 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -20,7 +20,7 @@ import dataclasses import functools import math -from typing import Any, cast +from typing import Any, Literal, TypedDict, cast import jax from jax._src import core as jax_core @@ -33,6 +33,7 @@ from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect +from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src.state import primitives as sp @@ -136,7 +137,7 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name def _eval_index_map( - ctx: ModuleContext, idx, block_mapping: pallas_core.BlockMapping + ctx: ModuleContext, idx: ir.Value, block_mapping: pallas_core.BlockMapping ) -> Sequence[ir.Value]: block_indices = lower_jaxpr_to_mosaic_gpu( ctx, block_mapping.index_map_jaxpr.jaxpr, idx @@ -151,6 +152,11 @@ def _eval_index_map( return tuple(result) +class Params(TypedDict, total=False): + num_stages: int + dimension_semantics: Sequence[Literal["sequential", "parallel"]] + + def lower_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, @@ -160,6 +166,8 @@ def lower_jaxpr_to_module( ) -> LoweringResult: del cost_estimate # Unused. + block_mappings = grid_mapping.block_mappings + assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims if len(grid_mapping.grid) > 3: @@ -175,8 +183,7 @@ def lower_jaxpr_to_module( "Scalar prefetch not supported in Mosaic GPU lowering." ) if not all( - isinstance(bm.indexing_mode, pallas_core.Blocked) - for bm in grid_mapping.block_mappings + isinstance(bm.indexing_mode, pallas_core.Blocked) for bm in block_mappings ): raise NotImplementedError( "Only Blocked indexing mode is supported in Mosaic GPU lowering." @@ -192,19 +199,31 @@ def lower_jaxpr_to_module( grid += (1,) * (3 - len(grid)) block = (128,) + (1,) * (len(grid) - 1) + params = Params(**compiler_params.get("mosaic_gpu", {})) + num_stages = params.get("num_stages", 1) + dimension_semantics = params.get( + "dimension_semantics", ["parallel"] * len(grid_mapping.grid) + ) + assert len(dimension_semantics) == len(grid_mapping.grid) + sequential_axes = tuple( + i for i, s in enumerate(dimension_semantics) if s == "sequential" + ) + assert all(grid[axis] for axis in sequential_axes) + assert all(block[axis] == 1 for axis in sequential_axes) + in_structs_gmem = [*grid_mapping.in_shapes] in_structs_smem = [ - jax.ShapeDtypeStruct(bm.block_shape, s.dtype) + jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype) for bm, s in zip( - grid_mapping.block_mappings[: grid_mapping.num_inputs], + block_mappings[: grid_mapping.num_inputs], grid_mapping.in_shapes, ) ] out_structs_gmem = [*grid_mapping.out_shapes] out_structs_smem = [ - jax.ShapeDtypeStruct(bm.block_shape, s.dtype) + jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype) for bm, s in zip( - grid_mapping.block_mappings[grid_mapping.num_inputs :], + block_mappings[grid_mapping.num_inputs :], grid_mapping.out_shapes, ) ] @@ -219,56 +238,117 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): buffers_smem, [grid_mapping.num_inputs] ) - [barrier] = cast(mgpu.BarrierRef, barriers) - module_ctx = ModuleContext( name_and_src_info.name, grid_mapping, runtime_smem, smem_used_bytes=0 ) program_ids = map(_program_id, range(len(grid_mapping.grid))) start_indices = map( functools.partial(_eval_index_map, module_ctx, program_ids), - grid_mapping.block_mappings, + block_mappings, ) in_start_indices, out_start_indices = util.split_list( start_indices, [grid_mapping.num_inputs] ) - with mgpu.single_thread(): + def gmem_slice( + start_indices: Sequence[ir.Value], + step: ir.Value, + shape: Sequence[int], + ) -> ir.Value: + return tuple( + mgpu.ds( + arith_dialect.addi( + start_index, arith_dialect.muli(step, _as_index(dim)) + ) + if axis in sequential_axes + else start_index, + dim, + ) + for axis, (start_index, dim) in enumerate(zip(start_indices, shape)) + ) + + @mgpu.single_thread() + def fetch(step: ir.Value, slot: ir.Value) -> None: for start_indices, b_gmem, b_smem in zip( in_start_indices, in_buffers_gmem, in_buffers_smem ): # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. + b_smem_shape = ir.MemRefType(b_smem.type).shape[1:] launch_ctx.async_copy( src_ref=b_gmem, - dst_ref=b_smem, - gmem_slice=tuple( - map(mgpu.ds, start_indices, ir.MemRefType(b_smem.type).shape) - ), - barrier=barrier, + dst_ref=mgpu.memref_slice(b_smem, slot), + gmem_slice=gmem_slice(start_indices, step, b_smem_shape), + barrier=barriers[slot], swizzle=None, arrive=True, uniform=False, ) - if grid_mapping.num_inputs: - # Only wait if async copies were issued. - barrier.wait() - - _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, buffers_smem) - mgpu.commit_shared() - - for start_indices, b_gmem, b_smem in zip( - out_start_indices, out_buffers_gmem, out_buffers_smem - ): - # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. - launch_ctx.async_copy( - src_ref=b_smem, - dst_ref=b_gmem, - gmem_slice=tuple( - map(mgpu.ds, start_indices, ir.MemRefType(b_smem.type).shape) - ), - swizzle=None, + @mgpu.single_thread() + def store(step: ir.Value, slot: ir.Value) -> None: + for start_indices, b_gmem, b_smem in zip( + out_start_indices, out_buffers_gmem, out_buffers_smem + ): + # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. + b_smem_shape = ir.MemRefType(b_smem.type).shape[1:] + launch_ctx.async_copy( + src_ref=mgpu.memref_slice(b_smem, slot), + dst_ref=b_gmem, + gmem_slice=gmem_slice(start_indices, step, b_smem_shape), + swizzle=None, + uniform=False, + ) + + # Compute the number of steps along each sequential axis. + if sequential_axes: + # TODO(slebedev): Support multiple sequential axes. + if len(sequential_axes) > 1: + raise NotImplementedError( + "Multiple sequential axes are not supported in Mosaic GPU lowering." + ) + [sequential_axis] = sequential_axes + if any( + b_gmem.shape[sequential_axis] % b_smem.shape[1 + sequential_axis] + for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) + ): + raise ValueError( + "Array dimensions along the sequential axis must be divisible by" + " the corresponding block dimensions." + ) + [num_steps] = { + b_gmem.shape[sequential_axis] // b_smem.shape[1 + sequential_axis] + for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) + } + else: + num_steps = 1 + + for slot in range(num_stages): + fetch(_as_index(slot), _as_index(slot)) + + @mgpu.fori(_as_index(num_steps), ()) + def _(step, _): + slot = arith_dialect.remui(step, _as_index(num_stages)) + if grid_mapping.num_inputs: + # Only wait if async copies were issued. + barriers[slot].wait() + + _ = lower_jaxpr_to_mosaic_gpu( + module_ctx, + jaxpr, + [mgpu.memref_slice(b_smem, slot) for b_smem in buffers_smem], + ) + mgpu.commit_shared() + store(step, slot) + + next_step = arith_dialect.addi(step, _as_index(num_stages)) + next_step_in_bounds = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps) ) + with ir.InsertionPoint(scf_dialect.IfOp(next_step_in_bounds).then_block): + fetch(next_step, slot) + scf_dialect.yield_([]) + + return () launch_ctx.await_async_copy(0) @@ -292,7 +372,10 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): *in_structs_smem, *out_structs_smem, *extra_smem_scratch, - mgpu.TMABarrier(), + mgpu.Barrier( + arrival_count=len(in_structs_gmem), + num_barriers=num_stages, + ), ), module_name=name_and_src_info.name, ) @@ -549,6 +632,10 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value: raise NotImplementedError(f"Unsupported constant: {v!r}") +def _i32_constant(v: object) -> ir.Value: + return _ir_constant(v, ir.IntegerType.get_signless(32)) + + def _as_index(v: int | ir.Value) -> ir.Value: if isinstance(v, int): return arith_dialect.constant(ir.IndexType.get(), v) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 3f690f9ca911..bdf6396cae5c 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -52,6 +52,18 @@ def kernel(x_ref, o_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_add_xy(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + + x = jnp.arange(256).astype(jnp.float32) + y = x + 1 + np.testing.assert_array_equal(kernel(x, y), x + y) + def test_add_one_grid(self): @functools.partial( pl.pallas_call, @@ -66,6 +78,27 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 2).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) + @parameterized.product(num_stages=[1, 2, 3]) + def test_add_one_grid_pipelined(self, num_stages): + @functools.partial( + pl.pallas_call, + in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))], + out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), + out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), + compiler_params=dict( + mosaic_gpu=dict( + dimension_semantics=["parallel", "sequential"], + num_stages=num_stages, + ), + ), + grid=(2, 1), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 + + x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_add_doubled_sum(self): @functools.partial( pl.pallas_call, From 252caebce34808dffa148cec88cfc1ad85272862 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Sep 2024 14:30:37 -0700 Subject: [PATCH 342/702] Create `jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None)` API to make it easier to create a mesh and reduce a ton of boilerplate. `jax.make_mesh` is the stable API endpoint of `mesh_utils` but without all the extra options. If you want those, you can still use the experimental endpoint in `mesh_utils`. PiperOrigin-RevId: 670707995 --- jax/BUILD | 12 +- jax/__init__.py | 1 + jax/_src/mesh_utils.py | 811 +++++++++++++++++++++++++++++++++ jax/_src/sharding_impls.py | 54 +++ jax/experimental/mesh_utils.py | 800 +------------------------------- tests/mesh_utils_test.py | 2 +- tests/pjit_test.py | 8 +- 7 files changed, 889 insertions(+), 799 deletions(-) create mode 100644 jax/_src/mesh_utils.py diff --git a/jax/BUILD b/jax/BUILD index 574559688c4d..b761722a5b1f 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -274,6 +274,7 @@ py_library_providing_imports_info( ":dtypes", ":effects", ":environment_info", + ":internal_mesh_utils", ":jaxpr_util", ":layout", ":lazy_loader", @@ -784,6 +785,7 @@ pytype_strict_library( deps = [ ":config", ":core", + ":internal_mesh_utils", ":mesh", ":op_shardings", ":partition_spec", @@ -807,6 +809,14 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_library( + name = "internal_mesh_utils", + srcs = ["_src/mesh_utils.py"], + deps = [ + ":xla_bridge", + ], +) + pytype_strict_library( name = "source_info_util", srcs = ["_src/source_info_util.py"], @@ -1064,7 +1074,7 @@ pytype_library( srcs = ["experimental/mesh_utils.py"], visibility = ["//visibility:public"], deps = [ - ":xla_bridge", + ":internal_mesh_utils", ], ) diff --git a/jax/__init__.py b/jax/__init__.py index 168ac9278586..7e958b21c5dd 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -129,6 +129,7 @@ from jax._src.api import vmap as vmap from jax._src.api import xla_computation as _deprecated_xla_computation from jax._src.sharding_impls import NamedSharding as NamedSharding +from jax._src.sharding_impls import make_mesh as make_mesh # Force import, allowing jax.interpreters.* to be used after import jax. from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py new file mode 100644 index 000000000000..7cac0338a923 --- /dev/null +++ b/jax/_src/mesh_utils.py @@ -0,0 +1,811 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils for building a device mesh.""" + +from __future__ import annotations + +import collections +from collections.abc import Callable, Generator, MutableMapping, Sequence +import itertools +import logging +import math +from typing import Any + +from jax._src import xla_bridge as xb +import numpy as np + +logger = logging.getLogger(__name__) + +_TPU_V2 = 'TPU v2' +_TPU_V3 = 'TPU v3' +_TPU_V4 = 'TPU v4' +_TPU_V5_LITE = "TPU v5 lite" + +# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's +# famous contiguous mesh trick. +# +# The trick only works for certain topologies and mesh shapes. Trivial dims of +# size 1 can be added to the shapes listed, and they are also supported. +_TRANSPOSE_TRICKS: dict[ + tuple[int, ...], dict[tuple[int, ...], tuple[int, ...]] +] = { + (2, 2, 1): { + (2, 2): (0, 1, 2), + }, + (2, 2, 4): { + (4, 4): (0, 1, 2), + }, + (4, 4, 4): { + (16, 4): (0, 2, 1), + }, + (4, 8, 8): { + (64, 4): (0, 2, 1), + (4, 64): (0, 2, 1), + }, + (8, 8, 8): { + (64, 8): (0, 2, 1), + }, + (8, 16, 16): { + (256, 8): (0, 2, 1), + (8, 256): (0, 2, 1), + }, +} + +# Physical ordering of core IDs in a tray that creates a ring +_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) +_TRAY_2x2_RING_ORDER = (0, 1, 3, 2) +_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4) + +def _tpu_v2_v3_create_device_mesh( + mesh_shape: Sequence[int], + devices: Sequence[Any], + **unused_kwargs, +) -> np.ndarray: + if len(devices) == 8: + logger.info( + 'Reordering mesh to physical ring order on single-tray TPU v2/v3.' + ) + device_mesh = np.asarray(devices) + device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + elif mesh_shape[-1] == 8: + device_mesh = np.asarray(devices).reshape(mesh_shape) + logger.info( + 'Reordering mesh to physical ring order on each TPU v2/v3 tray.' + ) + perm = np.array(_TRAY_RING_ORDER) + device_mesh = device_mesh[..., perm] + return device_mesh + else: + # TODO(skye): implement 2D mesh_shape logic here: + # https://github.com/tensorflow/lingvo/blob/0df40cf604dfcd14e28f7087d73687a0bd2fe5c6/lingvo/core/gshard_utils.py#L187 + # (possibly replaces above mesh_shape[-1] == 8 case) + return np.asarray(devices).reshape(mesh_shape) + + +def _vlc_create_device_mesh( + mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs +) -> np.ndarray | None: + """Creates rotated pincer device assignment for selected topologies. + + Args: + mesh_shape: Logical mesh shape used by the model. + devices: TPU devices. + **unused_kwargs: ... + + Returns: + None or reordered devices reshaped as `mesh_shape`. + """ + max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices) + bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1 + # Our ring re-ordering makes sense only if the passed-in devices are + # sequential, which may not always be the case. reversed() changes z-minor to + # x-minor. + sequential_devices = sorted( + devices, + key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) + + if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: # VLC2x2 + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # VLP4x4 + # Only uses ring order if the whole mesh is a replica group. + if max(mesh_shape) == len(devices): + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + return None + + +# Registers functions to create device mesh for specific device kinds. Takes +# precedence over the more general logic in create_device_mesh(). Handler may +# return None; in that case, it will fall back to using the default logic. +device_kind_handler_dict: dict[ + str, + Callable[..., np.ndarray | None], +] = { + _TPU_V2: _tpu_v2_v3_create_device_mesh, + _TPU_V3: _tpu_v2_v3_create_device_mesh, + _TPU_V5_LITE: _vlc_create_device_mesh, +} + + +def _create_device_mesh_for_nd_torus( + physical_mesh: np.ndarray, + mesh_shape: Sequence[int], + *, + allow_split_physical_axes: bool = False, +) -> tuple[np.ndarray, np.ndarray]: + """Assigns logical parallelism axes to physical axes of an N-D torus network. + + Given logical parallelism axes with sizes in `mesh_shape` and devices in an + N-dimensional torus network represented by `physical_mesh`, maps each logical + axis to one or more physical axes. Prefer to map more-performance-sensitive + logical axes to larger numbers of physical axes to maximize the bandwidth + available to them. Also prefer to assign logical axes to multiple physical + axes of the same size (e.g., a 2D square) rather than multiple physical axes + of different sizes when possible. + + If allow_split_physical_axes = False (default), this routine will error out + instead of splitting a physical axis over more than one logical axis (which + would reduce total usable bandwidth). + + Let's use a concrete example to explain the concepts and considerations. + + As an example, suppose the logical mesh is [data, model], for data and model + parallelism respectively. Also suppose that data parallelism is less + performance sensitive than model parallelism. Consider a 3D TPU pod slice of + shape 4x4x16, represented by a physical mesh of shape (4, 4, 16). + + A TPU pod slice has equal bandwidth along all axes with wraparound links, but + a 2D plane of size 4x4 may have faster XLA collective implementations than a + non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want + the more performance sensitive `model` axis to be mapped to the 4x4 XY plane. + + Args: + physical_mesh: a np.ndarray of devices in the shape of the N-D torus + physical topology. + mesh_shape: shape of the logical mesh (size of the various logical + parallelism axes), with axes ordered by increasing network intensity. + allow_split_physical_axes: If True, we would split physical axes if + necessary to fit the desired mesh shape. + + Returns: + An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with + each logical parallelism axis mapped to one or more physical mesh axes. + The axis assignment matrix, which is a 2-d array mapping from + (physical_axis, logical_axis) to the size assigned, with the invariant + np.prod(assignment, axis=1) = physical_mesh_shape, and + np.prod(assignment, axis=0) = mesh_shape. + """ + # Remaining physical axes to be assigned to logical axes. + assignable_physical_mesh = list(physical_mesh.shape) + # Map each logical axis to a subset of physical axes. + assignment: list[tuple[int, ...]] = [() for _ in mesh_shape] + + # Assign logical axes from highest network intensity to lowest. + # `mesh_shape` is assumed to ordered by lowest network intensity first, so + # reverse it first. + for logical_axis_index, logical_axis_size in reversed( + list(enumerate(mesh_shape)) + ): + # Preferentially map to more physical axes first for higher bandwidth. + for num_axes in range(3, 0, -1): + # Try assign to any subset of size num_axes. Generate all candidates. + indices_and_axes = itertools.combinations( + enumerate(assignable_physical_mesh), num_axes + ) + for elem in indices_and_axes: + c_indices, c_axes = zip(*elem) + # TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives only + # implemented for square 2D plane. Mapping a physical axis to two + # logical axes might be slower for non-square 2D plane, e.g., map 32 to + # 4x8 or a single axis. If XLA 2D collectives support non-square plane + # soon, we can continue to preferentially map to 2D plane in general, + # otherwise, we should treat non-square 2D plane and 1D submesh equally. + if np.prod(c_axes) == logical_axis_size: + assignment[logical_axis_index] = c_indices + # Zero the assigned physical axes. + assignable_physical_mesh = [ + 0 if i in c_indices else v + for i, v in enumerate(assignable_physical_mesh) + ] + break + if assignment[logical_axis_index]: + # We already found an assignment from one candidate above. + break + else: + # If the num_axes for loop did not break, i.e. none of the candidates work + # goto here with this while-else construct. + if logical_axis_size > 1: + if not allow_split_physical_axes: + # Although this is now implemented, there are downstream tasks + # counting on this being a NotImplementedError. + raise NotImplementedError( + 'Failed to find assignment for logical_axis_index' + f' {logical_axis_index} of size {logical_axis_size} with' + f' remaining assignable mesh {assignable_physical_mesh}. The size' + ' of each axis in your logical mesh must be equal to the product' + ' of some subset of the physical mesh axis sizes. E.g. logical' + ' mesh (4, 16) is compatible with physical mesh 4x4x4 since 4=4' + ' and 16=4x4. If you want to split physical axes, set ' + ' allow_split_physical_axes to True.' + ) + else: + # We will try finding an assignment, even if that means splitting the + # physical axes, which requires a more sophisticated implementation. + return _create_device_mesh_for_nd_torus_splitting_axes( + physical_mesh, mesh_shape + ) + + # Flatten the assignment, e.g., [(), (2,), (0, 1)] -> (2, 0, 1). + transpose: list[int] = [] + assignment_array = np.ones( + [len(physical_mesh.shape), len(mesh_shape)], dtype=np.int64 + ) + for i, x in enumerate(assignment): + for y in x: + physical_mesh_axis = int(y) + assignment_array[physical_mesh_axis, i] = physical_mesh.shape[ + physical_mesh_axis + ] + transpose.append(physical_mesh_axis) + return ( + physical_mesh.transpose(transpose).reshape(mesh_shape), + assignment_array, + ) + + +def _create_device_mesh_for_nd_torus_splitting_axes( + physical_mesh: np.ndarray, + mesh_shape: Sequence[int], +) -> tuple[np.ndarray, np.ndarray]: + """Assigns logical parallelism axes to physical axes of an N-D torus network. + + This implementation allows creating meshes that requires splitting physical + axes, and thus one could produce logical mesh of any shape, as long as the + number of devices matches, e.g., + + - Creating 2x2x4 from 4x4; + + - Creating 2x2x16 from 8x8; + + Args: + physical_mesh: a np.ndarray of devices in the shape of the N-D torus + physical topology. + mesh_shape: shape of the logical mesh (size of the various logical + parallelism axes), with axes ordered by increasing network intensity. + + Returns: + An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with + each logical parallelism axis mapped to one or more physical mesh axes. + The axis assignment matrix, which is a 2-d array mapping from + (physical_axis, logical_axis) to the size assigned, with the invariant + np.prod(assignment, axis=1) = physical_mesh_shape, and + np.prod(assignment, axis=0) = mesh_shape. + """ + if np.prod(physical_mesh.shape) != np.prod(mesh_shape): + raise ValueError( + 'The number of devices in physical mesh' + f' {physical_mesh.shape} does not match the number of devices' + f' in logical mesh {mesh_shape}.' + ) + + physical_mesh_shape = physical_mesh.shape + logical_mesh_shape = tuple(mesh_shape) + + # (Partial) assignment map as an 2-d array [p_axis, l_axis] -> size. + assignment = np.ones( + [len(physical_mesh_shape), len(logical_mesh_shape)], dtype=np.int64 + ) + + # Process logical axes from highest network intensity to lowest. + # `mesh_shape` is assumed to ordered by lowest network intensity first, so + # reverse it. + for logical_axis, logical_axis_size in reversed( + list(enumerate(logical_mesh_shape)) + ): + # Go over all the possible assignment for the logical axis, including the + # one that splits multiple physical axes. + best_logical_axis_assignment = None + for logical_axis_assignment in _enumerate_feasible_logical_axis_assignments( + physical_mesh_shape, assignment, logical_axis_size + ): + # TODO(rosun): Instead of using heuristics, replace this with a proper + # scoring function reflecting the underlying hardware properties. + if ( + best_logical_axis_assignment is None + or _prefer_first_logical_axis_assignment( + logical_axis_assignment, + best_logical_axis_assignment, + physical_mesh_shape=physical_mesh_shape, + assignment=assignment, + ) + ): + best_logical_axis_assignment = logical_axis_assignment + assignment[:, logical_axis] = best_logical_axis_assignment + + # Read out the assignment. + logical_mesh = _generate_logical_mesh( + physical_mesh, logical_mesh_shape, assignment + ) + + return logical_mesh, assignment + + +def _get_prime_factors(x: int) -> list[int]: + """Returns a sorted list of prime factors for the given number.""" + assert x > 0 + factors = [] + for p in range(2, math.isqrt(x) + 2): + while x % p == 0: + factors.append(p) + x //= p + if x == 1: + return factors + else: + return [x] # x is a prime number. + + +def _enumerate_feasible_logical_axis_assignments( + physical_mesh_shape: Sequence[int], + assignment: np.ndarray, + logical_axis_size: int, +) -> Generator[np.ndarray, None, None]: + """Yields feasible assignments for a single logical axis. + + For a physical mesh of shape [x_1, ..., x_n], and the product of all previous + assignments on each physical axes [y_1, ..., y_n], this function yields all + possible assignments for the axis as 1-d arrays [z_1, ..., z_n], so that: + + - prod(z_1, ..., z_n) = logical_axis_size + + - x_i % (z_i * y_i) = 0 + + Args: + physical_mesh_shape: Physical mesh shape. + assignment: Existing assignment matrix. + logical_axis_size: Size of the logical axis to assign. + + Yields: + All valid assignments for the logical axis. Each assignment is represented + as an integer array of length len(physical_mesh_shape). + """ + logical_axis_factors: MutableMapping[int, int] = collections.defaultdict(int) + for factor in _get_prime_factors(logical_axis_size): + logical_axis_factors[factor] += 1 + + available_physical_mesh_shape = np.array(physical_mesh_shape) // np.prod( + assignment, axis=-1 + ) + + # To enable efficient enumerations, we first index physical axes by their + # prime factors. Since we know the prime factorization of the logical axis + # size, we could simply enumerate by picking the correct count for each + # prime factor. + physical_axes_by_factor: MutableMapping[int, list[int]] = ( + collections.defaultdict(list) + ) + for physical_axis, physical_axis_size in enumerate( + available_physical_mesh_shape + ): + for factor in _get_prime_factors(physical_axis_size): + if factor not in logical_axis_factors: + continue + physical_axes_by_factor[factor].append(physical_axis) + + factors = [] + assignments_by_factor = [] + for factor, multiplicity in logical_axis_factors.items(): + factors.append(factor) + assignments_by_factor.append( + set( + itertools.combinations( + physical_axes_by_factor[factor], multiplicity + ) + ) + ) + + for axis_assignment in itertools.product(*assignments_by_factor): + result = np.ones([len(physical_mesh_shape)], dtype=np.int64) + for factor_index, per_factor_assignment in enumerate(axis_assignment): + for physical_axis in per_factor_assignment: + result[physical_axis] *= factors[factor_index] + yield result + + +def _prefer_first_logical_axis_assignment( + x: np.ndarray, + y: np.ndarray, + *, + physical_mesh_shape: Sequence[int], + assignment: np.ndarray, +) -> bool: + """Returns True if the first axis assignment is preferred over the second. + + For now, this is implemented with some very simple heuristics. However, + it is possible to introduce e.g., a value function here based on a more + precise model of the underlying hardware. + + TODO(rosun): Use a proxy of network capacity to select the partitions. + + Args: + x: Logical axis assignment as [len(physical_mesh_shape)] array. + y: Logical axis assignment as [len(physical_mesh_shape)] array. + physical_mesh_shape: Physical mesh shape. + assignment: Assignment matrix. + + Returns: + True if x is preferred over y. + """ + # Prefer occupying complete physical axes. I don't have a good reason for + # this, except that it is compatible with the existing behavior. + # + # E.g., on 4 x 4 x 8, [4, 4, -] will be preferred over [4, -, 4], and then + # over [2, 2, 4]. + x_whole_axis_size = np.prod( + [s for i, s in enumerate(x) if s == physical_mesh_shape[i]] + ) + y_whole_axis_size = np.prod( + [s for i, s in enumerate(y) if s == physical_mesh_shape[i]] + ) + + if x_whole_axis_size != y_whole_axis_size: + return x_whole_axis_size > y_whole_axis_size + + # Prefer occupying more whole physical axes for better bandwidth. + # + # This is consistent with existing logic, i.e., 2 x 2 is preferred over 4. + x_num_whole_axes = len( + [1 for i, s in enumerate(x) if s == physical_mesh_shape[i] and s > 1] + ) + y_num_whole_axes = len( + [1 for i, s in enumerate(y) if s == physical_mesh_shape[i] and s > 1] + ) + + if x_num_whole_axes != y_num_whole_axes: + return x_num_whole_axes > y_num_whole_axes + + # Prefer taking physical axes that are not taken by logical axes of higher + # network intensity. E.g., for a 4 x 4 x 4, suppose that the previous + # assignments are 1 x 2 x 4, and we want to place a new logical axis of size + # 2, we will go for [2, 1, 1] instead of [1, 2, 1], as the latter choice will + # tap into bandwidth already taken by the higher intensity axis. + assigned_physical_mesh_shape = np.prod(assignment, axis=-1) + + x_non_overlapping_axis_size = np.prod( + [s for i, s in enumerate(x) if assigned_physical_mesh_shape[i] > 1] + ) + y_non_overlapping_axis_size = np.prod( + [s for i, s in enumerate(y) if assigned_physical_mesh_shape[i] > 1] + ) + + if x_non_overlapping_axis_size != y_non_overlapping_axis_size: + return x_non_overlapping_axis_size > y_non_overlapping_axis_size + + # Otherwise sort by reverse lexical graphical order, to be consistent with + # existing behavior. + return tuple(x) > tuple(y) + + +def _generate_logical_mesh( + physical_mesh: np.ndarray, + logical_mesh_shape: Sequence[int], + assignment: np.ndarray, +) -> np.ndarray: + """Compute the logical mesh from assignment map. + + Args: + physical_mesh: Physical device mesh. + logical_mesh_shape: Logical mesh shape. + assignment: 2-d assignment matrix shape [physical_dims, logical_dims]. + + Returns: + Logical mesh reshaped from physical mesh. + """ + physical_indices = np.broadcast_to( + np.expand_dims( + np.arange(len(physical_mesh.shape), dtype=np.int64), axis=-1 + ), + assignment.shape, + ).reshape([-1]) + + logical_indices = np.broadcast_to( + np.expand_dims( + np.arange(len(logical_mesh_shape), dtype=np.int64), axis=0 + ), + assignment.shape, + ).reshape([-1]) + + # Axes of logical mesh is ordered by (physical_axis, logical_axis). + # + # Note that we sort for each physical_axis the logical_axis, so that higher + # intensity logical axes are replicated at inner (minor) dimensions. + # + # E.g., if a dimension size is 12 = 3x4, where 3 is higher intensity and 4 + # is lower, we want to reshape so that it becomes 12 = 4x3. Imagine in the + # 1-d case, this will allow more connections between the higher intensity + # axes. + logical_mesh = np.reshape(physical_mesh, assignment.reshape([-1])) + + # We will then group by l_axis as this is what is expected from output. + _, _, transpose_axes = zip( + *sorted( + zip(logical_indices, physical_indices, range(len(logical_indices))) + ) + ) + logical_mesh = np.transpose(logical_mesh, transpose_axes) + + # Reshape to add the trivial dimensions back. + logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) + + return logical_mesh + + +def _bounds_from_last_device(last_device) -> Sequence[int]: + """Gets the bound from the given last device.""" + # Must be passed the device at the highest-coordinate corner of the + # relevant mesh, which is a requirement we know is satisfied by the last + # device in jax.devices(). + assert hasattr(last_device, 'coords'), 'Only TPU supported' + x, y, z = last_device.coords + return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 + + +def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: + r"""Rearrange TPU devices in a slice into a physical mesh. + + Args: + jax_devices: A list of JAX devices in a TPU slice in process-tiled z, y, x, + core order, e.g. from jax.devices(). The coordinates of these devices + should constitute a cuboid with no holes; e.g., the coordinates can be + {(1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)} (a 1x2x2 cuboid); passing + only 3 of these devices would result in a "hole" in that cuboid, which is + an error. As in our example, the cuboid is not required to include the + point (0, 0, 0). + + Returns: + A np.ndarray of JAX devices with shape [global_x, global_y, global_z]. On + v2 and v3, global_z is instead cores_per_chip (i.e., 2). + """ + device_kind = jax_devices[0].device_kind + device_coords = [d.coords for d in jax_devices] + coord_size = len(device_coords[0]) + # Position-wise max and min coordinates: + max_coords = tuple( + max(dc[i] for dc in device_coords) for i in range(coord_size) + ) + min_coords = tuple( + min(dc[i] for dc in device_coords) for i in range(coord_size) + ) + dims = tuple(h - l + 1 for (h, l) in zip(max_coords, min_coords)) + + max_cores_per_chip = max(d.core_on_chip for d in jax_devices) + min_cores_per_chip = min(d.core_on_chip for d in jax_devices) + cores_per_chip = max_cores_per_chip - min_cores_per_chip + 1 + + assert len(dims) == 3, dims + assert ( + len(jax_devices) == np.prod(dims) * cores_per_chip + ), f'{jax_devices=} {dims=} {cores_per_chip=}' + + if device_kind in (_TPU_V2, _TPU_V3): + out = np.empty(dims[:2] + (cores_per_chip,), dtype=object) + for d in jax_devices: + coords = d.coords + assert coords[2] == 0, d + out[ + coords[0] - min_coords[0], + coords[1] - min_coords[1], + d.core_on_chip - min_cores_per_chip, + ] = d + else: + out = np.empty(dims, dtype=object) + for d in jax_devices: + coords = d.coords + if d.core_on_chip != 0: + raise AssertionError( + 'Creating meshes for TPU >v3 requires one device per chip' + f' ("megacore" mode). Got device id {d.core_on_chip} for a device' + f' of kind {device_kind}: {d}.' + ) + out[ + coords[0] - min_coords[0], + coords[1] - min_coords[1], + coords[2] - min_coords[2], + ] = d + + # Check there is no "hole" in the mesh we constructed. + if (out == None).any(): # pylint: disable=singleton-comparison + raise AssertionError( + 'Constructed mesh contains a "hole"; probable cause: coordinates ' + f'of jax_devices are not a contiguous cuboid: {jax_devices}' + ) + return out + + +# jekbradbury's famous trick for creating contiguous submeshes (where available) +def _transpose_trick( + physical_mesh: np.ndarray, mesh_shape: Sequence[int] +) -> np.ndarray: + mesh_shape = tuple(mesh_shape) + topology = physical_mesh.shape + if topology not in _TRANSPOSE_TRICKS: + raise ValueError( + 'create_device_mesh cannot create contiguous submeshes for ' + f'physical mesh topology {topology}' + ) + + mesh_shape_no_trivial_dims: tuple[int, ...] = () + for dim_size in mesh_shape: + if dim_size != 1: + mesh_shape_no_trivial_dims += (dim_size,) + + if mesh_shape_no_trivial_dims not in _TRANSPOSE_TRICKS[topology]: + raise ValueError( + 'create_device_mesh cannot create contiguous submeshes for ' + f'mesh_shape {mesh_shape} and physical mesh topology {topology}. ' + f'Available mesh_shapes: {list(_TRANSPOSE_TRICKS[topology].keys())}' + ) + + return physical_mesh.transpose( + *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] + ) + + +def create_device_mesh( + mesh_shape: Sequence[int], + devices: Sequence[Any] | None = None, + *, + contiguous_submeshes: bool = False, + allow_split_physical_axes: bool = False, +) -> np.ndarray: + """Creates a performant device mesh for jax.sharding.Mesh. + + Args: + mesh_shape: shape of logical mesh, ordered by increasing network-intensity + e.g. [replica, data, mdl] where mdl has the most network communication + requirements. + devices: optionally, the devices to construct a mesh for. Defaults to + jax.devices(). + contiguous_submeshes: if True, this function will attempt to create a mesh + where each process's local devices form a contiguous submesh. A ValueError + will be raised if this function can't produce a suitable mesh. This + setting was sometimes necessary before the introduction of jax.Array to + ensure non-ragged local arrays; if using jax.Arrays, it's better to keep + this set to False. + allow_split_physical_axes: If True, we will split physical axes if necessary + to produce the desired device mesh. + + Raises: + ValueError: if the number of devices doesn't equal the product of + `mesh_shape`. + + Returns: + A np.ndarray of JAX devices with mesh_shape as its shape that can be fed + into jax.sharding.Mesh with good collective performance. + """ + if devices is None: + devices = xb.devices() + if np.prod(mesh_shape) != len(devices): + raise ValueError( + f'Number of devices {len(devices)} must equal the product ' + f'of mesh_shape {mesh_shape}' + ) + last_device = devices[-1] + + handler = device_kind_handler_dict.get(last_device.device_kind, None) + if handler is not None: + result = handler( + mesh_shape, devices, contiguous_submeshes=contiguous_submeshes + ) + if result is not None: + return result + + if last_device.platform == 'tpu': + physical_mesh = _get_physical_tpu_mesh(devices) + if contiguous_submeshes: + physical_mesh = _transpose_trick(physical_mesh, mesh_shape) + device_mesh, _ = _create_device_mesh_for_nd_torus( + physical_mesh, + mesh_shape, + allow_split_physical_axes=allow_split_physical_axes, + ) + return device_mesh + else: + device_mesh = np.asarray(devices).reshape(mesh_shape) + return device_mesh + + +def create_hybrid_device_mesh( + mesh_shape: Sequence[int], + dcn_mesh_shape: Sequence[int], + devices: Sequence[Any] | None = None, + *, + process_is_granule: bool = False, + should_sort_granules_by_key: bool = True, + allow_split_physical_axes: bool = False, +) -> np.ndarray: + """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. + + Args: + mesh_shape: shape of the logical mesh for the faster/inner network, ordered + by increasing network intensity, e.g. [replica, data, mdl] where mdl has + the most network communication requirements. + dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in + the same order as mesh_shape. + devices: optionally, the devices to construct a mesh for. Defaults to + jax.devices(). + process_is_granule: if True, this function will treat processes as the units + of the slower/outer network. Otherwise it will look for slice_index + attributes on devices and use slices as the units. Enabling this is meant + as a fallback for platforms that don't set slice_index. + should_sort_granules_by_key: Whether device granules should be sorted by the + granule key, either slice or process index, depending on + process_is_granule. + allow_split_physical_axes: If True, we will split physical axes if necessary + to produce the desired device mesh. + + Raises: + ValueError: if the number of slices to which the `devices` belong doesn't + equal the product of `dcn_mesh_shape`, or if the number of devices + belonging to any single slice does not equal the product of `mesh_shape`. + + Returns: + A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape + that can be fed into jax.sharding.Mesh for hybrid parallelism. + """ + if devices is None: + devices = xb.devices() + attr = 'process_index' if process_is_granule else 'slice_index' + if not hasattr(devices[0], attr): + raise ValueError( + f'Device {devices[0]} does not have attribute {attr}. See' + ' `process_is_granule` option.' + ) + granule_dict = collections.defaultdict(list) + for dev in devices: + granule_dict[getattr(dev, attr)].append(dev) + granules = ( + [granule_dict[key] for key in sorted(granule_dict.keys())] + if should_sort_granules_by_key + else granule_dict.values() + ) + if np.prod(dcn_mesh_shape) != len(granules): + raise ValueError( + f'Number of slices {len(granules)} must equal the product of ' + f'dcn_mesh_shape {dcn_mesh_shape}' + ) + per_granule_meshes = [ + create_device_mesh( + mesh_shape, + granule, + allow_split_physical_axes=allow_split_physical_axes, + ) + for granule in granules + ] + # TODO(jekbradbury): handle non-uniform DCN topologies + granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) + blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])( + granule_mesh + ) + device_mesh = np.block(blocks.tolist()) + return device_mesh diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 00d408a73251..a1444c3a2345 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -31,6 +31,7 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge +from jax._src.mesh_utils import create_device_mesh from jax._src.lib import xla_client as xc from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) @@ -1679,3 +1680,56 @@ def _gspmd_to_named_sharding_via_mesh( return create_mesh_pspec_sharding( mesh, parsed_pspec.get_partition_spec(), parsed_pspec, out_s.memory_kind) + + +def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], + *, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh: + """Creates an efficient mesh with the shape and axis names specified. + + This function attempts to automatically compute a good mapping from a set of + logical axes to a physical mesh. For example, on a TPU v3 with 8 devices: + + >>> mesh = jax.make_mesh((8,), ('x')) # doctest: +SKIP + >>> [d.id for d in mesh.devices.flat] # doctest: +SKIP + [0, 1, 2, 3, 6, 7, 4, 5] + + The above ordering takes into account the physical topology of TPU v3. + It orders the devices into a ring, which yields efficient all-reduces on a + TPU v3. + + Now, let's see another example with 16 devices of TPU v3: + + >>> mesh = jax.make_mesh((2, 8), ('x', 'y')) # doctest: +SKIP + >>> [d.id for d in mesh.devices.flat] # doctest: +SKIP + [0, 1, 2, 3, 6, 7, 4, 5, 8, 9, 10, 11, 14, 15, 12, 13] + >>> mesh = jax.make_mesh((4, 4), ('x', 'y')) # doctest: +SKIP + >>> [d.id for d in mesh.devices.flat] # doctest: +SKIP + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + + As you can see, logical axes (`axis_shapes`) affect the ordering of the + devices. + + You can use `jax.experimental.mesh_utils.create_device_mesh` if you want to + use the extra arguments it provides like `contiguous_submeshes` and + `allow_split_physical_axes`. + + Args: + axis_shapes: Shape of the mesh. For example, axis_shape=(4, 2) + axis_names: Names of the mesh axes. For example, axis_names=('x', 'y') + devices: Optional keyword only argument, that allows you to specify the + devices you want to create a mesh with. + + Returns: + A `jax.sharding.Mesh` object. + """ + if devices is None: + devices = xla_bridge.devices() + axis_size = math.prod(axis_shapes) + if axis_size > len(devices): + raise ValueError( + f'Number of devices {len(devices)} must be >= the product ' + f'of mesh_shape {axis_shapes}') + elif axis_size < len(devices): + devices = devices[:axis_size] + mesh_devices = create_device_mesh(axis_shapes, devices) + return mesh_lib.Mesh(mesh_devices, axis_names) diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 7cac0338a923..075e4e6eed48 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -14,798 +14,8 @@ # ============================================================================== """Utils for building a device mesh.""" -from __future__ import annotations - -import collections -from collections.abc import Callable, Generator, MutableMapping, Sequence -import itertools -import logging -import math -from typing import Any - -from jax._src import xla_bridge as xb -import numpy as np - -logger = logging.getLogger(__name__) - -_TPU_V2 = 'TPU v2' -_TPU_V3 = 'TPU v3' -_TPU_V4 = 'TPU v4' -_TPU_V5_LITE = "TPU v5 lite" - -# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's -# famous contiguous mesh trick. -# -# The trick only works for certain topologies and mesh shapes. Trivial dims of -# size 1 can be added to the shapes listed, and they are also supported. -_TRANSPOSE_TRICKS: dict[ - tuple[int, ...], dict[tuple[int, ...], tuple[int, ...]] -] = { - (2, 2, 1): { - (2, 2): (0, 1, 2), - }, - (2, 2, 4): { - (4, 4): (0, 1, 2), - }, - (4, 4, 4): { - (16, 4): (0, 2, 1), - }, - (4, 8, 8): { - (64, 4): (0, 2, 1), - (4, 64): (0, 2, 1), - }, - (8, 8, 8): { - (64, 8): (0, 2, 1), - }, - (8, 16, 16): { - (256, 8): (0, 2, 1), - (8, 256): (0, 2, 1), - }, -} - -# Physical ordering of core IDs in a tray that creates a ring -_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) -_TRAY_2x2_RING_ORDER = (0, 1, 3, 2) -_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4) - -def _tpu_v2_v3_create_device_mesh( - mesh_shape: Sequence[int], - devices: Sequence[Any], - **unused_kwargs, -) -> np.ndarray: - if len(devices) == 8: - logger.info( - 'Reordering mesh to physical ring order on single-tray TPU v2/v3.' - ) - device_mesh = np.asarray(devices) - device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)] - device_mesh = device_mesh.reshape(mesh_shape) - return device_mesh - elif mesh_shape[-1] == 8: - device_mesh = np.asarray(devices).reshape(mesh_shape) - logger.info( - 'Reordering mesh to physical ring order on each TPU v2/v3 tray.' - ) - perm = np.array(_TRAY_RING_ORDER) - device_mesh = device_mesh[..., perm] - return device_mesh - else: - # TODO(skye): implement 2D mesh_shape logic here: - # https://github.com/tensorflow/lingvo/blob/0df40cf604dfcd14e28f7087d73687a0bd2fe5c6/lingvo/core/gshard_utils.py#L187 - # (possibly replaces above mesh_shape[-1] == 8 case) - return np.asarray(devices).reshape(mesh_shape) - - -def _vlc_create_device_mesh( - mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs -) -> np.ndarray | None: - """Creates rotated pincer device assignment for selected topologies. - - Args: - mesh_shape: Logical mesh shape used by the model. - devices: TPU devices. - **unused_kwargs: ... - - Returns: - None or reordered devices reshaped as `mesh_shape`. - """ - max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices) - bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1 - # Our ring re-ordering makes sense only if the passed-in devices are - # sequential, which may not always be the case. reversed() changes z-minor to - # x-minor. - sequential_devices = sorted( - devices, - key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) - - if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: # VLC2x2 - device_mesh = np.asarray(sequential_devices) - device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)] - device_mesh = device_mesh.reshape(mesh_shape) - return device_mesh - - if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # VLP4x4 - # Only uses ring order if the whole mesh is a replica group. - if max(mesh_shape) == len(devices): - device_mesh = np.asarray(sequential_devices) - device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)] - device_mesh = device_mesh.reshape(mesh_shape) - return device_mesh - - return None - - -# Registers functions to create device mesh for specific device kinds. Takes -# precedence over the more general logic in create_device_mesh(). Handler may -# return None; in that case, it will fall back to using the default logic. -device_kind_handler_dict: dict[ - str, - Callable[..., np.ndarray | None], -] = { - _TPU_V2: _tpu_v2_v3_create_device_mesh, - _TPU_V3: _tpu_v2_v3_create_device_mesh, - _TPU_V5_LITE: _vlc_create_device_mesh, -} - - -def _create_device_mesh_for_nd_torus( - physical_mesh: np.ndarray, - mesh_shape: Sequence[int], - *, - allow_split_physical_axes: bool = False, -) -> tuple[np.ndarray, np.ndarray]: - """Assigns logical parallelism axes to physical axes of an N-D torus network. - - Given logical parallelism axes with sizes in `mesh_shape` and devices in an - N-dimensional torus network represented by `physical_mesh`, maps each logical - axis to one or more physical axes. Prefer to map more-performance-sensitive - logical axes to larger numbers of physical axes to maximize the bandwidth - available to them. Also prefer to assign logical axes to multiple physical - axes of the same size (e.g., a 2D square) rather than multiple physical axes - of different sizes when possible. - - If allow_split_physical_axes = False (default), this routine will error out - instead of splitting a physical axis over more than one logical axis (which - would reduce total usable bandwidth). - - Let's use a concrete example to explain the concepts and considerations. - - As an example, suppose the logical mesh is [data, model], for data and model - parallelism respectively. Also suppose that data parallelism is less - performance sensitive than model parallelism. Consider a 3D TPU pod slice of - shape 4x4x16, represented by a physical mesh of shape (4, 4, 16). - - A TPU pod slice has equal bandwidth along all axes with wraparound links, but - a 2D plane of size 4x4 may have faster XLA collective implementations than a - non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want - the more performance sensitive `model` axis to be mapped to the 4x4 XY plane. - - Args: - physical_mesh: a np.ndarray of devices in the shape of the N-D torus - physical topology. - mesh_shape: shape of the logical mesh (size of the various logical - parallelism axes), with axes ordered by increasing network intensity. - allow_split_physical_axes: If True, we would split physical axes if - necessary to fit the desired mesh shape. - - Returns: - An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with - each logical parallelism axis mapped to one or more physical mesh axes. - The axis assignment matrix, which is a 2-d array mapping from - (physical_axis, logical_axis) to the size assigned, with the invariant - np.prod(assignment, axis=1) = physical_mesh_shape, and - np.prod(assignment, axis=0) = mesh_shape. - """ - # Remaining physical axes to be assigned to logical axes. - assignable_physical_mesh = list(physical_mesh.shape) - # Map each logical axis to a subset of physical axes. - assignment: list[tuple[int, ...]] = [() for _ in mesh_shape] - - # Assign logical axes from highest network intensity to lowest. - # `mesh_shape` is assumed to ordered by lowest network intensity first, so - # reverse it first. - for logical_axis_index, logical_axis_size in reversed( - list(enumerate(mesh_shape)) - ): - # Preferentially map to more physical axes first for higher bandwidth. - for num_axes in range(3, 0, -1): - # Try assign to any subset of size num_axes. Generate all candidates. - indices_and_axes = itertools.combinations( - enumerate(assignable_physical_mesh), num_axes - ) - for elem in indices_and_axes: - c_indices, c_axes = zip(*elem) - # TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives only - # implemented for square 2D plane. Mapping a physical axis to two - # logical axes might be slower for non-square 2D plane, e.g., map 32 to - # 4x8 or a single axis. If XLA 2D collectives support non-square plane - # soon, we can continue to preferentially map to 2D plane in general, - # otherwise, we should treat non-square 2D plane and 1D submesh equally. - if np.prod(c_axes) == logical_axis_size: - assignment[logical_axis_index] = c_indices - # Zero the assigned physical axes. - assignable_physical_mesh = [ - 0 if i in c_indices else v - for i, v in enumerate(assignable_physical_mesh) - ] - break - if assignment[logical_axis_index]: - # We already found an assignment from one candidate above. - break - else: - # If the num_axes for loop did not break, i.e. none of the candidates work - # goto here with this while-else construct. - if logical_axis_size > 1: - if not allow_split_physical_axes: - # Although this is now implemented, there are downstream tasks - # counting on this being a NotImplementedError. - raise NotImplementedError( - 'Failed to find assignment for logical_axis_index' - f' {logical_axis_index} of size {logical_axis_size} with' - f' remaining assignable mesh {assignable_physical_mesh}. The size' - ' of each axis in your logical mesh must be equal to the product' - ' of some subset of the physical mesh axis sizes. E.g. logical' - ' mesh (4, 16) is compatible with physical mesh 4x4x4 since 4=4' - ' and 16=4x4. If you want to split physical axes, set ' - ' allow_split_physical_axes to True.' - ) - else: - # We will try finding an assignment, even if that means splitting the - # physical axes, which requires a more sophisticated implementation. - return _create_device_mesh_for_nd_torus_splitting_axes( - physical_mesh, mesh_shape - ) - - # Flatten the assignment, e.g., [(), (2,), (0, 1)] -> (2, 0, 1). - transpose: list[int] = [] - assignment_array = np.ones( - [len(physical_mesh.shape), len(mesh_shape)], dtype=np.int64 - ) - for i, x in enumerate(assignment): - for y in x: - physical_mesh_axis = int(y) - assignment_array[physical_mesh_axis, i] = physical_mesh.shape[ - physical_mesh_axis - ] - transpose.append(physical_mesh_axis) - return ( - physical_mesh.transpose(transpose).reshape(mesh_shape), - assignment_array, - ) - - -def _create_device_mesh_for_nd_torus_splitting_axes( - physical_mesh: np.ndarray, - mesh_shape: Sequence[int], -) -> tuple[np.ndarray, np.ndarray]: - """Assigns logical parallelism axes to physical axes of an N-D torus network. - - This implementation allows creating meshes that requires splitting physical - axes, and thus one could produce logical mesh of any shape, as long as the - number of devices matches, e.g., - - - Creating 2x2x4 from 4x4; - - - Creating 2x2x16 from 8x8; - - Args: - physical_mesh: a np.ndarray of devices in the shape of the N-D torus - physical topology. - mesh_shape: shape of the logical mesh (size of the various logical - parallelism axes), with axes ordered by increasing network intensity. - - Returns: - An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with - each logical parallelism axis mapped to one or more physical mesh axes. - The axis assignment matrix, which is a 2-d array mapping from - (physical_axis, logical_axis) to the size assigned, with the invariant - np.prod(assignment, axis=1) = physical_mesh_shape, and - np.prod(assignment, axis=0) = mesh_shape. - """ - if np.prod(physical_mesh.shape) != np.prod(mesh_shape): - raise ValueError( - 'The number of devices in physical mesh' - f' {physical_mesh.shape} does not match the number of devices' - f' in logical mesh {mesh_shape}.' - ) - - physical_mesh_shape = physical_mesh.shape - logical_mesh_shape = tuple(mesh_shape) - - # (Partial) assignment map as an 2-d array [p_axis, l_axis] -> size. - assignment = np.ones( - [len(physical_mesh_shape), len(logical_mesh_shape)], dtype=np.int64 - ) - - # Process logical axes from highest network intensity to lowest. - # `mesh_shape` is assumed to ordered by lowest network intensity first, so - # reverse it. - for logical_axis, logical_axis_size in reversed( - list(enumerate(logical_mesh_shape)) - ): - # Go over all the possible assignment for the logical axis, including the - # one that splits multiple physical axes. - best_logical_axis_assignment = None - for logical_axis_assignment in _enumerate_feasible_logical_axis_assignments( - physical_mesh_shape, assignment, logical_axis_size - ): - # TODO(rosun): Instead of using heuristics, replace this with a proper - # scoring function reflecting the underlying hardware properties. - if ( - best_logical_axis_assignment is None - or _prefer_first_logical_axis_assignment( - logical_axis_assignment, - best_logical_axis_assignment, - physical_mesh_shape=physical_mesh_shape, - assignment=assignment, - ) - ): - best_logical_axis_assignment = logical_axis_assignment - assignment[:, logical_axis] = best_logical_axis_assignment - - # Read out the assignment. - logical_mesh = _generate_logical_mesh( - physical_mesh, logical_mesh_shape, assignment - ) - - return logical_mesh, assignment - - -def _get_prime_factors(x: int) -> list[int]: - """Returns a sorted list of prime factors for the given number.""" - assert x > 0 - factors = [] - for p in range(2, math.isqrt(x) + 2): - while x % p == 0: - factors.append(p) - x //= p - if x == 1: - return factors - else: - return [x] # x is a prime number. - - -def _enumerate_feasible_logical_axis_assignments( - physical_mesh_shape: Sequence[int], - assignment: np.ndarray, - logical_axis_size: int, -) -> Generator[np.ndarray, None, None]: - """Yields feasible assignments for a single logical axis. - - For a physical mesh of shape [x_1, ..., x_n], and the product of all previous - assignments on each physical axes [y_1, ..., y_n], this function yields all - possible assignments for the axis as 1-d arrays [z_1, ..., z_n], so that: - - - prod(z_1, ..., z_n) = logical_axis_size - - - x_i % (z_i * y_i) = 0 - - Args: - physical_mesh_shape: Physical mesh shape. - assignment: Existing assignment matrix. - logical_axis_size: Size of the logical axis to assign. - - Yields: - All valid assignments for the logical axis. Each assignment is represented - as an integer array of length len(physical_mesh_shape). - """ - logical_axis_factors: MutableMapping[int, int] = collections.defaultdict(int) - for factor in _get_prime_factors(logical_axis_size): - logical_axis_factors[factor] += 1 - - available_physical_mesh_shape = np.array(physical_mesh_shape) // np.prod( - assignment, axis=-1 - ) - - # To enable efficient enumerations, we first index physical axes by their - # prime factors. Since we know the prime factorization of the logical axis - # size, we could simply enumerate by picking the correct count for each - # prime factor. - physical_axes_by_factor: MutableMapping[int, list[int]] = ( - collections.defaultdict(list) - ) - for physical_axis, physical_axis_size in enumerate( - available_physical_mesh_shape - ): - for factor in _get_prime_factors(physical_axis_size): - if factor not in logical_axis_factors: - continue - physical_axes_by_factor[factor].append(physical_axis) - - factors = [] - assignments_by_factor = [] - for factor, multiplicity in logical_axis_factors.items(): - factors.append(factor) - assignments_by_factor.append( - set( - itertools.combinations( - physical_axes_by_factor[factor], multiplicity - ) - ) - ) - - for axis_assignment in itertools.product(*assignments_by_factor): - result = np.ones([len(physical_mesh_shape)], dtype=np.int64) - for factor_index, per_factor_assignment in enumerate(axis_assignment): - for physical_axis in per_factor_assignment: - result[physical_axis] *= factors[factor_index] - yield result - - -def _prefer_first_logical_axis_assignment( - x: np.ndarray, - y: np.ndarray, - *, - physical_mesh_shape: Sequence[int], - assignment: np.ndarray, -) -> bool: - """Returns True if the first axis assignment is preferred over the second. - - For now, this is implemented with some very simple heuristics. However, - it is possible to introduce e.g., a value function here based on a more - precise model of the underlying hardware. - - TODO(rosun): Use a proxy of network capacity to select the partitions. - - Args: - x: Logical axis assignment as [len(physical_mesh_shape)] array. - y: Logical axis assignment as [len(physical_mesh_shape)] array. - physical_mesh_shape: Physical mesh shape. - assignment: Assignment matrix. - - Returns: - True if x is preferred over y. - """ - # Prefer occupying complete physical axes. I don't have a good reason for - # this, except that it is compatible with the existing behavior. - # - # E.g., on 4 x 4 x 8, [4, 4, -] will be preferred over [4, -, 4], and then - # over [2, 2, 4]. - x_whole_axis_size = np.prod( - [s for i, s in enumerate(x) if s == physical_mesh_shape[i]] - ) - y_whole_axis_size = np.prod( - [s for i, s in enumerate(y) if s == physical_mesh_shape[i]] - ) - - if x_whole_axis_size != y_whole_axis_size: - return x_whole_axis_size > y_whole_axis_size - - # Prefer occupying more whole physical axes for better bandwidth. - # - # This is consistent with existing logic, i.e., 2 x 2 is preferred over 4. - x_num_whole_axes = len( - [1 for i, s in enumerate(x) if s == physical_mesh_shape[i] and s > 1] - ) - y_num_whole_axes = len( - [1 for i, s in enumerate(y) if s == physical_mesh_shape[i] and s > 1] - ) - - if x_num_whole_axes != y_num_whole_axes: - return x_num_whole_axes > y_num_whole_axes - - # Prefer taking physical axes that are not taken by logical axes of higher - # network intensity. E.g., for a 4 x 4 x 4, suppose that the previous - # assignments are 1 x 2 x 4, and we want to place a new logical axis of size - # 2, we will go for [2, 1, 1] instead of [1, 2, 1], as the latter choice will - # tap into bandwidth already taken by the higher intensity axis. - assigned_physical_mesh_shape = np.prod(assignment, axis=-1) - - x_non_overlapping_axis_size = np.prod( - [s for i, s in enumerate(x) if assigned_physical_mesh_shape[i] > 1] - ) - y_non_overlapping_axis_size = np.prod( - [s for i, s in enumerate(y) if assigned_physical_mesh_shape[i] > 1] - ) - - if x_non_overlapping_axis_size != y_non_overlapping_axis_size: - return x_non_overlapping_axis_size > y_non_overlapping_axis_size - - # Otherwise sort by reverse lexical graphical order, to be consistent with - # existing behavior. - return tuple(x) > tuple(y) - - -def _generate_logical_mesh( - physical_mesh: np.ndarray, - logical_mesh_shape: Sequence[int], - assignment: np.ndarray, -) -> np.ndarray: - """Compute the logical mesh from assignment map. - - Args: - physical_mesh: Physical device mesh. - logical_mesh_shape: Logical mesh shape. - assignment: 2-d assignment matrix shape [physical_dims, logical_dims]. - - Returns: - Logical mesh reshaped from physical mesh. - """ - physical_indices = np.broadcast_to( - np.expand_dims( - np.arange(len(physical_mesh.shape), dtype=np.int64), axis=-1 - ), - assignment.shape, - ).reshape([-1]) - - logical_indices = np.broadcast_to( - np.expand_dims( - np.arange(len(logical_mesh_shape), dtype=np.int64), axis=0 - ), - assignment.shape, - ).reshape([-1]) - - # Axes of logical mesh is ordered by (physical_axis, logical_axis). - # - # Note that we sort for each physical_axis the logical_axis, so that higher - # intensity logical axes are replicated at inner (minor) dimensions. - # - # E.g., if a dimension size is 12 = 3x4, where 3 is higher intensity and 4 - # is lower, we want to reshape so that it becomes 12 = 4x3. Imagine in the - # 1-d case, this will allow more connections between the higher intensity - # axes. - logical_mesh = np.reshape(physical_mesh, assignment.reshape([-1])) - - # We will then group by l_axis as this is what is expected from output. - _, _, transpose_axes = zip( - *sorted( - zip(logical_indices, physical_indices, range(len(logical_indices))) - ) - ) - logical_mesh = np.transpose(logical_mesh, transpose_axes) - - # Reshape to add the trivial dimensions back. - logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) - - return logical_mesh - - -def _bounds_from_last_device(last_device) -> Sequence[int]: - """Gets the bound from the given last device.""" - # Must be passed the device at the highest-coordinate corner of the - # relevant mesh, which is a requirement we know is satisfied by the last - # device in jax.devices(). - assert hasattr(last_device, 'coords'), 'Only TPU supported' - x, y, z = last_device.coords - return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 - - -def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: - r"""Rearrange TPU devices in a slice into a physical mesh. - - Args: - jax_devices: A list of JAX devices in a TPU slice in process-tiled z, y, x, - core order, e.g. from jax.devices(). The coordinates of these devices - should constitute a cuboid with no holes; e.g., the coordinates can be - {(1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)} (a 1x2x2 cuboid); passing - only 3 of these devices would result in a "hole" in that cuboid, which is - an error. As in our example, the cuboid is not required to include the - point (0, 0, 0). - - Returns: - A np.ndarray of JAX devices with shape [global_x, global_y, global_z]. On - v2 and v3, global_z is instead cores_per_chip (i.e., 2). - """ - device_kind = jax_devices[0].device_kind - device_coords = [d.coords for d in jax_devices] - coord_size = len(device_coords[0]) - # Position-wise max and min coordinates: - max_coords = tuple( - max(dc[i] for dc in device_coords) for i in range(coord_size) - ) - min_coords = tuple( - min(dc[i] for dc in device_coords) for i in range(coord_size) - ) - dims = tuple(h - l + 1 for (h, l) in zip(max_coords, min_coords)) - - max_cores_per_chip = max(d.core_on_chip for d in jax_devices) - min_cores_per_chip = min(d.core_on_chip for d in jax_devices) - cores_per_chip = max_cores_per_chip - min_cores_per_chip + 1 - - assert len(dims) == 3, dims - assert ( - len(jax_devices) == np.prod(dims) * cores_per_chip - ), f'{jax_devices=} {dims=} {cores_per_chip=}' - - if device_kind in (_TPU_V2, _TPU_V3): - out = np.empty(dims[:2] + (cores_per_chip,), dtype=object) - for d in jax_devices: - coords = d.coords - assert coords[2] == 0, d - out[ - coords[0] - min_coords[0], - coords[1] - min_coords[1], - d.core_on_chip - min_cores_per_chip, - ] = d - else: - out = np.empty(dims, dtype=object) - for d in jax_devices: - coords = d.coords - if d.core_on_chip != 0: - raise AssertionError( - 'Creating meshes for TPU >v3 requires one device per chip' - f' ("megacore" mode). Got device id {d.core_on_chip} for a device' - f' of kind {device_kind}: {d}.' - ) - out[ - coords[0] - min_coords[0], - coords[1] - min_coords[1], - coords[2] - min_coords[2], - ] = d - - # Check there is no "hole" in the mesh we constructed. - if (out == None).any(): # pylint: disable=singleton-comparison - raise AssertionError( - 'Constructed mesh contains a "hole"; probable cause: coordinates ' - f'of jax_devices are not a contiguous cuboid: {jax_devices}' - ) - return out - - -# jekbradbury's famous trick for creating contiguous submeshes (where available) -def _transpose_trick( - physical_mesh: np.ndarray, mesh_shape: Sequence[int] -) -> np.ndarray: - mesh_shape = tuple(mesh_shape) - topology = physical_mesh.shape - if topology not in _TRANSPOSE_TRICKS: - raise ValueError( - 'create_device_mesh cannot create contiguous submeshes for ' - f'physical mesh topology {topology}' - ) - - mesh_shape_no_trivial_dims: tuple[int, ...] = () - for dim_size in mesh_shape: - if dim_size != 1: - mesh_shape_no_trivial_dims += (dim_size,) - - if mesh_shape_no_trivial_dims not in _TRANSPOSE_TRICKS[topology]: - raise ValueError( - 'create_device_mesh cannot create contiguous submeshes for ' - f'mesh_shape {mesh_shape} and physical mesh topology {topology}. ' - f'Available mesh_shapes: {list(_TRANSPOSE_TRICKS[topology].keys())}' - ) - - return physical_mesh.transpose( - *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] - ) - - -def create_device_mesh( - mesh_shape: Sequence[int], - devices: Sequence[Any] | None = None, - *, - contiguous_submeshes: bool = False, - allow_split_physical_axes: bool = False, -) -> np.ndarray: - """Creates a performant device mesh for jax.sharding.Mesh. - - Args: - mesh_shape: shape of logical mesh, ordered by increasing network-intensity - e.g. [replica, data, mdl] where mdl has the most network communication - requirements. - devices: optionally, the devices to construct a mesh for. Defaults to - jax.devices(). - contiguous_submeshes: if True, this function will attempt to create a mesh - where each process's local devices form a contiguous submesh. A ValueError - will be raised if this function can't produce a suitable mesh. This - setting was sometimes necessary before the introduction of jax.Array to - ensure non-ragged local arrays; if using jax.Arrays, it's better to keep - this set to False. - allow_split_physical_axes: If True, we will split physical axes if necessary - to produce the desired device mesh. - - Raises: - ValueError: if the number of devices doesn't equal the product of - `mesh_shape`. - - Returns: - A np.ndarray of JAX devices with mesh_shape as its shape that can be fed - into jax.sharding.Mesh with good collective performance. - """ - if devices is None: - devices = xb.devices() - if np.prod(mesh_shape) != len(devices): - raise ValueError( - f'Number of devices {len(devices)} must equal the product ' - f'of mesh_shape {mesh_shape}' - ) - last_device = devices[-1] - - handler = device_kind_handler_dict.get(last_device.device_kind, None) - if handler is not None: - result = handler( - mesh_shape, devices, contiguous_submeshes=contiguous_submeshes - ) - if result is not None: - return result - - if last_device.platform == 'tpu': - physical_mesh = _get_physical_tpu_mesh(devices) - if contiguous_submeshes: - physical_mesh = _transpose_trick(physical_mesh, mesh_shape) - device_mesh, _ = _create_device_mesh_for_nd_torus( - physical_mesh, - mesh_shape, - allow_split_physical_axes=allow_split_physical_axes, - ) - return device_mesh - else: - device_mesh = np.asarray(devices).reshape(mesh_shape) - return device_mesh - - -def create_hybrid_device_mesh( - mesh_shape: Sequence[int], - dcn_mesh_shape: Sequence[int], - devices: Sequence[Any] | None = None, - *, - process_is_granule: bool = False, - should_sort_granules_by_key: bool = True, - allow_split_physical_axes: bool = False, -) -> np.ndarray: - """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. - - Args: - mesh_shape: shape of the logical mesh for the faster/inner network, ordered - by increasing network intensity, e.g. [replica, data, mdl] where mdl has - the most network communication requirements. - dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in - the same order as mesh_shape. - devices: optionally, the devices to construct a mesh for. Defaults to - jax.devices(). - process_is_granule: if True, this function will treat processes as the units - of the slower/outer network. Otherwise it will look for slice_index - attributes on devices and use slices as the units. Enabling this is meant - as a fallback for platforms that don't set slice_index. - should_sort_granules_by_key: Whether device granules should be sorted by the - granule key, either slice or process index, depending on - process_is_granule. - allow_split_physical_axes: If True, we will split physical axes if necessary - to produce the desired device mesh. - - Raises: - ValueError: if the number of slices to which the `devices` belong doesn't - equal the product of `dcn_mesh_shape`, or if the number of devices - belonging to any single slice does not equal the product of `mesh_shape`. - - Returns: - A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape - that can be fed into jax.sharding.Mesh for hybrid parallelism. - """ - if devices is None: - devices = xb.devices() - attr = 'process_index' if process_is_granule else 'slice_index' - if not hasattr(devices[0], attr): - raise ValueError( - f'Device {devices[0]} does not have attribute {attr}. See' - ' `process_is_granule` option.' - ) - granule_dict = collections.defaultdict(list) - for dev in devices: - granule_dict[getattr(dev, attr)].append(dev) - granules = ( - [granule_dict[key] for key in sorted(granule_dict.keys())] - if should_sort_granules_by_key - else granule_dict.values() - ) - if np.prod(dcn_mesh_shape) != len(granules): - raise ValueError( - f'Number of slices {len(granules)} must equal the product of ' - f'dcn_mesh_shape {dcn_mesh_shape}' - ) - per_granule_meshes = [ - create_device_mesh( - mesh_shape, - granule, - allow_split_physical_axes=allow_split_physical_axes, - ) - for granule in granules - ] - # TODO(jekbradbury): handle non-uniform DCN topologies - granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) - blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])( - granule_mesh - ) - device_mesh = np.block(blocks.tolist()) - return device_mesh +from jax._src.mesh_utils import ( + create_device_mesh as create_device_mesh, + create_hybrid_device_mesh as create_hybrid_device_mesh, + device_kind_handler_dict as device_kind_handler_dict, +) diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 6916c2c37e56..b182caf2dcd2 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -24,7 +24,7 @@ from jax._src import mesh as mesh_lib from jax._src import test_util from jax._src.sharding_impls import NamedSharding, PartitionSpec, local_to_global_shape -from jax.experimental import mesh_utils +from jax._src import mesh_utils from jax.sharding import Mesh # pylint: disable=g-importing-member import numpy as np diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 97dd65176a79..1efbcc146d13 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4352,7 +4352,9 @@ def test_convert_element_type_sharding(self): self.assertEqual(out.sharding, s) def test_jnp_array_sharding(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + if jax.device_count() < 4: + self.skipTest('Requires >=4 devices') + mesh = jax.make_mesh((2, 2), ('x', 'y'), devices=jax.devices()[:4]) s = NamedSharding(mesh, P('x', 'y')) inp = np.arange(16).reshape(8, 2) @@ -4361,7 +4363,9 @@ def test_jnp_array_sharding(self): self.assertEqual(out.sharding, s) def test_jnp_array_inside_jit_sharding(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + if jax.device_count() < 4: + self.skipTest('Requires >=4 devices') + mesh = jax.make_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = np.arange(16).reshape(8, 2) From d0d7493aae6a93036fe7ecb01089cc589393acdd Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 3 Sep 2024 15:04:01 -0700 Subject: [PATCH 343/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/950df464409d4b27c0cb452f78aa221b89e60672. PiperOrigin-RevId: 670719765 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index d21e51836919..c9aa0b8d61a1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "d18cd64b7cd61a2ade10089665ac104f639101b1" -XLA_SHA256 = "12d2a18d4f7549305c7949a0d13504e9c3de464792cf72c8d92d62dc414c8ff1" +XLA_COMMIT = "950df464409d4b27c0cb452f78aa221b89e60672" +XLA_SHA256 = "0fec02fa0838f7b2d67482488a58016eeec9393c1230a6b5206f0b0fa8e3eb96" def repo(): tf_http_archive( From 1289640f097562ad3c3f4446bcd2324e86310685 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Sep 2024 15:15:19 -0700 Subject: [PATCH 344/702] Deprecated calling ``jax.dlpack.from_dlpack`` with a DLPack tensor PiperOrigin-RevId: 670723176 --- CHANGELOG.md | 3 +++ jax/_src/deprecations.py | 1 + jax/_src/dlpack.py | 43 +++++++++++++++++++++++++--------------- pyproject.toml | 5 ++++- 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c633ebb40a50..e310b296b05f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. the future. * `jax.numpy.round_` has been deprecated, following removal of the corresponding API in NumPy 2.0. Use {func}`jax.numpy.round` instead. + * Passing a DLPack capsule to {func}`jax.dlpack.from_dlpack` is deprecated. + The argument to {func}`jax.dlpack.from_dlpack` should be an array from + another framework that implements the ``__dlpack__`` protocol. ## jaxlib 0.4.32 diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 96eca4ccf45c..4e01c88afd1f 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -121,6 +121,7 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: # Register a number of deprecations: we do this here to ensure they're # always registered by the time `accelerate` and `is_acelerated` are called. +register('jax-dlpack-import-legacy') register("jax-numpy-astype-complex-to-real") register("jax-numpy-array-none") register('jax-scipy-beta-args') diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 386123ae61f0..ac976234eda5 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -16,14 +16,17 @@ from typing import Any -from jax._src.api import device_put from jax import numpy as jnp from jax._src import array +from jax._src import deprecations from jax._src import xla_bridge +from jax._src.api import device_put from jax._src.lax.lax import _array_copy from jax._src.lib import xla_client -from jax._src.typing import Array, DLDeviceType from jax._src.sharding import Sharding +from jax._src.typing import Array +from jax._src.typing import DLDeviceType + DLPACK_VERSION = (0, 8) MIN_DLPACK_VERSION = (0, 5) @@ -237,21 +240,19 @@ def from_dlpack(external_array, device transfer or copy was requested. Args: - external_array: An array object that has __dlpack__ and __dlpack_device__ - methods, or a DLPack tensor on either CPU or GPU (legacy API). - + external_array: An array object that has ``__dlpack__` and + ``__dlpack_device__`` methods. device: The (optional) :py:class:`Device`, representing the device on which - the returned array should be placed. If given, then the result is committed - to the device. If unspecified, the resulting array will be unpacked onto the - same device it originated from. Setting ``device`` to a device different from - the source of ``external_array`` will require a copy, meaning ``copy`` must be - set to either ``True`` or ``None``. - + the returned array should be placed. If given, then the result is + committed to the device. If unspecified, the resulting array will be + unpacked onto the same device it originated from. Setting ``device`` to a + device different from the source of ``external_array`` will require a + copy, meaning ``copy`` must be set to either ``True`` or ``None``. copy: An (optional) boolean, controlling whether or not a copy is performed. - If ``copy=True`` then a copy is always performed, even if unpacked onto the - same device. If ``copy=False`` then the copy is never performed and will raise - an error if necessary. When ``copy=None`` then a copy may be performed if - needed for a device transfer. + If ``copy=True`` then a copy is always performed, even if unpacked onto + the same device. If ``copy=False`` then the copy is never performed and + will raise an error if necessary. When ``copy=None`` then a copy may be + performed if needed for a device transfer. Returns: A jax.Array @@ -274,5 +275,15 @@ def from_dlpack(external_array, if hasattr(external_array, "__dlpack__"): return _from_dlpack(external_array, device, copy) - # Legacy path + # Deprecated legacy path. + # TODO(slebedev): Remove on or after December 3rd 2023. + deprecations.warn( + "jax-dlpack-import-legacy", + ( + "Calling from_dlpack with a DLPack tensor is deprecated. The argument" + " to from_dlpack should be an array from another framework that" + " implements the __dlpack__ protocol." + ), + stacklevel=2, + ) return _legacy_from_dlpack(external_array, device, copy) diff --git a/pyproject.toml b/pyproject.toml index fb706dbbf8ea..23135b95e126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,9 @@ filterwarnings = [ # TODO(slebedev): Remove once we bump the minimum TensorFlow version. "default:The key path API is deprecated .*", "default:jax.xla_computation is deprecated.*:DeprecationWarning", + + # TODO(slebedev): Remove once we drop the legacy DLPack import path. + "default:.*from_dlpack with a DLPack tensor is deprecated.*:DeprecationWarning", ] doctest_optionflags = [ "NUMBER", @@ -115,7 +118,7 @@ ignore = [ "C408", # Unnecessary map usage "C417", - # Unnecessary dict comprehension for iterable + # Unnecessary dict comprehension for iterable "C420", # Object names too complex "C901", From 78212ae39e5a2d99071cbdf5d4469c1bd99d0cdf Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Thu, 20 Jun 2024 13:42:52 -0700 Subject: [PATCH 345/702] better true_divide and divide docs doc wording update --- jax/_src/numpy/ufuncs.py | 38 ++++++++++++++++++++++++++++++++++++-- tests/lax_numpy_test.py | 2 +- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 71adb5d1e1ec..64b5235220cf 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -1035,13 +1035,47 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) -@implements(np.true_divide, module='numpy') @partial(jit, inline=True) def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Calculates the division of x1 by x2 element-wise + + JAX implementation of :func:`numpy.true_divide`. + + Args: + x1: Input array, the dividend + x2: Input array, the divisor + + Returns: + An array containing the elementwise quotients, will always use + floating point division. + + Examples: + >>> x1 = jnp.array([3, 4, 5]) + >>> x2 = 2 + >>> jnp.true_divide(x1, x2) + Array([1.5, 2. , 2.5], dtype=float32) + + >>> x1 = 24 + >>> x2 = jnp.array([3, 4, 6j]) + >>> jnp.true_divide(x1, x2) + Array([8.+0.j, 6.+0.j, 0.-4.j], dtype=complex64) + + >>> x1 = jnp.array([1j, 9+5j, -4+2j]) + >>> x2 = 3j + >>> jnp.true_divide(x1, x2) + Array([0.33333334+0.j , 1.6666666 -3.j , + 0.6666667 +1.3333334j], dtype=complex64) + + See Also: + :func:`jax.numpy.floor_divide` for integer division + """ x1, x2 = promote_args_inexact("true_divide", x1, x2) return lax.div(x1, x2) -divide = true_divide + +def divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.true_divide`.""" + return true_divide(x1, x2) @jit diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 0e1e7937e8dd..3a5d1def4906 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6288,7 +6288,7 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', - 'amax', 'amin', 'around', 'round_'] + 'amax', 'amin', 'around', 'divide', 'round_'] for name in dir(jnp): if name.startswith('_') or name in unimplemented: From 7569dd54389b40e1ab70820ef7b8a138c4edb729 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Sep 2024 16:04:23 -0700 Subject: [PATCH 346/702] Update sharded-computation doc to use make_mesh() --- docs/sharded-computation.ipynb | 17 +++++------------ docs/sharded-computation.md | 17 +++++------------ 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 22d9156f607b..60bf4d41a7a6 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -188,12 +188,9 @@ } ], "source": [ - "# Pardon the boilerplate; constructing a sharding will become easier in future!\n", - "from jax.experimental import mesh_utils\n", + "from jax.sharding import PartitionSpec as P\n", "\n", - "P = jax.sharding.PartitionSpec\n", - "devices = mesh_utils.create_device_mesh((2, 4))\n", - "mesh = jax.sharding.Mesh(devices, ('x', 'y'))\n", + "mesh = jax.make_mesh((2, 4), ('x', 'y'))\n", "sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))\n", "print(sharding)" ] @@ -402,9 +399,7 @@ "@jax.jit\n", "def f_contract_2(x):\n", " out = x.sum(axis=0)\n", - " # mesh = jax.create_mesh((8,), 'x')\n", - " devices = mesh_utils.create_device_mesh(8)\n", - " mesh = jax.sharding.Mesh(devices, 'x')\n", + " mesh = jax.make_mesh((8,), ('x',))\n", " sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", " return jax.lax.with_sharding_constraint(out, sharding)\n", "\n", @@ -457,8 +452,7 @@ ], "source": [ "from jax.experimental.shard_map import shard_map\n", - "P = jax.sharding.PartitionSpec\n", - "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n", + "mesh = jax.make_mesh((8,), ('x',))\n", "\n", "f_elementwise_sharded = shard_map(\n", " f_elementwise,\n", @@ -656,8 +650,7 @@ } ], "source": [ - "P = jax.sharding.PartitionSpec\n", - "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n", + "mesh = jax.make_mesh((8,), ('x',))\n", "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", "\n", "x_sharded = jax.device_put(x, sharding)\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index e6e1948f9902..ef4dc2d3288d 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -72,12 +72,9 @@ Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimens ```{code-cell} :outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5 -# Pardon the boilerplate; constructing a sharding will become easier in future! -from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec as P -P = jax.sharding.PartitionSpec -devices = mesh_utils.create_device_mesh((2, 4)) -mesh = jax.sharding.Mesh(devices, ('x', 'y')) +mesh = jax.make_mesh((2, 4), ('x', 'y')) sharding = jax.sharding.NamedSharding(mesh, P('x', 'y')) print(sharding) ``` @@ -146,9 +143,7 @@ For example, suppose that within `f_contract` above, you'd prefer the output not @jax.jit def f_contract_2(x): out = x.sum(axis=0) - # mesh = jax.create_mesh((8,), 'x') - devices = mesh_utils.create_device_mesh(8) - mesh = jax.sharding.Mesh(devices, 'x') + mesh = jax.make_mesh((8,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) return jax.lax.with_sharding_constraint(out, sharding) @@ -174,8 +169,7 @@ In the automatic parallelism methods explored above, you can write a function as :outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2 from jax.experimental.shard_map import shard_map -P = jax.sharding.PartitionSpec -mesh = jax.sharding.Mesh(jax.devices(), 'x') +mesh = jax.make_mesh((8,), ('x',)) f_elementwise_sharded = shard_map( f_elementwise, @@ -265,8 +259,7 @@ If you shard the leading axis of both `x` and `weights` in the same way, then th ```{code-cell} :outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5 -P = jax.sharding.PartitionSpec -mesh = jax.sharding.Mesh(jax.devices(), 'x') +mesh = jax.make_mesh((8,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) x_sharded = jax.device_put(x, sharding) From e1b497078efe202fbc59dbdd19833482d81a3b12 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Sep 2024 16:22:23 -0700 Subject: [PATCH 347/702] Rename `jtu.create_global_mesh` to `jtu.create_mesh` and use `jax.make_mesh` inside `jtu.create_mesh` to get maximum test coverage of the new API. PiperOrigin-RevId: 670744047 --- jax/_src/test_util.py | 13 +- .../array_serialization/serialization_test.py | 32 ++- tests/array_test.py | 116 ++++---- tests/debugging_primitives_test.py | 4 +- tests/layout_test.py | 34 +-- tests/memories_test.py | 66 ++--- tests/multiprocess_gpu_test.py | 4 +- tests/pgle_test.py | 8 +- tests/pjit_test.py | 262 +++++++++--------- tests/random_test.py | 4 +- tests/shard_alike_test.py | 26 +- tests/shard_map_test.py | 120 ++++---- 12 files changed, 346 insertions(+), 343 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 72533e619fe6..102afc42d079 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1375,15 +1375,16 @@ def with_and_without_mesh(f): ('Mesh', (('x', 2),), (('i', 'x'),)) ))(with_mesh_from_kwargs(f)) -def create_global_mesh(mesh_shape, axis_names): +def create_mesh(mesh_shape, axis_names, iota_order=False): size = math.prod(mesh_shape) if len(jax.devices()) < size: raise unittest.SkipTest(f"Test requires {size} global devices.") - devices = sorted(jax.devices(), key=lambda d: d.id) - mesh_devices = np.array(devices[:size]).reshape(mesh_shape) - global_mesh = jax.sharding.Mesh(mesh_devices, axis_names) - return global_mesh - + if iota_order: + devices = sorted(jax.devices(), key=lambda d: d.id) + mesh_devices = np.array(devices[:size]).reshape(mesh_shape) + return jax.sharding.Mesh(mesh_devices, axis_names) + else: + return jax.make_mesh(mesh_shape, axis_names) class _cached_property: null = object() diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 004f03a85b04..61993637912f 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -52,7 +52,7 @@ def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir): @jtu.skip_on_devices('cpu') def test_memory_consumption(self): - global_mesh = jtu.create_global_mesh((2, 4), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 4), ('x', 'y')) inp_shape = (2_048, 4_096) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -97,7 +97,7 @@ async def deserialize_with_byte_limit(): tm.stop() def test_memory_consumption_for_save(self): - global_mesh = jtu.create_global_mesh((1, 1), ('x', 'y')) + global_mesh = jtu.create_mesh((1, 1), ('x', 'y')) inp_shape = (16 * 1024, 16 * 1024) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -132,7 +132,7 @@ def test_memory_consumption_for_save(self): tm.stop() def test_checkpointing_with_path_variant(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_shape = (8, 2) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -164,7 +164,7 @@ def test_checkpointing_with_path_variant(self): self.assertEqual(m1.dtype, np.int32) def test_checkpointing_jax_array(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_shape = (8, 2) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -188,7 +188,7 @@ def test_checkpointing_jax_array(self): # Third Array def cb3(_): return np.array([], dtype=np.float32) - global_mesh1d = jtu.create_global_mesh((8,), ('x',)) + global_mesh1d = jtu.create_mesh((8,), ('x',)) a3 = array.make_array_from_callback( (0,), NamedSharding(global_mesh1d, P(None)), cb3) ckpt_path3 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/third').full_path) @@ -232,7 +232,7 @@ def cb3(_): self.assertEqual(m3.dtype, np.float32) def test_checkpointing_ocdbt_transaction(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_shape = (8, 2) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -262,7 +262,7 @@ def test_checkpointing_ocdbt_transaction(self): def cb3(_): return np.array([], dtype=np.float32) - global_mesh1d = jtu.create_global_mesh((8,), ('x',)) + global_mesh1d = jtu.create_mesh((8,), ('x',)) a3 = array.make_array_from_callback( (0,), NamedSharding(global_mesh1d, P(None)), cb3 ) @@ -327,7 +327,7 @@ def cb3(_): @parameterized.product(input_dtype=[np.int32, jnp.bfloat16]) def test_checkpointing_with_bigger_shape_jax_array(self, input_dtype): - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) num = math.prod(global_input_shape) @@ -349,7 +349,8 @@ def cb1(index): on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y')) + ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), + P('x', 'y')) m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)], [np.float32]) @@ -375,7 +376,7 @@ def cb1(index): @parameterized.product(input_dtype=[jnp.int4, jnp.int8]) def test_checkpointing_with_int4(self, input_dtype): - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) num = math.prod(global_input_shape) @@ -397,7 +398,8 @@ def cb(index): on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y')) + ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), + P('x', 'y')) target_dtype = jnp.dtype('int4') m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)], @@ -424,7 +426,7 @@ def cb(index): self.assertArraysEqual(l.data, global_input_data.astype(target_dtype)) def test_checkpointing_scalar_jax_array(self): - global_mesh = jtu.create_global_mesh((2,), ('x')) + global_mesh = jtu.create_mesh((2,), ('x')) global_input_shape = () data = np.array(4) s = NamedSharding(global_mesh, P(None)) @@ -441,7 +443,7 @@ def test_checkpointing_scalar_jax_array(self): on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - ds = NamedSharding(jtu.create_global_mesh((2,), ('x')), P(None)) + ds = NamedSharding(jtu.create_mesh((2,), ('x')), P(None)) m1, = serialization.run_deserialization( [ds], @@ -454,7 +456,7 @@ def test_checkpointing_scalar_jax_array(self): self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32)) def test_deserialize_tensorstore_array_jax_array(self): - global_mesh = jtu.create_global_mesh((2,), ('x')) + global_mesh = jtu.create_mesh((2,), ('x')) data = np.arange(1024) tspec = ts.array(data).spec() m1, = serialization.run_deserialization( @@ -550,7 +552,7 @@ def test_load_with_layout(self): if not jtu.test_device_matches(['tpu']): self.skipTest('Layouts are only supported on TPUs') - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(32).reshape(8, 4) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) diff --git a/tests/array_test.py b/tests/array_test.py index 9d62b68bb0a5..3cbb27877a20 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -81,7 +81,7 @@ def test_array_impl_name(self): ("mesh_fully_replicated", P()), ) def test_jax_array_value(self, mesh_axes): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, global_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes)) @@ -121,7 +121,7 @@ def test_jax_array_value(self, mesh_axes): ) def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape, expected_replica_ids, expected_is_fully_replicated): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) s = jax.sharding.NamedSharding(global_mesh, mesh_axes) arr, global_input_data = create_array(global_input_shape, s) @@ -148,7 +148,7 @@ def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape, self.assertArraysEqual(g.data, l.data) def test_addressable_data(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) shape = (8, 2) s = jax.sharding.NamedSharding(global_mesh, P(None)) arr, inp_data = create_array(shape, s) @@ -156,7 +156,7 @@ def test_addressable_data(self): self.assertArraysEqual(inp_data, arr.addressable_data(i)) def test_array_delete(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, _ = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -174,7 +174,7 @@ def test_single_device_array_usage_after_delete(self): _ = x + 1 def test_multi_device_array_usage_after_delete(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) shape = (8, 2) arr = jax.device_put(np.arange(math.prod(shape), dtype=np.int32), jax.sharding.NamedSharding(global_mesh, P('x'))) @@ -205,14 +205,14 @@ def test_device_put_array_delete(self): self.assertIsNone(arr._arrays) def test_array_device_get(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) self.assertArraysEqual(jax.device_get(arr), input_data) def test_repr(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, _ = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -254,7 +254,7 @@ def test_jnp_array_normal_add(self): self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding) def test_array_sharded_astype(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -272,7 +272,7 @@ def test_jnp_array_astype(self): self.assertArraysEqual(arr_float32, arr.astype(np.float32)) def test_array_delete_idempotent(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) arr = jax.device_put(np.arange(8), jax.sharding.NamedSharding(mesh, P('x'))) arr.delete() @@ -282,7 +282,7 @@ def test_array_delete_idempotent(self): self.assertTrue(arr.is_deleted()) def test_sharded_add(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) a, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -296,7 +296,7 @@ def test_sharded_add(self): self.assertArraysEqual(i.data, expected[i.index]) def test_sharded_zeros_like(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) a, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -318,7 +318,7 @@ def test_wrong_num_arrays(self): if jax.device_count() < 4: self.skipTest('Requires more than 4 devices') shape = (8, 2) - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) devices = jax.local_devices()[:2] # Taking up to 2 devices s = jax.sharding.NamedSharding(mesh, P('x', 'y')) inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -342,7 +342,7 @@ def test_arrays_not_in_device_assignment(self): if jax.device_count() < 4: self.skipTest('Requires more than 4 devices') shape = (8, 2) - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) # sharding device ids = {0, 1} s = jax.sharding.NamedSharding(mesh, P('x')) inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -378,7 +378,7 @@ def test_duplicated_devices_in_arrays(self): if xc._version <= 274: self.skipTest('Test requires jaxlib version 275') shape = (8, 2) - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) # Sharding device ids = {0, 1} s = jax.sharding.NamedSharding(mesh, P('x')) inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -401,7 +401,7 @@ def test_duplicated_devices_in_arrays(self): ) def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shape): shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) inp_data = np.arange(5) @@ -415,7 +415,7 @@ def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shap def test_mismatch_dtype(self): shape = (8, 2) - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape) indices = s.devices_indices_map(shape) @@ -452,7 +452,7 @@ def test_array_iter_pmap_sharding_last_dim_sharded(self): self.assertArraysAllClose(i, j) def test_array_iter_mesh_pspec_sharding_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -462,7 +462,7 @@ def test_array_iter_mesh_pspec_sharding_multi_device(self): self.assertArraysEqual(i, j) def test_array_iter_replicated_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P(None))) @@ -477,7 +477,7 @@ def test_array_iter_replicated_multi_device(self): i.sharding._to_xla_hlo_sharding(i.ndim))) def test_array_getitem_mesh_pspec_sharding_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -496,7 +496,7 @@ def _check(out, inp, shard_shape): self.assertEqual(out.sharding.shard_shape(out.shape), shard_shape) self.assertNotIsInstance(out.sharding, jax.sharding.SingleDeviceSharding) - global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + global_mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) input_shape = (4, 4, 2) arr, np_inp = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y', 'z'))) @@ -523,7 +523,7 @@ def _check(out, inp, shard_shape): _check(arr[1], np_inp[1], (2, 1)) def test_array_getitem_replicated_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P(None))) @@ -575,7 +575,7 @@ def test_array_shards_committed(self): self.assertTrue(s.data._committed) def test_array_jnp_array_copy_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, _ = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -592,7 +592,7 @@ def test_array_jnp_array_copy_multi_device(self): c.data.unsafe_buffer_pointer()) def test_array_addressable_shards(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, _ = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -620,7 +620,7 @@ def check_tracer_hash(x): check_tracer_hash(x) def test_shape_dtype_struct_sharding_jit(self): - mesh = jtu.create_global_mesh((8,), ('x')) + mesh = jtu.create_mesh((8,), ('x')) s = jax.sharding.NamedSharding(mesh, P('x')) x_dummy = jax.ShapeDtypeStruct( @@ -647,7 +647,7 @@ def f(x): s._to_xla_hlo_sharding(x_dummy.ndim))) def test_shape_dtype_struct_sharding_pjit(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) def f(x): @@ -677,7 +677,7 @@ def test_defragment(self): self.skipTest("Manual defragment not exposed via PJRT C API") # Create a few arrays - global_mesh = jtu.create_global_mesh((jax.local_device_count(),), ('x',)) + global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',)) shape = (8, 2) mpsharding = jax.sharding.NamedSharding(global_mesh, P('x',)) arr1, data = create_array(shape, mpsharding) @@ -700,7 +700,7 @@ def test_defragment(self): # OOM, and exposing allocator stats in Python. def test_on_device_size_in_bytes(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) a, _ = create_array( (8, 2), jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) shard_size = a.addressable_shards[0].data.on_device_size_in_bytes() @@ -756,7 +756,7 @@ def test_buffer_protocol_deletion(self): self.assertEqual(x_bytes, y_bytes) def test_array_copy_to_host_async(self): - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = pjit(lambda: jnp.arange(8.), out_shardings=jax.sharding.NamedSharding(global_mesh, P(None)))() self.assertLen(x.sharding.device_set, 4) @@ -765,7 +765,7 @@ def test_array_copy_to_host_async(self): def test_array_fully_replicated_shard(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_shape = (8, 2) arr, inp_data = create_array( inp_shape, jax.sharding.NamedSharding(global_mesh, P())) @@ -776,7 +776,7 @@ def test_array_fully_replicated_shard(self): self.assertArraysEqual(arr.addressable_data(0), inp_data) def test_shard_array_to_fully_replicated(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) sharding = jax.sharding.NamedSharding(global_mesh, P()) arr = jnp.arange(16) self.assertFalse(arr._committed) @@ -786,7 +786,7 @@ def test_shard_array_to_fully_replicated(self): self.assertArraysEqual(out, arr * 2) def test_fully_replicated_donated_array_is_deleted(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) sharding = jax.sharding.NamedSharding(global_mesh, P()) arr = jnp.arange(16) arr_copy = arr.copy() @@ -804,7 +804,7 @@ def test_shards_have_correct_dtype(self, dtype): self.assertEqual(shard.data.dtype, dtype) def test_make_array_from_callback_global_array(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) sharding = jax.sharding.NamedSharding(mesh, P()) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, sharding) @@ -823,7 +823,7 @@ def test_make_array_from_callback_global_array(self): def test_make_array_from_process_data_single_host_data_sharding(self): data = np.ones((1, 512)) - mesh = jtu.create_global_mesh((1, 1), ('x', 'unused')) + mesh = jtu.create_mesh((1, 1), ('x', 'unused')) sharding_spec = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec('x') ) @@ -838,7 +838,7 @@ def test_make_array_from_process_data_single_host_data_sharding(self): class ShardingTest(jtu.JaxTestCase): def test_mesh_pspec_sharding_interface(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) pspec = P('y', 'x') global_shape = (8, 4) mp_sharding = jax.sharding.NamedSharding(mesh, pspec) @@ -855,7 +855,7 @@ def test_mesh_pspec_sharding_interface(self): [0, 2, 4, 6, 1, 3, 5, 7]) def test_util_clear_cache(self): - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P()) s.devices_indices_map((8,)) jax.clear_caches() @@ -874,7 +874,7 @@ def test_util_clear_cache(self): ) def test_op_sharding_indices(self, pspec): shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) ops = jax.sharding.GSPMDSharding( list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape))) @@ -892,12 +892,12 @@ def test_op_sharding_indices(self, pspec): ) def test_shard_shape(self, pspec, expected_shard_shape): shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) self.assertEqual(mps.shard_shape(shape), expected_shard_shape) def test_uneven_shard_error(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, P('x', 'y')) with self.assertRaisesRegex( ValueError, @@ -930,7 +930,7 @@ def test_pmap_sharding_hash_eq(self): def test_is_compatible_error(self): shape = (8, 2) - mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl')) + mesh = jtu.create_mesh((1, 1, 2), ('replica', 'data', 'mdl')) mps = jax.sharding.NamedSharding(mesh, P(None, ('mdl',), None, None)) new_mps = jax.sharding.NamedSharding._from_parsed_pspec( mps.mesh, mps._parsed_pspec) @@ -982,7 +982,7 @@ def test_positional_sharding_op_sharding_lowering( self, pspec, shape, axes, transpose): value_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) devices = jax.local_devices()[:8] # Taking up to 8 devices @@ -1038,7 +1038,7 @@ def test_positional_sharding_aval_compatible(self): ) def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec): ndim = len(mesh_shape) - mesh = jtu.create_global_mesh( + mesh = jtu.create_mesh( mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z')) mps = jax.sharding.NamedSharding(mesh, pspec) original_op_sharding = mps._to_xla_hlo_sharding(ndim) @@ -1071,7 +1071,7 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec): axis_names = ('x', 'y', 'z') else: axis_names = ('x',) - mesh = jtu.create_global_mesh(mesh_shape, axis_names) + mesh = jtu.create_mesh(mesh_shape, axis_names) mps = jax.sharding.NamedSharding(mesh, pspec) shape = (8, 2, 4) mps_op_sharding = mps._to_xla_hlo_sharding(len(shape)) @@ -1086,7 +1086,7 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec): def test_devices_sharding_respects_init_mesh_shape(self): value_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, P('x', 'y')) devices_sharding = jax.sharding.PositionalSharding(mesh.devices) @@ -1140,14 +1140,14 @@ def test_default_pmap_sharding_with_devices(self): self.assertEqual(ps._device_assignment, new_order) def test_mesh_repr(self): - mesh = jtu.create_global_mesh((1, 1), ('x', 'y')) + mesh = jtu.create_mesh((1, 1), ('x', 'y')) mesh_repr = repr(mesh) self.assertIn('device_ids', mesh_repr) self.assertIn('axis_names', mesh_repr) def test_are_shardings_equivalent(self): - mesh = jtu.create_global_mesh((1,), ('x')) - mesh2 = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((1,), ('x')) + mesh2 = jtu.create_mesh((2, 1), ('x', 'y')) s1 = jax.sharding.NamedSharding(mesh, P('x')) s2 = jax.sharding.SingleDeviceSharding(jax.devices()[0]) @@ -1196,7 +1196,7 @@ def test_are_shardings_equivalent(self): def test_devices_indices_map_good_error_message(self): shape = (1, 2) - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) with self.assertRaisesRegex( ValueError, @@ -1205,7 +1205,7 @@ def test_devices_indices_map_good_error_message(self): s.devices_indices_map(shape) def test_scalar_input_wrong_pspec(self): - mesh = jtu.create_global_mesh((1, ), ('x')) + mesh = jtu.create_mesh((1, ), ('x')) shape = () s = jax.sharding.NamedSharding(mesh, P('x')) with self.assertRaisesRegex( @@ -1222,13 +1222,13 @@ def test_mesh_caching_during_construction(self): self.assertIs(mesh1, mesh2) def test_mesh_str(self): - mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) self.assertEqual(str(mesh), "Mesh('x': 2, 'y': 2, 'z': 2)") def test_make_array_from_callback_error(self): mesh_shape = (2, 3) global_shape = tuple(np.square(mesh_shape)) - mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) + mesh = jtu.create_mesh(mesh_shape, ('x', 'y')) pspec = P('x', 'y') sharding = jax.sharding.NamedSharding(mesh, pspec) n = math.prod(global_shape) @@ -1257,7 +1257,7 @@ def f(x): def test_make_array_from_single_device_arrays_bad_inputs(self): x = jnp.arange(10) - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) s = jax.sharding.NamedSharding(mesh, P('x')) x = jax.device_put(x, s) @@ -1268,7 +1268,7 @@ def test_make_array_from_single_device_arrays_bad_inputs(self): def test_gspmd_sharding_hash_eq(self): - mesh = jtu.create_global_mesh((1, 1, 1), ('x', 'y', 'z')) + mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z')) ns = NamedSharding(mesh, P('x', 'y', 'z')) x1 = GSPMDSharding(mesh._flat_devices_tuple, ns._to_xla_hlo_sharding(3)) @@ -1283,14 +1283,14 @@ def test_device_attr(self): self.assertEqual(x.device, list(x.devices())[0]) # For sharded arrays, x.device returns the sharding - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) x = jax.device_put(x, sharding) self.assertEqual(x.device, sharding) def test_to_device(self): device = jax.devices()[-1] - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) x = jnp.ones((2, 10)) @@ -1306,7 +1306,7 @@ def test_to_device(self): class ShardyShardingTest(jtu.JaxTestCase): def test_long_axis_names(self): - mesh = jtu.create_global_mesh((2, 2, 2), ('sequence', 'data', 'model')) + mesh = jtu.create_mesh((2, 2, 2), ('sequence', 'data', 'model')) s = jax.sharding.NamedSharding(mesh, P(('sequence', 'data'), 'model')) sdy_sharding = s._to_sdy_sharding(3) self.assertEqual( @@ -1323,7 +1323,7 @@ def test_long_axis_names(self): '#sdy.sharding<@mesh, [{"sequence", "data"}, {"model"}, {}]>') def test_unconstrained(self): - mesh = jtu.create_global_mesh((8,), ('x',)) + mesh = jtu.create_mesh((8,), ('x',)) s = jax.sharding.NamedSharding(mesh, P(None, P.UNCONSTRAINED, 'x')) sdy_sharding = s._to_sdy_sharding(3) self.assertEqual( @@ -1351,7 +1351,7 @@ def f(x): 32, x.shape) return bits + x - mesh = jtu.create_global_mesh((num_devices,), ('x',)) + mesh = jtu.create_mesh((num_devices,), ('x',), iota_order=True) s = jax.sharding.NamedSharding(mesh, P('x')) n = num_devices ** 2 @@ -1387,7 +1387,7 @@ def f(x): global_shape = tuple(np.square(mesh_shape)) - mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) + mesh = jtu.create_mesh(mesh_shape, ('x', 'y')) s = jax.sharding.NamedSharding(mesh, pspec) n = math.prod(global_shape) diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index a508373b61a7..f5d7c47115b6 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -1098,7 +1098,7 @@ def f_(x): return jnp.square(x) f = jax.jit(f_) - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) arr = jax.device_put(np.arange(8).reshape(2, 2, 2), s) @@ -1114,7 +1114,7 @@ def f_(x): return jnp.square(x) f = pjit.pjit(f_) - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) arr = jax.device_put(np.arange(8).reshape(2, 2, 2), s) diff --git a/tests/layout_test.py b/tests/layout_test.py index 1af8e259ecce..2f240195f22d 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -49,7 +49,7 @@ def setUp(self): def test_auto_layout(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape1 = (128, 128) shape2 = (128, 128) s1 = NamedSharding(mesh, P('x', 'y')) @@ -116,7 +116,7 @@ def init(x, y): def test_default_layout(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -157,7 +157,7 @@ def f(x): def test_in_layouts_out_layouts(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (8, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -183,7 +183,7 @@ def f(x): def test_sharding_and_layouts(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -226,7 +226,7 @@ def f(x, y, z, a, b, c): self.assertArraysEqual(out2, out6) def test_no_error_dced_args(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) shape = (8, 2) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -247,7 +247,7 @@ def f(x, y): def test_aot_layout_mismatch(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (256, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x')) @@ -283,7 +283,7 @@ def test_cpu_default_backend_layout(self): out_cpu, out_cpu).compile() # doesn't crash def test_device_put_concrete_layout(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (8, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -326,7 +326,7 @@ def invalid_layout_spec(self): Layout(compiled.output_layouts[0], None) def test_layout_on_sds(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -345,7 +345,7 @@ def test_layout_on_sds(self): jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO)) def test_make_array_from_callback(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) @@ -370,7 +370,7 @@ def test_make_array_from_callback(self): np_inp.shape, Layout(None, None), lambda idx: np_inp[idx]) def test_wsc_concrete_layout(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (16, 128) s = NamedSharding(mesh, P('x')) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -393,7 +393,7 @@ def f(x): self.assertArraysEqual(out, np_inp.T) def test_wsc_bfloat16_concrete_layout(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (16, 128) s = NamedSharding(mesh, P('x')) inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) @@ -430,7 +430,7 @@ def test_device_put_user_concrete_layout(self): self.assertArraysEqual(out, np_inp) def test_concrete_layout_jit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (16, 128) s = NamedSharding(mesh, P('x')) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -474,7 +474,7 @@ def test_incompatible_aval_error_device_put(self): def test_concrete_layout_in_shardings(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (16, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -528,7 +528,7 @@ def test_in_layouts_jit_jnp_input(self): self.assertArraysEqual(out4, np_inp + 1) def test_layout_donation(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (16, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -544,7 +544,7 @@ def f(x): self.assertTrue(arr.is_deleted()) def test_layout_donation_auto(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (128, 16) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -559,7 +559,7 @@ def f(x): self.assertTrue(arr.is_deleted()) def test_layout_donation_matching_in_and_out(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (128, 16) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -577,7 +577,7 @@ def f(x): @jtu.skip_on_devices('cpu', 'gpu') def test_layout_donation_mismatching_in_and_out_fails(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (16*2, 32016*2) np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) diff --git a/tests/memories_test.py b/tests/memories_test.py index 6140c6945df5..9b8b990d674b 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -47,7 +47,7 @@ def get_memory_kinds_from_executable(f, args): def _create_inputs(shape, pspec, mem_kind=None): - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, pspec, memory_kind=mem_kind) inp = jax.device_put(np_inp, s) @@ -71,7 +71,7 @@ def setUp(self): ) def test_canonicalize_memory_kind(self, name): if name == "named_sharding": - mesh = jtu.create_global_mesh((1,), "x") + mesh = jtu.create_mesh((1,), "x") ns = NamedSharding(mesh, P("x")) self.assertEqual(ns.memory_kind, self._default_memory_kind) elif name == "positional_sharding": @@ -96,7 +96,7 @@ def test_wrong_memory_kind(self, name): with self.assertRaisesRegex( ValueError, "Could not find memory addressable by device.*" ): - mesh = jtu.create_global_mesh((1,), ("x",)) + mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind="hbm") elif name == "positional_sharding": with self.assertRaisesRegex( @@ -128,7 +128,7 @@ def test_correct_tpu_memory_kind(self, name): self.skipTest("TPU memory kind test.") if name == "named_sharding": - mesh = jtu.create_global_mesh((1,), ("x",)) + mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) elif name == "positional_sharding": PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) @@ -146,7 +146,7 @@ def test_correct_tpu_memory_kind(self, name): ) def test_sharding_eq(self, name): if name == "named_sharding": - mesh = jtu.create_global_mesh((1,), ("x",)) + mesh = jtu.create_mesh((1,), ("x",)) s1 = NamedSharding(mesh, P("x")) s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) self.assertEqual(s1, s2) @@ -164,7 +164,7 @@ def test_sharding_eq(self, name): self.assertEqual(s1, s2) def test_sharding_equivalent(self): - mesh = jtu.create_global_mesh((1,), ("x",)) + mesh = jtu.create_mesh((1,), ("x",)) ndim = 2 ns1 = NamedSharding(mesh, P("x")) gs1 = GSPMDSharding( @@ -215,7 +215,7 @@ def test_error_transfer_to_memory_kind_outside_jit(self): def test_device_put_host_to_hbm(self, host_memory_kind: str): if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host": self.skipTest("unpinned_host does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind) np_inp = np.arange(16).reshape(8, 2) @@ -231,7 +231,7 @@ def test_device_put_host_to_hbm(self, host_memory_kind: str): def test_device_put_hbm_to_host(self, host_memory_kind: str): if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host": self.skipTest("unpinned_host does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind) inp = jnp.arange(16).reshape(8, 2) @@ -314,7 +314,7 @@ def test_device_put_on_different_device_with_the_same_memory_kind( # TODO(yashkatariya): Enable this once we can compute on host. # def test_device_put_resharding(self): - # mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + # mesh = jtu.create_mesh((2, 2), ("x", "y")) # s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") # s_hbm = s_host.with_memory_kind("device") # np_inp = np.arange(16).reshape(8, 2) @@ -341,7 +341,7 @@ def test_device_put_on_different_device_with_the_same_memory_kind( def test_device_put_numpy_array(self, host_memory_kind: str): if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host": self.skipTest("unpinned_host does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device") s_host = s_hbm.with_memory_kind(host_memory_kind) @@ -462,7 +462,7 @@ def f(a): def test_parameter_streaming_with_scalar_and_constant(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) scalar_inp = 1 s_host = NamedSharding(mesh, P(), memory_kind="pinned_host") @@ -488,7 +488,7 @@ def f(scalar_input): def test_parameter_and_output_streaming_with_array(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") inp_host = jax.device_put(np_inp, s_host) @@ -540,7 +540,7 @@ def f(x): ) def test_identity_jit_host_to_device_and_vice_versa(self): - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') s_dev = s_host.with_memory_kind('device') @@ -560,7 +560,7 @@ def test_identity_jit_host_to_device_and_vice_versa(self): self.assertEqual(out_host.sharding, s_host) def test_parameter_streaming_inside_scan(self): - mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z")) + mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) np_inp = np.arange(4096.0).reshape(16, 16, 16) s_host = NamedSharding(mesh, P("x", "y", "z"), memory_kind="pinned_host") arr_host = jax.device_put(np_inp, s_host) @@ -582,7 +582,7 @@ def body(carry, x): def test_output_streaming(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test is flaky on GPU backend.") - mesh = jtu.create_global_mesh((1, 1), ("x", "y")) + mesh = jtu.create_mesh((1, 1), ("x", "y")) np_inp = np.arange(16.0).reshape(8, 2) s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device") s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") @@ -619,7 +619,7 @@ def test_output_streaming_inside_scan(self): self.skipTest("This test does not work on GPU backend.") if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") - mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z")) + mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) np_inp = np.arange(4096).reshape(16, 16, 16) s_hbm = NamedSharding(mesh, P(None, "y", "z"), memory_kind="device") arr_hbm = jax.device_put(np_inp, s_hbm) @@ -680,7 +680,7 @@ def _check_mem_kind(self, executable_kind, out_sharding, expected_kind): self.assertEqual(executable_kind, expected_kind) def test_compute_no_inputs(self): - mesh = jtu.create_global_mesh((4,), ('data')) + mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) cpu_sharding = NamedSharding(mesh, P('data'), memory_kind='pinned_host') @@ -698,7 +698,7 @@ def init(): def test_compute_no_inputs_host_replicated(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: self.skipTest("This test requires an xla_version >= 3.") - mesh = jtu.create_global_mesh((4,), ('data')) + mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) cpu_sharding = NamedSharding(mesh, P(), memory_kind='pinned_host') @@ -843,7 +843,7 @@ def f(x): self.assertLen(out, 2) def test_nested_no_op_compute(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -868,7 +868,7 @@ def f2(x): self.assertEqual(out.sharding, s) def test_sharded_compute_on_host(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -921,7 +921,7 @@ def f_bwd(res, tx): def test_host_offload_in_custom_vjp_sharded(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) s = NamedSharding(mesh, P('x')) @jax.custom_vjp @@ -1005,7 +1005,7 @@ def f(x): def test_pure_host_data_and_compute(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') np_inp = np.arange(16).reshape(8, 2) arr_host = jax.device_put(np_inp, s) @@ -1032,7 +1032,7 @@ def test_eager_compute(self): self.assertArraysAllClose(out, jnp.sin(inp * 2)) def test_compute_per_annotation(self): - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) s = NamedSharding(mesh, P("x", "y")) np_inp = np.arange(16.).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -1223,7 +1223,7 @@ def mul(x): self.assertArraysEqual(out2, np_inp2 @ np_inp2.T) def test_sharding_devices_indices_map_cache_hit(self): - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) shape = (8, 2) s1 = NamedSharding(mesh, P("x", "y")) s2 = NamedSharding(mesh, P("x", "y"), memory_kind="device") @@ -1238,7 +1238,7 @@ def test_sharding_devices_indices_map_cache_hit(self): def test_no_donation_across_memory_kinds(self): if xb.using_pjrt_c_api(): raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API") - mesh = jtu.create_global_mesh((2, 1), ("x", "y")) + mesh = jtu.create_mesh((2, 1), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) s_hbm = NamedSharding(mesh, P("x")) s_host = s_hbm.with_memory_kind("pinned_host") @@ -1257,7 +1257,7 @@ def f(x): self.assertNotDeleted(inp) def test_single_mem_kind_donation_default_mem_kind(self): - mesh = jtu.create_global_mesh((2,), "x") + mesh = jtu.create_mesh((2,), "x") s = NamedSharding(mesh, P()) @functools.partial(jax.jit, out_shardings=s, donate_argnums=0) @@ -1273,7 +1273,7 @@ def f(inp1): self.assertDeleted(x) def test_compute_offload_inside_shmap(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -1327,7 +1327,7 @@ def h(x): self.assertArraysAllClose(out, expected_out, rtol=1e-3) def test_mem_kind_donation_pinned_host(self): - mesh = jtu.create_global_mesh((2,), "x") + mesh = jtu.create_mesh((2,), "x") s = NamedSharding(mesh, P(), memory_kind='pinned_host') s_dev = s.with_memory_kind('device') @@ -1349,7 +1349,7 @@ def f(inp1, inp2): @parameterized.parameters("pinned_host", "device") def test_identity_mem_kind_donation(self, mem_kind): - mesh = jtu.create_global_mesh((2,), "x") + mesh = jtu.create_mesh((2,), "x") s = NamedSharding(mesh, P(), memory_kind=mem_kind) @functools.partial(jax.jit, out_shardings=s, donate_argnums=0) @@ -1367,7 +1367,7 @@ def f(inp): @jtu.run_on_devices('tpu') def test_aot_device_implicit_transfer(self): - mesh = jtu.create_global_mesh((1,), 'x') + mesh = jtu.create_mesh((1,), 'x') np_inp = np.arange(8) arr = jax.device_put(np_inp, NamedSharding(mesh, P())) @@ -1397,7 +1397,7 @@ def setUp(self): super().setUp() def test_remat_jaxpr_offloadable(self): - mesh = jtu.create_global_mesh((2,), ("x",)) + mesh = jtu.create_mesh((2,), ("x",)) inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x"))) def policy(prim, *avals, **params): @@ -1440,7 +1440,7 @@ def f(x): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_jaxpr_offloadable(self): - mesh = jtu.create_global_mesh((2,), ("x",)) + mesh = jtu.create_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) s = NamedSharding(mesh, P("x")) @@ -1498,7 +1498,7 @@ def g(ys, _): def test_remat_scan_layout_change_offloadable(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Remat scan does not work on GPU backend.") - mesh = jtu.create_global_mesh((2,), ("x",)) + mesh = jtu.create_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) s = NamedSharding(mesh, P("x")) diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index 0235ba89293f..40c5b6b2ea92 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -354,7 +354,7 @@ def test_gpu_multi_node_transparent_initialize_and_psum(self): def test_pjit_gda_multi_input_multi_output(self): jax.distributed.initialize() - global_mesh = jtu.create_global_mesh((8, 2), ("x", "y")) + global_mesh = jtu.create_mesh((8, 2), ("x", "y")) global_input_shape = (16, 2) global_input_data = np.arange( util.prod(global_input_shape)).reshape(global_input_shape) @@ -558,7 +558,7 @@ def test_pjit_gda_non_contiguous_mesh_2d_aot(self): def test_pjit_gda_eval_shape(self): jax.distributed.initialize() - with jtu.create_global_mesh((16,), ("x")): + with jtu.create_mesh((16,), ("x")): @functools.partial(pjit.pjit, in_shardings=jax.sharding.PartitionSpec(None), diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7dc015c90bca..4d60c3017287 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -54,7 +54,7 @@ def tearDown(self): @unittest.skip("Test failing in CI") def testPGLEProfilerGetFDOProfile(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) @partial( jax.jit, @@ -83,7 +83,7 @@ def f(x, y): @unittest.skip("Test failing in CI") def testPGLEProfilerGetFDOProfileLarge(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) its = 500 @partial( @@ -112,7 +112,7 @@ def f(x): self.assertEqual(fdo_profile.count(b'custom'), its) def testAutoPgle(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) @partial( jax.jit, @@ -245,7 +245,7 @@ def check_if_cache_hit(event): self.assertFalse(pgle_profiler.is_fdo_consumed()) def testPassingFDOProfile(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) @partial( jax.jit, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 1efbcc146d13..3b0fbb1455f4 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -187,10 +187,10 @@ def f(x, y): shape = (8, 8) x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - with jtu.create_global_mesh((2,), ('x')) as mesh: + with jtu.create_mesh((2,), ('x')) as mesh: actual = f(x, x + 1) expected = x + (x + 1) - self.assertEqual(mesh, jtu.create_global_mesh((2,), ('x'))) + self.assertEqual(mesh, jtu.create_mesh((2,), ('x'))) self.assertAllClose(actual, expected, check_dtypes=False) _check_instance(self, actual) self.assertLen(actual.addressable_shards, 2) @@ -226,15 +226,15 @@ def f(x, y): check_dtypes=False) def testDifferentNestedMesh(self): - with jtu.create_global_mesh((2, 1), ("x", "y")) as m1: - with jtu.create_global_mesh((2, 2), ("a", "b")) as m2: + with jtu.create_mesh((2, 1), ("x", "y")) as m1: + with jtu.create_mesh((2, 2), ("a", "b")) as m2: self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m2) self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m1) self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, mesh_lib.EMPTY_ENV.physical_mesh) def testSameNestedMesh(self): - mesh = jtu.create_global_mesh((2, 1), ("a", "b")) + mesh = jtu.create_mesh((2, 1), ("a", "b")) thread_resources = mesh_lib.thread_resources with mesh as m1: with mesh as m2: @@ -258,7 +258,7 @@ def dec(): self.assertArraysEqual(out, x) def testMeshHashRace(self): - mesh = jtu.create_global_mesh((2, 1), ('a', 'testMeshHashRace')) + mesh = jtu.create_mesh((2, 1), ('a', 'testMeshHashRace')) self.assertFalse(hasattr(mesh, '_hash')) with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool: fs = [] @@ -311,7 +311,7 @@ def f(x, y): @jtu.run_on_devices('cpu', 'gpu', 'tpu') def testBufferDonationWithNames(self): - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = NamedSharding(mesh, P('x')) @partial(pjit, out_shardings=s, donate_argnames='inp2') @@ -327,7 +327,7 @@ def f(inp1, inp2): @jtu.run_on_devices('cpu', 'gpu', 'tpu') def testBufferDonationWithKwargs(self): - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = NamedSharding(mesh, P('x')) @partial(pjit, out_shardings=s, donate_argnames=('inp2', 'inp3')) @@ -346,7 +346,7 @@ def f(inp1, inp2, inp3): @jtu.run_on_devices('cpu', 'gpu', 'tpu') def testBufferDonationWithPyTreeKwargs(self): - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = NamedSharding(mesh, P('x')) @partial(pjit, out_shardings=s, donate_argnames='inp2') @@ -371,7 +371,7 @@ def f(inp1, inp2, inp3): @jtu.run_on_devices('tpu', 'cpu', 'gpu') def testBufferDonationWithOutputShardingInference(self): - mesh = jtu.create_global_mesh((2,), 'x') + mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) rs = NamedSharding(mesh, P()) @@ -404,7 +404,7 @@ def f(inp1, inp2, inp3): def testBufferDonationWithOutputShardingInferenceAndTokens(self): if config.use_shardy_partitioner.value: self.skipTest('b/355263220: Shardy does not support callbacks yet.') - mesh = jtu.create_global_mesh((2,), 'x') + mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) def _callback(x): @@ -425,7 +425,7 @@ def f(x): @jtu.run_on_devices('tpu', 'cpu', 'gpu') def testBufferDonationNotDonated(self): - mesh = jtu.create_global_mesh((2,), 'x') + mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) @partial(pjit, donate_argnames=('x')) @@ -467,7 +467,7 @@ def f(x): self.assertIn('sharding = "{replicated}"', str(hlo)) def testShardingConstraintWithArray(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) @partial(pjit, in_shardings=s, out_shardings=s) @@ -495,7 +495,7 @@ def testShardingConstraintWithArrayOpSharding(self): if config.use_shardy_partitioner.value: self.skipTest("Shardy doesn't support PositionalSharding") shape = (8, 8) - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) ops = pxla.to_gspmd_sharding( NamedSharding(mesh, P('x', 'y')), len(shape)) @@ -521,7 +521,7 @@ def f(x): self.assertIn("sharding={replicated}", hlo.as_hlo_text()) def testShardingConstraintPyTreeWithArray(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) @jax.jit def f(x): @@ -544,7 +544,7 @@ def f(x): def testShardingConstraintPyTreeWithUnconstrainedDimsWithJit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @jax.jit def f(x): x = with_sharding_constraint( @@ -1169,7 +1169,7 @@ def f(x, y): def test_local_sharded_key_array_sda(self): input_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) seeds = jnp.arange( math.prod(input_shape), dtype=np.uint32).reshape(input_shape) @@ -1186,7 +1186,7 @@ def make_keys(seeds): jax.random.key_data(out) # doesn't crash def test_with_sharding_constraint_is_compatible_error(self): - mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl')) + mesh = jtu.create_mesh((1, 1, 2), ('replica', 'data', 'mdl')) with mesh: def f(x): @@ -1287,7 +1287,7 @@ def f(x): ) def test_with_sharding_constraint_vmap_spmd_axis_name_error(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) def f(x): return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x'))) @@ -1571,7 +1571,7 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape): partition=partition, ) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32) s = NamedSharding(mesh, P('x')) @@ -1592,7 +1592,7 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape, mesh_axis_names): if config.use_shardy_partitioner.value: self.skipTest('Must register auto partitioner for Shardy') - global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) + global_mesh = jtu.create_mesh(mesh_shape, mesh_axis_names) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1610,7 +1610,7 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape, def test_xla_arr_sharding_mismatch(self): if config.use_shardy_partitioner.value: self.skipTest('Must register auto partitioner for Shardy') - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 2), ('x', 'y')) global_input_shape = (6, 2) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1639,7 +1639,7 @@ def test_xla_arr_sharding_mismatch(self): def test_gda_auto_shardings_len(self): if config.use_shardy_partitioner.value: self.skipTest('Must register auto partitioner for Shardy') - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 2), ('x', 'y')) global_input_shape = (4, 2) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1661,7 +1661,7 @@ def test_jit_arr_partial_auto_sharding_array( self, mesh_shape, mesh_axis_names, pspec): if config.use_shardy_partitioner.value: self.skipTest('Must register auto partitioner for Shardy') - mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) + mesh = jtu.create_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1682,7 +1682,7 @@ def test_jit_arr_partial_auto_sharding_array( self.assertArraysEqual(o._value, input_data) def test_jit_different_mesh_in_auto(self): - mesh1 = jtu.create_global_mesh((4,), ('x',)) + mesh1 = jtu.create_mesh((4,), ('x',)) dev = jax.devices() mesh2 = jax.sharding.Mesh([dev[0], dev[3], dev[2], dev[1]], 'x') f = jax.jit(lambda x, y: (x, y), @@ -1704,7 +1704,7 @@ def test_jit_auto_sharding_partial_tuple_input_shardings( if config.use_shardy_partitioner.value: self.skipTest('Must register auto partitioner for Shardy') - mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) + mesh = jtu.create_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1733,7 +1733,7 @@ def test_jit_auto_sharding_partial_tuple_input_shardings( @unittest.skip('The error is not raised yet. Enable this back once we raise ' 'the error in pjit again.') def test_pjit_array_error(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1763,7 +1763,7 @@ class ArrayPjitTest(jtu.JaxTestCase): ) def test_pjit_array_single_output(self, out_axis_resources, shard_shape): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1789,7 +1789,7 @@ def test_pjit_array_single_output(self, out_axis_resources, shard_shape): def test_pjit_array_single_output_with_mesh_context_manager( self, out_axis_resources, shard_shape): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1810,7 +1810,7 @@ def test_pjit_array_single_output_with_mesh_context_manager( def test_numpy_array_input_assume_fully_replicated(self): input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_data = np.arange( math.prod(input_shape)).reshape(input_shape) @@ -1825,7 +1825,7 @@ def test_numpy_array_input_assume_fully_replicated(self): def test_numpy_array_input(self): input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_data = np.arange( math.prod(input_shape), dtype=np.float32).reshape(input_shape) with global_mesh: @@ -1854,7 +1854,7 @@ def _checks(out, input_data): self.assertArraysEqual(out._value, input_data) global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1879,7 +1879,7 @@ def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, 'TODO(b/355263220) Shardy conflict resolution is not complete. Issue ' 'here is that for `a1 @ a1.T` GSPMD gives dim 0 sharded on `x` while ' 'Shardy gives it fully replicated.') - global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) + global_mesh = jtu.create_mesh(mesh_shape, ('x', 'y')) global_input_shape = (8, 2) spec1 = P('x', 'y') @@ -1924,7 +1924,7 @@ def f(tree): def test_sds_full_like(self): # https://github.com/google/jax/issues/20390 - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) x = jax.ShapeDtypeStruct((4, 4), jnp.float32, sharding=s) y = jnp.zeros_like(x) @@ -1936,7 +1936,7 @@ def test_sds_full_like(self): def test_in_axis_resources_mismatch_error(self): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, _ = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1952,7 +1952,7 @@ def test_in_axis_resources_mismatch_error(self): def test_in_axis_resources_same_as_array_sharding(self): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, _ = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1970,11 +1970,11 @@ def f(): def test_array_device_assignment_mismatch_with_mesh(self): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, _ = create_array( - global_input_shape, jtu.create_global_mesh((2, 2), ('x', 'y')), + global_input_shape, jtu.create_mesh((2, 2), ('x', 'y')), mesh_axes) with global_mesh: @@ -1984,7 +1984,7 @@ def test_array_device_assignment_mismatch_with_mesh(self): def test_array_lower_compile(self): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) a1, input_data = create_array(global_input_shape, global_mesh, P('x', 'y')) a2, _ = create_array(global_input_shape, global_mesh, P('x')) @@ -2038,7 +2038,7 @@ def make_keys(seeds): def test_globally_sharded_key_array_8x4_multi_device_with_out_sharding(self): input_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -2055,7 +2055,7 @@ def make_keys(seeds): def test_globally_sharded_key_array_8x4_multi_device(self): input_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -2072,8 +2072,8 @@ def make_keys(seeds): def test_array_device_assignment_mismatch_out_shardings(self): input_shape = (8, 2) - m1 = jtu.create_global_mesh((4, 2), ('x', 'y')) - m2 = jtu.create_global_mesh((2, 2), ('x', 'y')) + m1 = jtu.create_mesh((4, 2), ('x', 'y')) + m2 = jtu.create_mesh((2, 2), ('x', 'y')) spec = P('x', 'y') a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape) @@ -2087,8 +2087,8 @@ def test_array_device_assignment_mismatch_out_shardings(self): def test_array_device_assignment_mismatch_in_and_out_shardings(self): input_shape = (8, 2) - m1 = jtu.create_global_mesh((4, 2), ('x', 'y')) - m2 = jtu.create_global_mesh((2, 2), ('x', 'y')) + m1 = jtu.create_mesh((4, 2), ('x', 'y')) + m2 = jtu.create_mesh((2, 2), ('x', 'y')) spec = P('x', 'y') a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape) @@ -2104,7 +2104,7 @@ def test_array_device_assignment_mismatch_in_and_out_shardings(self): def test_mixed_inputs(self): input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') a1, input_data = create_array(input_shape, global_mesh, spec) @@ -2119,7 +2119,7 @@ def test_mixed_inputs(self): f(input_data, a1) def test_pjit_array_same_sharding_aot(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) a1, _ = create_array(input_shape, global_mesh, P(None,)) with global_mesh: @@ -2265,7 +2265,7 @@ def test_array_enabled_non_empty_mesh_with_pspec(self): def test_pjit_uncommitted_array_reshard(self): arr = jnp.array([[1, 2, 3]]) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) with mesh: out = pjit(lambda x: x)(arr) self.assertArraysEqual(out, arr) @@ -2273,7 +2273,7 @@ def test_pjit_uncommitted_array_reshard(self): def test_pjit_uncommitted_array_in_axis_resources_reshard(self): arr = jnp.arange(16).reshape(8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) with mesh: out = pjit(lambda x: x, in_shardings=P('x', 'y'))(arr) self.assertArraysEqual(out, arr) @@ -2285,7 +2285,7 @@ def test_pjit_uncommitted_array_in_axis_resources_reshard(self): def test_pjit_uncommitted_array_and_committed_array(self): shape = (8, 2) uarr = jnp.arange(math.prod(shape), dtype=np.float32).reshape(shape) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) carr, inp_data = create_array(shape, mesh, P('x', 'y')) with mesh: out1, out2 = pjit(lambda x, y: (x, y))(uarr, carr) @@ -2298,7 +2298,7 @@ def test_pjit_uncommitted_array_and_committed_array(self): self.assertEqual(mul_out.shape, (8, 8)) self.assertLen(mul_out.addressable_shards, 8) - with jtu.create_global_mesh((2, 2), ('x', 'y')): + with jtu.create_mesh((2, 2), ('x', 'y')): with self.assertRaisesRegex( ValueError, "Received incompatible devices for pjitted computation"): @@ -2306,7 +2306,7 @@ def test_pjit_uncommitted_array_and_committed_array(self): def test_pjit_uncommitted_array_multi_devices(self): shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp = np.arange(math.prod(shape), dtype=np.int32).reshape(shape) arr = array.ArrayImpl( core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)), @@ -2342,7 +2342,7 @@ def test_pjit_committed_array_different_devices_variadic_args(self): pjit(lambda *x: x)(a, b) def test_pjit_pytree_inp_device_assignment_mismatch(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0]) b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1]) c = jax.device_put(np.arange(16).reshape(8, 2), @@ -2365,7 +2365,7 @@ def test_pjit_pytree_inp_device_assignment_mismatch(self): def test_same_out_sharding_id(self): shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) arr, inp_data = create_array(shape, mesh, P('x', 'y')) f = pjit(lambda x: x) @@ -2387,7 +2387,7 @@ def test_same_out_sharding_id(self): def test_out_sharding_indices_id_cache_hit(self): shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) arr, _ = create_array(shape, mesh, P('x', 'y')) f = pjit(lambda x: x) @@ -2423,7 +2423,7 @@ def f(tree): @jax.enable_custom_prng() def test_device_put_sharding_prng(self): - mesh = jtu.create_global_mesh((8,), ('x',)) + mesh = jtu.create_mesh((8,), ('x',)) s = NamedSharding(mesh, P('x')) x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices())) @@ -2455,7 +2455,7 @@ def test_device_put_sharding_prng(self): self.assertEqual(b.sharding, gs) def test_device_put_on_different_sharding(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) x = jnp.arange(8).reshape(4, 2) s1 = NamedSharding(mesh, P('x')) @@ -2467,7 +2467,7 @@ def test_device_put_on_different_sharding(self): self.assertEqual(b.sharding, s2) def test_with_sharding_constraint_jit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(jax.jit, static_argnums=(0, 1)) def sharded_zeros(shape, pspec): @@ -2481,7 +2481,7 @@ def sharded_zeros(shape, pspec): out_s._to_xla_hlo_sharding(out.ndim))) def test_with_sharding_constraint_pjit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(pjit, static_argnums=(0, 1)) def sharded_zeros(shape, pspec): @@ -2495,7 +2495,7 @@ def sharded_zeros(shape, pspec): out_s._to_xla_hlo_sharding(out.ndim))) def test_jit_with_sharding_constraint_committed_inp_error(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) @@ -2529,7 +2529,7 @@ def f(x, y, z): @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument") def test_jit_device_with_sharding_constraint_error(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0]) def sharded_zeros(shape, pspec): @@ -2544,7 +2544,7 @@ def sharded_zeros(shape, pspec): sharded_zeros((4096, 3072), P('x', 'y')) def test_concurrent_pjit(self): - global_mesh = jtu.create_global_mesh((1,), ('x',)) + global_mesh = jtu.create_mesh((1,), ('x',)) sharding = NamedSharding(global_mesh, P('x',)) n = 10 with global_mesh: @@ -2569,7 +2569,7 @@ def _invoke_with_mesh_twice(arg_tuple): def test_trivial_computation(self): shape = (8, 2) - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp_data = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(inp_data, s) @@ -2577,7 +2577,7 @@ def test_trivial_computation(self): self.assertArraysEqual(out, inp_data) def test_trivial_computation_with_sharded_const(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) const = jax.device_put(np.arange(16).reshape(8, 2), NamedSharding(mesh, P('x', 'y'))) with mesh: @@ -2586,17 +2586,17 @@ def test_trivial_computation_with_sharded_const(self): self.assertArraysEqual(out, np.arange(16).reshape(8, 2)) def test_trivial_computation_with_sharded_const_using_transposed_mesh(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) const = jax.device_put(np.arange(16).reshape(8, 2), NamedSharding(mesh, P('x', 'y'))) - mesh2 = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh2 = jtu.create_mesh((1, 2), ('x', 'y')) with mesh2: out = pjit(lambda: const)() self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, np.arange(16).reshape(8, 2)) def test_trivial_computation_with_replicated_literal(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) with mesh: out = pjit(lambda: 1)() self.assertEqual(out.sharding, NamedSharding(mesh, P())) @@ -2605,7 +2605,7 @@ def test_trivial_computation_with_replicated_literal(self): def test_multi_device_pjit_mul(self): shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_data = np.arange(math.prod(shape)).reshape(shape) arr1 = jax.device_put(inp_data, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(inp_data, NamedSharding(mesh, P(None, 'y'))) @@ -2619,7 +2619,7 @@ def test_multi_device_pjit_mul(self): def test_single_device_pjit_cpp_dispatch(self): shape = (8, 2) - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) inp_data = np.arange(math.prod(shape)).reshape(shape) f = pjit(lambda x: x @ x.T, in_shardings=None, out_shardings=None) @@ -2645,7 +2645,7 @@ def test_single_device_add_single_compile(self): def test_global_array_to_host_local_array_already_host_local(self): inp_shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) pspec = P('x', 'y') arr, _ = create_array(inp_shape, mesh, pspec) @@ -2667,7 +2667,7 @@ def f(c, x): self.assertAllClose(exe(x), x + 1, check_dtypes=False) def test_vmap_of_jvp_pjit_no_axis_resources(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) pjit_inp1 = jax.device_put( jnp.arange(8.), jax.sharding.NamedSharding(mesh, P('x'))) pjit_inp2 = jax.device_put( @@ -2693,7 +2693,7 @@ def g_(x, n): self.assertArraysEqual(pjit_out2, jit_out2) def test_vmap_of_jvp_pjit_no_axis_resources_2d(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) f_inp = jnp.arange(8.).reshape(2, 2, 2) # g_inp is sharded with P(None, 'x') because f_inp is sharded with P('x') @@ -2930,7 +2930,7 @@ def f(x, y, z, a, b, c): # pylint: disable=unused-argument self.assertLen(compiled._executable.in_avals, 1) def test_pjit_relayout_multi_slice(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @jax.jit def mul(x): @@ -2952,7 +2952,7 @@ def _check(out, expected_device, expected_out): self.assertLen(out.sharding.device_set, 1) self.assertArraysEqual(out, expected_out @ expected_out.T) - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) with jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument"): @@ -2987,7 +2987,7 @@ def _check(out, expected_device, expected_out): _check(out3, jax.devices()[1], y) def test_pjit_with_device_arg_input_from_another_pjit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) inp = np.arange(8).reshape(4, 2) y = jax.device_put(inp, jax.sharding.NamedSharding(mesh, P('x', 'y'))) @@ -3067,7 +3067,7 @@ def test_pjit_device_backend_both_error(self): pjit(lambda x: x, device=jax.devices()[0], backend='cpu') def test_pjit_mesh_with_device_or_backend_error(self): - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) with mesh: with self.assertRaisesRegex( ValueError, @@ -3118,7 +3118,7 @@ def test_pmap_sharding_input_to_pjit_single_device(self): self.assertLen(out.devices(), 1) def test_pmap_sharding_input_to_pjit_multi_device(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) @@ -3137,7 +3137,7 @@ def test_pmap_sharding_input_to_pjit_multi_device(self): out2.sharding._to_xla_hlo_sharding(inp2.ndim))) def test_pmap_sharding_input_pjit_in_axis_resources(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) @@ -3207,7 +3207,7 @@ def g(z): f(inp) # doesn't crash def test_pjit_sin_nested(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) @pjit def f(x): @@ -3220,7 +3220,7 @@ def f(x): self.assertLen(out.devices(), 8) def test_jit_with_mesh_context_manager(self): - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) with self.assertRaisesRegex( RuntimeError, "jax.jit only supports `Sharding`s being passed to " @@ -3281,7 +3281,7 @@ def f(x): self.assertEqual(count[0], 1) def test_pjit_no_global_cache_hit_axis_resources(self): - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P('x')) inp = jnp.arange(8.0) @@ -3312,7 +3312,7 @@ def test_pjit_no_global_cache_hit_axis_resources(self): self.assertEqual(count[0], 1) def test_with_sharding_constraint_spmd_axis_name(self): - mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl')) + mesh = jtu.create_mesh((2, 2, 2), ('replica', 'data', 'mdl')) shape = (8, 4, 2, 2) x = jnp.arange(math.prod(shape)).reshape(shape) @@ -3335,7 +3335,7 @@ def apply_with_scan(x): self.assertListEqual(ns2, [2, 2, 1, 1]) def test_device_put_sharding_nondivisible_sharding_error(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) s = NamedSharding(mesh, P('x')) x = jnp.ones((1,)) @@ -3452,7 +3452,7 @@ def g(x): def test_pjit_out_sharding_preserved(self): if config.use_shardy_partitioner.value: raise unittest.SkipTest("Shardy doesn't support PositionalSharding") - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3498,7 +3498,7 @@ def mul(x): self.assertEqual(cache_info4.misses, cache_info3.misses) def test_cache_hit_pjit_lower_with_cpp_cache_miss(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) np_arr = np.arange(8, dtype=np.float32).reshape(8, 1) arr = jax.device_put(np_arr, ns) @@ -3524,7 +3524,7 @@ def mul(x): self.assertEqual(cache_info2.misses, cache_info1.misses) def test_list_in_pspec(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) with mesh: out = with_sharding_constraint(jnp.arange(8), P(['x'])) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) @@ -3532,7 +3532,7 @@ def test_list_in_pspec(self): def test_sharding_preserved_trivial(self): if config.use_shardy_partitioner.value: raise unittest.SkipTest("Shardy doesn't support PositionalSharding") - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3549,7 +3549,7 @@ def identity(x): self.assertIsInstance(out2.sharding, PositionalSharding) def test_sharding_preserved_aot(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3566,7 +3566,7 @@ def test_sharding_preserved_aot(self): self.assertIsInstance(out2.sharding, NamedSharding) def test_sharding_on_output_with_vmap(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) arr = jax.device_put( np.arange(16).reshape(8, 2), NamedSharding(mesh, P(None, 'x'))) @@ -3586,7 +3586,7 @@ def test_sharding_on_output_with_vmap(self): def test_jit_mul_sum_sharding_preserved(self): if config.use_shardy_partitioner.value: raise unittest.SkipTest("Shardy doesn't support PositionalSharding") - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3644,7 +3644,7 @@ def test_single_device_sharding_preserved(self): self.assertEqual(out4.devices(), {jax.devices()[1]}) def test_none_out_sharding(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) x = jnp.arange(8) with mesh: out = pjit(lambda x: x * 2, out_shardings=None)(x) @@ -3661,7 +3661,7 @@ def test_none_out_sharding(self): def test_sharding_preserved_apply_primitive(self): if config.use_shardy_partitioner.value: raise unittest.SkipTest("Shardy doesn't support PositionalSharding") - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) arr = jax.device_put(np.arange(8).reshape(8, 1), ns) @@ -3684,14 +3684,14 @@ def test_sharding_preserved_apply_primitive(self): self.assertEqual(out4.devices(), {jax.devices()[1]}) def test_same_named_sharding_pspec_on_eager_ops(self): - mesh = jtu.create_global_mesh((1, 8, 1), ('x', 'y', 'z')) + mesh = jtu.create_mesh((1, 8, 1), ('x', 'y', 'z')) sharding = jax.sharding.NamedSharding(mesh, P('x', 'y', 'z')) x = jax.device_put(jnp.arange(32).reshape(1, -1, 1), sharding) y = x + 1 self.assertEqual(x.sharding, y.sharding) def test_different_named_sharding_object_replicated(self): - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) sharding = jax.sharding.NamedSharding(mesh, P('x')) x = jax.device_put(np.arange(16).reshape(8, 2), sharding) y = jnp.sum(x) @@ -3705,7 +3705,7 @@ def test_vmap_pjit_single_device(self): self.assertIsInstance(out.sharding, SingleDeviceSharding) def test_to_gspmd_sharding_cache_with_and_without_device(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = jnp.arange(4) def identity(x): @@ -3744,7 +3744,7 @@ def top(x): self.assertEqual(count[0], 1) def test_wsc_eager(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) inp = jax.device_put(np_inp, NamedSharding(mesh, P())) out = with_sharding_constraint(inp, NamedSharding(mesh, P('x'))) @@ -3754,14 +3754,14 @@ def test_wsc_eager(self): self.assertArraysEqual(s.data, np_inp[s.index]) def test_wsc_eager_no_resharding(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) inp = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) out = with_sharding_constraint(inp, NamedSharding(mesh, P('x'))) self.assertEqual(id(out), id(inp)) def test_wsc_eager_different_order_devices(self): - mesh1 = jtu.create_global_mesh((2,), ('x',)) + mesh1 = jtu.create_mesh((2,), ('x',)) mesh2 = jax.sharding.Mesh([jax.devices()[1], jax.devices()[0]], 'x') inp = jax.device_put(np.arange(8), NamedSharding(mesh1, P())) with self.assertRaisesRegex( @@ -3803,7 +3803,7 @@ def test_shape_dtype_struct_as_const_error(self): jax.jit(lambda x: (x, const))(jnp.arange(8)) def test_jit_out_shardings_none(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -3812,7 +3812,7 @@ def test_jit_out_shardings_none(self): self.assertEqual(out.sharding, s) def test_jit_in_shardings_none(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -3826,7 +3826,7 @@ def test_jit_in_shardings_none(self): self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0])) def test_device_put_in_jit_default_mem_kind_no_op(self): - mesh = jtu.create_global_mesh((2,), 'x') + mesh = jtu.create_mesh((2,), 'x') np_inp = np.arange(8) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) @@ -3840,7 +3840,7 @@ def f(x): self.assertNotIn('@annotate_device_placement', lowered_text) def test_jit_both_shardings_none(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -3854,7 +3854,7 @@ def test_jit_both_shardings_none(self): self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0])) def test_jit_lower_shape_dtype_struct_sharding_none(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) lower_inp1 = jax.ShapeDtypeStruct((8, 2), np.int32, sharding=s) @@ -3891,7 +3891,7 @@ def f(inp): jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr) def test_no_output_multiple_devices(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) @pjit def f(): @@ -3982,7 +3982,7 @@ def test_mpmd_device_put_fast_path(self): def test_prng_sharding_propagation(self): input_shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -4008,7 +4008,7 @@ def make_keys(seeds): def test_prng_sharding_propagation_with_nested_jit(self): input_shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -4037,7 +4037,7 @@ def f(): def test_partial_sharded_prng_key_inp(self): input_shape = (8, 2, 2) - mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) spec = P('x', 'y', None) seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -4064,7 +4064,7 @@ def make_keys(seeds): def test_jit_partially_specified_shardings(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) s2 = NamedSharding(mesh, P('x')) @@ -4084,7 +4084,7 @@ def f(x, y, z, a, b): self.assertArraysEqual(out5, np_inp.T) def test_input_shardings_aot(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) @@ -4100,7 +4100,7 @@ def test_parameter_tupled_jit(self): if not jtu.test_device_matches(["tpu"]): self.skipTest('Parameters are tupled only on TPU if >2000 parameters') - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x')) @jax.jit @@ -4151,7 +4151,7 @@ def test_jit_token_input(self): self.assertIsInstance(out2, core.Token) def test_uneven_sharding_wsc(self): - mesh = jtu.create_global_mesh( + mesh = jtu.create_mesh( (2, 1, 1, 1, 1), ('data', 'expert', 'fsdp', 'seq', 'model') ) @@ -4198,7 +4198,7 @@ def get_wsc_eqn_sharding(jaxpr): for s in core.subjaxprs(jaxpr): return get_wsc_eqn_sharding(s) - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) inp = jnp.ones((10, 10)) def a_function(x): @@ -4341,7 +4341,7 @@ def test_device_put_efficient_reshard_complex_mesh(self, shape): self.assertEqual(out2.sharding, s1) def test_convert_element_type_sharding(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = np.arange(16).reshape(8, 2) @@ -4405,7 +4405,7 @@ def test_jnp_array_sharded_array_no_op(self): self.assertEqual(out.unsafe_buffer_pointer(), arr.unsafe_buffer_pointer()) def test_wsc_named_sharding_nullary(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) s = NamedSharding(mesh, P()) @jax.jit @@ -4417,7 +4417,7 @@ def f(): @jtu.run_on_devices('tpu', 'gpu') def test_aot_device_mismatch(self): - mesh = jtu.create_global_mesh((1,), 'x') + mesh = jtu.create_mesh((1,), 'x') np_inp = np.arange(8) arr = jax.device_put(np_inp, NamedSharding(mesh, P())) @@ -4466,7 +4466,7 @@ def f(x): @unittest.skipIf(xla_extension_version < 281, 'Requires xla_extension_version >= 281') def test_wsc_abstract_mesh(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -4487,7 +4487,7 @@ def f(x): @unittest.skipIf(xla_extension_version < 281, 'Requires xla_extension_version >= 281') def test_wsc_sds_abstract_mesh(self): - mesh = jtu.create_global_mesh((2,), 'x') + mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P()) abstract_mesh = mesh_lib.AbstractMesh(mesh.shape_tuple) @@ -4502,7 +4502,7 @@ def f(x): @unittest.skipIf(xla_extension_version < 281, 'Requires xla_extension_version >= 281') def test_wsc_vmap_abstract_mesh(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -4520,7 +4520,7 @@ def f(x): @unittest.skipIf(xla_extension_version < 281, 'Requires xla_extension_version >= 281') def test_wsc_abstract_mesh_errors(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) s_abs = NamedSharding(abstract_mesh, P('x')) @@ -4535,7 +4535,7 @@ def test_wsc_abstract_mesh_errors(self): arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) abs_mesh2 = mesh_lib.AbstractMesh( - jtu.create_global_mesh((2,), 'y').shape_tuple) + jtu.create_mesh((2,), 'y').shape_tuple) with self.assertRaisesRegex( ValueError, 'Mesh shape of the input.*does not' @@ -4551,7 +4551,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): @config.sharding_in_types(True) def test_basic_mul(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -4783,7 +4783,7 @@ def h(x): ) def test_pjit_with_deleted_input_at_first_call(self, committed): shape = (8,) - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) inp_data = np.arange(math.prod(shape)).reshape(shape) if committed: s = NamedSharding(mesh, P('x',)) @@ -4801,7 +4801,7 @@ def test_pjit_with_deleted_input_at_first_call(self, committed): ) def test_pjit_with_deleted_input_at_subsequent_call(self, committed): shape = (8,) - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) inp_data = np.arange(math.prod(shape)).reshape(shape) if committed: s = NamedSharding(mesh, P('x',)) @@ -4836,7 +4836,7 @@ def f(x, y): g(x, y2) def test_dce_no_array(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x'))) @jax.jit @@ -5181,7 +5181,7 @@ def test_hlo_sharding_manual_replicated(self): def test_op_sharding_cache_on_mesh_pspec_sharding(self): ndim = 2 - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps1 = NamedSharding(mesh, P('x', 'y')) op1 = mps1._to_xla_hlo_sharding(ndim) cache_info1 = sharding_impls.named_sharding_to_xla_hlo_sharding.cache_info() @@ -5196,7 +5196,7 @@ def test_op_sharding_cache_on_mesh_pspec_sharding(self): self.assertEqual(cache_info2.currsize, cache_info1.currsize) def test_get_partition_spec(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y', None)) self.assertEqual(s._parsed_pspec.get_partition_spec(), P('x', 'y', None)) @@ -5221,7 +5221,7 @@ def test_mesh_with_string_axis_names(self): self.assertTupleEqual(mesh.axis_names, ('dp',)) def test_sharded_in_place_assignment(self): - mesh = jtu.create_global_mesh((8,), ('data',)) + mesh = jtu.create_mesh((8,), ('data',)) idx = [0, 2, 5, 7, 8, 10, 13, 15] n = 16 @@ -5248,7 +5248,7 @@ def setUp(self): raise unittest.SkipTest('Shardy is not available.') def test_lowering_input_output_sharding(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -5260,7 +5260,7 @@ def f(x): self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text()) def test_lowering_with_sharding_constraint(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) arr = np.arange(16).reshape(4, 2, 2) @jax.jit @@ -5272,7 +5272,7 @@ def f(x): self.assertIn('<@mesh, [{"x"}, {}, {"y"}]>', lowered_str) def test_lowering_with_sharding_constraint_unconstrained(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) arr = np.arange(16).reshape(4, 2, 2) @jax.jit @@ -5286,7 +5286,7 @@ def f(x): # TODO(bartchr): run on CPU once Shardy is added to the XLA CPU pipeline. @jtu.skip_on_devices('cpu') def test_compile_with_inferred_out_sharding(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jax.device_put(np.arange(8 * 4).reshape(8, 4), jax.sharding.NamedSharding(mesh, P('x', 'y'))) y = jax.device_put(np.arange(4 * 16).reshape(4, 16), diff --git a/tests/random_test.py b/tests/random_test.py index 80e8ea76f82c..72b57d3f8723 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -957,7 +957,7 @@ def test_device_put_replicated(self): def test_make_array_from_callback(self): devices = jax.devices() shape = (len(devices),) - mesh = jtu.create_global_mesh((len(devices),), ('x',)) + mesh = jtu.create_mesh((len(devices),), ('x',)) sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) def callback(index): i = jnp.arange(len(devices))[index[0]] @@ -969,7 +969,7 @@ def callback(index): def test_make_array_from_single_device_arrays(self): devices = jax.devices() shape = (len(devices),) - mesh = jtu.create_global_mesh((len(devices),), ('x',)) + mesh = jtu.create_mesh((len(devices),), ('x',)) sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) keys = random.split(random.key(0), len(devices)) arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)] diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 11305e937a08..10267ff5eb98 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -39,7 +39,7 @@ class ShardAlikeDownstreamTest(jtu.JaxTestCase): def test_full_like(self): x = jnp.arange(16, dtype='float32').reshape(8, 2) - mesh = jtu.create_global_mesh((8,), ("i",)) + mesh = jtu.create_mesh((8,), ("i",)) x = jax.device_put(x, NamedSharding(mesh, P('i', None))) y = jnp.full_like(x, 1) self.assertEqual(x.sharding, y.sharding) @@ -51,7 +51,7 @@ def setUp(self): super().setUp() def test_basic(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -68,7 +68,7 @@ def f(x): self.assertArraysEqual(out, np_inp * np_inp * 4) def test_output_sharded_alike_input(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -83,7 +83,7 @@ def f(x): self.assertArraysEqual(out, np_inp * 2) def test_arange_shard_alike_jit(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -98,7 +98,7 @@ def f(x): self.assertArraysEqual(out, np_inp) def test_different_shapes(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x',)) inp = jax.device_put(np_inp, s) @@ -113,7 +113,7 @@ def f(x): f(inp) def test_double_shard_alike(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -131,7 +131,7 @@ def f(x): self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) def test_shard_like_eager(self): - mesh = jtu.create_global_mesh((4, 1), ('x', 'y')) + mesh = jtu.create_mesh((4, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -145,7 +145,7 @@ def f(x): self.assertArraysEqual(out, np_inp) def test_shard_map(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -167,7 +167,7 @@ def f(x): self.assertEqual(out2.sharding, s) def test_grad(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) np_inp = np.arange(8.) s = NamedSharding(mesh, P('x')) inp = jax.device_put(np_inp, s) @@ -188,7 +188,7 @@ def f(x): jax.grad(jax.jit(f))(inp) # doesn't crash def test_shard_input_as_output(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) np_inp = np.arange(8.) s = NamedSharding(mesh, P('x')) @@ -218,7 +218,7 @@ def g(x): self.assertEqual(out4.sharding, s) def test_shard_alike_inputs(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8.) s = NamedSharding(mesh, P('x')) rep_s = NamedSharding(mesh, P()) @@ -237,7 +237,7 @@ def f(x, y): self.assertEqual(out2.sharding, s) def test_vmap_one_mapped(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(2) s = NamedSharding(mesh, P('y')) inp = jax.device_put(np_inp, s) @@ -256,7 +256,7 @@ def _shard_slice_like_arg(s): self.assertArraysEqual(out, np.tile(np_inp, [8, 1])) def test_vmap_both_mapped(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp1 = jax.device_put(np_inp, s) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 16daaca6d48e..857c69ce3d91 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -749,7 +749,7 @@ def f(x): @unittest.skipIf(xla_extension_version < 281, 'Requires xla_extension_version >= 281') def test_shard_map_abstract_mesh(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -807,7 +807,7 @@ def f(x): @unittest.skipIf(xla_extension_version < 281, 'Requires xla_extension_version >= 281') def test_shmap_abstract_mesh_errors(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) @@ -819,7 +819,7 @@ def test_shmap_abstract_mesh_errors(self): out_specs=P('x'))(jnp.arange(8)) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) - mesh2 = jtu.create_global_mesh((2,), 'y') + mesh2 = jtu.create_mesh((2,), 'y') abs_mesh2 = AbstractMesh(mesh2.shape_tuple) with self.assertRaisesRegex( ValueError, @@ -901,7 +901,7 @@ def f(_): @jax.legacy_prng_key('allow') def test_prngkeyarray_eager(self): # https://github.com/google/jax/issues/15398 - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) rng = jax.random.PRNGKey(0) @@ -917,7 +917,7 @@ def f(key): _ = g(sharded_rng) # don't crash! def test_functools_partial_rank_error(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial def f(x): @@ -929,7 +929,7 @@ def f(x): g(x) def test_in_specs_none_error(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) def f(x): return x @@ -943,7 +943,7 @@ def f(x): return x shard_map(f, mesh, in_specs=P(), out_specs=P())(3.) # doesn't crash def test_scan_rep_rule(self): - mesh = jtu.create_global_mesh((2, 2,), ('x', 'y')) + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) def f(x, y, z): x, y, z = x.sum(), y.sum(), z.sum() @@ -996,7 +996,7 @@ def foo_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents return foo(x), 3. * x_dot - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) @@ -1015,7 +1015,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) @@ -1029,7 +1029,7 @@ def foo(): if jit: foo = jax.jit(foo) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))() expected = jnp.arange(4.) self.assertAllClose(ans, expected, check_dtypes=False) @@ -1045,7 +1045,7 @@ def foo(): if jit: foo = jax.jit(foo) - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) ans1, ans2, ans3 = shard_map(foo, mesh, in_specs=(), out_specs=P('i', 'j'))() expected1 = jnp.arange(4.)[:, None] + jnp.zeros((4, 2)) @@ -1056,7 +1056,7 @@ def foo(): self.assertAllClose(ans3, expected3, check_dtypes=False) def test_axis_index_eager(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P()) def foo(): @@ -1068,7 +1068,7 @@ def foo(): def test_jaxpr_shardings_with_no_outputs(self): # https://github.com/google/jax/issues/15385 - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) @jax.jit @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('i')) @@ -1084,7 +1084,7 @@ def g(a_block): g(np.arange(32)) # don't crash def test_device_put(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): @@ -1109,7 +1109,7 @@ def g(x): def test_key_array_with_replicated_last_tile_dim(self): # See https://github.com/google/jax/issues/16137 - mesh = jtu.create_global_mesh((2, 4), ('i', 'j')) + mesh = jtu.create_mesh((2, 4), ('i', 'j')) def f(rng): @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'), @@ -1149,7 +1149,7 @@ def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: list[bool], jtu.check_grads(f, inputs_dce, order=2, modes=['rev']) def test_returned_out_sharding(self): - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(jnp.zeros((2, 2)), s) out = shard_map(lambda x: x, mesh, P('x', 'y'), P('x', 'y'))(inp) @@ -1157,7 +1157,7 @@ def test_returned_out_sharding(self): self.assertArraysEqual(out, inp) def test_dce(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) def f(x, y, z): @partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P(None, 'i')), @@ -1208,7 +1208,7 @@ def g(y, z): check_diff=False) def test_post_process_partial_eval_with_scalar_res(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) g = jax.grad(lambda x: shard_map(lambda: jnp.sin(x), mesh=mesh, in_specs=P(), out_specs=P())())(2.0) self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False) @@ -1239,7 +1239,7 @@ def test_rewrite_process_call(self): def f(x): return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x')) x = jnp.arange(4.) y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call @@ -1248,7 +1248,7 @@ def f(x): def test_rewrite_post_process_call(self): # We shouldn't hit post_process_call here because of RewriteTrace's dynamic # behavior (i.e. no data dependence). - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @jax.jit @partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) @@ -1270,7 +1270,7 @@ def foo_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents return foo(x), 2. * x_dot - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: @@ -1298,7 +1298,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: @@ -1326,7 +1326,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: @@ -1343,7 +1343,7 @@ def foo_bwd(_, y_bar): def test_same_pspec_eager_shard_map(self): # This behavior is not guaranteed by JAX and this test can be changed if # the behavior changes. - mesh = jtu.create_global_mesh((1, 4, 1), ('data', 'seq', 'model')) + mesh = jtu.create_mesh((1, 4, 1), ('data', 'seq', 'model')) def f(x): return x * x + 2 @@ -1371,7 +1371,7 @@ def foo_bwd(y, _): foo.defvjp(foo_fwd, foo_bwd) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x, y: foo(x, y) * y, mesh, in_specs=(P(), P('x')), out_specs=P('x')) if jit: @@ -1406,7 +1406,7 @@ def foo_scan(x): y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1) return y - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo_scan(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: @@ -1421,7 +1421,7 @@ def foo_scan(x): self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True) def test_transpose_identity(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) def f(x): @@ -1446,7 +1446,7 @@ def g(x): self.assertLen(e2.params['jaxpr'].eqns, 1) def test_fanout_specs_transpose_to_psum(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x')) def f(x): @@ -1459,7 +1459,7 @@ def f(x): self.assertEqual(e2.params['axes'], ('x',)) def test_fanin_psum_transposes_to_fanout(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P()) def f(x): @@ -1471,7 +1471,7 @@ def f(x): self.assertEqual(str(e1.primitive), 'pbroadcast') def test_psum_with_implicit_fanout_self_transposes(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): @@ -1484,7 +1484,7 @@ def f(x): self.assertEqual(str(e2.primitive), 'pbroadcast') def test_rewrite_binops(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x')) def f(x, y): @@ -1497,7 +1497,7 @@ def f(x, y): self.assertEqual(e.params['axes'], ('x',)) def test_rewrite_scan(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): @@ -1515,7 +1515,7 @@ def f(x): def test_check_rep_false_grads(self): # This test is redundant with the systematic tests below, but it serves as a # direct regression test for a bug. - mesh = jtu.create_global_mesh((4,), ('heads',)) + mesh = jtu.create_mesh((4,), ('heads',)) def f(q, k, v): @@ -1549,7 +1549,7 @@ def bar(x): @parameterized.parameters(it.product([True, False], repeat=2)) def test_res_forwarding_optimization(self, jit, remat): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): @@ -1572,7 +1572,7 @@ def f(x): @parameterized.parameters(it.product([True, False], repeat=2)) def test_res_forwarding_optimization_complex(self, jit, remat): # like the above test, but a different function `f` - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): @@ -1594,7 +1594,7 @@ def f(x): @parameterized.parameters([True, False]) def test_check_rep_failure_inside_rule(self, jit): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) def loss(w, x): @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) @@ -1608,7 +1608,7 @@ def f(x): jax.grad(loss)(3.0, jnp.arange(8.)) # don't crash def test_conv_general_dilated(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) dot = partial(lax.conv_general_dilated, window_strides=(), padding='VALID', dimension_numbers=('NC', 'IO', 'NC')) @@ -1624,25 +1624,25 @@ def f(x, y): self.assertAllClose(y, a @ b, check_dtypes=False, atol=1e-2, rtol=1e-2) def test_cumsum(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(8.) shard_map(jnp.cumsum, mesh=mesh, in_specs=P('i'), out_specs=P('i') )(x) # don't crash def test_custom_jvp_inside_jit(self): - mesh = jtu.create_global_mesh((4,), ('batch',)) + mesh = jtu.create_mesh((4,), ('batch',)) x = shard_map(jax.jit(jax.nn.relu), mesh=mesh, in_specs=P('batch'), out_specs=P('batch'))(jnp.arange(16.)) # don't crash def test_random_normal_rules(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) keys = jax.random.split(jax.random.key(0), 4) shard_map(lambda k: jax.random.normal(k[0], (1,)), mesh=mesh, in_specs=P('i'), out_specs=P('i'))(keys) # don't crash def test_erf_rules(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(16.) shard_map(jax.lax.erf, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) # don't crash @@ -1732,7 +1732,7 @@ def f(inputs): modes=['rev'], atol=1e-3, rtol=1e-3) def test_partial_auto(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( @@ -1759,7 +1759,7 @@ def f(x): def test_sharded_prng_with_abstract_mesh(self): shape = (8, 2, 2) - mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) np_inp = np.arange(math.prod(shape), dtype=np.uint32).reshape(shape) key = prng.random_seed(np_inp, impl=prng.threefry_prng_impl) @@ -1774,7 +1774,7 @@ def shard_key(key): self.assertEqual(out.sharding, NamedSharding(mesh, P())) def test_partial_auto_error_wsc_manual(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( @@ -1797,7 +1797,7 @@ def f(x): f(v) def test_partial_auto_error_invalid_auto(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( @@ -1820,7 +1820,7 @@ def f(x): f(v) def test_partial_auto_error_wrong_in_specs(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( @@ -1843,7 +1843,7 @@ def f(x): f(v) def test_nested_partial_auto(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): return x * x @@ -1866,7 +1866,7 @@ def f(x): self.assertAllClose(v*v, f(v), check_dtypes=False) def test_axis_size_1_partial_auto(self): - mesh = jtu.create_global_mesh((1, 2, 2), ('i', 'j', 'k')) + mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k')) def h(x): return x * x @@ -1884,7 +1884,7 @@ def f(x): self.assertAllClose(v*v, f(v), check_dtypes=False) def test_partial_auto_of_pjit(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def h(): def _make_zeros(): @@ -1901,7 +1901,7 @@ def f(): self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) def test_partial_auto_of_pjit_different_mesh(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) mesh2 = jax.sharding.Mesh(mesh.devices, ('k', 'l')) def h(): @@ -1920,7 +1920,7 @@ def f(): def test_vmap_grad_shmap_spmd_axis_name_residuals(self): # https://github.com/google/jax/pull/21032 - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial( shard_map, @@ -1937,7 +1937,7 @@ def f(x): def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self): # https://github.com/google/jax/pull/21056 - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial(jax.remat, policy=lambda *_, **__: True) @partial( @@ -1955,7 +1955,7 @@ def f(x): def test_grad_shmap_residuals_axis_names_in_mesh_order(self): # https://github.com/google/jax/issues/21236 - mesh = jtu.create_global_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a')) + mesh = jtu.create_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a')) @partial( shard_map, @@ -1975,7 +1975,7 @@ def f(x): ) def test_vmap_spmd_axis_name_error(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial( shard_map, @@ -2005,7 +2005,7 @@ def g(x): jax.vmap(g, spmd_axis_name='i')(xs) def test_in_spec_none(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) x = jnp.arange(8).reshape(4, 2) @@ -2056,7 +2056,7 @@ def f4(o1, o2, x, o3): self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def test_in_spec_none_divisibility_errors(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) x = jnp.arange(4).reshape(2, 2) with self.assertRaisesRegex(ValueError, 'divisible'): @@ -2078,7 +2078,7 @@ def test_in_spec_none_divisibility_errors(self): )((object(), object()), x) def test_in_spec_none_rank_errors(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) x = jnp.arange(4) with self.assertRaisesRegex(ValueError, 'rank'): @@ -2101,7 +2101,7 @@ def test_in_spec_none_rank_errors(self): def test_custom_linear_solve_rep_rules(self): # https://github.com/google/jax/issues/20162 - mesh = jtu.create_global_mesh((1,), ('i',)) + mesh = jtu.create_mesh((1,), ('i',)) a = jnp.array(1).reshape(1, 1) b = jnp.array(1).reshape(1) @@ -2113,7 +2113,7 @@ def f(a, b): _ = f(a, b) # don't crash def test_temporary_error_suppression_flag(self): - mesh = jtu.create_global_mesh((2,), ('i',)) + mesh = jtu.create_mesh((2,), ('i',)) def f(x, y): z = shard_map(lambda x, y: x + jax.lax.all_gather(y, 'i', tiled=True), @@ -2375,7 +2375,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): @staticmethod def make_mesh(mesh_shape): - return jtu.create_global_mesh(tuple(mesh_shape.values()), tuple(mesh_shape)) + return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape)) @parameterized.named_parameters( sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)) From 87c78f8ece4827e479767c768258fd48e0e74b37 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Sep 2024 16:30:43 -0700 Subject: [PATCH 348/702] doc: update examples of deprecation timelines --- docs/deprecation.md | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/deprecation.md b/docs/deprecation.md index 7a8b867b6f2e..385d31271421 100644 --- a/docs/deprecation.md +++ b/docs/deprecation.md @@ -13,24 +13,25 @@ nine months longer than SPEC-0 recommends. This means we support at least: -* All minor Python releases in the 45 months prior to each JAX release. For example: +* All Python feature releases in the 45 months prior to each JAX release. For example: - * **Python 3.9** was released October 2020, and will be supported in new JAX releases at least until **July 2024**. * **Python 3.10** was released October 2021, and will be supported in new JAX releases at least until **July 2025**. * **Python 3.11** was released October 2022, and will be supported in new JAX releases at least until **July 2026**. + * **Python 3.12** was released October 2023, and will be supported in new JAX releases at least until **July 2027**. -* All minor NumPy releases in the 24 months prior to each JAX release. For example: +* All NumPy feature releases in the 24 months prior to each JAX release. For example: - * **NumPy 1.22** was released December 2021, and will be supported in new JAX releases at least until **December 2023**. - * **NumPy 1.23** was released June 2022, and will be supported in new JAX releases at least until **June 2024**. * **NumPy 1.24** was released December 2022, and will be supported in new JAX releases at least until **December 2024**. + * **NumPy 1.25** was released June 2023, and will be supported in new JAX releases at least until **June 2025** + * **NumPy 1.26** was released September 2023, and will be supported in new JAX releases at least until **September 2025** + * **NumPy 2.0** was released June 2024, and will be supported in new JAX releases at least until **June 2026** -* All minor SciPy releases in the 24 months prior to each JAX release, starting - with SciPy version 1.9. For example: +* All SciPy feature releases in the 24 months prior to each JAX release. For example: - * **Scipy 1.9** was released July 2022, and will be supported in new JAX releases at least until **July 2024**. * **Scipy 1.10** was released January 2023, and will be supported in new JAX releases at least until **January 2025**. * **Scipy 1.11** was released June 2023, and will be supported in new JAX releases at least until **June 2025**. + * **Scipy 1.12** was released January 2024, and will be supported in new JAX releases at least until **January 2026**. + * **Scipy 1.13** was released April 2024, and will be supported in new JAX releases at least until **April 2026**. JAX releases may support older versions of Python, NumPy, and SciPy than strictly required by this policy, but support for older versions may be dropped at any time beyond the listed From c1d3c2db9f4e63a99c479a663ace47540ee36054 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 3 Sep 2024 22:56:48 -0700 Subject: [PATCH 349/702] [Mosaic TPU] Fix mosaic alignment check in concatenate rule. PiperOrigin-RevId: 670837792 --- jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 25775a1994ab..7f1f9b63bdc8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2466,11 +2466,11 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, "Not implemented: Only native tiling with offset (0, 0) is supported " "when concatenation along tiling dims."); } - // Check if shapes of src and res are aligned to native tiling. + // Check if the concat dim size of src and res is aligned to native tiling. auto check_aligned = [&](const VectorType &vty) { + auto i = dimension - res_ty.getRank(); return vty.getRank() >= 2 && - *(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) == 0 && - *(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) == 0; + *(vty.getShape().end() + i) % *(layout.tiling().end() + i) == 0; }; bool is_aligned = check_aligned(res_ty); int op_idx = 0; From 8310a6ab1b2c0231fdd8c1f5b4dd4a128d8176fa Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 3 Sep 2024 23:38:39 -0700 Subject: [PATCH 350/702] update example optimizers library docstring * JAXopt is being merged into Optax, so point only to Optax * Update Optax's github repository URL --- jax/example_libraries/optimizers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jax/example_libraries/optimizers.py b/jax/example_libraries/optimizers.py index f9b66ea1c082..3ad717ce358a 100644 --- a/jax/example_libraries/optimizers.py +++ b/jax/example_libraries/optimizers.py @@ -16,7 +16,7 @@ You likely do not mean to import this module! The optimizers in this library are intended as examples only. If you are looking for a fully featured optimizer -library, two good options are JAXopt_ and Optax_. +library, consider Optax_. This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or @@ -85,8 +85,7 @@ def step(step, opt_state): value, opt_state = step(i, opt_state) -.. _JAXopt: https://github.com/google/jaxopt -.. _Optax: https://github.com/deepmind/optax +.. _Optax: https://github.com/google-deepmind/optax """ from __future__ import annotations From cb45fb426a57f65e954cf5265a58197ed3b1dd8c Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 4 Sep 2024 19:22:58 +0530 Subject: [PATCH 351/702] Better docs for jax.numpy: log and log1p --- jax/_src/numpy/ufuncs.py | 71 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 64b5235220cf..91356f40b2fa 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -374,9 +374,41 @@ def exp(x: ArrayLike, /) -> Array: """ return lax.exp(*promote_args_inexact('exp', x)) -@implements(np.log, module='numpy') + @partial(jit, inline=True) def log(x: ArrayLike, /) -> Array: + """Calculate element-wise natural logarithm of the input. + + JAX implementation of :obj:`numpy.log`. + + Args: + x: input array or scalar. + + Returns: + An array containing the logarithm of each element in ``x``, promotes to inexact + dtype. + + See also: + - :func:`jax.numpy.exp`: Calculates element-wise exponential of the input. + - :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input. + - :func:`jax.numpy.log1p`: Calculates element-wise logarithm of one plus input. + + Examples: + ``jnp.log`` and ``jnp.exp`` are inverse functions of each other. Applying + ``jnp.log`` on the result of ``jnp.exp(x)`` yields the original input ``x``. + + >>> x = jnp.array([2, 3, 4, 5]) + >>> jnp.log(jnp.exp(x)) + Array([2., 3., 4., 5.], dtype=float32) + + Using ``jnp.log`` we can demonstrate well-known properties of logarithms, such + as :math:`log(a*b) = log(a)+log(b)`. + + >>> x1 = jnp.array([2, 1, 3, 1]) + >>> x2 = jnp.array([1, 3, 2, 4]) + >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2)) + Array(True, dtype=bool) + """ return lax.log(*promote_args_inexact('log', x)) @@ -423,9 +455,44 @@ def expm1(x: ArrayLike, /) -> Array: """ return lax.expm1(*promote_args_inexact('expm1', x)) -@implements(np.log1p, module='numpy') + @partial(jit, inline=True) def log1p(x: ArrayLike, /) -> Array: + """Calculates element-wise logarithm of one plus input, ``log(x+1)``. + + JAX implementation of :obj:`numpy.log1p`. + + Args: + x: input array or scalar. + + Returns: + An array containing the logarithm of one plus of each element in ``x``, + promotes to inexact dtype. + + Note: + ``jnp.log1p`` is more accurate than when using the naive computation of + ``log(x+1)`` for small values of ``x``. + + See also: + - :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the + input. + - :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input. + - :func:`jax.numpy.log`: Calculates element-wise logarithm of the input. + + Examples: + >>> x = jnp.array([2, 5, 9, 4]) + >>> jnp.allclose(jnp.log1p(x), jnp.log(x+1)) + Array(True, dtype=bool) + + For values very close to 0, ``jnp.log1p(x)`` is more accurate than + ``jnp.log(x+1)``: + + >>> x1 = jnp.array([1e-4, 1e-6, 2e-10]) + >>> jnp.expm1(jnp.log1p(x1)) # doctest: +SKIP + Array([1.00000005e-04, 9.99999997e-07, 2.00000003e-10], dtype=float32) + >>> jnp.expm1(jnp.log(x1+1)) # doctest: +SKIP + Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32) + """ return lax.log1p(*promote_args_inexact('log1p', x)) From 10893033b9633205c9db3897ec080770329e096c Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 4 Sep 2024 19:57:51 +0530 Subject: [PATCH 352/702] Remove unused docstring addition: _PRECISION_DOC --- jax/_src/numpy/lax_numpy.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 7c0162d784c0..8d17f09ac3db 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -99,16 +99,6 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: else: return core.canonicalize_shape(shape, context) -# Common docstring additions: - -_PRECISION_DOC = """\ -In addition to the original NumPy arguments listed below, also supports -``precision`` for extra control over matrix-multiplication precision -on supported devices. ``precision`` may be set to ``None``, which means -default precision for the backend, a :class:`~jax.lax.Precision` enum value -(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple -of two :class:`~jax.lax.Precision` enums indicating separate precision for each argument. -""" # Some objects below rewrite their __module__ attribute to this name. _PUBLIC_MODULE_NAME = "jax.numpy" From d6394c0795e3af6a7a55f7854307d040b34b0d3b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 4 Sep 2024 10:10:31 -0700 Subject: [PATCH 353/702] random.key_impl: improve repr of output --- jax/_src/random.py | 12 +++++++++--- tests/random_test.py | 6 ++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 7e006fac8319..3b85dd01875b 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -146,11 +146,17 @@ class PRNGSpec: def __init__(self, impl): self._impl = impl - def __str__(self) -> str: return str(self._impl) - def __hash__(self) -> int: return hash(self._impl) + def __repr__(self) -> str: + return f"PRNGSpec({self._impl.name!r})" + + def __str__(self) -> str: + return str(self._impl) + + def __hash__(self) -> int: + return hash(self._impl) def __eq__(self, other) -> bool: - return self._impl == other._impl + return isinstance(other, PRNGSpec) and self._impl == other._impl # TODO(frostig,vanderplas): remove PRNGImpl from this union when it's diff --git a/tests/random_test.py b/tests/random_test.py index 72b57d3f8723..941172f75278 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1119,6 +1119,12 @@ class A: pass with self.assertRaisesRegex(TypeError, 'unrecognized type .* PRNG'): jax.random.key(42, impl=A()) + @jtu.sample_product(name=[name for name, _ in PRNG_IMPLS]) + def test_key_spec_repr(self, name): + key = jax.random.key(42, impl=name) + spec = jax.random.key_impl(key) + self.assertEqual(repr(spec), f"PRNGSpec({name!r})") + def test_keyarray_custom_vjp(self): # Regression test for https://github.com/google/jax/issues/18442 @jax.custom_vjp From 5b8c8dd12c5fda75d8df7b8a68986089a0d55fef Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Wed, 4 Sep 2024 14:43:54 -0400 Subject: [PATCH 354/702] Add citation for random.orthogonal. --- jax/_src/random.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jax/_src/random.py b/jax/_src/random.py index 6105d56f9148..84d7328166a6 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -2041,6 +2041,11 @@ def orthogonal( Returns: A random array of shape `(*shape, n, n)` and specified dtype. + + References: + .. [1] Mezzadri, Francesco. (2007). "How to generate random matrices from + the classical compact groups". Notices of the American Mathematical + Society, 54(5), 592-604. https://arxiv.org/abs/math-ph/0609050. """ shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("orthogonal", key) From bf66e816ddbaedf47ac784603891a8e36a96735f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 4 Sep 2024 11:48:37 -0700 Subject: [PATCH 355/702] Split physical axes by default when device kind is `TPU v5 lite` to allow for mesh shapes (2, 2) when there are 8 v5e devices on a 4x2 topology. PiperOrigin-RevId: 671047455 --- jax/_src/sharding_impls.py | 9 +++++++-- tests/pjit_test.py | 8 ++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index a1444c3a2345..0b1dc082765e 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -31,7 +31,7 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge -from jax._src.mesh_utils import create_device_mesh +from jax._src import mesh_utils from jax._src.lib import xla_client as xc from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) @@ -1731,5 +1731,10 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], f'of mesh_shape {axis_shapes}') elif axis_size < len(devices): devices = devices[:axis_size] - mesh_devices = create_device_mesh(axis_shapes, devices) + if devices[0].device_kind == mesh_utils._TPU_V5_LITE: + allow_split_physical_axes = True + else: + allow_split_physical_axes = False + mesh_devices = mesh_utils.create_device_mesh( + axis_shapes, devices, allow_split_physical_axes=allow_split_physical_axes) return mesh_lib.Mesh(mesh_devices, axis_names) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3b0fbb1455f4..25b2680e8d15 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -2351,7 +2351,7 @@ def test_pjit_pytree_inp_device_assignment_mismatch(self): msg = ("Received incompatible devices for pjitted computation. Got " r"argument {} of.* with shape int.*\[3\] and device ids " r"\[0\].*and argument {} of.* with shape int.*\[8,2\] and " - r"device ids \[0, 1, 2, 3\].*") + r"device ids.*") with self.assertRaisesRegex( ValueError, msg.format(r'tuple_inp\[0\]', r'tuple_inp\[1\]\[0\]')): @@ -2509,7 +2509,7 @@ def sharded_inp(inp): ValueError, "Received incompatible devices for jitted computation. Got argument " r"inp of.*sharded_inp with shape bfloat16\[8,2\] and device ids \[0\].*" - r"sharding_constraint inside jit with device ids \[0, 1, 2, 3\].*"): + r"sharding_constraint inside jit with device ids.*"): sharded_inp(committed_inp) @pjit @@ -2523,7 +2523,7 @@ def f(x, y, z): ValueError, "Received incompatible devices for pjitted computation. Got argument " r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*" - r"pjit inside pjit with device ids \[0, 1, 2, 3\].*"): + r"pjit inside pjit with device ids.*"): my_nested_pjit(committed_inp, committed_inp, committed_inp) @jtu.ignore_warning(category=DeprecationWarning, @@ -2540,7 +2540,7 @@ def sharded_zeros(shape, pspec): ValueError, "Received incompatible devices for jitted computation. Got explicit " r"output sharding with device ids \[0\].*sharding_constraint inside " - r"jit with device ids \[0, 1, 2, 3\].*"): + r"jit with device ids.*"): sharded_zeros((4096, 3072), P('x', 'y')) def test_concurrent_pjit(self): From e7d3785b18c68ab4b0674307783da7bb76de0095 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Sep 2024 15:40:51 -0700 Subject: [PATCH 356/702] Refactor & document cumulative reductions --- jax/_src/numpy/reductions.py | 288 ++++++++++++++++++++++++++++------- jax/numpy/__init__.pyi | 8 +- 2 files changed, 238 insertions(+), 58 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index dddb44dc9207..b6aea9e195a9 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -1787,73 +1787,253 @@ def __call__(self, a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None) -> Array: ... -# TODO(jakevdp): should we change these semantics to match those of numpy? -CUML_REDUCTION_LAX_DESCRIPTION = """ -Unlike the numpy counterpart, when ``dtype`` is not specified the output dtype will always -match the dtype of the input. -""" - -def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array], - fill_nan: bool = False, fill_value: ArrayLike = 0, - promote_integers: bool = False) -> CumulativeReduction: - @implements(np_reduction, skip_params=['out'], - lax_description=CUML_REDUCTION_LAX_DESCRIPTION) - def cumulative_reduction(a: ArrayLike, axis: Axis = None, - dtype: DTypeLike | None = None, out: None = None) -> Array: - return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out) - - @partial(api.jit, static_argnames=('axis', 'dtype')) - def _cumulative_reduction(a: ArrayLike, axis: Axis = None, - dtype: DTypeLike | None = None, out: None = None) -> Array: - check_arraylike(np_reduction.__name__, a) - if out is not None: - raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} " - f"is not supported.") - dtypes.check_user_dtype_supported(dtype, np_reduction.__name__) - - if axis is None or _isscalar(a): - a = lax.reshape(a, (np.size(a),)) - if axis is None: - axis = 0 +def _cumulative_reduction( + name: str, reduction: Callable[..., Array], + a: ArrayLike, axis: int | None, dtype: DTypeLike | None, out: None, + fill_nan: bool = False, fill_value: ArrayLike = 0, + promote_integers: bool = False) -> Array: + """Helper function for implementing cumulative reductions.""" + check_arraylike(name, a) + if out is not None: + raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported") + dtypes.check_user_dtype_supported(dtype, name) + + if axis is None or _isscalar(a): + a = lax.reshape(a, (np.size(a),)) + if axis is None: + axis = 0 + + a_shape = list(np.shape(a)) + num_dims = len(a_shape) + axis = _canonicalize_axis(axis, num_dims) + + if fill_nan: + a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a) + + result_type: DTypeLike = dtypes.dtype(dtype or a) + if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_): + result_type = _promote_integer_dtype(result_type) + result_type = dtypes.canonicalize_dtype(result_type) + + a = lax.convert_element_type(a, result_type) + result = reduction(a, axis) + + # We downcast to boolean because we accumulate in integer types + if dtypes.issubdtype(dtype, np.bool_): + result = lax.convert_element_type(result, np.bool_) + return result + + +@partial(api.jit, static_argnames=('axis', 'dtype')) +def cumsum(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Cumulative sum of elements along an axis. + + JAX implementation of :func:`numpy.cumsum`. + + Args: + a: N-dimensional array to be accumulated. + axis: integer axis along which to accumulate. If None (default), then + array will be flattened and accumulated along the flattened axis. + dtype: optionally specify the dtype of the output. If not specified, + then the output dtype will match the input dtype. + out: unused by JAX + + Returns: + An array containing the accumulated sum along the given axis. + + See also: + - :func:`jax.numpy.cumulative_sum`: cumulative sum via the array API standard. + - :meth:`jax.numpy.add.accumulate`: cumulative sum via ufunc methods. + - :func:`jax.numpy.nancumsum`: cumulative sum ignoring NaN values. + - :func:`jax.numpy.sum`: sum along axis + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.cumsum(x) # flattened cumulative sum + Array([ 1, 3, 6, 10, 15, 21], dtype=int32) + >>> jnp.cumsum(x, axis=1) # cumulative sum along axis 1 + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + """ + return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out) + + +@partial(api.jit, static_argnames=('axis', 'dtype')) +def cumprod(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Cumulative product of elements along an axis. + + JAX implementation of :func:`numpy.cumprod`. + + Args: + a: N-dimensional array to be accumulated. + axis: integer axis along which to accumulate. If None (default), then + array will be flattened and accumulated along the flattened axis. + dtype: optionally specify the dtype of the output. If not specified, + then the output dtype will match the input dtype. + out: unused by JAX + + Returns: + An array containing the accumulated product along the given axis. + + See also: + - :meth:`jax.numpy.multiply.accumulate`: cumulative product via ufunc methods. + - :func:`jax.numpy.nancumprod`: cumulative product ignoring NaN values. + - :func:`jax.numpy.prod`: product along axis + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.cumprod(x) # flattened cumulative product + Array([ 1, 2, 6, 24, 120, 720], dtype=int32) + >>> jnp.cumprod(x, axis=1) # cumulative product along axis 1 + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + """ + return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out) + - a_shape = list(np.shape(a)) - num_dims = len(a_shape) - axis = _canonicalize_axis(axis, num_dims) +@partial(api.jit, static_argnames=('axis', 'dtype')) +def nancumsum(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Cumulative sum of elements along an axis, ignoring NaN values. - if fill_nan: - a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a) + JAX implementation of :func:`numpy.nancumsum`. - result_type: DTypeLike = dtypes.dtype(dtype or a) - if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_): - result_type = _promote_integer_dtype(result_type) - result_type = dtypes.canonicalize_dtype(result_type) + Args: + a: N-dimensional array to be accumulated. + axis: integer axis along which to accumulate. If None (default), then + array will be flattened and accumulated along the flattened axis. + dtype: optionally specify the dtype of the output. If not specified, + then the output dtype will match the input dtype. + out: unused by JAX + + Returns: + An array containing the accumulated sum along the given axis. + + See also: + - :func:`jax.numpy.cumsum`: cumulative sum without ignoring NaN values. + - :func:`jax.numpy.cumulative_sum`: cumulative sum via the array API standard. + - :meth:`jax.numpy.add.accumulate`: cumulative sum via ufunc methods. + - :func:`jax.numpy.sum`: sum along axis + + Examples: + >>> x = jnp.array([[1., 2., jnp.nan], + ... [4., jnp.nan, 6.]]) + + The standard cumulative sum will propagate NaN values: + + >>> jnp.cumsum(x) + Array([ 1., 3., nan, nan, nan, nan], dtype=float32) + + :func:`~jax.numpy.nancumsum` will ignore NaN values, effectively replacing + them with zeros: - a = lax.convert_element_type(a, result_type) - result = reduction(a, axis) + >>> jnp.nancumsum(x) + Array([ 1., 3., 3., 7., 7., 13.], dtype=float32) - # We downcast to boolean because we accumulate in integer types - if dtypes.issubdtype(dtype, np.bool_): - result = lax.convert_element_type(result, np.bool_) - return result + Cumulative sum along axis 1: - return cumulative_reduction + >>> jnp.nancumsum(x, axis=1) + Array([[ 1., 3., 3.], + [ 4., 4., 10.]], dtype=float32) + """ + return _cumulative_reduction("nancumsum", lax.cumsum, a, axis, dtype, out, + fill_nan=True, fill_value=0) + + +@partial(api.jit, static_argnames=('axis', 'dtype')) +def nancumprod(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Cumulative product of elements along an axis, ignoring NaN values. + + JAX implementation of :func:`numpy.nancumprod`. + + Args: + a: N-dimensional array to be accumulated. + axis: integer axis along which to accumulate. If None (default), then + array will be flattened and accumulated along the flattened axis. + dtype: optionally specify the dtype of the output. If not specified, + then the output dtype will match the input dtype. + out: unused by JAX + Returns: + An array containing the accumulated product along the given axis. + + See also: + - :func:`jax.numpy.cumprod`: cumulative product without ignoring NaN values. + - :meth:`jax.numpy.multiply.accumulate`: cumulative product via ufunc methods. + - :func:`jax.numpy.prod`: product along axis + + Examples: + >>> x = jnp.array([[1., 2., jnp.nan], + ... [4., jnp.nan, 6.]]) + + The standard cumulative product will propagate NaN values: + + >>> jnp.cumprod(x) + Array([ 1., 2., nan, nan, nan, nan], dtype=float32) + + :func:`~jax.numpy.nancumprod` will ignore NaN values, effectively replacing + them with ones: + + >>> jnp.nancumprod(x) + Array([ 1., 2., 2., 8., 8., 48.], dtype=float32) + + Cumulative product along axis 1: + + >>> jnp.nancumprod(x, axis=1) + Array([[ 1., 2., 2.], + [ 4., 4., 24.]], dtype=float32) + """ + return _cumulative_reduction("nancumprod", lax.cumprod, a, axis, dtype, out, + fill_nan=True, fill_value=1) + + +@partial(api.jit, static_argnames=('axis', 'dtype')) +def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Utility function to compute cumsum with integer promotion.""" + return _cumulative_reduction("_cumsum_with_promotion", lax.cumsum, + a, axis, dtype, out, promote_integers=True) -cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False) -cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False) -nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum, - fill_nan=True, fill_value=0) -nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, - fill_nan=True, fill_value=1) -_cumsum_with_promotion = _make_cumulative_reduction( - np.cumsum, lax.cumsum, fill_nan=False, promote_integers=True -) -@implements(getattr(np, 'cumulative_sum', None)) def cumulative_sum( x: ArrayLike, /, *, axis: int | None = None, dtype: DTypeLike | None = None, include_initial: bool = False) -> Array: + """Cumulative sum along the axis of an array. + + JAX implementation of :func:`numpy.cumulative_sum`. + + Args: + x: N-dimensional array + axis: integer axis along which to accumulate. If ``x`` is one-dimensional, + this argument is optional. + dtype: optional dtype of the output. + include_initial: if True, then include the initial value in the cumulative + sum. Default is False. + + Returns: + An array containing the accumulated values. + + See Also: + - :func:`jax.numpy.cumsum`: alternative API for cumulative sum. + - :func:`jax.numpy.nancumsum`: cumulative sum while ignoring NaN values. + - :func:`jax.numpy.add.accumulate`: cumulative sum via the ufunc API. + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.cumulative_sum(x, axis=1) + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + >>> jnp.cumulative_sum(x, axis=1, include_initial=True) + Array([[ 0, 1, 3, 6], + [ 0, 4, 9, 15]], dtype=int32) + """ check_arraylike("cumulative_sum", x) x = lax_internal.asarray(x) if x.ndim == 0: diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 5e2c1dce4c3d..d5b66c1b3b32 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -272,10 +272,10 @@ def cross( axis: int | None = ..., ) -> Array: ... csingle: Any -def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... cumproduct = cumprod -def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., @@ -633,9 +633,9 @@ def nanargmin( out: None = ..., keepdims: builtins.bool | None = ..., ) -> Array: ... -def nancumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... -def nancumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def nanmax(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., From 0e6650e89d3abd1c2e9f8116ed378bc67b587ed9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 4 Sep 2024 12:31:19 -0700 Subject: [PATCH 357/702] filecheck test: use lax.cumsum directly to prevent false-positive --- tests/filecheck/subcomputations.filecheck.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/filecheck/subcomputations.filecheck.py b/tests/filecheck/subcomputations.filecheck.py index 1f8e9d32e5b1..b3c3191ca416 100644 --- a/tests/filecheck/subcomputations.filecheck.py +++ b/tests/filecheck/subcomputations.filecheck.py @@ -19,7 +19,6 @@ from absl import app import jax -from jax import numpy as jnp from jax.interpreters import mlir from jax._src.lib.mlir import ir import numpy as np @@ -39,7 +38,7 @@ def main(_): # CHECK-NOT: func private @cumsum @print_ir(np.empty([2, 7], np.int32), np.empty([2, 7], np.int32)) def cumsum_only_once(x, y): - return jnp.cumsum(x) + jnp.cumsum(y) + return jax.lax.cumsum(x) + jax.lax.cumsum(y) # Test merging modules # CHECK-LABEL: TEST: merge_modules From 3672b633c30fe82ef94d6cb83889894bdda64295 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Wed, 4 Sep 2024 12:40:56 -0700 Subject: [PATCH 358/702] Fix a deprecation warning for NumPy array conversion. To address https://github.com/google/jax/actions/runs/10654663500/job/29531268089#step:6:656 ``` E DeprecationWarning: NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 536870912 to int16 will fail in the future. E For the old behavior, usually: E np.array(value).astype(dtype) E will give the desired result (the cast overflows). ``` PiperOrigin-RevId: 671064730 --- jax/_src/dtypes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 81f4180a1c12..05350d8621f8 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -320,7 +320,7 @@ def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray: """ if dtype is None and type(x) in python_scalar_dtypes: dtype = _scalar_type_to_dtype(type(x), x) - return np.asarray(x, dtype) + return np.array(x).astype(dtype) iinfo = ml_dtypes.iinfo finfo = ml_dtypes.finfo From a8a55e0f2edd7e43e7ebd54171c8cf5fee973cb3 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 4 Sep 2024 12:47:57 -0700 Subject: [PATCH 359/702] Added pl.CompilerParams subclass for Mosaic GPU PiperOrigin-RevId: 671066741 --- jax/_src/pallas/core.py | 12 ++++++++---- jax/_src/pallas/mosaic/core.py | 6 +++--- jax/_src/pallas/mosaic_gpu/core.py | 19 +++++++++++++++++++ jax/_src/pallas/mosaic_gpu/lowering.py | 9 ++------- jax/_src/pallas/pallas_call.py | 2 +- jax/experimental/pallas/__init__.py | 1 + .../paged_attention/paged_attention_kernel.py | 9 ++++++--- tests/pallas/mosaic_gpu_test.py | 9 ++++----- 8 files changed, 44 insertions(+), 23 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 03bfd28d0b9a..01b027d386c3 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -23,7 +23,7 @@ import functools import itertools import threading -from typing import Any, ClassVar, Hashable, Union +from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable import warnings import jax @@ -66,10 +66,14 @@ def __repr__(self): SEMAPHORE_INTERPRET_DTYPE = jnp.int16 -@dataclasses.dataclass(frozen=True) -class CompilerParams: +@runtime_checkable +class CompilerParams(Protocol): """Base class for compiler parameters.""" - PLATFORM: ClassVar[str] = "unspecified" + PLATFORM: ClassVar[str] + + # Subclasses must be dataclasses. + __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] + @dataclasses.dataclass(frozen=True) class NameAndSrcInfo: diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index e549ee05e770..61b1dc435e72 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -19,14 +19,14 @@ import dataclasses import enum import functools -from typing import Any, ClassVar, Hashable +from typing import Any, ClassVar, Hashable, Literal import jax from jax._src import core as jax_core from jax._src import dtypes from jax._src import util -import jax.numpy as jnp from jax._src.pallas import core as pallas_core +import jax.numpy as jnp import numpy as np map, unsafe_map = util.safe_map, map @@ -68,7 +68,7 @@ class TPUCompilerParams(pallas_core.CompilerParams): device_type: The device type to compile for. """ PLATFORM: ClassVar[str] = "mosaic" - dimension_semantics: Sequence[str] | None = None + dimension_semantics: Sequence[Literal["parallel", "arbitrary"]] | None = None allow_input_fusion: Sequence[bool] | None = None vmem_limit_bytes: int | None = None collective_id: int | None = None diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index fd06a9829644..6619a9acfd02 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -14,8 +14,10 @@ """Contains GPU-specific Pallas abstractions.""" +from collections.abc import Sequence import dataclasses import enum +from typing import ClassVar, Literal from jax import core as jax_core from jax._src.pallas import core as pallas_core import jax.numpy as jnp @@ -23,6 +25,23 @@ AbstractMemoryRef = pallas_core.AbstractMemoryRef +@dataclasses.dataclass(frozen=True) +class GPUCompilerParams(pallas_core.CompilerParams): + """Mosaic GPU compiler parameters. + + Attributes: + dimension_semantics: A list of dimension semantics for each grid + dimension of the kernel. Either "parallel" for dimensions that can + execute in any order, or "sequential" for dimensions that must be + executed sequentially. + num_stages: The number of pipline stages in the kernel. Defaults to 1, + meaning no pipelining is done. + """ + PLATFORM: ClassVar[str] = "mosaic_gpu" + dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None + num_stages: int = 1 + + class GPUMemorySpace(enum.Enum): GMEM = "gmem" SMEM = "smem" diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index fb0119025d4c..88cd4545ae7b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -20,7 +20,7 @@ import dataclasses import functools import math -from typing import Any, Literal, TypedDict, cast +from typing import Any, cast import jax from jax._src import core as jax_core @@ -152,11 +152,6 @@ def _eval_index_map( return tuple(result) -class Params(TypedDict, total=False): - num_stages: int - dimension_semantics: Sequence[Literal["sequential", "parallel"]] - - def lower_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, @@ -199,7 +194,7 @@ def lower_jaxpr_to_module( grid += (1,) * (3 - len(grid)) block = (128,) + (1,) * (len(grid) - 1) - params = Params(**compiler_params.get("mosaic_gpu", {})) + params = compiler_params.get("mosaic_gpu", {}) num_stages = params.get("num_stages", 1) dimension_semantics = params.get( "dimension_semantics", ["parallel"] * len(grid_mapping.grid) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 5bbf37dc3663..d3b0ea8ca080 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1291,7 +1291,7 @@ def pallas_call( if compiler_params is None: compiler_params = {} if isinstance(compiler_params, pallas_core.CompilerParams): - if compiler_params.PLATFORM not in ["mosaic", "triton"]: + if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]: raise ValueError( f"Unknown platform in compiler params: {compiler_params.PLATFORM}") compiler_params = { diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 9a768ed53e75..832f7b7d1184 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -21,6 +21,7 @@ from jax._src.deprecations import register as _register_deprecation from jax._src.pallas.core import Blocked from jax._src.pallas.core import BlockSpec +from jax._src.pallas.core import CompilerParams from jax._src.pallas.core import CostEstimate from jax._src.pallas.core import IndexingMode from jax._src.pallas.core import no_block_spec diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index cd811a874385..eb1e11df17da 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -14,7 +14,9 @@ """PagedAttention TPU kernel.""" +from collections.abc import Sequence import functools +from typing import Literal import jax from jax import lax @@ -516,6 +518,7 @@ def paged_attention( ) q_dtype_for_kernel_launch = q.dtype + dimension_semantics: Sequence[Literal["parallel", "arbitrary"]] if inline_seq_dim: kernel = paged_flash_attention_kernel_inline_seq_dim grid = ( @@ -525,7 +528,7 @@ def paged_attention( if megacore_mode == "kv_head" else num_kv_heads, ) - dimension_sematics = ("parallel", "arbitrary", "arbitrary") + dimension_semantics = ("parallel", "arbitrary", "arbitrary") else: kernel = paged_flash_attention_kernel grid = ( @@ -536,7 +539,7 @@ def paged_attention( else num_kv_heads, pages_per_sequence // pages_per_compute_block, ) # type: ignore - dimension_sematics = ("parallel", "arbitrary", "arbitrary", "arbitrary") # type: ignore + dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary") if k_scales_pages is not None and v_scales_pages is not None: in_specs = [ @@ -641,7 +644,7 @@ def paged_attention( scratch_shapes=scratch_shapes, ), compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=dimension_sematics), + dimension_semantics=dimension_semantics), out_shape=[ jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index bdf6396cae5c..80cfd04c44d6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -80,16 +80,15 @@ def kernel(x_ref, o_ref): @parameterized.product(num_stages=[1, 2, 3]) def test_add_one_grid_pipelined(self, num_stages): + @functools.partial( pl.pallas_call, in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), - compiler_params=dict( - mosaic_gpu=dict( - dimension_semantics=["parallel", "sequential"], - num_stages=num_stages, - ), + compiler_params=plgpu.GPUCompilerParams( + dimension_semantics=["parallel", "sequential"], + num_stages=num_stages, ), grid=(2, 1), ) From 2d74c6aa05340b3d69d8234f137be6cf023d678f Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 4 Sep 2024 13:31:35 -0700 Subject: [PATCH 360/702] Add TritonCompilerParams for specifying compiler arguments instead of a dict. PiperOrigin-RevId: 671081069 --- jax/BUILD | 2 +- jax/_src/pallas/pallas_call.py | 6 ++- jax/_src/pallas/triton/BUILD | 6 +++ jax/_src/pallas/triton/core.py | 38 +++++++++++++++++++ .../pallas/triton/pallas_call_registration.py | 10 +++-- jax/experimental/pallas/gpu.py | 1 + jax/experimental/pallas/ops/gpu/attention.py | 6 +-- .../pallas/ops/gpu/decode_attention.py | 5 ++- jax/experimental/pallas/ops/gpu/layer_norm.py | 6 +-- jax/experimental/pallas/ops/gpu/rms_norm.py | 3 +- jax/experimental/pallas/ops/gpu/softmax.py | 4 +- tests/pallas/ops_test.py | 4 +- 12 files changed, 73 insertions(+), 18 deletions(-) create mode 100644 jax/_src/pallas/triton/core.py diff --git a/jax/BUILD b/jax/BUILD index b761722a5b1f..c4a421362f37 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -650,9 +650,9 @@ pytype_strict_library( ":pallas_gpu_users", ], deps = [ - ":pallas", "//jax/_src/pallas/mosaic_gpu:core", # build_cleaner: keep "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/triton:core", "//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep "//jax/_src/pallas/triton:primitives", ], diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d3b0ea8ca080..301504c8f319 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1277,8 +1277,10 @@ def pallas_call( If missing, then we use `{kernel_name} at {file}:{line}`. compiler_params: Optional compiler parameters. If a dict is provided, it should be of the form {platform: {param_name: param_value}}, where - platform is either 'mosaic' or 'triton'. For TPUs, it is also possible - to pass in a pallas.tpu.TPUCompilerParams struct. + platform is either 'mosaic' or 'triton'. It is also possible + to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and + `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs. + Returns: A function that can be called on a number of positional array arguments to diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index c40fb19ec808..a9babcba0577 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -27,6 +27,12 @@ package( ], ) +pytype_strict_library( + name = "core", + srcs = ["core.py"], + deps = ["//jax/_src/pallas"], +) + pytype_strict_library( name = "primitives", srcs = ["primitives.py"], diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py new file mode 100644 index 000000000000..a61dfd61b9b1 --- /dev/null +++ b/jax/_src/pallas/triton/core.py @@ -0,0 +1,38 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains Triton-specific Pallas abstractions.""" +from __future__ import annotations + +import dataclasses +from typing import ClassVar + +from jax._src.pallas import core as pallas_core + +@dataclasses.dataclass(frozen=True) +class TritonCompilerParams(pallas_core.CompilerParams): + """Compiler parameters for Triton. + + Attributes: + num_warps: The number of warps to use for the kernel. Each warp consists of + 32 threads. + num_stages: The number of stages the compiler should use for software + pipelining loops. + serialized_metadata: Additional compiler metadata. This field is unstable + and may be removed in the future. + """ + PLATFORM: ClassVar[str] = "triton" + num_warps: int | None = None + num_stages: int | None = None + serialized_metadata: str | None = None diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index b94adfb8fb3f..5ee7077dcc1f 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -61,11 +61,14 @@ def pallas_call_lowering( ) triton_params = compiler_params.get("triton", compiler_params) num_warps = triton_params.pop("num_warps", 4) + num_warps = 4 if num_warps is None else num_warps [lowering_platform] = ctx.platforms or ctx.module_context.platforms if lowering_platform == "rocm": num_stages = triton_params.pop("num_stages", 1) + num_stages = 1 if num_stages is None else num_stages else: num_stages = triton_params.pop("num_stages", 3) + num_stages = 3 if num_stages is None else num_stages if debug: print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") @@ -101,9 +104,10 @@ def pallas_call_lowering( ) if "serialized_metadata" in (triton_params or {}): # This field is unstable and may be removed in the future. - backend_config["serialized_metadata"] = ir.StringAttr.get( - triton_params["serialized_metadata"] - ) + if triton_params["serialized_metadata"] is not None: + backend_config["serialized_metadata"] = ir.StringAttr.get( + triton_params["serialized_metadata"] + ) return mlir.custom_call( call_target_name="__gpu$xla.gpu.triton", result_types=out_types, diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py index a24bfe4150df..4f38192e3a14 100644 --- a/jax/experimental/pallas/gpu.py +++ b/jax/experimental/pallas/gpu.py @@ -14,6 +14,7 @@ """Triton-specific Pallas APIs.""" +from jax._src.pallas.triton.core import TritonCompilerParams from jax._src.pallas.triton.primitives import approx_tanh from jax._src.pallas.triton.primitives import debug_barrier from jax._src.pallas.triton.primitives import elementwise_inline_asm diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 63541e8cb439..8e28be840d37 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -21,6 +21,7 @@ import jax from jax import lax from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu import jax.numpy as jnp import numpy as np @@ -216,9 +217,8 @@ def mha( out_specs=pl.BlockSpec( (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), - compiler_params=dict( - triton=dict(num_warps=num_warps_, num_stages=num_stages) - ), + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps_, num_stages=num_stages), out_shape=out_shape, debug=debug, interpret=interpret, diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index 9be724a1f42c..dde80d4603cc 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -21,6 +21,7 @@ import jax from jax import lax from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu import jax.numpy as jnp @@ -153,8 +154,8 @@ def attn_unbatched( pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m ], - compiler_params=dict( - triton=dict(num_warps=num_warps_, num_stages=num_stages) + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps_, num_stages=num_stages ), out_shape=[ jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index 0c39a9bf6e0d..e531395079ba 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -24,6 +24,7 @@ from jax._src.lax.control_flow.for_loop import for_loop from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu def layer_norm_forward_kernel( x_ref, weight_ref, bias_ref, # Input arrays @@ -282,9 +283,8 @@ def layer_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict( - triton=dict(num_warps=num_warps, num_stages=num_stages) - ), + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps, num_stages=num_stages), grid=(), out_shape=out_shape, debug=False, diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index e1dfa3c5b9b7..3e373b895b8d 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -26,6 +26,7 @@ from jax._src.lax.control_flow.for_loop import for_loop from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu def rms_norm_forward_kernel( x_ref, weight_ref, bias_ref, # Input arrays @@ -83,7 +84,7 @@ def rms_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, diff --git a/jax/experimental/pallas/ops/gpu/softmax.py b/jax/experimental/pallas/ops/gpu/softmax.py index 3671331b8df8..33b416d165d7 100644 --- a/jax/experimental/pallas/ops/gpu/softmax.py +++ b/jax/experimental/pallas/ops/gpu/softmax.py @@ -18,6 +18,7 @@ import jax import jax.numpy as jnp from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu def _vmappable_softmax_kernel( @@ -79,7 +80,8 @@ def softmax( kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row) f = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)), + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps, num_stages=1), grid=(), out_shape=out_shape, debug=debug, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 85bda21ec7f2..564a59ec2552 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -996,7 +996,7 @@ def test_debug_print(self): self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), grid=1, - compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) + compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1) ) def kernel(x_ref, o_ref): pl.debug_print("It works!") @@ -1016,7 +1016,7 @@ def test_debug_print_with_values(self): self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), grid=1, - compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) + compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1) ) def kernel(x_ref, o_ref): pl.debug_print("x[0] =", x_ref[0]) From 38184dda9a5f230c405727572f88f4312dc15197 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Sep 2024 13:56:37 -0700 Subject: [PATCH 361/702] Remove CUDA and NCCL repository rules calls from RBE configs. The CUDA and NCCL repositories are created on a host machine now and shared via Bazel cache between host and remote machines. PiperOrigin-RevId: 671089856 --- .bazelrc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.bazelrc b/.bazelrc index 4456994bf2f7..98bc29d8a991 100644 --- a/.bazelrc +++ b/.bazelrc @@ -201,9 +201,10 @@ build:rbe_linux --host_linkopt=-lm # https://github.com/bazelbuild/bazel/issues/13623 build:rbe_cpu_linux_base --config=rbe_linux build:rbe_cpu_linux_base --config=cuda_clang -build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" +build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain" +build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_cpu_linux_base --repo_env=TF_SYSROOT="/dt9" build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" @@ -223,9 +224,10 @@ build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base build:rbe_linux_cuda12.3_nvcc_base --config=nvcc_clang build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2" build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" -build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@local_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_SYSROOT="/dt9" build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" From ea5fd29b90b244e29e382357bf641bbd542f7fab Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Sep 2024 14:44:13 -0700 Subject: [PATCH 362/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/79f2636de5521c8fb98fa4ab33724a5043fdf2d2. PiperOrigin-RevId: 671106526 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index c9aa0b8d61a1..24be11d85893 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "950df464409d4b27c0cb452f78aa221b89e60672" -XLA_SHA256 = "0fec02fa0838f7b2d67482488a58016eeec9393c1230a6b5206f0b0fa8e3eb96" +XLA_COMMIT = "79f2636de5521c8fb98fa4ab33724a5043fdf2d2" +XLA_SHA256 = "fe46154bbeac942fe166580ab0e91a329adcc7428aadd90fcf96784f0291fde8" def repo(): tf_http_archive( From 85d792a92d07f5600d7796d57019fcad58228a59 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Wed, 24 Jul 2024 16:31:03 +0200 Subject: [PATCH 363/702] Add cudnn_fusion decorator lowering computations to XLA cuDNN fusions. --- jax/_src/cudnn/__init__.py | 2 + jax/_src/cudnn/fusion.py | 91 ++++++++++++++++++++++++++++++++++++++ tests/BUILD | 14 ++++++ tests/cudnn_fusion_test.py | 69 +++++++++++++++++++++++++++++ 4 files changed, 176 insertions(+) create mode 100644 jax/_src/cudnn/fusion.py create mode 100644 tests/cudnn_fusion_test.py diff --git a/jax/_src/cudnn/__init__.py b/jax/_src/cudnn/__init__.py index 862a661e24b9..23d1fa28ff43 100644 --- a/jax/_src/cudnn/__init__.py +++ b/jax/_src/cudnn/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .fusion import cudnn_fusion diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py new file mode 100644 index 000000000000..8a13399e3d63 --- /dev/null +++ b/jax/_src/cudnn/fusion.py @@ -0,0 +1,91 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import jax +from jax import core as jax_core +from jax.interpreters import mlir +from jax.interpreters.mlir import hlo +from jax.interpreters.mlir import ir + + + +def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs): + del unused_kwargs + return jax_core.jaxpr_as_fun(jaxpr)(*args) + + +def _custom_abstract_eval(*args, jaxpr, **unused_kwargs): + del unused_kwargs + del args + return jaxpr.out_avals + + +cudnn_fusion_p = jax_core.Primitive("cudnn_fusion") +cudnn_fusion_p.multiple_results = True +cudnn_fusion_p.def_abstract_eval(_custom_abstract_eval) +cudnn_fusion_p.def_impl(_cudnn_fusion_impl) + + +def call_cudnn_fusion(f, *args, **kwargs): + """Creates a new cudnn_fusion corresponding to calling + the given function f with args and kwargs.""" + jaxpr, out_shapes = jax.make_jaxpr( + functools.partial(f, **kwargs), return_shape=True + )(*args) + flat_args = jax.tree.leaves(args) + out_tree = jax.tree.structure(out_shapes) + out_flat = cudnn_fusion_p.bind(*flat_args, name=f.__name__, jaxpr=jaxpr) + return jax.tree.unflatten(out_tree, out_flat) + + +def _cudnn_fusion_stablehlo_lowering( + ctx, + *args, + name, + jaxpr, +): + """Make cudnn_fusion which calls the implementation function. + Currently this leaks a CallOp since we're using the `core_call_lowering` + function, but this should get cleaned up by DCE easily. + """ + impl = mlir.core_call_lowering( + ctx, *args, name=name + ".impl", call_jaxpr=jaxpr + ) + call_op = impl[0].owner + called_fn = call_op.attributes["callee"] + cudnn_fusion = hlo.CustomCallOp( + [r.type for r in call_op.results], + call_op.operands, + call_target_name="__cudnn$fusion", + called_computations=ir.ArrayAttr.get([called_fn]), + ) + return cudnn_fusion.results + + +mlir.register_lowering( + cudnn_fusion_p, _cudnn_fusion_stablehlo_lowering, platform="cuda" + ) + + +def cudnn_fusion(f): + """Makes a function become a cuDNN kernel. Relies on XLA's handling of + custom fusions with __cudnn$fusion backend. Currently limited to GEMM + fusions. For example - batch matmul with mixed types and addition: + + @cudnn_fusion + def fn(x, y, z): + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + """ + return functools.partial(call_cudnn_fusion, f) diff --git a/tests/BUILD b/tests/BUILD index 45743d306fd6..b624b6bef3ac 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1523,6 +1523,20 @@ py_test( ], ) +jax_test( + name = "cudnn_fusion_test", + srcs = ["cudnn_fusion_test.py"], + disable_backends = [ + "cpu", + "tpu", + ], + enable_configs = [ + "gpu_a100", + "gpu_h100", + ], + tags = ["multiaccelerator"], +) + exports_files( [ "api_test.py", diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py new file mode 100644 index 000000000000..e70ba12361a2 --- /dev/null +++ b/tests/cudnn_fusion_test.py @@ -0,0 +1,69 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest, parameterized +from unittest import SkipTest +from jax._src import test_util as jtu +import jax +import jax.numpy as jnp +from jax._src.cudnn import cudnn_fusion + + +jax.config.parse_flags_with_absl() + + +class CudnnFusionTest(jtu.JaxTestCase): + def setUp(self): + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("8.0")): + self.skipTest("Only works on >= sm80 GPUs") + super().setUp() + + @parameterized.parameters(["", "pmap"]) + @jtu.run_on_devices("cuda") + def test_cudnn_fusion(self, mode): + batch_size = 2 + if mode == "pmap" and jax.device_count() < batch_size: + raise SkipTest("pmap test requires 2 GPUs") + + @cudnn_fusion + def comp1(x, y, z): + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + + k = jax.random.key(0) + s = batch_size, 16, 16 + x = jnp.int8(jax.random.normal(k, shape=s)) + y = jnp.bfloat16(jax.random.normal(k, shape=s)) + z = jnp.float32(jax.random.normal(k, shape=s)) + + fn = jax.pmap(comp1) if mode == "pmap" else comp1 + jitted = jax.jit(comp1) + lowered = jitted.lower(x, y, z) + stablehlo = lowered.as_text("stablehlo") + self.assertIn("func.func private @comp1", stablehlo) + self.assertIn("__cudnn$fusion", stablehlo) + + hlo = lowered.as_text("hlo") + self.assertIn('custom_call_target="__cudnn$fusion"', hlo) + self.assertIn("called_computations=", hlo) + + hlo_after_opt = lowered.compile().as_text() + self.assertIn("kind=kCustom", hlo_after_opt) + self.assertIn("plan_id", hlo_after_opt) + + self.assertAllClose(jitted(x, y, z), fn(x, y, z)) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) From 145228600f3171504e8b4291f270daf9b44a22f5 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 4 Sep 2024 16:32:47 -0700 Subject: [PATCH 364/702] Allow optionally coloring output arguments from tpu_custom_call PiperOrigin-RevId: 671143385 --- jax/_src/tpu_custom_call.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index d3c61e5a1722..97b6a2cfd32a 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -113,7 +113,7 @@ class CustomCallBackendConfig: allow_input_fusion: list[bool] | None serialization_format: int | None internal_scratch_in_bytes: int | None - output_memory_spaces: tuple[MemorySpace, ...] | None + output_memory_spaces: tuple[MemorySpace | None, ...] | None # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -161,7 +161,8 @@ def to_json(self) -> bytes: for i, memory_space in enumerate(self.output_memory_spaces): if i: config.write(b",") - config.write(str(memory_space.color).encode("ascii")) + color = memory_space.color if memory_space is not None else -1 + config.write(str(color).encode("ascii")) config.write(b"]") config.write(b"}") # End of custom_call_config. if self.device_type is not None: @@ -446,7 +447,7 @@ def _lower_to_custom_call_config( internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, - output_memory_spaces: tuple[MemorySpace, ...] | None = None, + output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ) -> CustomCallBackendConfig: lowered_module_asm, ( has_communication, @@ -491,7 +492,7 @@ def _lowered_to_custom_call_config( needs_hlo_passes: bool, needs_layout_passes: bool, device_type: str | None, - output_memory_spaces: tuple[MemorySpace, ...] | None = None, + output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ): if has_custom_barrier: if collective_id is None: @@ -541,7 +542,7 @@ def lower_module_to_custom_call( internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, - output_memory_spaces: tuple[MemorySpace, ...] | None, + output_memory_spaces: tuple[MemorySpace | None, ...] | None, device_type: str | None, ) -> Sequence[ir.Value]: config = _lower_to_custom_call_config( @@ -582,7 +583,7 @@ def as_tpu_kernel( internal_scratch_in_bytes: int | None = None, collective_id: int | None = None, serialization_format: int | None = 1, - output_memory_spaces: tuple[MemorySpace, ...] | None = None, + output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" config = _lower_to_custom_call_config( From 51eb0d27c73df9bba44e6f5155913336c58ff270 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 4 Sep 2024 22:17:19 +0100 Subject: [PATCH 365/702] Fixed some type errors under pyright These are mostly due to relience on submodule import side-effects, which AFAIU are unchecked by both pytype and mypy. --- jax/_src/export/_export.py | 2 +- jax/_src/interpreters/ad.py | 6 ++++-- jax/_src/pallas/pallas_call.py | 24 ++++++++++++++---------- jax/_src/pallas/primitives.py | 4 ++-- jax/_src/scipy/signal.py | 29 ++++++++++++++--------------- 5 files changed, 35 insertions(+), 30 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index ee1b0dabba8d..d0159f7a4334 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -58,7 +58,7 @@ zip = util.safe_zip DType = Any -Shape = jax._src.core.Shape +Shape = core.Shape # The values of input and output sharding from the lowering. LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue] HloSharding = xla_client.HloSharding diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ea9da4574e3d..f1b25cf96a95 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -21,7 +21,6 @@ from functools import partial from typing import Any -import jax from jax._src import config from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe @@ -389,6 +388,9 @@ def post_process_custom_jvp_call(self, out_tracers, _): def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, symbolic_zeros): + # Local import to prevent an import cycle. + from jax._src.lax import lax + primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) fwd_in = [(core.full_lower(p), type(t) is not Zero) for p, t in zip(primals_in, tangents_in)] @@ -402,7 +404,7 @@ def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) - tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out) + tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d3b0ea8ca080..28123f0a0fc1 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -36,7 +36,8 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core -from jax._src.pallas.primitives import uninitialized_value +from jax._src.pallas import primitives +from jax._src.pallas import utils as pallas_utils from jax._src.state import discharge as state_discharge from jax._src.util import ( safe_map, @@ -111,13 +112,15 @@ def _pad_values_to_block_dimension(value, ) if padded_shape != value.shape: pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape)) - pad_value = uninitialized_value(shape=(), dtype=value.dtype) + pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype) value = jnp.pad(value, pad_width, constant_values=pad_value) return value def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]: scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals) - return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals) + return tuple( + primitives.uninitialized_value(a.shape, a.dtype) for a in scratch_avals + ) def _initialize_output_vals( block_mappings_output: Iterable[BlockMapping], @@ -128,8 +131,9 @@ def _initialize_output_vals( if i in oi_map: output_vals.append(input_args[oi_map[i]]) else: - output_vals.append(uninitialized_value(bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype)) + output_vals.append(primitives.uninitialized_value( + bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype)) return output_vals def _logical_to_interpret_mode_dtype(dtype): @@ -212,7 +216,7 @@ def _pallas_call_impl_interpret( if padding is not None and any(p != (0, 0) for p in padding): if input_output_aliases: raise NotImplementedError("Padding with aliasing not supported.") - pad_value = uninitialized_value(shape=(), dtype=x.dtype) + pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype) x = lax.pad(x, pad_value, [(*p, 0) for p in padding]) carry.append(x) @@ -872,9 +876,9 @@ def get_size(i, x, d): val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim] def when_wrapped_kernel(lengths_ref, *args, **kwargs): - b_idx = jax.experimental.pallas.program_id(stacked_axis) + b_idx = primitives.program_id(stacked_axis) i_idx = ( - jax.experimental.pallas.program_id(ragged_axis_dim) + primitives.program_id(ragged_axis_dim) * val_at_ragged_dim ) b_len = lengths_ref[b_idx] @@ -883,7 +887,7 @@ def when_wrapped_kernel(lengths_ref, *args, **kwargs): # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") - @jax.experimental.pallas.when(i_idx < b_len) + @pallas_utils.when(i_idx < b_len) def f(): # Important! This allows us to trace the inner kernel with the correct # grid to preserve user program_id semantics. Ex: program_id(0) will @@ -893,7 +897,7 @@ def f(): if debug_zero_fill_counterfactual: - @jax.experimental.pallas.when(i_idx >= b_len) + @pallas_utils.when(i_idx >= b_len) def g(): for arg_ref in args: arg_ref[...] = jnp.zeros_like(arg_ref) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 53227478c312..e41a8cf59975 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -707,7 +707,7 @@ class PrintEffect(effects.Effect): debug_print_p.multiple_results = True -def debug_print(fmt: str, *args: jax.ArrayLike): +def debug_print(fmt: str, *args: jax.typing.ArrayLike): """Prints scalar values from inside a Pallas kernel. Args: @@ -732,7 +732,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike): def check_debug_print_format( - fmt: str, *args: jax.ArrayLike + fmt: str, *args: jax.typing.ArrayLike ): n_placeholders = 0 for _, field, spec, conversion in string.Formatter().parse(fmt): diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index cb3719fafd8f..d950cd2ea395 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -27,8 +27,9 @@ import jax.numpy.fft import jax.numpy as jnp from jax import lax -from jax._src.api_util import _ensure_index_tuple +from jax._src import core from jax._src import dtypes +from jax._src.api_util import _ensure_index_tuple from jax._src.lax.lax import PrecisionLike from jax._src.numpy import linalg from jax._src.numpy.util import ( @@ -655,8 +656,7 @@ def pad(x, n, axis=-1): f"Unknown boundary option '{boundary}', " f"must be one of: {list(boundary_funcs.keys())}") - axis = jax.core.concrete_or_error(operator.index, axis, - "axis of windowed-FFT") + axis = core.concrete_or_error(operator.index, axis, "axis of windowed-FFT") axis = canonicalize_axis(axis, x.ndim) if y is None: @@ -686,8 +686,8 @@ def pad(x, n, axis=-1): noverlap_int: int = 0 if nperseg is not None: # if specified by user - nperseg_int = jax.core.concrete_or_error(int, nperseg, - "nperseg of windowed-FFT") + nperseg_int = core.concrete_or_error( + int, nperseg, "nperseg of windowed-FFT") if nperseg_int < 1: raise ValueError('nperseg must be a positive integer') # parse window; if array like, then set nperseg = win.shape @@ -698,14 +698,13 @@ def pad(x, n, axis=-1): if noverlap is None: noverlap_int = nperseg_int // 2 else: - noverlap_int = jax.core.concrete_or_error(int, noverlap, - "noverlap of windowed-FFT") + noverlap_int = core.concrete_or_error( + int, noverlap, "noverlap of windowed-FFT") if nfft is None: nfft_int = nperseg_int else: - nfft_int = jax.core.concrete_or_error(int, nfft, - "nfft of windowed-FFT") + nfft_int = core.concrete_or_error(int, nfft, "nfft of windowed-FFT") # Special cases for size == 0 if y is None: @@ -1015,8 +1014,8 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: An array with `(..., output_size)`-shape containing overlapped signal. """ check_arraylike("_overlap_and_add", x) - step_size = jax.core.concrete_or_error(int, step_size, - "step_size for overlap_and_add") + step_size = core.concrete_or_error( + int, step_size, "step_size for overlap_and_add") if x.ndim < 2: raise ValueError('Input must have (..., frames, frame_length) shape.') @@ -1114,7 +1113,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided else Zxx.shape[freq_axis]) - nperseg_int = jax.core.concrete_or_error(int, nperseg or n_default, + nperseg_int = core.concrete_or_error(int, nperseg or n_default, "nperseg: segment length of STFT") if nperseg_int < 1: raise ValueError('nperseg must be a positive integer') @@ -1125,13 +1124,13 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', if input_onesided and nperseg_int == n_default + 1: nfft_int += 1 # Odd nperseg, no FFT padding else: - nfft_int = jax.core.concrete_or_error(int, nfft, "nfft of STFT") + nfft_int = core.concrete_or_error(int, nfft, "nfft of STFT") if nfft_int < nperseg_int: raise ValueError( f'FFT length ({nfft_int}) must be longer than nperseg ({nperseg_int}).') - noverlap_int = jax.core.concrete_or_error(int, noverlap or nperseg_int // 2, - "noverlap of STFT") + noverlap_int = core.concrete_or_error( + int, noverlap or nperseg_int // 2, "noverlap of STFT") if noverlap_int >= nperseg_int: raise ValueError('noverlap must be less than nperseg.') nstep = nperseg_int - noverlap_int From 2082662bb1b25eb3576db9896881cce16d4527f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Thu, 5 Sep 2024 01:58:58 -0700 Subject: [PATCH 366/702] Port Hessenberg Decomposition to XLA's FFI This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks. PiperOrigin-RevId: 671283487 --- jaxlib/cpu/cpu_kernels.cc | 4 ++ jaxlib/cpu/lapack.cc | 8 +++ jaxlib/cpu/lapack_kernels.cc | 70 +++++++++++++++++++++++ jaxlib/cpu/lapack_kernels.h | 34 +++++++++-- jaxlib/cpu/lapack_kernels_using_lapack.cc | 29 ++++++++-- 5 files changed, 136 insertions(+), 9 deletions(-) diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 93717ea9b492..c2e122c048a4 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -149,6 +149,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgeev_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dgeev_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgeev_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zgeev_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zgehrd_ffi); #undef JAX_CPU_REGISTER_HANDLER diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 354a1cf9ab34..8fc480951b1e 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -142,6 +142,10 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("dgehrd")); AssignKernelFn>>(lapack_ptr("cgehrd")); AssignKernelFn>>(lapack_ptr("zgehrd")); + AssignKernelFn>(lapack_ptr("sgehrd")); + AssignKernelFn>(lapack_ptr("dgehrd")); + AssignKernelFn>(lapack_ptr("cgehrd")); + AssignKernelFn>(lapack_ptr("zgehrd")); AssignKernelFn>(lapack_ptr("ssytrd")); AssignKernelFn>(lapack_ptr("dsytrd")); @@ -253,6 +257,10 @@ nb::dict Registrations() { dict["lapack_dgeev_ffi"] = EncapsulateFunction(lapack_dgeev_ffi); dict["lapack_cgeev_ffi"] = EncapsulateFunction(lapack_cgeev_ffi); dict["lapack_zgeev_ffi"] = EncapsulateFunction(lapack_zgeev_ffi); + dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi); + dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi); + dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi); + dict["lapack_zgehrd_ffi"] = EncapsulateFunction(lapack_zgehrd_ffi); return dict; } diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 2bc62542e4da..7d58395228d1 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -1627,6 +1627,59 @@ template struct Gehrd; template struct Gehrd>; template struct Gehrd>; +// FFI Kernel + +template +ffi::Error HessenbergDecomposition::Kernel( + ffi::Buffer x, lapack_int low, lapack_int high, + ffi::ResultBuffer x_out, ffi::ResultBuffer tau, + ffi::ResultBuffer info) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); + + CopyIfDiffBuffer(x, x_out); + + ValueType* x_out_data = x_out->typed_data(); + ValueType* tau_data = tau->typed_data(); + lapack_int* info_data = info->typed_data(); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, + MaybeCastNoOverflow(x_rows)); + // Prepare LAPACK workspaces. + int64_t work_size = GetWorkspaceSize(x_rows, x_cols, low, high); + FFI_ASSIGN_OR_RETURN(auto work_size_v, + MaybeCastNoOverflow(work_size)); + auto work_data = AllocateScratchMemory(work_size); + + int64_t x_size{x_rows * x_cols}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&x_cols_v, &low, &high, x_out_data, &x_leading_dim_v, tau_data, + work_data.get(), &work_size_v, info_data); + x_out_data += x_size; + tau_data += x_cols - 1; + ++info_data; + } + return ffi::Error::Success(); +} + +template +int64_t HessenbergDecomposition::GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols, + lapack_int low, + lapack_int high) { + ValueType optimal_size = {}; + lapack_int workspace_query = -1; + lapack_int info = 0; + fn(&x_cols, &low, &high, nullptr, &x_rows, nullptr, &optimal_size, + &workspace_query, &info); + return info == 0 ? static_cast(std::real(optimal_size)) : -1; +} + +template struct HessenbergDecomposition; +template struct HessenbergDecomposition; +template struct HessenbergDecomposition; +template struct HessenbergDecomposition; + //== Tridiagonal Reduction ==// // lapack sytrd/hetrd @@ -1811,6 +1864,17 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*eigvecs_right*/) \ .Ret<::xla::ffi::Buffer>(/*info*/)) +#define JAX_CPU_DEFINE_GEHRD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, HessenbergDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("low") \ + .Attr("high") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*tau*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + // FFI Handlers JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32); @@ -1853,6 +1917,11 @@ JAX_CPU_DEFINE_GEEV(lapack_dgeev_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_cgeev_ffi, ::xla::ffi::DataType::C64); JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128); +JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128); + #undef JAX_CPU_DEFINE_TRSM #undef JAX_CPU_DEFINE_GETRF #undef JAX_CPU_DEFINE_GEQRF @@ -1864,5 +1933,6 @@ JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128); #undef JAX_CPU_DEFINE_HEEVD #undef JAX_CPU_DEFINE_GEEV #undef JAX_CPU_DEFINE_GEEV_COMPLEX +#undef JAX_CPU_DEFINE_GEHRD } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index a571de5dd6de..b4f54b923478 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -192,9 +192,9 @@ struct QrFactorization { inline static FnType* fn = nullptr; - static ::xla::ffi::Error Kernel( - ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer tau); + static ::xla::ffi::Error Kernel(::xla::ffi::Buffer x, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer tau); static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); }; @@ -444,8 +444,7 @@ struct EigenvalueDecompositionHermitian { ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, ::xla::ffi::ResultBuffer x_out, ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues, - ::xla::ffi::ResultBuffer info, - eig::ComputationMode mode); + ::xla::ffi::ResultBuffer info, eig::ComputationMode mode); }; // lapack geev @@ -579,6 +578,27 @@ struct real_type> { typedef T type; }; +// FFI Kernel + +template <::xla::ffi::DataType dtype> +struct HessenbergDecomposition { + using ValueType = ::xla::ffi::NativeType; + using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, + ValueType* a, lapack_int* lda, ValueType* tau, + ValueType* work, lapack_int* lwork, lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, lapack_int low, lapack_int high, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer tau, + ::xla::ffi::ResultBuffer info); + + static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, + lapack_int low, lapack_int high); +}; + //== Tridiagonal Reduction ==// //== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==// @@ -630,6 +650,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeev_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeev_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeev_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeev_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgehrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgehrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgehrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgehrd_ffi); } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 2a2597629b93..9f13bb99d582 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -71,10 +71,10 @@ jax::RealGees::FnType dgees_; jax::ComplexGees>::FnType cgees_; jax::ComplexGees>::FnType zgees_; -jax::Gehrd::FnType sgehrd_; -jax::Gehrd::FnType dgehrd_; -jax::Gehrd>::FnType cgehrd_; -jax::Gehrd>::FnType zgehrd_; +jax::HessenbergDecomposition::FnType sgehrd_; +jax::HessenbergDecomposition::FnType dgehrd_; +jax::HessenbergDecomposition::FnType cgehrd_; +jax::HessenbergDecomposition::FnType zgehrd_; jax::Sytrd::FnType ssytrd_; jax::Sytrd::FnType dsytrd_; @@ -211,6 +211,22 @@ static_assert( jax::EigenvalueDecompositionComplex::FnType, jax::ComplexGeev>::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Gehrd::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Gehrd::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Gehrd>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Gehrd>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); #undef JAX_KERNEL_FNTYPE_MISMATCH_MSG @@ -315,6 +331,11 @@ static auto init = []() -> int { AssignKernelFn>(cgeev_); AssignKernelFn>(zgeev_); + AssignKernelFn>(sgehrd_); + AssignKernelFn>(dgehrd_); + AssignKernelFn>(cgehrd_); + AssignKernelFn>(zgehrd_); + return 0; }(); From f9204e6311afa98e4eea252efebfe966ac2fb729 Mon Sep 17 00:00:00 2001 From: Yury Kirpichev Date: Sun, 18 Aug 2024 19:24:44 -0700 Subject: [PATCH 367/702] Add md page about XLA&XLA_FLAGS in JAX --- docs/advanced_guide.rst | 1 + docs/xla_flags.md | 89 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 docs/xla_flags.md diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index 5cf32f696252..85ed315c98e5 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -29,3 +29,4 @@ operations. :maxdepth: 1 notebooks/convolutions + xla_flags diff --git a/docs/xla_flags.md b/docs/xla_flags.md new file mode 100644 index 000000000000..b332940ccb9d --- /dev/null +++ b/docs/xla_flags.md @@ -0,0 +1,89 @@ +# List of XLA compiler flags + + + +## Introduction +This guide gives a brief overview of XLA and how XLA relates to Jax. +For in-depth details please refer to [XLA documentation](https://openxla.org/xla). Then it lists commonly-used XLA compiler flags designed to optimize performance of Jax programs. + +## XLA: The Powerhouse Behind Jax +XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that plays a pivotal role in Jax's performance and flexibility. It enables Jax to generate optimized code for various hardware backends (CPUs, GPUs, TPUs) by transforming and compiling your Python/NumPy-like code into efficient machine instructions. + +Jax uses XLA's JIT compilation capabilities to transform your Python functions into optimized XLA computations at runtime. + +## Configuring XLA in Jax: +You can influence XLA's behavior in Jax by setting XLA_FLAGS environment variables before running your Python script or colab notebook. + +For the colab notebooks: + +Provide flags using `os.environ['XLA_FLAGS']`: + + +```python +import os + +# Set multiple flags separated by spaces +os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2' +``` + +For the python scripts: + +Specify `XLA_FLAGS` as a part of cli command: + +```bash +XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py +``` + +**Important Notes:** + +* Set `XLA_FLAGS` before importing Jax or other relevant libraries. Changing `XLA_FLAGS` after backend initialization will have no effect and given backend initialization time is not clearly defined it is usually safer to set `XLA_FLAGS` before executing any Jax code. +* Experiment with different flags to optimize performance for your specific use case. + + +**For further information:** +* Complete and up to date documentation about XLA can be found in the official [XLA documentation](https://openxla.org/xla). + +* For backends supported by open-source version of XLA (CPU, GPU), XLA flags are defined with their default values in [xla/debug_options_flags.cc](https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc), and a complete list of flags could be found [here](https://github.com/openxla/xla/blob/main/xla/xla.proto). +* TPU compiler flags are not part of [OpenXLA](https://github.com/openxla/xla), but commonly-used options are listed below. + +* Please note that this list of flags is not exhaustive and is subject to change. These flags are implementation details, and there is no guarantee that they will remain available or maintain their current behavior. +### Common XLA flags +| Flag | Type | Notes | +| ---- | ---- | ----- | +| `xla_dump_to` | String (filepath) | The folder where pre-optimization HLO files and other artifacts will be placed (see [XLA Tools](https://openxla.org/xla/tools)). | +| `xla_enable_async_collective_permute` | TristateFlag (true/false/auto) | Rewrites all collective-permute operations to their asynchronous variants. When set to `auto`, XLA can turn on async collective based on other configurations or conditions automatically. | +| `xla_enable_async_all_gather` | TristateFlag (true/false/auto) | If set to true, enables async all gather. If `auto`, enables only for platforms that implement async all-gather. The implementation (such as BC-offload or continuation fusion) is chosen based on other flag values. | +| `xla_disable_hlo_passes` | String (comma-separated list of pass names) | Comma-separated list of HLO passes to be disabled. These names must exactly match the pass name (no whitespace around commas). | + +### TPU XLA flags +| Flag | Type | Notes | +| ---- | ---- | ----- | +| `xla_tpu_enable_data_parallel_all_reduce_opt` | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data parallel sharding. | +| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes doesn't match what can Be saved in place in the stacked variables. Can increase memory pressure. | +| `xla_tpu_enable_async_collective_fusion` | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop-fusion or convolution) that are scheduled between their -start and -done instructions. | +| `xla_tpu_enable_async_collective_fusion_fuse_all_gather` | TristateFlag (true/false/auto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass.
If set to `auto`, it will be enabled based on the target. | +| `xla_tpu_enable_async_collective_fusion_multiple_steps` | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass. | +| `xla_tpu_overlap_compute_collective_tc` | Boolean (true/false) | Enables the overlap of compute and communication on a single TensorCore, i.e., one core equivalent of MegaCore fusion. | +| `xla_tpu_spmd_rng_bit_generator_unsafe` | Boolean (true/false) | Whether to run RngBitGenerator HLO in a partitioned way, which is unsafe if deterministic results are expected with different shardings on different parts of the computation. | +| `xla_tpu_megacore_fusion_allow_ags` | Boolean (true/false) | Allows fusing all-gathers with convolutions/all-reduces. | +| `xla_tpu_enable_ag_backward_pipelining` | Boolean (true/false) | Pipelines all-gathers (currently megascale all-gathers) backwards through scan loops. | + +### GPU XLA flags +| Flag | Type | Notes | +| ---- | ---- | ----- | +| `xla_gpu_enable_latency_hiding_scheduler` | Boolean (true/false) |This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False. | +| `xla_gpu_enable_triton_gemm` | Boolean (true/false) | Use Triton-based matrix multiplication. | +| `xla_gpu_graph_level` | Flag (0-3) | The legacy flag for setting GPU graph level. Use xla_gpu_enable_command_buffer in new use cases. 0 = off; 1 = capture fusions and memcpys; 2 = capture gemms; 3 = capture convolutions. | +| `xla_gpu_all_reduce_combine_threshold_bytes` | Integer (bytes) | These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce into one big AllGather / ReduceScatter / AllReduce to reduce time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer’s weight AllGather / ReduceScatter. By default, the combine_threshold_bytes is set to 256. | +| `xla_gpu_all_gather_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. | +| `xla_gpu_reduce_scatter_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. | +| `xla_gpu_enable_pipelined_all_gather` | Boolean (true/false) | Enable pipelinling of all-gather instructions. | +| `xla_gpu_enable_pipelined_reduce_scatter` | Boolean (true/false) | Enable pipelinling of reduce-scatter instructions. | +| `xla_gpu_enable_pipelined_all_reduce` | Boolean (true/false) | Enable pipelinling of all-reduce instructions. | +| `xla_gpu_enable_while_loop_double_buffering` | Boolean (true/false) | Enable double-buffering for while loop. | +| `xla_gpu_enable_triton_softmax_fusion` | Boolean (true/false) | Use Triton-based Softmax fusion. | +| `xla_gpu_enable_all_gather_combine_by_dim` | Boolean (true/false) | Combine all-gather ops with the same gather dimension or irrespective of their dimension. | +| `xla_gpu_enable_reduce_scatter_combine_by_dim` | Boolean (true/false) | Combine reduce-scatter ops with the same dimension or irrespective of their dimension. | + +**Additional reading:** +* [GPU performance tips](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) From f3b91b2042ae6d9c23bda28f82fa656767ef898b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 5 Sep 2024 04:14:48 -0700 Subject: [PATCH 368/702] Export `PointerType` and `register_dialect` from `jaxlib.triton.dialect` The `... as ...` form tells the type checker that the name is exported. See #7570. PiperOrigin-RevId: 671318047 --- jaxlib/triton/dialect.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jaxlib/triton/dialect.py b/jaxlib/triton/dialect.py index 1bbb565b69b2..0e3fb4d982cb 100644 --- a/jaxlib/triton/dialect.py +++ b/jaxlib/triton/dialect.py @@ -21,9 +21,9 @@ from collections.abc import Sequence from jaxlib.mlir._mlir_libs._triton_ext import ( - PointerType, - infer_reduce_op_encoding, - register_dialect, + PointerType as PointerType, + register_dialect as register_dialect, + infer_reduce_op_encoding as _infer_reduce_op_encoding, ) from jaxlib.mlir import ir @@ -86,7 +86,7 @@ def _infer_reduce_op_return_types( if not shape: return_types.append(op_type.element_type) elif op_encoding := op_type.encoding: - encoding = infer_reduce_op_encoding(op_encoding, axis) + encoding = _infer_reduce_op_encoding(op_encoding, axis) if encoding is not None: raise RuntimeError("Failed to infer return type encoding for ReduceOp") return_types.append( From 8feab682097b0949d0504ec0ee73f4637aeb1f57 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 5 Sep 2024 04:57:02 -0700 Subject: [PATCH 369/702] [Mosaic GPU] Remove the unnecessary scratch space operand And clean up the C++ dispatch code. We don't use HBM scratch anymore since we pass TMA descriptors as kernel arguments. PiperOrigin-RevId: 671327420 --- .../mosaic_gpu/pallas_call_registration.py | 1 - jax/experimental/mosaic/gpu/__init__.py | 27 +++++++------------ jaxlib/mosaic/gpu/custom_call.cc | 25 +++++++---------- 3 files changed, 19 insertions(+), 34 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 9f28fa7c2944..5b46caf1553a 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -70,6 +70,5 @@ def pallas_call_lowering( ctx, *args, module=module.operation.get_asm(binary=True, enable_debug_info=True), - gmem_scratch_bytes=lowering_result.gmem_scratch_bytes, out_types=lowering_result.out_structs, ) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index eb8eba9dfacf..2e2941fca5b1 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -88,14 +88,14 @@ @mosaic_gpu_p.def_abstract_eval -def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes): - del module, gmem_scratch_bytes # Unused. +def _mosaic_gpu_abstract_eval(*_, module, out_types): + del module # Unused. return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] # TODO(apaszke): Implement a proper system for managing kernel lifetimes KNOWN_KERNELS = {} -def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes): +def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): del out_types # Unused. kernel_id = hashlib.sha256(module).digest() # Note that this is technically only a half measure. Someone might load a @@ -108,19 +108,13 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes) KNOWN_KERNELS[kernel_id] = module op = mlir.custom_call( "mosaic_gpu", - result_types=[ - *(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out), - mlir.aval_to_ir_type( - jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8) - ), - ], + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], operands=args, operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], - result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out] - + [[0]], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], backend_config=kernel_id + module, ) - return op.results[:-1] # Skip the scratch space. + return op.results mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") @@ -766,8 +760,8 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: ir.Attribute.parse("#llvm.linkage"), addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. ) - @func.FuncOp.from_py_func(ptr_ty, ptr_ty, ptr_ty) - def main(token_ptr, buffers, gmem_scratch_ptr): + @func.FuncOp.from_py_func(ptr_ty, ptr_ty) + def main(token_ptr, buffers): nonlocal gmem_scratch_bytes token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) arg_refs = [] @@ -803,7 +797,7 @@ def main(token_ptr, buffers, gmem_scratch_ptr): sym_tab.insert(global_scratch) module.operation.verify() - return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple + return module, out_shape, unwrap_output_tuple def as_gpu_kernel( @@ -822,7 +816,7 @@ def as_gpu_kernel( elif not isinstance(in_shape, tuple): in_shape = (in_shape,) - module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = ( + module, out_shape, unwrap_output_tuple = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, module_name, prof_spec @@ -844,7 +838,6 @@ def bind(*args): *args, out_types=out_shape, module=module_asm, - gmem_scratch_bytes=gmem_scratch_bytes, ) if prof_spec is not None: diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index d9b1e0775ecc..2e5723b184a8 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -353,20 +353,16 @@ absl::StatusOr> Compile( class CompiledKernel { public: CompiledKernel(std::unique_ptr engine, void* ctx, - void* scratch_addr, MosaicHostFunc* host_launch) - : engine_(std::move(engine)), - ctx_(ctx), - scratch_addr_(scratch_addr), - host_launch_(host_launch) {} + MosaicHostFunc* host_launch) + : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch) {} - std::tuple GetHostLaunch() { - return std::make_tuple(ctx_, scratch_addr_, host_launch_); + std::tuple GetHostLaunch() { + return std::make_tuple(ctx_, host_launch_); } private: std::unique_ptr engine_; void* ctx_; // TODO(apaszke): Destroy this properly - void* scratch_addr_; MosaicHostFunc* host_launch_; }; @@ -384,7 +380,7 @@ GetKernelCache() { // Each compiled kernel has a unique init func, and each kernel is used from // a single HLO module. So it should be safe to not include the CUDA context // in the key. -absl::StatusOr> CompileAndInit( +absl::StatusOr> CompileAndInit( CacheKey key, const char* module) { auto cache_and_mutex = GetKernelCache(); auto* cache = cache_and_mutex.first; @@ -426,10 +422,8 @@ absl::StatusOr> CompileAndInit( void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); cache->insert_or_assign( - key, - CompiledKernel(std::move(*maybe_engine), kernel_ptr, - nullptr, // TODO(apaszke): Clean this up. - reinterpret_cast(*main))); + key, CompiledKernel(std::move(*maybe_engine), kernel_ptr, + reinterpret_cast(*main))); } return cache->at(key).GetHostLaunch(); } @@ -454,9 +448,8 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, ctx_and_kernel.status().message().size()); return; } - void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers, - &std::get<1>(*ctx_and_kernel)}; - std::get<2>(*ctx_and_kernel)(args); + void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers}; + std::get<1>(*ctx_and_kernel)(args); } XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, From 104ebc2e905a0eda7e563a53b295b55235885c2c Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Thu, 5 Sep 2024 10:31:45 -0400 Subject: [PATCH 370/702] Fix expression parentheses in shape polymorphism docs for division of symbolic dimensions. --- docs/export/shape_poly.md | 2 +- jax/experimental/jax2tf/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 695ca6cd21d9..f025c60577d5 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -619,7 +619,7 @@ compilation. ### Division of symbolic dimensions is partially supported JAX will attempt to simplify division and modulo operations, -e.g., `(a * b + a) // (b + 1) == a` and `6*a + 4 % 3 == 1`. +e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`. In particular, JAX will handle the cases when either (a) there is no remainder, or (b) the divisor is a constant in which case there may be a constant remainder. diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index b190829fe7d0..dbdc4f563368 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -880,7 +880,7 @@ is unsound. ### Division of symbolic dimensions is partially supported JAX will attempt to simplify division and modulo operations, -e.g., `(a * b + a) // (b + 1) == a` and `6*a + 4 % 3 == 1`. +e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`. In particular, JAX will handle the cases when either (a) there is no remainder, or (b) the divisor is a constant in which case there may be a constant remainder. From 2dd13ce3e8bb6245158c49f0b365700bcaf6966e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 5 Sep 2024 07:59:52 -0700 Subject: [PATCH 371/702] Fix some tests to use iota_order when creating a mesh PiperOrigin-RevId: 671373187 --- tests/array_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/array_test.py b/tests/array_test.py index 3cbb27877a20..9260e800258e 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1228,7 +1228,7 @@ def test_mesh_str(self): def test_make_array_from_callback_error(self): mesh_shape = (2, 3) global_shape = tuple(np.square(mesh_shape)) - mesh = jtu.create_mesh(mesh_shape, ('x', 'y')) + mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True) pspec = P('x', 'y') sharding = jax.sharding.NamedSharding(mesh, pspec) n = math.prod(global_shape) @@ -1387,7 +1387,7 @@ def f(x): global_shape = tuple(np.square(mesh_shape)) - mesh = jtu.create_mesh(mesh_shape, ('x', 'y')) + mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True) s = jax.sharding.NamedSharding(mesh, pspec) n = math.prod(global_shape) From 97db78ba24da5eaf0e5886729d86d8de04e4ff34 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 5 Sep 2024 09:57:17 -0700 Subject: [PATCH 372/702] Adds test_compute_offload_with_donation in memories_test PiperOrigin-RevId: 671410527 --- tests/memories_test.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/memories_test.py b/tests/memories_test.py index 9b8b990d674b..affe5de99644 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1388,6 +1388,38 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P())) self.assertEqual(out.sharding.memory_kind, 'device') + def test_compute_offload_with_donation(self): + sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) + p_sharding = jax.sharding.SingleDeviceSharding( + jax.devices()[0], memory_kind="pinned_host" + ) + + @compute_on("device_host") + @jax.jit + def host_fn(x_in, y_in): + return x_in * x_in, y_in + y_in + + def test_fn(x_in, y_in): + x_out, y_out = host_fn(x_in, y_in) + return x_out, y_out + + x = jnp.arange(0, 1024, dtype=jnp.float32) + y = jnp.arange(0, 1024, dtype=jnp.float32) + y = jax.device_put(y, p_sharding) + + x1 = jnp.arange(0, 1024, dtype=jnp.float32) + y1 = jnp.arange(0, 1024, dtype=jnp.float32) + + jit_fn = jax.jit( + test_fn, + in_shardings=(sharding, p_sharding), + out_shardings=(sharding, p_sharding), + donate_argnums=(0, 1), + ) + x_out, y_out = jit_fn(x, y) + self.assertArraysEqual(x_out, x1 * x1) + self.assertArraysEqual(y_out, y1 + y1) + class ActivationOffloadingTest(jtu.JaxTestCase): From 2d2cbbc5fbbeb9869b0c883e8a2a3fc401979a6b Mon Sep 17 00:00:00 2001 From: kaixih Date: Tue, 3 Sep 2024 20:25:02 +0000 Subject: [PATCH 373/702] Relax q_seqlen and kv_seqlen --- jax/_src/nn/functions.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index b49e16d95408..a5b5aaf31799 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -786,10 +786,14 @@ def _get_causal_mask(T, S): return mask[None, None, :, :] def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen): - q_indices = jnp.arange(0, T)[None, :, None] - kv_indices = jnp.arange(0, S)[None, None, :] - q_mask = q_indices < q_seqlen[:, None, None] - kv_mask = kv_indices < kv_seqlen[:, None, None] + q_mask = True + kv_mask = True + if q_seqlen is not None: + q_indices = jnp.arange(0, T)[None, :, None] + q_mask = q_indices < q_seqlen[:, None, None] + if kv_seqlen is not None: + kv_indices = jnp.arange(0, S)[None, None, :] + kv_mask = kv_indices < kv_seqlen[:, None, None] mask = jnp.logical_and(q_mask, kv_mask) return mask[:, None, :, :] @@ -813,7 +817,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen): mask = _get_causal_mask(T, S) combined_mask = jnp.logical_and(combined_mask, mask) - if q_seqlen is not None and kv_seqlen is not None: + if q_seqlen is not None or kv_seqlen is not None: mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen) combined_mask = jnp.logical_and(combined_mask, mask) @@ -1001,12 +1005,22 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], kv_seqlen=key_value_seq_lengths, ) case 'cudnn': + use_padding = ( + query_seq_lengths is not None or key_value_seq_lengths is not None + ) + if use_padding: + if query_seq_lengths is None: + T = query_arr.shape[1] + query_seq_lengths = jnp.full((B,), T, dtype=jnp.int32) + if key_value_seq_lengths is None: + key_value_seq_lengths = jnp.full((B,), S, dtype=jnp.int32) + mask_type = MaskType.NO_MASK - if query_seq_lengths is not None and is_causal: + if use_padding and is_causal: mask_type = MaskType.PADDING_CAUSAL elif is_causal: mask_type = MaskType.CAUSAL - elif query_seq_lengths is not None: + elif use_padding: mask_type = MaskType.PADDING out = cudnn_dot_product_attention( query_arr, key_arr, value_arr, bias, mask, query_seq_lengths, From dba674153e4d850f2089c99abd8a5c997de4665b Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 5 Sep 2024 11:05:54 -0700 Subject: [PATCH 374/702] [Mosaic TPU] Fix operands order in try canonicalize add of matmul. PiperOrigin-RevId: 671437100 --- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index ff349160dc50..bc1d30893537 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -493,8 +493,9 @@ class CanonicalizeAddOfMatmul : public OpRewritePattern { } return failure(); }; - return success(succeeded(try_canonicalize(op.getLhs(), op.getRhs())) || - succeeded(try_canonicalize(op.getLhs(), op.getRhs()))); + // We tried try_canonicalize(op.getRhs(), op.getLhs()) and it caused + // worrying numerical differences in some of kernels. + return try_canonicalize(op.getLhs(), op.getRhs()); } }; From ee67c414f0b1f84bac7e1513cdc3b6cb69189590 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Thu, 5 Sep 2024 23:38:35 +0530 Subject: [PATCH 375/702] Better doc for jnp.rot90 --- jax/_src/numpy/lax_numpy.py | 58 ++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8d17f09ac3db..74dd152f2684 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -961,9 +961,65 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return lax.transpose(x, axes) -@util.implements(np.rot90, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('k', 'axes')) def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: + """Rotate an array by 90 degrees counterclockwise in the plane specified by axes. + + JAX implementation of :func:`numpy.rot90`. + + Args: + m: input array. Must have ``m.ndim >= 2``. + k: int, optional, default=1. Specifies the number of times the array is rotated. + For negative values of ``k``, the array is rotated in clockwise direction. + axes: tuple of 2 integers, optional, default= (0, 1). The axes define the plane + in which the array is rotated. Both the axes must be different. + + Returns: + An array containing the copy of the input, ``m`` rotated by 90 degrees. + + See also: + - :func:`jax.numpy.flip`: reverse the order along the given axis + - :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right) + - :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down) + + Examples: + >>> m = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.rot90(m) + Array([[3, 6], + [2, 5], + [1, 4]], dtype=int32) + >>> jnp.rot90(m, k=2) + Array([[6, 5, 4], + [3, 2, 1]], dtype=int32) + + ``jnp.rot90(m, k=1, axes=(1, 0))`` is equivalent to + ``jnp.rot90(m, k=-1, axes(0,1))``. + + >>> jnp.rot90(m, axes=(1, 0)) + Array([[4, 1], + [5, 2], + [6, 3]], dtype=int32) + >>> jnp.rot90(m, k=-1, axes=(0, 1)) + Array([[4, 1], + [5, 2], + [6, 3]], dtype=int32) + + when input array has ``ndim>2``: + + >>> m1 = jnp.array([[[1, 2, 3], + ... [4, 5, 6]], + ... [[7, 8, 9], + ... [10, 11, 12]]]) + >>> jnp.rot90(m1, k=1, axes=(2, 1)) + Array([[[ 4, 1], + [ 5, 2], + [ 6, 3]], + + [[10, 7], + [11, 8], + [12, 9]]], dtype=int32) + """ util.check_arraylike("rot90", m) if np.ndim(m) < 2: raise ValueError("rot90 requires its first argument to have ndim at least " From 65b1b0bd95bfdc57f31e598f80d7f17ea9eb31d2 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 5 Sep 2024 11:13:02 -0700 Subject: [PATCH 376/702] Roll back #23404, because it incorrectly casts numpy scalars to `float` when `dtype = None` For example: ``` >>> dtypes.coerce_to_array(np.complex64(1+1j)) jax/_src/dtypes.py:323: ComplexWarning: Casting complex values to real discards the imaginary part return np.array(x).astype(dtype) array(1.) ``` Reverts 3672b633c30fe82ef94d6cb83889894bdda64295 PiperOrigin-RevId: 671439898 --- jax/_src/dtypes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 05350d8621f8..81f4180a1c12 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -320,7 +320,7 @@ def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray: """ if dtype is None and type(x) in python_scalar_dtypes: dtype = _scalar_type_to_dtype(type(x), x) - return np.array(x).astype(dtype) + return np.asarray(x, dtype) iinfo = ml_dtypes.iinfo finfo = ml_dtypes.finfo From 86f48a85b468690b390aae043082e0ffe7dc75c5 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 5 Sep 2024 12:19:59 -0400 Subject: [PATCH 377/702] Add support for the DeviceLocalLayout API when lowering FFI calls. This PR updates the FFI lowering rule to support a DeviceLoweringLayout object as input when specifying the input and output layouts. For now, this just converts the DLL object to its appropriate list of minor-to-major integers because that's what the custom call op expects. --- jax/_src/extend/ffi.py | 33 +++++++++++++------ tests/extend_test.py | 72 +++++++++++++++++++++++------------------- 2 files changed, 63 insertions(+), 42 deletions(-) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index aa092a768bb3..833ac4f615a8 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Mapping, Sequence import ctypes import functools import os @@ -27,12 +27,14 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.layout import DeviceLocalLayout from jax._src.lib import jaxlib from jax._src.lib import xla_client from jax._src.lib.mlir import ir -from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, Shape +from jax._src.typing import Array, ArrayLike, DuckTypedArray, Shape map, unsafe_map = util.safe_map, map +FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None def register_ffi_target( @@ -104,15 +106,24 @@ def _aval_shape(aval: core.AbstractValue) -> Shape: return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error -def _default_layouts(avals: Iterable[core.AbstractValue]) -> list[list[DimSize]]: - return [list(reversed(range(len(_aval_shape(aval))))) for aval in avals] +def _convert_layout(aval: core.AbstractValue, + layout: FfiLayoutOptions = None) -> Sequence[int]: + """Convert a layout to the minor-to-major order used by the custom call API.""" + if layout is None: + return list(reversed(range(len(_aval_shape(aval))))) + elif isinstance(layout, DeviceLocalLayout): + if layout._tiling is not None: + raise ValueError("The FFI does not support layouts with tiling") + return layout.major_to_minor[::-1] + else: + return layout def ffi_lowering( call_target_name: str, *, - operand_layouts: Sequence[Sequence[DimSize]] | None = None, - result_layouts: Sequence[Sequence[DimSize]] | None = None, + operand_layouts: Sequence[FfiLayoutOptions] | None = None, + result_layouts: Sequence[FfiLayoutOptions] | None = None, backend_config: Mapping[str, ir.Attribute] | None = None, **lowering_args: Any ) -> mlir.LoweringRule: @@ -147,13 +158,15 @@ def _lowering( if "result_types" not in kwargs: kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] if operand_layouts is None: - kwargs["operand_layouts"] = _default_layouts(ctx.avals_in) + kwargs["operand_layouts"] = map(_convert_layout, ctx.avals_in) else: - kwargs["operand_layouts"] = operand_layouts + kwargs["operand_layouts"] = [ + _convert_layout(*args) for args in zip(ctx.avals_in, operand_layouts)] if result_layouts is None: - kwargs["result_layouts"] = _default_layouts(ctx.avals_out) + kwargs["result_layouts"] = map(_convert_layout, ctx.avals_out) else: - kwargs["result_layouts"] = result_layouts + kwargs["result_layouts"] = [ + _convert_layout(*args) for args in zip(ctx.avals_out, result_layouts)] if "result_shapes" not in kwargs and not all( core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out): kwargs["result_shapes"] = [ diff --git a/tests/extend_test.py b/tests/extend_test.py index 867e46fe6262..fff3314a7656 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -31,6 +31,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.interpreters import mlir +from jax._src.layout import DeviceLocalLayout from jax._src.lib.mlir.dialects import hlo jax.config.parse_flags_with_absl() @@ -97,33 +98,50 @@ def no_rule(*args, **kwargs): class FfiTest(jtu.JaxTestCase): + def find_custom_call_in_module(self, module): + for func in module.body.operations: + for block in func.body.blocks: + for op in block.operations: + if op.OPERATION_NAME == "stablehlo.custom_call": + return op + self.fail("No custom_call found in the lowered IR") + def testHeadersExist(self): base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api") for header in ["c_api.h", "api.h", "ffi.h"]: self.assertTrue(os.path.exists(os.path.join(base_dir, header))) - def testLoweringLayouts(self): + @parameterized.parameters([ + (tuple(range(3)), tuple(range(3))), + (None, tuple(reversed(range(3)))), + (DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))), + ]) + def testLoweringLayouts(self, layout_spec, expected_layout): # Regression test to ensure that the lowering rule properly captures # layouts. def lowering_rule(ctx, x): aval, = ctx.avals_in ndim = len(aval.shape) - layout = tuple(range(ndim)) - return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout], - result_layouts=[layout])(ctx, x) + return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec], + result_layouts=[layout_spec])(ctx, x) prim = core.Primitive("test_ffi") prim.def_impl(lambda x: x) prim.def_abstract_eval(lambda x: x) mlir.register_lowering(prim, lowering_rule) - x = jnp.linspace(0, 1, 5) + + x = jnp.ones((3,) * len(expected_layout)) lowered = jax.jit(prim.bind).lower(x) module = lowered.compiler_ir("stablehlo") - for func in module.body.operations: - for block in func.body.blocks: - for op in block.operations: - if op.OPERATION_NAME == "stablehlo.custom_call": - self.assertIn("operand_layouts", op.attributes) - self.assertIn("result_layouts", op.attributes) + op = self.find_custom_call_in_module(module) + self.assertIn("operand_layouts", op.attributes) + self.assertIn("result_layouts", op.attributes) + + text = lowered.as_text() + expected = ", ".join(map(str, expected_layout)) + pattern = rf"operand_layouts = \[dense<\[{expected}\]>" + self.assertRegex(text, pattern) + pattern = rf"result_layouts = \[dense<\[{expected}\]>" + self.assertRegex(text, pattern) @parameterized.parameters([ (True, mlir.ir.BoolAttr.get), @@ -140,19 +158,14 @@ def fun(x): # Here we inspect the lowered IR to test that the parameter has been # serialized with the appropriate type. module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo") - for func in module.body.operations: - for block in func.body.blocks: - for op in block.operations: - if op.OPERATION_NAME == "stablehlo.custom_call": - config = op.attributes["mhlo.backend_config"] - self.assertIsInstance(config, mlir.ir.DictAttr) - self.assertIn("param", config) - with mlir.make_ir_context(), mlir.ir.Location.unknown(): - expected = expected_builder(param) - self.assertEqual(type(config["param"]), type(expected)) - self.assertTrue(expected.type.isinstance(config["param"].type)) - return - self.fail("No custom_call found in the lowered IR") + op = self.find_custom_call_in_module(module) + config = op.attributes["mhlo.backend_config"] + self.assertIsInstance(config, mlir.ir.DictAttr) + self.assertIn("param", config) + with mlir.make_ir_context(), mlir.ir.Location.unknown(): + expected = expected_builder(param) + self.assertEqual(type(config["param"]), type(expected)) + self.assertTrue(expected.type.isinstance(config["param"].type)) def testToken(self): def fun(): @@ -161,14 +174,9 @@ def fun(): # Ensure that token inputs and outputs are translated to the correct type module = jax.jit(fun).lower().compiler_ir("stablehlo") - for func in module.body.operations: - for block in func.body.blocks: - for op in block.operations: - if op.OPERATION_NAME == "stablehlo.custom_call": - self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) - self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) - return - self.fail("No custom_call found in the lowered IR") + op = self.find_custom_call_in_module(module) + self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) + self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) @jtu.sample_product( shape=[(1,), (4,), (5,)], From be4383b3f9ea51b7ff7d2c1a308e2d0ee5179c61 Mon Sep 17 00:00:00 2001 From: Frederik Wilde Date: Thu, 5 Sep 2024 20:47:40 +0200 Subject: [PATCH 378/702] Typo --- docs/type_promotion.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst index 103a8331df2b..d3724745fe08 100644 --- a/docs/type_promotion.rst +++ b/docs/type_promotion.rst @@ -218,7 +218,7 @@ Strict dtype promotion ---------------------- In some contexts it can be useful to disable implicit type promotion behavior, and instead require all promotions to be explicit. This can be done in JAX by setting the -``jax_numpy_dtype_promtion`` flag to ``'strict'``. Locally, it can be done with a\ +``jax_numpy_dtype_promotion`` flag to ``'strict'``. Locally, it can be done with a\ context manager: .. code-block:: python From d08b68996a128969e134634734e2332b7f0bee1d Mon Sep 17 00:00:00 2001 From: Frederik Wilde <42576579+frederikwilde@users.noreply.github.com> Date: Thu, 5 Sep 2024 21:29:56 +0200 Subject: [PATCH 379/702] Update jax-primitives.md --- docs/_tutorials/jax-primitives.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_tutorials/jax-primitives.md b/docs/_tutorials/jax-primitives.md index 41ff86fd60f0..abdc8be6d0a8 100644 --- a/docs/_tutorials/jax-primitives.md +++ b/docs/_tutorials/jax-primitives.md @@ -306,7 +306,7 @@ from jax.interpreters import mlir mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') ``` -You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_xla_translation`. +You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_lowering`. ```{code-cell} assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. From 4c8bed92703c855e9304127f3ec57313b9ebdf60 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 5 Sep 2024 12:47:16 -0700 Subject: [PATCH 380/702] Don't add a `sharding` property to `ShapedArray` if `sharding_in_types` flag is not switched on. PiperOrigin-RevId: 671475186 --- jax/_src/core.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 32677ade1b75..f6da7ac7a1cb 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1745,8 +1745,6 @@ def __init__(self, shape, dtype, weak_type=False, named_shape=None, self.weak_type = weak_type if config.sharding_in_types.value: self.sharding = sharding - else: - self.sharding = None def update(self, shape=None, dtype=None, weak_type=None, named_shape=None, sharding=None): @@ -1758,7 +1756,7 @@ def update(self, shape=None, dtype=None, weak_type=None, named_shape=None, if weak_type is None: weak_type = self.weak_type if sharding is None: - sharding = self.sharding + sharding = getattr(self, 'sharding', None) return ShapedArray(shape, dtype, weak_type, sharding=sharding) ndim = property(lambda self: len(self.shape)) @@ -1775,13 +1773,14 @@ def __eq__(self, other): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type - and self.sharding == other.sharding) + and getattr(self, 'sharding', None) == getattr(other, 'sharding', None)) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) - return hash((self.shape, self.dtype, self.weak_type, self.sharding)) + return hash((self.shape, self.dtype, self.weak_type, + getattr(self, 'sharding', None))) def at_least_vspace(self): return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), @@ -1800,10 +1799,10 @@ def str_short(self, short_dtypes=False): dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name dt_str = dt_str.replace('void', 'float0') shapestr = ','.join(map(str, self.shape)) - if self.sharding is None: - return f'{dt_str}[{shapestr}]' - else: + if hasattr(self, 'sharding'): return f'{dt_str}[{shapestr}]({self.sharding})' + else: + return f'{dt_str}[{shapestr}]' def _len(self, ignored_tracer): try: From a144eb234b7bd0286be2914c8c8d36ece3e564b5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 5 Sep 2024 14:15:33 -0700 Subject: [PATCH 381/702] Add compute_on_context_manager to thread local jit state. This is to avoid getting false cache hits PiperOrigin-RevId: 671507042 --- jax/BUILD | 2 +- jax/_src/compute_on.py | 5 +++++ jax/_src/config.py | 7 ++++++- tests/memories_test.py | 16 ++++++++++++++++ 4 files changed, 28 insertions(+), 2 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index c4a421362f37..4c622194941f 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -765,7 +765,7 @@ pytype_strict_library( pytype_strict_library( name = "compute_on", srcs = ["_src/compute_on.py"], - deps = [], + deps = [":config"], ) pytype_strict_library( diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index 25b2be78d287..4495d38f9da8 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -15,6 +15,7 @@ from __future__ import annotations import threading from contextlib import contextmanager +from jax._src import config class ComputeOnContext(threading.local): @@ -28,6 +29,8 @@ def __init__(self): @contextmanager def extend_compute_type(c_type: str): compute_on_context.stack.append(c_type) + config.update_thread_local_jit_state( + compute_on_context_manager=tuple(compute_on_context.stack)) try: if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1: raise NotImplementedError( @@ -36,6 +39,8 @@ def extend_compute_type(c_type: str): yield compute_on_context.stack[-1] finally: compute_on_context.stack.pop() + config.update_thread_local_jit_state( + compute_on_context_manager=tuple(compute_on_context.stack)) def current_compute_type() -> str | None: return compute_on_context.stack[-1] if compute_on_context.stack else None diff --git a/jax/_src/config.py b/jax/_src/config.py index a0b91e2ad6ed..646e487c5f1c 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -202,12 +202,16 @@ def trace_context(): tls = jax_jit.thread_local_state() axis_env_state = () mesh_context_manager = () + compute_on_context_manager = () context: Any = tls.extra_jit_context if context and context.axis_env_state is not None: axis_env_state = context.axis_env_state if context and context.mesh_context_manager: mesh_context_manager = context.mesh_context_manager - return (axis_env_state, mesh_context_manager, enable_x64.value, + if context and context.compute_on_context_manager: + compute_on_context_manager = context.compute_on_context_manager + return (axis_env_state, mesh_context_manager, compute_on_context_manager, + enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, dynamic_shapes.value, numpy_dtype_promotion.value, default_device.value, random_seed_offset.value, @@ -853,6 +857,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): dynamic_trace_state: Any | None = None axis_env_state: Hashable = () mesh_context_manager: Hashable = () + compute_on_context_manager: Hashable = () # Values set by _StateContextManager context managers. # CAUTION: these must be initialized to `None`! The state context manager diff --git a/tests/memories_test.py b/tests/memories_test.py index affe5de99644..68aecfdf669f 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1420,6 +1420,22 @@ def test_fn(x_in, y_in): self.assertArraysEqual(x_out, x1 * x1) self.assertArraysEqual(y_out, y1 + y1) + def test_compute_on_cache_miss(self): + @jax.jit + def f(x): + return x * 2 + + inp = jnp.arange(10) + with jtu.count_jit_tracing_cache_miss() as count: + with compute_on('device_host'): + f(inp) + + with compute_on('device'): + f(inp) + + # 2 for `f` and `2` for `mul` (compute type changes for `mul`) + self.assertEqual(count[0], 4) + class ActivationOffloadingTest(jtu.JaxTestCase): From 02334cdaa5087a66cd5ed50830c246a3caa80f5c Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Fri, 14 Jun 2024 13:21:44 -0700 Subject: [PATCH 382/702] updating bitwise_right_shift_doc as an alias simpler bitwise_right_shift implementation to match previous PR updating bitwise_right_shift_doc as an alias readded jnp.bitwise_left_shift, jnp.bitwise_right_shift Update sharded-computation doc to use make_mesh() Rename `jtu.create_global_mesh` to `jtu.create_mesh` and use `jax.make_mesh` inside `jtu.create_mesh` to get maximum test coverage of the new API. PiperOrigin-RevId: 670744047 better true_divide and divide docs doc wording update [Mosaic TPU] Fix mosaic alignment check in concatenate rule. PiperOrigin-RevId: 670837792 Fix pytype errors and args for jax.Array methods Add docker builds for ubu22 and 24 Better docs for jax.numpy: log and log1p random.key_impl: improve repr of output Remove unused docstring addition: _PRECISION_DOC update example optimizers library docstring * JAXopt is being merged into Optax, so point only to Optax * Update Optax's github repository URL fixing merge duplication updating tests to skip bitwise shift if numpy major version < 2 removed whitespace 659 keep non-bitwise tests for numpy < 2.0.0 more readable edit --- jax/_src/numpy/ufuncs.py | 8 +++----- tests/lax_numpy_operators_test.py | 8 +++++++- tests/lax_numpy_test.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 91356f40b2fa..2cae8de1712b 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -974,13 +974,11 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic return lax_fn(x1, x2) -@implements(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy') + @partial(jit, inline=True) def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: - x1, x2 = promote_args_numeric("bitwise_right_shift", x1, x2) - lax_fn = lax.shift_right_logical if \ - np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic - return lax_fn(x1, x2) + """Alias of :func:`jax.numpy.right_shift`.""" + return right_shift(x1, x2) @partial(jit, inline=True) diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index 4c31684e145f..d9d6fa464d98 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -654,7 +654,13 @@ def testShiftOpAgainstNumpy(self, op, dtypes, shapes): shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits)) args_maker = lambda: (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype)) - np_op = getattr(np, op.__name__) + if jtu.numpy_version() < (2, 0, 0) and op.__name__ in ("bitwise_left_shift", "bitwise_right_shift"): + # numpy < 2.0.0 does not have bitwise shift functions. + op_name = op.__name__.removeprefix("bitwise_") + else: + op_name = op.__name__ + + np_op = getattr(np, op_name) with jtu.strict_promotion_if_dtypes_match(dtypes): self._CompileAndCheck(op, args_maker) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 3a5d1def4906..534a395ef777 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6288,7 +6288,7 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', - 'amax', 'amin', 'around', 'divide', 'round_'] + 'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'round_'] for name in dir(jnp): if name.startswith('_') or name in unimplemented: From 45dc05eaa6fcac0de5cbf2109e024252435e9e69 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 5 Sep 2024 14:30:25 -0700 Subject: [PATCH 383/702] Delete remote python repository rule calls from TF configs. Remote configurations of python repositories are removed because hermetic Python repository rules install and configure python modules in Bazel cache on the host machine. The cache is shared across host and remote machines. PiperOrigin-RevId: 671512134 --- .bazelrc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.bazelrc b/.bazelrc index 98bc29d8a991..9d5d9664939e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -209,11 +209,11 @@ build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylin build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" +build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" -build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" +build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" -build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12" +build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" build:rbe_linux_cuda_base --config=rbe_linux @@ -231,11 +231,11 @@ build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_SYSROOT="/dt9" build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" +build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" -build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" +build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" -build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12" +build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" # These you may need to change for your own GCP project. From 125bb4f158df081b6a921b1452f0d3b56acacc6b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 5 Sep 2024 14:58:50 -0700 Subject: [PATCH 384/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/3c4102f71c0dc443619b2848d3f5080377518166. PiperOrigin-RevId: 671522199 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 24be11d85893..a99cd0d6b6e9 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "79f2636de5521c8fb98fa4ab33724a5043fdf2d2" -XLA_SHA256 = "fe46154bbeac942fe166580ab0e91a329adcc7428aadd90fcf96784f0291fde8" +XLA_COMMIT = "3c4102f71c0dc443619b2848d3f5080377518166" +XLA_SHA256 = "73ddf9ed6dc16cedd5ffd6990e3331221916d4eff54c4445082cad096ff3d40a" def repo(): tf_http_archive( From 9c86fdec02532274d49e22c362f7a7bc5fae758e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 5 Sep 2024 19:49:12 +0000 Subject: [PATCH 385/702] Make optimization_barrier a public lax API. --- CHANGELOG.md | 3 ++ docs/jax.lax.rst | 1 + jax/_src/ad_checkpoint.py | 27 ++----------- jax/_src/lax/control_flow/__init__.py | 2 +- jax/_src/lax/lax.py | 56 +++++++++++++++++++++++++++ jax/lax/__init__.py | 2 + tests/lax_test.py | 4 ++ 7 files changed, 71 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e310b296b05f..93745d9d3d11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. {obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`, {obj}`~jax.numpy.logical_and`, and {obj}`~jax.numpy.logical_and`. In future releases we plan to expand these to other ufuncs. + * Added {func}`jax.lax.optimization_barrier`, which allows users to prevent + compiler optimizations such as common-subexpression elimination and to + control scheduling. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 7b19955d3d78..e0fc5ad46b3e 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -119,6 +119,7 @@ Operators ne neg nextafter + optimization_barrier pad platform_dependent polygamma diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 6bf481ef0496..fd30119882e7 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -27,7 +27,6 @@ from jax._src import api from jax._src import config from jax._src import core -from jax._src import dispatch from jax._src import dtypes from jax._src import linear_util as lu from jax._src import effects @@ -755,7 +754,7 @@ def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool, return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr) def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): - args = _optimization_barrier(args) + args = lax_internal.optimization_barrier(args) return core.eval_jaxpr(jaxpr, (), *args) # TODO(mattjj): add core utility for 'create dummy value for this type'? @@ -837,27 +836,6 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True), platform="gpu") -def _optimization_barrier_abstract_eval(*args): - return args - -def _optimization_barrier_lowering_rule(ctx, *args): - barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in) - flat_args = mlir.flatten_ir_values(args) - barrier_op = hlo.OptimizationBarrierOp(flat_args) - return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types) - -def _optimization_barrier(arg): - flat_args, treedef = tree_flatten(arg) - return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args)) - -optimization_barrier_p = core.Primitive('optimization_barrier') -optimization_barrier_p.multiple_results = True -optimization_barrier_p.def_impl( - partial(dispatch.apply_primitive, optimization_barrier_p)) -optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) -mlir.register_lowering(optimization_barrier_p, - _optimization_barrier_lowering_rule) - def checkpoint_name(x, name): return name_p.bind(x, name=name) @@ -936,3 +914,6 @@ def checkpoint_wrapper( raise NotImplementedError(msg) return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums) + +# TODO(phawkins): update users to refer to the public name. +_optimization_barrier = lax_internal.optimization_barrier diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index 05dcade84999..5e6fa86f706e 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -33,4 +33,4 @@ _initial_style_jaxprs_with_common_consts, _check_tree_and_avals) # TODO(mattjj): fix dependent library which expects optimization_barrier_p here -from jax._src.ad_checkpoint import optimization_barrier_p +from jax._src.lax.lax import optimization_barrier_p diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 2186e767ab37..8d2c24d6e64c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5346,3 +5346,59 @@ def convert_to(other_dtype, bint_dtype) -> bool: core.bint._rules = BIntRules + + +def optimization_barrier(operand, /): + """Prevents the compiler from moving operations across the barrier. + + Optimization barriers have a number of possible uses: + + * An optimization barrier ensures that all inputs are evaluated before any + operators that depend on the barrier's outputs. This can be used to enforce + a particular order of operations. + * An optimization barrier prevents common subexpression elimination. This is + used by JAX to implement rematerialization. + * Optimization barriers prevent compiler fusions. That is, operations before + the barrier may not be fused into the same kernel as operations after the + barrier by the compiler. + + JAX does not define derivative or batching rules for an optimization barrier. + + Optimization barriers have no effect outside a compiled function. + + Args: + operand: a pytree of JAX values. + + Returns: + A pytree of JAX values, with the same structure and contents as ``operand``. + + Examples: + Prevents common-subexpression elimination between the two calls to `sin`: + + >>> def f(x): + ... return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x) + >>> jax.jit(f)(0.) + Array(0., dtype=float32, weak_type=True) + """ + flat_args, treedef = tree_util.tree_flatten(operand) + return tree_util.tree_unflatten( + treedef, optimization_barrier_p.bind(*flat_args)) + + +def _optimization_barrier_abstract_eval(*args): + return args + +def _optimization_barrier_lowering_rule(ctx, *args): + barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in) + flat_args = mlir.flatten_ir_values(args) + barrier_op = hlo.OptimizationBarrierOp(flat_args) + return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types) + + +optimization_barrier_p = core.Primitive('optimization_barrier') +optimization_barrier_p.multiple_results = True +optimization_barrier_p.def_impl( + partial(dispatch.apply_primitive, optimization_barrier_p)) +optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) +mlir.register_lowering(optimization_barrier_p, + _optimization_barrier_lowering_rule) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index bac005b81650..e2bcd5de9408 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -142,6 +142,8 @@ nextafter as nextafter, nextafter_p as nextafter_p, not_p as not_p, + optimization_barrier as optimization_barrier, + optimization_barrier_p as optimization_barrier_p, or_p as or_p, outfeed as outfeed, outfeed_p as outfeed_p, diff --git a/tests/lax_test.py b/tests/lax_test.py index 7ed17adf45bc..ce30131953af 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3093,6 +3093,10 @@ def testAsarray(self, typ): with jax.transfer_guard('disallow'): jax.jit(asarray_closure)() + def testOptimizationBarrier(self): + x = lax.optimization_barrier((2, 3)) + self.assertEqual((2, 3), x) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): From 27e19239cafa88e4a6518316084c5869630bedd5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 5 Sep 2024 16:02:32 +0000 Subject: [PATCH 386/702] Fix triton capi_objects target to depend on MLIR CAPIIRObjects bazel target. "...Objects" targets should only depend on other "...Objects" targets in MLIR land. Don't mix them. --- jaxlib/triton/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 95482e47e864..99cddd9e6381 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -116,7 +116,7 @@ cc_library( hdrs = ["triton_dialect_capi.h"], deps = [ "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPIIRObjects", "@llvm-project//mlir:IR", "@triton//:TritonDialects", ], From e3b8177af3c4cb6b7891a4164612f65588ec2a69 Mon Sep 17 00:00:00 2001 From: Sebastian Bodenstein Date: Thu, 5 Sep 2024 18:41:53 -0700 Subject: [PATCH 387/702] Internal change. PiperOrigin-RevId: 671583042 --- jaxlib/gpu/BUILD | 14 +++++++++++++- jaxlib/jax.bzl | 3 +++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index f3524ccdf781..8c4144974b4a 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -14,7 +14,12 @@ # Shared CUDA/ROCM GPU kernels. -load("//jaxlib:jax.bzl", "cc_proto_library") +load( + "//jaxlib:jax.bzl", + "cc_proto_library", + "jax_visibility", + "xla_py_proto_library", +) licenses(["notice"]) @@ -72,3 +77,10 @@ cc_proto_library( name = "triton_cc_proto", deps = [":triton_proto"], ) + +xla_py_proto_library( + name = "triton_py_pb2", + api_version = 2, + visibility = jax_visibility("triton_proto_py_users"), + deps = [":triton_proto"], +) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 7df848aad843..cf9047cc4e17 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -297,3 +297,6 @@ def jax_generate_backend_suites(backends = []): ) jax_test_file_visibility = [] + +def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable + pass From 0d8ffd33abada5da6338a8977796578ec4e1cc55 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 6 Sep 2024 11:52:12 +0300 Subject: [PATCH 388/702] [shape_polyO] Improve handling of equality shape constraints This fixes several bugs in presence of equality constraints where the left-hand side is just a dimension variable. First, such constraints were not applied when parsing variables. Now, with a constraint `a == b` when we parse "a" we obtain `b`. Second, when we evaluate symbolic dimensions that contain dimension variables that are constrained to be equal to something else, we may fail to find the dimension variable in the environment because the environment construction has applied the constraints. We fix this by looking up the unknown dimension variable in the equality constraints. Fixes: #23437 Fixes: #23456 --- docs/export/shape_poly.md | 4 +- jax/_src/export/shape_poly.py | 60 ++++++++++++++-------- jax/_src/export/shape_poly_decision.py | 8 ++- jax/experimental/jax2tf/jax2tf.py | 2 +- tests/shape_poly_test.py | 70 +++++++++++++++++++++++--- 5 files changed, 111 insertions(+), 33 deletions(-) diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 695ca6cd21d9..f3da7a0e55fe 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -353,7 +353,7 @@ symbolic constraints: E.g., `floordiv(a, b) == c` works by replacing all occurences of `floordiv(a, b)` with `c`. Equality constraints must not contain addition or - subtraction at the top-leve on the left-hand-side. Examples of + subtraction at the top-level on the left-hand-side. Examples of valid left-hand-sides are `a * b`, or `4 * a`, or `floordiv(a + c, b)`. @@ -530,7 +530,7 @@ Array([[ 9, 8, 7], >>> k, = export.symbolic_shape("k", constraints=["k <= 10"]) >>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): -KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments +UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments ``` diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 0173df4fd345..77786cbf1a9d 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -80,6 +80,8 @@ def __init__(self, message: str): # https://github.com/python/mypy/issues/5887 super().__init__(error_msg) +class UnexpectedDimVar(Exception): + pass class Comparator(Enum): EQ = 1 @@ -87,12 +89,14 @@ class Comparator(Enum): @dataclasses.dataclass(frozen=True) class _SymbolicConstraint: + # Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2 cmp: Comparator debug_str: str # The form in which the user expressed it, for error messages - diff: _DimExpr # For GEQ: diff >= 0, and for EQ: diff == 0 + e1: DimSize # This has been normalized w.r.t. previous constraints only + e2: DimSize # This has been normalized w.r.t. previous constraints only def __repr__(self): - return f"Constraint({self.debug_str}: {self.diff})" + return f"Constraint({self.debug_str})" class _DimFactor: @@ -209,15 +213,22 @@ def __ge__(self, other: _DimFactor): """Lexicographic comparison""" return self._syntactic_cmp(other) >= 0 - def evaluate(self, env: DimVarEnv): + def evaluate(self, env: DimVarEnv, scope: SymbolicScope): if self.var is not None: try: return env[self.var] except KeyError: + # Perhaps there is a normalization rule for this variable + normalized_var = _DimExpr._from_var(self.var, scope) + if core.is_constant_dim(normalized_var): + return normalized_var + non_trivial_normalization = (v1 := normalized_var._to_var()) is None or v1 != self.var # type: ignore + if non_trivial_normalization: + return normalized_var._evaluate(env) # type: ignore err_msg = ( f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n" "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") - raise KeyError(err_msg) + raise UnexpectedDimVar(err_msg) else: operand_values = [opnd._evaluate(env) for opnd in self.operands] if self.operation == _DimFactor.FLOORDIV: @@ -370,11 +381,11 @@ def divide(self, divisor: _DimTerm) -> _DimTerm: raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.") return _DimTerm(new_factors) - def evaluate(self, env: DimVarEnv): + def evaluate(self, env: DimVarEnv, scope: SymbolicScope): prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1) def pow_opt(v, p: int): return v if p == 1 else prod([v] * p) - return prod([pow_opt(f.evaluate(env), exp) for f, exp in self._factors]) + return prod([pow_opt(f.evaluate(env, scope), exp) for f, exp in self._factors]) def __deepcopy__(self, memo): return _DimTerm(copy.deepcopy(self._factors, memo)) @@ -404,7 +415,7 @@ class _DimExpr: def __init__(self, sorted_terms: SortedTerms, scope: SymbolicScope): # Do not construct _DimExpr directly, unless you are sure that `terms` is - # normalized; Use _DimExpr.normalize. + # normalized; Use _DimExpr._normalize_sorted_terms. self._sorted_terms = tuple(sorted_terms) or ((_DimTerm_one, 0),) self._scope = scope self._hash = None @@ -426,8 +437,8 @@ def _from_term(t: _DimTerm, t_k: int, scope: SymbolicScope) -> DimSize: return _DimExpr._normalize_sorted_terms(((t, t_k),), scope) @staticmethod - def _from_var(v: str, scope: SymbolicScope) -> _DimExpr: - return _DimExpr(((_DimTerm.from_var(v), 1),), scope) + def _from_var(v: str, scope: SymbolicScope) -> DimSize: + return _DimExpr._normalize_sorted_terms(((_DimTerm.from_var(v), 1),), scope) @staticmethod def _from_operation(operation: str, *operands: DimSize, @@ -475,8 +486,9 @@ def _add_coeff(coeffs: dict[_DimTerm, int], t: _DimTerm, coeff: int): def _normalize_term(t: _DimTerm, t_k: int, scope: SymbolicScope) -> Sequence[tuple[_DimTerm, int]]: # If (t, t_k) is among the scope normalization rules, then return - # a list of updates to apply to the expression containing (t, t_k). - # Returns empty sequence if no normalizations are necessary. + # a list of `term * coefficient` to add to the expression containing (t, t_k). + # Returns the empty sequence if no normalizations are necessary. + if not scope._normalization_rules: return [] updates = [] after, t_k_after = scope._normalization_rules.get(t, (None, 0)) if after is not None and t_k % t_k_after == 0: @@ -899,7 +911,7 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]: def _evaluate(self, env: DimVarEnv): # Evaluates as a value of dtype=core.dim_value_dtype() - terms = [_evaluate_multiply(t.evaluate(env), core.dim_constant(t_k)) + terms = [_evaluate_multiply(t.evaluate(env, self.scope), core.dim_constant(t_k)) for t, t_k in self._sorted_terms] return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0] @@ -1046,8 +1058,6 @@ def _parse_and_process_explicit_constraint(self, c_str: str): raise ValueError(f"Unsatisfiable explicit constraint: {c_str}") return - constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, diff=diff) # type: ignore[arg-type] - self._explicit_constraints.append(constr) if cmp == Comparator.EQ: if not isinstance(e1, _DimExpr): raise ValueError("Invalid equality constraint: {e1} == {e2}. " @@ -1063,6 +1073,9 @@ def _parse_and_process_explicit_constraint(self, c_str: str): f"Found multiple equality constraints with the same left-hand-side: {before}") self._normalization_rules[before] = (after, before_k) + constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2) + self._explicit_constraints.append(constr) + def _check_same_scope(self, other: _DimExpr, when: str = "", self_descr: str = " ", @@ -2016,7 +2029,7 @@ def _solve_dim_equations( # Returns a shape environment and the shape constraints if it can solve all # dimension variables. Raises an exception if it cannot. shape_env: DimVarEnv = {} - solution_error_message_pieces: list[str | _DimExpr] = [ + solution_error_message_pieces: list[str | DimSize] = [ " Obtained dimension variables: " ] # Error message describing the solution # Prepare error message piece describing the polymorphic shape specs @@ -2050,8 +2063,8 @@ def process_one_eqn(eqn: _DimEquation) -> bool: for term, term_k in eqn.aval_dim_expr._sorted_terms: # Perhaps we can already evaluate this term (all vars solved) try: - term_value = term.evaluate(shape_env) - except KeyError: + term_value = term.evaluate(shape_env, scope) + except UnexpectedDimVar: # `mon` still uses some variables not yet solved. We handle only the # case when `mon` is a single variable. v = term.to_var() @@ -2118,14 +2131,19 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): if not shape_env: return assert scope is not None for constr in scope._explicit_constraints: - c_value = constr.diff._evaluate(shape_env) + # We can't just construct constr.e1 - constr.e2 because for an equality + # constraint it would be reduced to 0. + c_e1 = constr.e1._evaluate(shape_env) if not core.is_constant_dim(constr.e1) else constr.e1 # type: ignore + c_e2 = constr.e2._evaluate(shape_env) if not core.is_constant_dim(constr.e2) else constr.e2 # type: ignore + c_diff = c_e1 - c_e2 shape_constraints.add_constraint( - constr.cmp, c_value, 0, + constr.cmp, c_diff, 0, error_message_pieces=[ f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. " - f"Expected '{constr.diff}' to be " + f"Expected '{constr.e1} - {constr.e2}' to be " f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, " - "but found ", c_value, + "but found ", c_diff, + ". " + poly_specs_err_msg ] + solution_error_message_pieces + [ solution_err_msg_trailer_errors]) diff --git a/jax/_src/export/shape_poly_decision.py b/jax/_src/export/shape_poly_decision.py index e325722b0c26..4bad8b7be06d 100644 --- a/jax/_src/export/shape_poly_decision.py +++ b/jax/_src/export/shape_poly_decision.py @@ -23,6 +23,7 @@ import numpy as np +from jax._src import core from jax._src.export import shape_poly from jax._src.export.shape_poly import ( _DimExpr, _DimTerm, _DimFactor, @@ -84,7 +85,10 @@ def initialize(self) -> _DecisionByElimination: # the result (albeit, for now, without a good feedback loop to understand # how the order matters for inequalities). for constr in self.scope._explicit_constraints: - self.add_implicit_constraints_expr(constr.diff) + if not core.is_constant_dim(constr.e1): + self.add_implicit_constraints_expr(constr.e1) # type: ignore + if not core.is_constant_dim(constr.e2): + self.add_implicit_constraints_expr(constr.e2) # type: ignore # The equality constraints are not needed for inequality decisions, # because the LHS should always be rewritten in terms of the RHS. # In fact, adding them may break the assumption that if we eliminate @@ -92,7 +96,7 @@ def initialize(self) -> _DecisionByElimination: # may appear in the rest and may be rewritten to something larger. # However, we want to add the implicit constraints within. if constr.cmp == Comparator.GEQ: - self.combine_and_add_constraint(constr.cmp, constr.diff, 0, + self.combine_and_add_constraint(constr.cmp, constr.e1 - constr.e2, 0, constr.debug_str) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 545945c91ffd..741887abf24b 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1253,7 +1253,7 @@ def __init__(self, trace: TensorFlowTrace, val: TfVal, # We have a TF value with known shape, and the abstract shape is a shape variable. try: aval_int = int(_eval_shape([aval_dim])) # type: ignore - except (TypeError, KeyError): + except (TypeError, KeyError, shape_poly.UnexpectedDimVar): continue assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index d5b32cdbd7fc..357b2e08d091 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -995,6 +995,16 @@ def test_constraints_ge_override(self): self.assertEqual(_bounds(a), (10, np.inf)) self.assertEqual(_bounds(b), (1, 10)) + def test_constraint_eq_0(self): + a, b, c, d = shape_poly.symbolic_shape( + "a, b, c, d", + constraints=("b == a", "c == a + b", "d == 5")) + # Check that we have already applied the normalizaton rules + self.assertEqual(a._to_var(), "a") + self.assertEqual(b._to_var(), "a") + self.assertEqual(c._to_single_term(), (0, 2, a._to_term())) + self.assertIs(d, 5) + def test_constraints_eq_1(self): # Some constaints override other a, b, c = shape_poly.symbolic_shape("a, b, c", @@ -1073,6 +1083,20 @@ def test_constraints_eq_7(self): self.assertEqual(128 * (t1_ceil // 128), t1_ceil) self.assertEqual(128 * b1 * (t1_ceil // 128), b1 * t1_ceil) + def test_constraints_eq_bug_23456(self): + b, = jax.export.symbolic_shape('b', constraints=['b==5']) + jax.eval_shape(lambda k: jnp.tile(k, 3), jax.ShapeDtypeStruct((b,), jnp.float32)) + + def test_constraints_eq_bug_23437(self): + def f1(x, y): + return x + y + + x = jnp.ones((4,), dtype=jnp.int32) + y = jnp.ones((4,), dtype=jnp.int32) + args_specs = jax.export.symbolic_args_specs((x, y), ("a*2", "b*2"), constraints=("a==b",)) + exp = jax.export.export(jax.jit(f1))(*args_specs) + self.assertEqual(exp.in_avals[0], exp.in_avals[1]) + def test_constraints_eq_threefry(self): # Test equalities that arise out of the threefree lowering # x : i32[a] # a may be even or odd @@ -1106,12 +1130,9 @@ def test_constraints_a_minus_4d_eq(self): assumptions1 = ["m1 >= 0", "m1 <= 3", "a1 == 4*d1 + m1"] scope1 = shape_poly.SymbolicScope(assumptions1) a1, d1, m1 = shape_poly.symbolic_shape("a1, d1, m1", scope=scope1) - # TODO: The incompleteness is due to the way we combine external constraints self.assertEqual(_bounds(a1 - 4*d1), (1, 3)) # a - 4d = m >= 1 self.assertEqual(_bounds(a1 - 2*d1), (3, np.inf)) # a - 2d = m + 2d >= 3 - # TODO: The incompleteness is due to the way we combine external constraints - self.assertEqual(_bounds(a1), - _expect(best=(5, np.inf), current=(-np.inf, np.inf))) # a >= 4d + m >= 5 + self.assertEqual(_bounds(a1), (5, np.inf)) # a >= 4d + m >= 5 def test_constraints_error_msg(self): a, b = shape_poly.symbolic_shape("a, b", @@ -1642,8 +1663,7 @@ def f(x): # x: i32[a, b] _ = export.export(jax.jit(f))( jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32)) - - def test_constraints_compile_time_check(self): + def test_constraints_ge_compile_time_check(self): def f(x): # x: i32[a] a = x.shape[0] assert _bounds(a) == (2, 4) @@ -1669,9 +1689,45 @@ def f(x): # x: i32[a] with self.assertRaisesRegex( ValueError, - re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")): + re.escape("Expected '4 - a' to be greater or equal to 0, but found -1")): exp.call(np.arange(5, dtype=np.int32)) + def test_constraints_eq_0_compile_time_check(self): + def f(x): # x: i32[a, b] + return x + + x_spec = jax.ShapeDtypeStruct( + export.symbolic_shape("a, b", + constraints=["max(a, b) == b"]), np.int32) + exp = export.export(jax.jit(f))(x_spec) + with self.assertRaisesRegex( + ValueError, + re.escape("Expected 'max(a, b) - b' to be equal to 0, but found 1")): + exp.call(np.ones((3, 2), dtype=np.int32)) + + def test_constraints_eq_1_compile_time_check(self): + def f(x): # x: i32[a, b] + return x + + x_spec = jax.ShapeDtypeStruct( + export.symbolic_shape("a, b", + constraints=["a == b"]), np.int32) + exp = export.export(jax.jit(f))(x_spec) + exp.call(np.ones((3, 3), dtype=np.int32)) + + def test_constraints_eq_2_compile_time_check(self): + def f(x): # x: i32[a, b] + return x + + x_spec = jax.ShapeDtypeStruct( + export.symbolic_shape("a, b", + constraints=["max(a, b) == 4", "a == b"]), np.int32) + exp = export.export(jax.jit(f))(x_spec) + with self.assertRaisesRegex( + ValueError, + re.escape("Expected 'max(a, b) - 4' to be equal to 0, but found -1")): + exp.call(np.ones((3, 3), dtype=np.int32)) + def test_caching_with_scopes(self): f_tracing_count = 0 expected_a_bounds = (1, np.inf) From ef947a0ce673201b60763fe1f8b26c1c7dcd0fbc Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 6 Sep 2024 04:38:56 -0700 Subject: [PATCH 389/702] Added a bit more error checking to Pallas Mosaic GPU pipelining logic PiperOrigin-RevId: 671711873 --- jax/_src/pallas/mosaic_gpu/lowering.py | 57 ++++++++++++++------------ 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 88cd4545ae7b..fa546949780e 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -33,7 +33,6 @@ from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect -from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src.state import primitives as sp @@ -128,7 +127,6 @@ class LoweringRuleContext: class LoweringResult: module: ir.Module grid: tuple[int, ...] - gmem_scratch_bytes: int out_structs: tuple[jax.ShapeDtypeStruct, ...] @@ -199,7 +197,11 @@ def lower_jaxpr_to_module( dimension_semantics = params.get( "dimension_semantics", ["parallel"] * len(grid_mapping.grid) ) - assert len(dimension_semantics) == len(grid_mapping.grid) + if len(dimension_semantics) != len(grid_mapping.grid): + raise ValueError( + "dimension_semantics must have an entrey for each grid dimension:" + f" {len(dimension_semantics)=}, but len(grid={grid_mapping.grid})." + ) sequential_axes = tuple( i for i, s in enumerate(dimension_semantics) if s == "sequential" ) @@ -310,14 +312,20 @@ def store(step: ir.Value, slot: ir.Value) -> None: "Array dimensions along the sequential axis must be divisible by" " the corresponding block dimensions." ) - [num_steps] = { + num_steps, *rest = { b_gmem.shape[sequential_axis] // b_smem.shape[1 + sequential_axis] for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) } + if rest: + raise ValueError( + "Array dimensions along the sequential axis must produce the same" + " number of steps when devided by the corresponding block" + " dimensions." + ) else: num_steps = 1 - for slot in range(num_stages): + for slot in range(min(num_stages, num_steps)): fetch(_as_index(slot), _as_index(slot)) @mgpu.fori(_as_index(num_steps), ()) @@ -339,9 +347,8 @@ def _(step, _): next_step_in_bounds = arith_dialect.cmpi( arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps) ) - with ir.InsertionPoint(scf_dialect.IfOp(next_step_in_bounds).then_block): + with mgpu.when(next_step_in_bounds): fetch(next_step, slot) - scf_dialect.yield_([]) return () @@ -355,28 +362,26 @@ def _(step, _): dtype=np.int8, ) ] - module, out_structs_smem, gmem_scratch_bytes, _ = ( - mosaic_gpu._lower_as_gpu_kernel( - body, - grid=grid, - cluster=(), - block=block, - in_shapes=in_structs_gmem, - out_shape=out_structs_gmem, - smem_scratch_shape=( - *in_structs_smem, - *out_structs_smem, - *extra_smem_scratch, - mgpu.Barrier( - arrival_count=len(in_structs_gmem), - num_barriers=num_stages, - ), + module, out_structs_smem, _ = mosaic_gpu._lower_as_gpu_kernel( + body, + grid=grid, + cluster=(), + block=block, + in_shapes=in_structs_gmem, + out_shape=out_structs_gmem, + smem_scratch_shape=( + *in_structs_smem, + *out_structs_smem, + *extra_smem_scratch, + mgpu.Barrier( + arrival_count=len(in_structs_gmem), + num_barriers=num_stages, ), - module_name=name_and_src_info.name, - ) + ), + module_name=name_and_src_info.name, ) - return LoweringResult(module, grid, gmem_scratch_bytes, out_structs_smem) + return LoweringResult(module, grid, out_structs_smem) mosaic_lowering_rules = {} From 7326db77917556de1a403e6597b3036140c3f219 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 6 Sep 2024 06:59:34 -0700 Subject: [PATCH 390/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/53fd000440d637359a6546a04b06d11c823553ed. PiperOrigin-RevId: 671741772 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a99cd0d6b6e9..20ffedcb9393 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "3c4102f71c0dc443619b2848d3f5080377518166" -XLA_SHA256 = "73ddf9ed6dc16cedd5ffd6990e3331221916d4eff54c4445082cad096ff3d40a" +XLA_COMMIT = "53fd000440d637359a6546a04b06d11c823553ed" +XLA_SHA256 = "fe1098143e2515d472b87d9052bd48e60d9332729502d2d98742e2b7892a2937" def repo(): tf_http_archive( From bcbc0962bb03b6afa9073b2459d36f36b13664c9 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 6 Sep 2024 12:30:28 -0400 Subject: [PATCH 391/702] Add the FFI functions and tutorial to the changelog. Although we soft launched the FFI with v0.4.31, it would be nice to include an update in the changelog to help with visibility. --- CHANGELOG.md | 5 +++++ docs/ffi.ipynb | 2 ++ docs/ffi.md | 2 ++ 3 files changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 93745d9d3d11..869b9dfdd196 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.32 +* New Functionality + * Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering` + to support the use of the new {ref}`ffi-tutorial` to interface with custom + C++ and CUDA code from JAX. + * Changes * `jax_enable_memories` flag is set to `True` by default. * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index 12a2781f7b13..7f7bcc07ce85 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -4,6 +4,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "(ffi-tutorial)=\n", + "\n", "# Foreign function interface (FFI)\n", "\n", "_This tutorial requires JAX v0.4.31 or newer._\n", diff --git a/docs/ffi.md b/docs/ffi.md index 802fd4f2264e..d96d9ff8c4fc 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -12,6 +12,8 @@ kernelspec: name: python3 --- +(ffi-tutorial)= + # Foreign function interface (FFI) _This tutorial requires JAX v0.4.31 or newer._ From fc6b22e2e4d4577ae9a6df6c51ac98b045543577 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 6 Sep 2024 10:55:42 -0700 Subject: [PATCH 392/702] [host_callback] Fix type promotion error Fix a type error that arises when we try to run the host callback tests with JAX_HOST_CALLBACK_LEGACY=False (in the process of deprecating jax.experimental.host_callback). PiperOrigin-RevId: 671825020 --- tests/host_callback_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 5988b2774408..944b47dc8b1d 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -2240,7 +2240,7 @@ def f_outside(arg): def test_call_cond(self): def f_outside(args): x, y = args - return x * y + return x * y.astype(np.float32) def loop(x, use_outside=True): def body(i, acc): @@ -2253,8 +2253,8 @@ def body(i, acc): return lax.fori_loop(0, 18, body, x) - res_inside = loop(1.2, use_outside=False) - self.assertAllClose(res_inside, jax.jit(loop)(1.2)) + res_inside = loop(np.float32(1.2), use_outside=False) + self.assertAllClose(res_inside, jax.jit(loop)(np.float32(1.2))) def test_call_jit_scan_call(self): def f_outside(x): From f97bfc85a38cf82cb71211c8448ad25855feace6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 6 Sep 2024 11:57:36 -0700 Subject: [PATCH 393/702] Implement symmetric_product() to produce a symmetric matrix: `C = alpha * X @ X.T + beta * C` PiperOrigin-RevId: 671845818 --- jax/_src/lax/linalg.py | 76 +++++++++++++++++++++ jax/experimental/jax2tf/jax2tf.py | 1 + jaxlib/cuda/BUILD | 4 ++ jaxlib/gpu/gpu_kernels.cc | 2 + jaxlib/gpu/solver.cc | 4 ++ jaxlib/gpu/solver_kernels_ffi.cc | 107 ++++++++++++++++++++++++++++++ jaxlib/gpu/solver_kernels_ffi.h | 2 + jaxlib/gpu/vendor.h | 22 ++++++ tests/linalg_test.py | 38 +++++++++++ 9 files changed, 256 insertions(+) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 714b1279037b..e2d2e2f0ab8d 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -210,6 +210,20 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: return cholesky_update_p.bind(r_matrix, w_vector) +def symmetric_product( + a_matrix: ArrayLike, c_matrix: ArrayLike, + alpha: float = 1., beta: float = 0., + symmetrize_output=False): + """Computes C = alpha * A @ A.T + beta * C (where C is symmetric).""" + result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta) + if symmetrize_output: + upper_half = lax.transpose( + _tril(result, k=-1), + (*range(result.ndim - 2), result.ndim - 1, result.ndim - 2)) + result = _tril(result, k=0) + upper_half + return result + + def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array: """Converts the pivots (row swaps) returned by LU to a permutation. @@ -592,6 +606,7 @@ def _drot( R = R.at[k, :].set(row_k) return R + cholesky_update_p = Primitive('cholesky_update') cholesky_update_p.multiple_results = False cholesky_update_p.def_abstract_eval(_cholesky_update_abstract_eval) @@ -604,6 +619,67 @@ def _drot( cholesky_update_p, mlir.lower_fun(_cholesky_update_jax_fn, multiple_results=False)) +# symmetric_update + +def _symmetric_product_abstract_eval(a, c, *, alpha, beta): + a_dtype = dtypes.canonicalize_dtype(a.dtype) + c_dtype = dtypes.canonicalize_dtype(c.dtype) + if not (a_dtype == c_dtype and a_dtype in (np.float32, np.float64)): + raise NotImplementedError( + "Symmetric update is only implemented for float32 and float64.") + if not (a.ndim >= 2 and c.ndim >= 2 + and a.shape[-2] == c.shape[-1] + and c.shape[-1] == c.shape[-2]): + raise ValueError( + "Symmetric update takes (maybe batched) matrices of matching shapes. " + "Got shapes {}, {} instead".format(a.shape, c.shape)) + return ShapedArray(c.shape, c.dtype) + + +def _symmetric_product_batching_rule(batched_args, batch_dims, *, alpha, beta): + a_tensor, c_tensor = batched_args + a_bd, c_bd = batch_dims + a_tensor = batching.moveaxis(a_tensor, a_bd, 0) + c_tensor = batching.moveaxis(c_tensor, c_bd, 0) + return ( + symmetric_product_p.bind(a_tensor, c_tensor, alpha=alpha, beta=beta), 0) + +symmetric_product_p = Primitive('symmetric_update') +symmetric_product_p.multiple_results = False +symmetric_product_p.def_abstract_eval(_symmetric_product_abstract_eval) +symmetric_product_p.def_impl( + partial(dispatch.apply_primitive, symmetric_product_p)) +batching.primitive_batchers[ + symmetric_product_p] = _symmetric_product_batching_rule + + +def _symmetric_product_gpu_lowering( + platform, ctx, a_tensor, c_tensor, alpha, beta): + a_aval, c_aval = ctx.avals_in[:2] + dtype = a_aval.dtype + alpha_aval = beta_aval = ShapedArray((), dtype) + + alpha_array = mlir.full_like_aval(ctx, alpha, alpha_aval) + beta_array = mlir.full_like_aval(ctx, beta, beta_aval) + + rule = ffi.ffi_lowering(f"{platform}_syrk_ffi", operand_output_aliases={1: 0}) + ctx = ctx.replace(avals_in=[a_aval, c_aval, alpha_aval, beta_aval]) + return rule(ctx, a_tensor, c_tensor, alpha_array, beta_array, transpose=False) + + +def _symmetric_product_jax_fn(a, c, *, alpha, beta): + a_T = lax.transpose(a, (*range(a.ndim - 2), a.ndim - 1, a.ndim - 2)) + return alpha * lax.batch_matmul( + a, a_T, precision=lax.Precision.HIGHEST) + beta * c + + +mlir.register_lowering( + symmetric_product_p, + partial(_symmetric_product_gpu_lowering, 'cu'), platform='cuda') +mlir.register_lowering( + symmetric_product_p, + mlir.lower_fun(_symmetric_product_jax_fn, multiple_results=False)) + # Asymmetric eigendecomposition def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 741887abf24b..24dee390f398 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1548,6 +1548,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "consume", "ragged_dot", "cholesky_update", + "symmetric_update", # Pallas TPU primitives "bitcast", "repeat", diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 72db9868e427..5cf85f3697c7 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -351,6 +351,7 @@ cc_library( hdrs = ["//jaxlib/gpu:linalg_kernels.h"], features = ["-use_header_modules"], deps = [ + ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", ":cuda_linalg_kernels_impl", ":cuda_vendor", @@ -363,6 +364,7 @@ cc_library( "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cublas", ], ) @@ -373,6 +375,8 @@ cuda_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + "//jaxlib:ffi_helpers", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", ], diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 3841393654a8..c17d0ac9fd5a 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -46,6 +46,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32, XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_syrk_ffi", "CUDA", + SyrkFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index fee1c1014c75..2a006d033afe 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -473,12 +473,16 @@ nb::dict Registrations() { #ifdef JAX_GPU_CUDA dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr); dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); + #endif // JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi); + dict[JAX_GPU_PREFIX "_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); + + return dict; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 6e988a6ca5e6..b757d303510f 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -483,5 +483,112 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, #undef SOLVER_DISPATCH_IMPL + +#define SYRK_KERNEL_IMPL(type, fn) \ + template <> \ + struct SyrkKernel { \ + static absl::Status Run(gpublasHandle_t handle, std::int64_t n, \ + std::int64_t k, bool transpose, \ + const type* alpha, const type* beta, \ + const type* a_matrix, type* c_matrix) { \ + gpublasOperation_t op = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; \ + gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; \ + int lda = transpose ? n : k; \ + return JAX_AS_STATUS(fn(handle, uplo, op, n, k, \ + alpha, a_matrix, lda, beta, \ + c_matrix, n)); \ + } \ + } + +template +struct SyrkKernel; + +SYRK_KERNEL_IMPL(float, gpublasSsyrk); +SYRK_KERNEL_IMPL(double, gpublasDsyrk); +SYRK_KERNEL_IMPL(gpublasComplex, gpublasCsyrk); +SYRK_KERNEL_IMPL(gpublasDoubleComplex, gpublasZsyrk); +#undef SYRK_KERNEL_IMPL + +template +ffi::Error SyrkImpl(gpuStream_t stream, + ffi::AnyBuffer a_matrix, + ffi::AnyBuffer c_matrix, + bool transpose, + ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, + ffi::Result c_matrix_out) { + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a_matrix.dimensions())); + FFI_ASSIGN_OR_RETURN((auto [batch_c, rows_c, cols_c]), + SplitBatch2D(c_matrix.dimensions())); + FFI_ASSIGN_OR_RETURN((auto [batch_out, rows_out, cols_out]), + SplitBatch2D(c_matrix_out->dimensions())); + if (batch != batch_c || batch != batch_out) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "a_matrix, c_matrix and c_matrix_out must have the same " + "batch size."); + } + int n = transpose ? cols : rows; + int k = transpose ? rows : cols; + + FFI_RETURN_IF_ERROR( + CheckShape(c_matrix_out->dimensions().last(2), {n, n}, "out", "Syrk")); + FFI_RETURN_IF_ERROR( + CheckShape(c_matrix.dimensions().last(2), {n, n}, "C", "Syrk")); + + const T* a_data = static_cast(a_matrix.untyped_data()); + T* c_data = static_cast(c_matrix.untyped_data()); + T* c_out_data = static_cast(c_matrix_out->untyped_data()); + + // with alpha or beta provided as device_pointers, cublassyrk will SIGSEGV + T host_alpha; + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + &host_alpha, alpha.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost, + stream))); + + T host_beta; + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + &host_beta, beta.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost, + stream))); + + if (c_data != c_out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + c_out_data, c_data, c_matrix.size_bytes(), gpuMemcpyDeviceToDevice, + stream))); + } + FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); + for (int i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(SyrkKernel::Run( + handle.get(), n, k, transpose, &host_alpha, &host_beta, + a_data + i * k * n, c_out_data + i * n * n)); + } + return ffi::Error::Success(); +} + +ffi::Error SyrkDispatch( + gpuStream_t stream, + ffi::AnyBuffer a_matrix, + ffi::AnyBuffer c_matrix, + bool transpose, + ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, + ffi::Result c_matrix_out) { + auto dataType = a_matrix.element_type(); + SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, a_matrix, c_matrix, transpose, + alpha, beta, c_matrix_out); + return ffi::Error::InvalidArgument("Unsupported element type for Syrk"); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Arg() // a_matrix + .Arg() // c_matrix + .Attr("transpose") // transpose + .Arg() // alpha + .Arg() // beta + .Ret()); // c_matrix_out + + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 7dbc7454c2e6..4d9b6d1371fa 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -25,6 +25,8 @@ namespace JAX_GPU_NAMESPACE { XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi); + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 077d3bb54185..bc61d58181ab 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -57,6 +57,9 @@ typedef cuDoubleComplex gpublasDoubleComplex; typedef cublasFillMode_t gpusolverFillMode_t; typedef cublasStatus_t gpublasStatus_t; typedef cublasHandle_t gpublasHandle_t; +typedef cublasOperation_t gpublasOperation_t; +typedef cublasFillMode_t gpublasFillMode_t; + typedef CUcontext gpuContext_t; typedef CUstreamCaptureMode gpustreamCaptureMode_t; typedef CUstreamCaptureStatus gpustreamCaptureStatus_t; @@ -101,6 +104,11 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpublasCgetrfBatched cublasCgetrfBatched #define gpublasZgetrfBatched cublasZgetrfBatched +#define gpublasSsyrk cublasSsyrk +#define gpublasDsyrk cublasDsyrk +#define gpublasCsyrk cublasCsyrk +#define gpublasZsyrk cublasZsyrk + #define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS #define gpudnnCreate cudnnCreate @@ -190,6 +198,10 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR #define GPUSOLVER_STATUS_SUCCESS CUSOLVER_STATUS_SUCCESS +#define GPUBLAS_OP_N CUBLAS_OP_N +#define GPUBLAS_OP_T CUBLAS_OP_T +#define GPUBLAS_OP_C CUBLAS_OP_C + #define gpusparseCooSetStridedBatch cusparseCooSetStridedBatch #define gpusparseCreate cusparseCreate #define gpusparseCreateCoo cusparseCreateCoo @@ -330,6 +342,7 @@ typedef hipsolverHandle_t gpusolverDnHandle_t; typedef hipblasFillMode_t gpublasFillMode_t; typedef hipsolverFillMode_t gpusolverFillMode_t; typedef hipblasHandle_t gpublasHandle_t; +typedef hipblasOperation_t gpublasOperation_t; typedef hipblasStatus_t gpublasStatus_t; typedef hipCtx_t gpuContext_t; typedef hipStreamCaptureMode gpustreamCaptureMode_t; @@ -372,6 +385,11 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpublasCgetrfBatched hipblasCgetrfBatched #define gpublasZgetrfBatched hipblasZgetrfBatched +#define gpublasSsyrk hipblasSsyrk +#define gpublasDsyrk hipblasDsyrk +#define gpublasCsyrk hipblasCsyrk +#define gpublasZsyrk hipblasZsyrk + #define GPUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define gpusolverDnCreate hipsolverCreate @@ -456,6 +474,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR #define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS +#define GPUBLAS_OP_N HIPBLAS_OP_N +#define GPUBLAS_OP_T HIPBLAS_OP_T +#define GPUBLAS_OP_C HIPBLAS_OP_C + #define gpusparseCooSetStridedBatch hipsparseCooSetStridedBatch #define gpusparseCreate hipsparseCreate #define gpusparseSetStream hipsparseSetStream diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 944880066437..0eb4f800309b 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -2187,6 +2187,44 @@ def testHilbert(self, n): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker) self._CompileAndCheck(jsp_fun, args_maker) + @jtu.sample_product( + shape=[ + (128, 12), + (128, 64), + (2048, 128), + ], + dtype=[jnp.float32, jnp.float64], + symmetrize_output=[True, False], + ) + @jtu.skip_on_devices("tpu") + def testSymmetricProduct(self, shape, dtype, symmetrize_output): + if dtype is jnp.float64 and not config.enable_x64.value: + self.skipTest("Test disabled for x32 mode") + + rng = jtu.rand_default(self.rng()) + batch_size = 10 + atol = 1e-6 if dtype == jnp.float64 else 1e-3 + + a_matrix = rng((batch_size,) + shape, dtype) + c_shape = a_matrix.shape[:-1] + (a_matrix.shape[-2],) + c_matrix = jnp.zeros(c_shape, dtype) + + old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix) + new_product = lax_linalg.symmetric_product( + a_matrix, c_matrix, symmetrize_output=symmetrize_output) + new_product_with_batching = jax.vmap( + lambda a, c: lax_linalg.symmetric_product( + a, c, symmetrize_output=symmetrize_output), + in_axes=(0, 0))(a_matrix, c_matrix) + + if not symmetrize_output: + old_product = jnp.tril(old_product) + new_product = jnp.tril(new_product) + new_product_with_batching = jnp.tril(new_product_with_batching) + self.assertAllClose(new_product, old_product, atol=atol) + self.assertAllClose( + new_product_with_batching, old_product, atol=atol) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 1d12a9934c1514137a96fc4208ce7fc31f46d260 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 6 Sep 2024 12:22:10 -0700 Subject: [PATCH 394/702] Port GPU kernel for symmetric eigendecomposition to GPU. Of note, I moved the logic about which algorithm to use, and when to use the batched algorithm into the kernel in order to support shape polymorphism and export. PiperOrigin-RevId: 671853879 --- jaxlib/gpu/gpu_kernels.cc | 2 + jaxlib/gpu/solver.cc | 3 +- jaxlib/gpu/solver_kernels_ffi.cc | 235 ++++++++++++++++++++++++++++++- jaxlib/gpu/solver_kernels_ffi.h | 10 +- 4 files changed, 243 insertions(+), 7 deletions(-) diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index c17d0ac9fd5a..2310a4cf20ed 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -57,6 +57,8 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", OrgqrFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA", + SyevdFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 2a006d033afe..43c2dfd85604 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -479,10 +479,9 @@ nb::dict Registrations() { dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi); - + dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi); dict[JAX_GPU_PREFIX "_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); - return dict; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index b757d303510f..3c74b85192ad 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include #include "absl/status/status.h" @@ -30,6 +32,8 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::JAX_GPU_NAMESPACE::SyevdAlgorithm); + namespace jax { namespace JAX_GPU_NAMESPACE { @@ -48,6 +52,21 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, } return static_cast(maybe_workspace.value()); } + +template +struct RealType { + using Type = T; +}; + +template <> +struct RealType { + using Type = float; +}; + +template <> +struct RealType { + using Type = double; +}; } // namespace #define SOLVER_DISPATCH_IMPL(impl, ...) \ @@ -206,7 +225,8 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, SOLVER_DISPATCH_IMPL(GetrfImpl, batch, rows, cols, stream, scratch, a, out, ipiv, info); } - return ffi::Error::InvalidArgument("Unsupported element type for getrf"); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in getrf", absl::FormatStreamed(dataType))); } } // namespace @@ -362,7 +382,8 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out, tau); } - return ffi::Error::InvalidArgument("Unsupported element type for geqrf"); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in geqrf", absl::FormatStreamed(dataType))); } } // namespace @@ -468,7 +489,8 @@ ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, CheckShape(out->dimensions(), {batch, rows, cols}, "out", "orgqr")); SOLVER_DISPATCH_IMPL(OrgqrImpl, batch, rows, cols, size, stream, scratch, a, tau, out); - return ffi::Error::InvalidArgument("Unsupported element type for orgqr"); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in orgqr", absl::FormatStreamed(dataType))); } } // namespace @@ -481,8 +503,211 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, .Ret() // out ); -#undef SOLVER_DISPATCH_IMPL +// Symmetric (Hermitian) eigendecomposition: +// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) +// * QR algorithm: syevd/heevd +// For historical reasons, the target is called "syevd" even though it +// dispatches dynamically to both syevd and syevj depending on the problem +// size and the algorithm selected by the user via the `algorithm` attribute. + +namespace { +#define SYEVJ_KERNEL_IMPL(type, name) \ + template <> \ + struct SyevjKernel { \ + static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, \ + gpuSyevjInfo_t params) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ + /*w=*/nullptr, &lwork, params))); \ + return lwork; \ + } \ + static absl::Status Run(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ + int n, type* a, RealType::Type* w, \ + type* workspace, int lwork, int* info, \ + gpuSyevjInfo_t params) { \ + return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \ + lwork, info, params)); \ + } \ + } + +template +struct SyevjKernel; +SYEVJ_KERNEL_IMPL(float, gpusolverDnSsyevj); +SYEVJ_KERNEL_IMPL(double, gpusolverDnDsyevj); +SYEVJ_KERNEL_IMPL(gpuComplex, gpusolverDnCheevj); +SYEVJ_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevj); +#undef SYEVJ_KERNEL_IMPL + +#define SYEVJ_BATCHED_KERNEL_IMPL(type, name) \ + template <> \ + struct SyevjBatchedKernel { \ + static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, \ + gpuSyevjInfo_t params, int batch) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ + /*w=*/nullptr, &lwork, params, batch))); \ + return lwork; \ + } \ + static absl::Status Run(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ + int n, type* a, RealType::Type* w, \ + type* workspace, int lwork, int* info, \ + gpuSyevjInfo_t params, int batch) { \ + return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \ + lwork, info, params, batch)); \ + } \ + } + +template +struct SyevjBatchedKernel; +SYEVJ_BATCHED_KERNEL_IMPL(float, gpusolverDnSsyevjBatched); +SYEVJ_BATCHED_KERNEL_IMPL(double, gpusolverDnDsyevjBatched); +SYEVJ_BATCHED_KERNEL_IMPL(gpuComplex, gpusolverDnCheevjBatched); +SYEVJ_BATCHED_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevjBatched); +#undef SYEVJ_BATCHED_KERNEL_IMPL + +#define SYEVD_KERNEL_IMPL(type, name) \ + template <> \ + struct SyevdKernel { \ + static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ + /*w=*/nullptr, &lwork))); \ + return lwork; \ + } \ + static absl::Status Run(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ + int n, type* a, RealType::Type* w, \ + type* workspace, int lwork, int* info) { \ + return JAX_AS_STATUS( \ + name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \ + } \ + } +template +struct SyevdKernel; +SYEVD_KERNEL_IMPL(float, gpusolverDnSsyevd); +SYEVD_KERNEL_IMPL(double, gpusolverDnDsyevd); +SYEVD_KERNEL_IMPL(gpuComplex, gpusolverDnCheevd); +SYEVD_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevd); +#undef SYEVD_KERNEL_IMPL + +template +ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, + ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm, + bool lower, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result w, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; + gpusolverFillMode_t uplo = + lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto w_data = static_cast::Type*>(w->untyped_data()); + auto info_data = info->typed_data(); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + } + if (algorithm == SyevdAlgorithm::kJacobi || + (algorithm == SyevdAlgorithm::kDefault && size <= 32)) { + gpuSyevjInfo_t params; + FFI_RETURN_IF_ERROR_STATUS( + JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); + std::unique_ptr params_cleanup( + params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); + + if (batch == 1) { + FFI_ASSIGN_OR_RETURN(int lwork, SyevjKernel::BufferSize( + handle.get(), jobz, uplo, n, params)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "syevj")); + FFI_RETURN_IF_ERROR_STATUS( + SyevjKernel::Run(handle.get(), jobz, uplo, n, out_data, w_data, + workspace, lwork, info_data, params)); + } else { + FFI_ASSIGN_OR_RETURN( + int lwork, SyevjBatchedKernel::BufferSize(handle.get(), jobz, uplo, + n, params, batch)); + FFI_ASSIGN_OR_RETURN( + auto workspace, + AllocateWorkspace(scratch, lwork, "syevj_batched")); + FFI_RETURN_IF_ERROR_STATUS(SyevjBatchedKernel::Run( + handle.get(), jobz, uplo, n, out_data, w_data, workspace, lwork, + info_data, params, batch)); + } + } else { + FFI_ASSIGN_OR_RETURN( + int lwork, SyevdKernel::BufferSize(handle.get(), jobz, uplo, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "syevd")); + int out_step = n * n; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS( + SyevdKernel::Run(handle.get(), jobz, uplo, n, out_data, w_data, + workspace, lwork, info_data)); + out_data += out_step; + w_data += n; + ++info_data; + } + } + return ffi::Error::Success(); +} + +ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + SyevdAlgorithm algorithm, bool lower, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result w, + ffi::Result> info) { + auto dataType = a.element_type(); + if (dataType != out->element_type() || + ffi::ToReal(dataType) != w->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to syevd must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to syevd must be square"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "syevd")); + FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "syevd")); + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "syevd")); + SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, algorithm, + lower, a, out, w, info); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in syevd", absl::FormatStreamed(dataType))); +} +} // namespace + +XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("algorithm") + .Attr("lower") + .Arg() // a + .Ret() // out + .Ret() // w + .Ret>() // info +); #define SYRK_KERNEL_IMPL(type, fn) \ template <> \ @@ -589,6 +814,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch, .Arg() // beta .Ret()); // c_matrix_out +#undef SOLVER_DISPATCH_IMPL +#undef SOLVER_BLAS_DISPATCH_IMPL } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 4d9b6d1371fa..3bebe40bee26 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -16,18 +16,26 @@ limitations under the License. #ifndef JAXLIB_GPU_SOLVER_KERNELS_FFI_H_ #define JAXLIB_GPU_SOLVER_KERNELS_FFI_H_ +#include + #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" namespace jax { namespace JAX_GPU_NAMESPACE { +enum class SyevdAlgorithm : uint8_t { + kDefault = 0, + kDivideAndConquer, + kJacobi, +}; + XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi); - } // namespace JAX_GPU_NAMESPACE } // namespace jax From b6213aaa858601955fcae905a8203efad438afec Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 6 Sep 2024 12:34:13 -0700 Subject: [PATCH 395/702] Make pltpu key derivation more robust. PiperOrigin-RevId: 671857080 --- jax/_src/pallas/mosaic/random.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py index c642d99578cd..68a4fe508917 100644 --- a/jax/_src/pallas/mosaic/random.py +++ b/jax/_src/pallas/mosaic/random.py @@ -14,8 +14,8 @@ from collections.abc import Callable +import functools import jax -import numpy as np from jax import numpy as jnp from jax import random as jax_api_random from jax._src import blocked_sampler @@ -37,15 +37,13 @@ def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray: """Helper function for converting non-Pallas PRNG keys into Pallas keys.""" - batch_dims = key.shape - key_data = jax_api_random.key_data(key) - pallas_key_size = np.prod(tpu_key_impl.key_shape) - if np.prod(key_data.shape[len(batch_dims):]) < pallas_key_size: - raise ValueError(f"Key data must be at least {pallas_key_size} bytes.") - pallas_key_data = jnp.reshape( - key_data, batch_dims + (-1,))[..., :pallas_key_size] - pallas_key_data = jnp.reshape(pallas_key_data, - batch_dims + tpu_key_impl.key_shape) + generate_key = functools.partial( + jax.random.bits, shape=tpu_key_impl.key_shape, dtype=jnp.uint32 + ) + if len(key.shape) == 0: + pallas_key_data = generate_key(key) + else: + pallas_key_data = (jax.vmap(generate_key))(key) return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu") def _seed_func(seed: jnp.int32): From 7266e338c85bea349483328288d43ac1e90e4f6a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 6 Sep 2024 13:20:49 -0700 Subject: [PATCH 396/702] Update FFI target name for `syrk` operation to be consistent with other kernels. PiperOrigin-RevId: 671870569 --- jax/_src/lax/linalg.py | 3 ++- jaxlib/gpu/gpu_kernels.cc | 2 +- jaxlib/gpu/solver.cc | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index e2d2e2f0ab8d..0cc0e774af53 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -662,7 +662,8 @@ def _symmetric_product_gpu_lowering( alpha_array = mlir.full_like_aval(ctx, alpha, alpha_aval) beta_array = mlir.full_like_aval(ctx, beta, beta_aval) - rule = ffi.ffi_lowering(f"{platform}_syrk_ffi", operand_output_aliases={1: 0}) + rule = ffi.ffi_lowering(f"{platform}solver_syrk_ffi", + operand_output_aliases={1: 0}) ctx = ctx.replace(avals_in=[a_aval, c_aval, alpha_aval, beta_aval]) return rule(ctx, a_tensor, c_tensor, alpha_array, beta_array, transpose=False) diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 2310a4cf20ed..62977c5f57a1 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -46,7 +46,7 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32, XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); -XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_syrk_ffi", "CUDA", +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", SyrkFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 43c2dfd85604..c65ad088af21 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -480,7 +480,7 @@ nb::dict Registrations() { dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi); dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi); - dict[JAX_GPU_PREFIX "_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); + dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); return dict; } From 2ce0fc25e03d0b1109f4a4124577b9e6e3880fef Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 6 Sep 2024 13:57:49 -0700 Subject: [PATCH 397/702] Fix tolerances for failing linalg tests. PiperOrigin-RevId: 671881600 --- tests/linalg_test.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 0eb4f800309b..9fcc940c6c2c 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -2188,19 +2188,12 @@ def testHilbert(self, n): self._CompileAndCheck(jsp_fun, args_maker) @jtu.sample_product( - shape=[ - (128, 12), - (128, 64), - (2048, 128), - ], - dtype=[jnp.float32, jnp.float64], + shape=[(5, 1), (10, 4), (128, 12)], + dtype=float_types, symmetrize_output=[True, False], ) @jtu.skip_on_devices("tpu") def testSymmetricProduct(self, shape, dtype, symmetrize_output): - if dtype is jnp.float64 and not config.enable_x64.value: - self.skipTest("Test disabled for x32 mode") - rng = jtu.rand_default(self.rng()) batch_size = 10 atol = 1e-6 if dtype == jnp.float64 else 1e-3 @@ -2209,7 +2202,8 @@ def testSymmetricProduct(self, shape, dtype, symmetrize_output): c_shape = a_matrix.shape[:-1] + (a_matrix.shape[-2],) c_matrix = jnp.zeros(c_shape, dtype) - old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix) + old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix, + precision=lax.Precision.HIGHEST) new_product = lax_linalg.symmetric_product( a_matrix, c_matrix, symmetrize_output=symmetrize_output) new_product_with_batching = jax.vmap( From 51a666fb8c022bfa4a4995cb29ae570e67b459bd Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 6 Sep 2024 14:26:52 -0700 Subject: [PATCH 398/702] [Pallas] Update Pallas docs with new figures and TPUCompilerParams --- .../pallas/distributed/reduce_sum_1.svg | 2 +- .../pallas/distributed/reduce_sum_2.svg | 2 +- docs/pallas/tpu/details.rst | 6 ++---- docs/pallas/tpu/distributed.ipynb | 18 +++++++++--------- docs/pallas/tpu/distributed.md | 16 ++++++++-------- docs/pallas/tpu/matmul.ipynb | 16 ++++++++-------- docs/pallas/tpu/matmul.md | 16 ++++++++-------- docs/pallas/tpu/pipelining.ipynb | 3 ++- docs/pallas/tpu/pipelining.md | 3 ++- 9 files changed, 41 insertions(+), 41 deletions(-) diff --git a/docs/_static/pallas/distributed/reduce_sum_1.svg b/docs/_static/pallas/distributed/reduce_sum_1.svg index 6c397a87be88..9a527aff6a2e 100644 --- a/docs/_static/pallas/distributed/reduce_sum_1.svg +++ b/docs/_static/pallas/distributed/reduce_sum_1.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/reduce_sum_2.svg b/docs/_static/pallas/distributed/reduce_sum_2.svg index ef2a76330a61..61685cf41863 100644 --- a/docs/_static/pallas/distributed/reduce_sum_2.svg +++ b/docs/_static/pallas/distributed/reduce_sum_2.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index ae9505c4eb8b..93d7e55473f2 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -148,10 +148,8 @@ grid axes over cores. This is an opt-in procedure. To allow that, .. pallas_call( ..., - compiler_params=dict( - mosaic=dict( - dimension_semantics=["parallel", "parallel", "arbitrary"] - ) + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=["parallel", "parallel", "arbitrary"] ), ) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index 5209f2ff8e52..8552e10d8552 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -196,11 +196,7 @@ "- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang.\n", "- If a device is sent a more bytes than it expected to receive, it may also crash due to non-zero semaphore states. If sent less, it may hang indefinitely.\n", "- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.\n", - "- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states.\n", - "\n", - "### Megacore\n", - "\n", - "Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core." + "- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states." ] }, { @@ -576,7 +572,7 @@ "kernel = pl.pallas_call(\n", " example_kernel,\n", " ...,\n", - " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", ")\n", "```" ] @@ -815,7 +811,7 @@ " all_reduce_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", ")\n", "\n", "pallas_result = jax.jit(\n", @@ -1169,7 +1165,7 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", @@ -1626,7 +1622,7 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=dict(mosaic=dict(collective_id=0)),\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", @@ -1705,6 +1701,10 @@ "source": [ "## Final Notes\n", "\n", + "### Megacore\n", + "\n", + "Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n", + "\n", "### Interaction with XLA\n", "\n", "In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index b7c058b117ca..dbdb00e8018f 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -183,10 +183,6 @@ Some common causes of the above include: - If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states. - If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states. -### Megacore - -Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = lax.axis_index(name)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core. - +++ {"id": "vpGSN1Sui0Bu"} ### Example: Right Permute (`lax.ppermute`) @@ -498,7 +494,7 @@ When using barrier semaphores, the `collective_id` compiler parameter must be pa kernel = pl.pallas_call( example_kernel, ..., - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) ``` @@ -709,7 +705,7 @@ kernel = pl.pallas_call( all_reduce_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), ) pallas_result = jax.jit( @@ -1042,7 +1038,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), )(input_arr)[0] @@ -1460,7 +1456,7 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), )(input_arr)[0] @@ -1518,6 +1514,10 @@ print( ## Final Notes +### Megacore + +Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core. + ### Interaction with XLA In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations. diff --git a/docs/pallas/tpu/matmul.ipynb b/docs/pallas/tpu/matmul.ipynb index 0bd16095cb7e..51ce2ed6868f 100644 --- a/docs/pallas/tpu/matmul.ipynb +++ b/docs/pallas/tpu/matmul.ipynb @@ -210,8 +210,8 @@ " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " grid=(m // bm, n // bn, k // bk),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, @@ -466,8 +466,8 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, @@ -741,8 +741,8 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, @@ -929,8 +929,8 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, diff --git a/docs/pallas/tpu/matmul.md b/docs/pallas/tpu/matmul.md index a00880ebaf37..e542dedc7d10 100644 --- a/docs/pallas/tpu/matmul.md +++ b/docs/pallas/tpu/matmul.md @@ -167,8 +167,8 @@ def matmul( pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))], out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), grid=(m // bm, n // bn, k // bk), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -321,8 +321,8 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -489,8 +489,8 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -613,8 +613,8 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 275a72f3837b..2a3aa9d114de 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -33,6 +33,7 @@ "\n", "import jax\n", "from jax.experimental import pallas as pl\n", + "from jax.experimental.pallas import tpu as pltpu\n", "import jax.numpy as jnp\n", "import numpy as np" ] @@ -696,7 +697,7 @@ " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n", + " compiler_params=pltpu.TPUCompilerParams(dimension_semantics=(\"parallel\",))\n", " )(x, y)\n", "\n", "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index d753b404db1a..67c1900a0d3a 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -29,6 +29,7 @@ pipelines in Pallas that overlap memory I/O with compute. import jax from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np ``` @@ -465,7 +466,7 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: in_specs=[block_spec, block_spec], out_specs=block_spec, grid=(2,), - compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))) + compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",)) )(x, y) x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) From 671acef5aba1a0caaea62c75491884a2a67d28e5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 6 Sep 2024 14:33:54 -0700 Subject: [PATCH 399/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b0368b065c36a4dc27cf8de8a31f6510fa6f5086. PiperOrigin-RevId: 671893460 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 20ffedcb9393..a0173da6221e 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "53fd000440d637359a6546a04b06d11c823553ed" -XLA_SHA256 = "fe1098143e2515d472b87d9052bd48e60d9332729502d2d98742e2b7892a2937" +XLA_COMMIT = "b0368b065c36a4dc27cf8de8a31f6510fa6f5086" +XLA_SHA256 = "82b889ff11bac258df7e9230fbb45921f0adb0cde85561f8ab699fb0f4cfb5f1" def repo(): tf_http_archive( From 02b7a767683dd92c7cd3503fbb0d60b2a7440bb9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 6 Sep 2024 16:44:19 -0700 Subject: [PATCH 400/702] Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes. PiperOrigin-RevId: 671930431 --- jax/BUILD | 11 + jax/_src/config.py | 9 +- jax/_src/core.py | 20 +- jax/_src/interpreters/mlir.py | 27 +++ jax/_src/interpreters/partial_eval.py | 8 +- jax/_src/xla_metadata.py | 55 +++++ jax/experimental/xla_metadata.py | 17 ++ tests/BUILD | 6 + tests/xla_metadata_test.py | 290 ++++++++++++++++++++++++++ 9 files changed, 434 insertions(+), 9 deletions(-) create mode 100644 jax/_src/xla_metadata.py create mode 100644 jax/experimental/xla_metadata.py create mode 100644 tests/xla_metadata_test.py diff --git a/jax/BUILD b/jax/BUILD index 4c622194941f..74072dc44644 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -299,6 +299,7 @@ py_library_providing_imports_info( ":version", ":xla", ":xla_bridge", + ":xla_metadata", "//jax/_src/lib", ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps, ) @@ -451,6 +452,7 @@ pytype_strict_library( ":tree_util", ":typing", ":util", + ":xla_metadata", "//jax/_src/lib", ] + py_deps("numpy"), ) @@ -703,6 +705,7 @@ pytype_strict_library( ":state_types", ":tree_util", ":util", + ":xla_metadata", ] + py_deps("numpy"), ) @@ -768,6 +771,14 @@ pytype_strict_library( deps = [":config"], ) +pytype_strict_library( + name = "xla_metadata", + srcs = ["_src/xla_metadata.py"], + deps = [ + ":config", + ], +) + pytype_strict_library( name = "layout", srcs = ["_src/layout.py"], diff --git a/jax/_src/config.py b/jax/_src/config.py index 646e487c5f1c..b2d1aa52ef2a 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -202,16 +202,20 @@ def trace_context(): tls = jax_jit.thread_local_state() axis_env_state = () mesh_context_manager = () + xla_metadata_context_manager = () compute_on_context_manager = () + context: Any = tls.extra_jit_context if context and context.axis_env_state is not None: axis_env_state = context.axis_env_state if context and context.mesh_context_manager: mesh_context_manager = context.mesh_context_manager + if context and context.xla_metadata_context_manager: + xla_metadata_context_manager = context.xla_metadata_context_manager if context and context.compute_on_context_manager: compute_on_context_manager = context.compute_on_context_manager - return (axis_env_state, mesh_context_manager, compute_on_context_manager, - enable_x64.value, + return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, + compute_on_context_manager, enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, dynamic_shapes.value, numpy_dtype_promotion.value, default_device.value, random_seed_offset.value, @@ -858,6 +862,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): axis_env_state: Hashable = () mesh_context_manager: Hashable = () compute_on_context_manager: Hashable = () + xla_metadata_context_manager: Hashable = () # Values set by _StateContextManager context managers. # CAUTION: these must be initialized to `None`! The state context manager diff --git a/jax/_src/core.py b/jax/_src/core.py index f6da7ac7a1cb..ef3ace2e0e31 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -55,6 +55,7 @@ from jax._src import traceback_util from jax._src.typing import Array, DimSize, Shape from jax._src import typing +from jax._src import xla_metadata as xla_metadata_lib traceback_util.register_exclusion(__file__) @@ -261,12 +262,15 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args): class JaxprEqnContext: - def __init__(self, compute_type: str | None, threefry_partitionable: bool): + def __init__(self, compute_type: str | None, threefry_partitionable: bool, + xla_metadata=None): self.compute_type = compute_type self.threefry_partitionable = threefry_partitionable + self.xla_metadata = xla_metadata self._managers = [ (compute_on.extend_compute_type, self.compute_type), (config.threefry_partitionable.__call__, self.threefry_partitionable), + (xla_metadata_lib.set_xla_metadata, self.xla_metadata), ] @property @@ -278,8 +282,11 @@ def manager(self): yield def __repr__(self): - return (f"JaxprEqnContext(compute_type={self.compute_type}," - f"threefry_partitionable={self.threefry_partitionable})") + return ( + f"JaxprEqnContext(compute_type={self.compute_type}," + f"threefry_partitionable={self.threefry_partitionable})," + f"xla_metadata={self.xla_metadata}" + ) class JaxprEqn: @@ -333,8 +340,11 @@ def replace( def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, ctx=None): source_info = source_info or source_info_util.new_source_info() - ctx = ctx or JaxprEqnContext(compute_on.current_compute_type(), - config.threefry_partitionable.value) + ctx = ctx or JaxprEqnContext( + compute_on.current_compute_type(), + config.threefry_partitionable.value, + xla_metadata_lib.current_xla_metadata(), + ) if config.enable_checks.value: assert all(isinstance(x, (Var, Literal)) for x in invars) assert all(isinstance(v, Var) for v in outvars) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 0e7e0146e984..c4c77c72b88b 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1923,6 +1923,10 @@ def lower_per_platform(ctx: LoweringRuleContext, lambda o: wrap_compute_type_in_place(ctx, o.owner), filter(_is_not_block_argument, flatten_ir_values(output)), ) + map( + lambda o: wrap_xla_metadata_in_place(ctx, o.owner), + flatten_ir_values(output), + ) return output assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules) @@ -1964,6 +1968,10 @@ def lower_per_platform(ctx: LoweringRuleContext, lambda o: wrap_compute_type_in_place(ctx, o.owner), filter(_is_not_block_argument, out_nodes), ) + map( + lambda o: wrap_xla_metadata_in_place(ctx, o.owner), + out_nodes, + ) if inner_ctx.tokens_out is not None: assert len(ordered_effects) == len(inner_ctx.tokens_out) out_nodes = [inner_ctx.tokens_out.get(eff) @@ -2125,6 +2133,25 @@ def wrap_compute_type_in_place(ctx, op): op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) +def wrap_xla_metadata_in_place(ctx, op): + ctx_attributes = {} + existing_attributes = {} + if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.xla_metadata: + for k, v in ctx.jaxpr_eqn_ctx.xla_metadata.items(): + ctx_attributes[k] = ir.StringAttr.get(str(v).lower()) + if isinstance(op, ir.Operation): + # combine with existing mhlo.frontend_attributes + op_attributes_dict = {attr.name: attr.attr for attr in op.attributes} + for k, attributes in op_attributes_dict.items(): + if k == "mhlo.frontend_attributes": + v_dict = {attr.name: attr.attr for attr in attributes} + for fa_key, fa_val in v_dict.items(): + existing_attributes[fa_key] = fa_val + op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + ctx_attributes | existing_attributes + ) + + def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *, broadcast_dimensions) -> ir.Value: # broadcast_dimension[i] is the axis of the result where the axis i of diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 94f8918aa2ae..5bb3e204ced0 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -35,6 +35,7 @@ from jax._src import profiler from jax._src import source_info_util from jax._src import compute_on +from jax._src import xla_metadata as xla_metadata_lib from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs, fun_sourceinfo) from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval, @@ -898,8 +899,11 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], assert ("donated_invars" in params and len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers] - ctx = ctx or JaxprEqnContext(compute_on.current_compute_type(), - config.threefry_partitionable.value) + ctx = ctx or JaxprEqnContext( + compute_on.current_compute_type(), + config.threefry_partitionable.value, + xla_metadata_lib.current_xla_metadata(), + ) return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers), out_avals, primitive, params, effects, source_info, ctx) diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py new file mode 100644 index 000000000000..94b482e2dea4 --- /dev/null +++ b/jax/_src/xla_metadata.py @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +import threading +from contextlib import contextmanager + +from jax._src import config + + +class _XlaMetadata(threading.local): + val: dict[Any, Any] + + def __init__(self): + self.val = {} + +thread_local_metadata = _XlaMetadata() + +def current_xla_metadata(): + return thread_local_metadata.val + +@contextmanager +def set_xla_metadata(*args, **kwargs): + new_metadata = thread_local_metadata.val.copy() + if args: + new_metadata.update(args[0] if args[0] else {}) + else: + new_metadata.update(**kwargs) + prev_metadata, thread_local_metadata.val = ( + thread_local_metadata.val, + new_metadata, + ) + config.update_thread_local_jit_state( + xla_metadata_context_manager=tuple( + (v, k) for k, v in sorted(new_metadata.items()))) + try: + yield + finally: + thread_local_metadata.val = prev_metadata + config.update_thread_local_jit_state( + xla_metadata_context_manager=tuple( + (v, k) for k, v in sorted(prev_metadata.items()) + ) + ) diff --git a/jax/experimental/xla_metadata.py b/jax/experimental/xla_metadata.py new file mode 100644 index 000000000000..fb15e4743d2b --- /dev/null +++ b/jax/experimental/xla_metadata.py @@ -0,0 +1,17 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the ific language governing permissions and +# limitations under the License. + +from jax._src.xla_metadata import ( + set_xla_metadata as set_xla_metadata, +) diff --git a/tests/BUILD b/tests/BUILD index b624b6bef3ac..f8cf35c5d1e0 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1505,6 +1505,12 @@ jax_test( tags = ["multiaccelerator"], ) +jax_test( + name = "xla_metadata_test", + srcs = ["xla_metadata_test.py"], + deps = ["//jax:experimental"], +) + py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py new file mode 100644 index 000000000000..38bd7e05533e --- /dev/null +++ b/tests/xla_metadata_test.py @@ -0,0 +1,290 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests whether the frontend attributes added by the context manager are + +correctly propagated to the jaxpr and mlir. +""" + +from absl.testing import absltest +import jax +from jax._src import config +from jax._src import dispatch +from jax._src import test_util as jtu +from jax._src.lax import lax +from jax.experimental.xla_metadata import set_xla_metadata +import jax.numpy as jnp + +config.parse_flags_with_absl() + + +class XlaMetadataTest(jtu.JaxTestCase): + + def test_f_jitted(self): + @jax.jit + def f(a, b): + with set_xla_metadata(a="b"): + return a + b + + f_jaxpr = jax.make_jaxpr(f)(1, 2) + eqns = f_jaxpr.eqns + for eq in eqns[1:]: + self.assertDictEqual(eq.ctx.attributes, {"a": "b"}) + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn('mhlo.frontend_attributes = {a = "b"}', f_lowered_text) + + def test_f_jitted_bool_attributes(self): + @jax.jit + def f(a, b): + with set_xla_metadata(a=True): + return a + b + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn('mhlo.frontend_attributes = {a = "true"}', f_lowered_text) + + def test_f_jitted_int_attributes(self): + @jax.jit + def f(a, b): + with set_xla_metadata(a=10): + return a + b + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn('mhlo.frontend_attributes = {a = "10"}', f_lowered_text) + + def test_f_nonjitted(self): + def f_add(a, b): + return dispatch.apply_primitive(lax.add_p, a, b) + + arg1 = jnp.arange(2) + with set_xla_metadata(a="b"): + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', + jax.jit(f_add).lower(arg1, arg1).as_text(), + ) + + def test_f_attributes_overwrite(self): + @jax.jit + def g(a, b): + return a * b + + with set_xla_metadata(a="b"): + + @jax.jit + def f(a, b): + with set_xla_metadata(a="c"): + return a + b + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn('mhlo.frontend_attributes = {a = "c"}', f_lowered_text) + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', g.lower(1.0, 2.0).as_text() + ) + self.assertNotIn("mhlo.frontend_attributes", g.lower(1.0, 2.0).as_text()) + + def test_f_attributes_merge(self): + with set_xla_metadata(key1="val1"): + + @jax.jit + def f(a, b): + with set_xla_metadata(key2="val2"): + return a + b + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn( + 'mhlo.frontend_attributes = {key1 = "val1", key2 = "val2"}', + f_lowered_text, + ) + + def test_attr_caching_jit(self): + @jax.jit + def f_add_jit(a, b): + return a + b + + with set_xla_metadata(b="c"): + f_add_lowered1 = f_add_jit.lower(2.0, 3.0).as_text() + # Expect no attributes in the mlir. + f_add_lowered2 = f_add_jit.lower(1.0, 2.0).as_text() + with set_xla_metadata(c="d"): + f_add_lowered3 = f_add_jit.lower(4.0, 5.0).as_text() + self.assertIn('mhlo.frontend_attributes = {b = "c"}', f_add_lowered1) + self.assertNotIn("mhlo.frontend_attributes = {}", f_add_lowered2) + self.assertNotIn('mhlo.frontend_attributes = {b = "c"}', f_add_lowered2) + self.assertNotIn('mhlo.frontend_attributes = {c = "d"}', f_add_lowered2) + self.assertIn('mhlo.frontend_attributes = {c = "d"}', f_add_lowered3) + + def test_attr_caching_nonjit(self): + def f_add(a, b): + return dispatch.apply_primitive(lax.add_p, a, b) + + arg1 = jnp.arange(2) + arg2 = jnp.arange(2) + 1 + arg3 = jnp.arange(2) + 2 + with set_xla_metadata(b="c"): + self.assertIn( + 'mhlo.frontend_attributes = {b = "c"}', + jax.jit(f_add).lower(arg1, arg1).as_text(), + ) + # Expect no attributes in the jaxpr. + self.assertNotIn( + "mhlo.frontend_attributes", + jax.jit(f_add).lower(arg2, arg2).as_text(), + ) + + with set_xla_metadata(c="d"): + self.assertIn( + 'mhlo.frontend_attributes = {c = "d"}', + jax.jit(f_add).lower(arg3, arg3).as_text(), + ) + + def test_axpy(self): + @jax.jit + def axpy(a, x, y): + with set_xla_metadata(a="b"): + return a * x + y + + for line in axpy.lower(1.0, 2.0, 3.0).as_text().split("\n"): + if "stablehlo.multiply" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + if "stablehlo.add" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + + def test_while(self): + @jax.jit + def f(a): + with set_xla_metadata(a="b"): + return jax.lax.while_loop(lambda x: x < 10, lambda x: x + 1, a) + + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', f.lower(1.0).as_text() + ) + + def test_while_condition_body(self): + @jax.jit + def f_condition(x): + with set_xla_metadata(a="b"): + return x < 10 + + @jax.jit + def f_body(x): + with set_xla_metadata(a="c"): + return x + 1 + + @jax.jit + def while_fn(a): + return jax.lax.while_loop(f_condition, f_body, a) + + for line in while_fn.lower(1.0).as_text().split("\n"): + if "stablehlo.compare" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + if "stablehlo.add" in line: + self.assertIn('mhlo.frontend_attributes = {a = "c"}', line) + + def test_nested_jit(self): + @jax.jit + def f(x, y): + with set_xla_metadata(a="b"): + z = x * y + + @jax.jit + def g(z): + with set_xla_metadata(c="d"): + return z**2 + 1 + + return g(z) + + self.assertIn( + 'mhlo.frontend_attributes = {a = "b", c = "d"}', + f.lower(1.0, 2.0).as_text(), + ) + + def test_grad(self): + @jax.jit + def f(x, y): + with set_xla_metadata(a="b"): + return jax.grad(lambda x: x**3 + y**2 + jnp.sin(x))(x) + + f_jaxpr = jax.make_jaxpr(f)(1.0, 2.0) + eqns = f_jaxpr.eqns + for eq in eqns[1:]: + self.assertDictEqual(eq.ctx.attributes, {"a": "b"}) + + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', f.lower(1.0, 2.).as_text() + ) + + def test_grad_outside_ctx(self): + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return x**3 + x**2 + jnp.sin(x) + + grad_fn = jax.jit(jax.grad(f)) + for line in grad_fn.lower(1.0).as_text().split("\n"): + if "stablehlo.cosine" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + if "call @integer_pow" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + + def test_vmap(self): + dct = {"a": 0.0, "b": jnp.arange(5.0)} + + @jax.jit + def f(dct, x): + with set_xla_metadata(a="b"): + return dct["a"] + dct["b"] + x + + with set_xla_metadata(a="d"): + f_vmap = jax.vmap(f, in_axes=({"a": None, "b": 0}, None)) + f_jaxpr = jax.make_jaxpr(f_vmap)(dct, 1.0) + eqns = f_jaxpr.eqns + for eq in eqns[1:]: + self.assertDictEqual(eq.ctx.attributes, {"a": "d"}) + @jax.jit + def f2(x, y): + with set_xla_metadata(a="b"): + return (x + y, y * 2.0) + + f_vmap_jaxpr = jax.make_jaxpr(jax.vmap(f2, in_axes=(0, None))) + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', + f_vmap_jaxpr.lower(jnp.arange(5.0), 1.0).as_text(), + ) + + def test_multiple_instructions(self): + @jax.jit + def f(x, a): + y = jnp.matmul(x, x) + with set_xla_metadata(a="b"): + return y + a + + for line in f.lower(jnp.arange(5.0), 1.0).as_text().split("\n"): + # matmul doesn't have attributes + if "stablehlo.dot_general" in line: + self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', line) + if "stablehlo.add" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + + def test_softmax(self): + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.nn.softmax(x) + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', f.lower(jnp.arange(5.0)).as_text() + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 8ab503158bf9c76998d3cf0bd8e160fe07fd69b1 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 7 Sep 2024 18:08:48 +0000 Subject: [PATCH 401/702] tweak readme title to be more about what jax can do for you, dear user we should rewrite this whole readme... --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 878543304b25..52dedbe80746 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ logo -# Pushing back the limits on numerical computing +# Scalable, transformable, high-performance machine learning ![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg) ![PyPI version](https://img.shields.io/pypi/v/jax) From 3e1c2b3ee9aeac48192cbe760fa9ad85c33b5cda Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Sat, 7 Sep 2024 11:25:52 -0700 Subject: [PATCH 402/702] Removed dead code from `add_jaxvals` PiperOrigin-RevId: 672103395 --- jax/_src/ad_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 90ae6c1413ec..57e881c34f82 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -31,7 +31,6 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: - dtype = core.get_aval(x).dtype return add_jaxvals_p.bind(x, y) add_jaxvals_p = Primitive('add_any') From cd782643a1b0877a842b80c13b9cad962cc18883 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 7 Sep 2024 14:15:47 -0700 Subject: [PATCH 403/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8c6dafbe7e941fcd38647f8b8c4dc73c7916c6a4. PiperOrigin-RevId: 672126938 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a0173da6221e..e5eaa3cb67f2 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "b0368b065c36a4dc27cf8de8a31f6510fa6f5086" -XLA_SHA256 = "82b889ff11bac258df7e9230fbb45921f0adb0cde85561f8ab699fb0f4cfb5f1" +XLA_COMMIT = "8c6dafbe7e941fcd38647f8b8c4dc73c7916c6a4" +XLA_SHA256 = "3bf092c27eabc80d8a353912af0db4f0754233a60773bf57fd37b8d0446c9a54" def repo(): tf_http_archive( From 265bb7bf4c79017096512291d9b34729dc71bd27 Mon Sep 17 00:00:00 2001 From: Keith Rush Date: Sat, 7 Sep 2024 20:29:30 -0700 Subject: [PATCH 404/702] Adds failing test for https://github.com/google/jax/issues/23476. PiperOrigin-RevId: 672183133 --- tests/shard_map_test.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 857c69ce3d91..2df477454646 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -746,6 +746,37 @@ def f(x): self.assertIn('out_names', e.params) self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) + def test_nested_vmap_with_capture_spmd_axis_name(self): + self.skipTest('https://github.com/google/jax/issues/23476') + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + + def to_map_with_capture(x, y): + + # We capture x from `to_map_with_capture`'s parameters. + def with_capture(y_slice): + # Inside of all the maps, we have 'mapped everything away'--we are just + # adding two scalars, but one by fully mapping across each of the two + # dimensions, the other by mapping across one and capturing the + # resulting scalar. + self.assertEqual(x.shape, ()) + self.assertEqual(y_slice.shape, ()) + return x + y_slice + + # This vmap i will refer to as 'inner vmap'. + vmap_with_capture = jax.vmap(with_capture) + shmap_vmap_capture = shard_map( + vmap_with_capture, mesh=mesh, in_specs=P('y'), out_specs=P('y') + ) + return shmap_vmap_capture(y) + + # And this one is the outer vmap. + mapped = jax.vmap(to_map_with_capture, spmd_axis_name='x') + x = jnp.arange(2).reshape(2) + y = jnp.arange(2 * 2).reshape(2, 2) + # Inner vmap inside of shard-map will be over an axis of size 1. Outer vmap + # is over an axis of size 2. This is a problem at the moment. + jax.make_jaxpr(mapped)(x, y).jaxpr + @unittest.skipIf(xla_extension_version < 281, 'Requires xla_extension_version >= 281') def test_shard_map_abstract_mesh(self): From 5af1efb28507fd89f0e128106682df3ed4ceb77a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 8 Sep 2024 12:03:17 -0700 Subject: [PATCH 405/702] Skip symmetric product test on older jaxlibs. The new symmetric product operator will appear to jaxlib 0.4.32. PiperOrigin-RevId: 672311569 --- tests/linalg_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 9fcc940c6c2c..901bfca997dc 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,6 +16,7 @@ from functools import partial import itertools +import unittest import numpy as np import scipy @@ -2193,6 +2194,9 @@ def testHilbert(self, n): symmetrize_output=[True, False], ) @jtu.skip_on_devices("tpu") + @unittest.skipIf( + jax._src.lib.version < (0, 4, 32), "requires jaxlib >= 0.4.32" + ) def testSymmetricProduct(self, shape, dtype, symmetrize_output): rng = jtu.rand_default(self.rng()) batch_size = 10 From b6abd738d9ab57511390a4fee932ef6a142876b9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sun, 8 Sep 2024 12:09:01 -0700 Subject: [PATCH 406/702] Relax some test tolerances in for_loop_test.py. This PR attempts to fix some CI failures on Mac ARM. PiperOrigin-RevId: 672312564 --- tests/for_loop_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index 0fea62c12cf7..438ba55203a9 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -345,7 +345,8 @@ def g(a, b): expected_tangents = g_lin(a, b) _, actual_tangents = jax.jvp(g, (a, b), (a, b)) np.testing.assert_allclose(actual_tangents[0], expected_tangents[0]) - np.testing.assert_allclose(actual_tangents[1], expected_tangents[1]) + np.testing.assert_allclose(actual_tangents[1], expected_tangents[1], + rtol=1e-6) @jtu.sample_product( [dict(for_body_name=for_body_name, f=for_body, ref=ref, From 201d3ff8f1bbcce6bebf90d2c1a5e5d6dec4fd5e Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 8 Sep 2024 13:37:41 -0700 Subject: [PATCH 407/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/492151fc193162218dba2d89a5d7f7415737b092. PiperOrigin-RevId: 672324700 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e5eaa3cb67f2..088c1ba10466 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "8c6dafbe7e941fcd38647f8b8c4dc73c7916c6a4" -XLA_SHA256 = "3bf092c27eabc80d8a353912af0db4f0754233a60773bf57fd37b8d0446c9a54" +XLA_COMMIT = "492151fc193162218dba2d89a5d7f7415737b092" +XLA_SHA256 = "b531573a0e4d97068615f4a93423a1320e89384b02f288b85eabe496cf1d246e" def repo(): tf_http_archive( From d37c8501ead1f7cb8a2e10f1e85659ac029e41ca Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 9 Sep 2024 10:25:23 +0530 Subject: [PATCH 408/702] Better dosc for jax.numpy: minimum and maximum --- jax/_src/numpy/ufuncs.py | 113 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 2cae8de1712b..eb4ab343ed5b 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -753,14 +753,123 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) -@implements(np.minimum, module='numpy') + @partial(jit, inline=True) def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise minimum of the input arrays. + + JAX implementation of :obj:`numpy.minimum`. + + Args: + x: input array or scalar. + y: input array or scalar. Both ``x`` and ``y`` should either have same shape + or be broadcast compatible. + + Returns: + An array containing the element-wise minimum of ``x`` and ``y``. + + Note: + For each pair of elements, ``jnp.minimum`` returns: + - smaller of the two if both elements are finite numbers. + - ``nan`` if one element is ``nan``. + + See also: + - :func:`jax.numpy.maximum`: Returns element-wise maximum of the input arrays. + - :func:`jax.numpy.fmin`: Returns element-wise minimum of the input arrays, + ignoring NaNs. + - :func:`jax.numpy.amin`: Returns the minimum of array elements along a given + axis. + - :func:`jax.numpy.nanmin`: Returns the minimum of the array elements along + a given axis, ignoring NaNs. + + Examples: + Inputs with ``x.shape == y.shape``: + + >>> x = jnp.array([2, 3, 5, 1]) + >>> y = jnp.array([-3, 6, -4, 7]) + >>> jnp.minimum(x, y) + Array([-3, 3, -4, 1], dtype=int32) + + Inputs having broadcast compatibility: + + >>> x1 = jnp.array([[1, 5, 2], + ... [-3, 4, 7]]) + >>> y1 = jnp.array([-2, 3, 6]) + >>> jnp.minimum(x1, y1) + Array([[-2, 3, 2], + [-3, 3, 6]], dtype=int32) + + Inputs with ``nan``: + + >>> nan = jnp.nan + >>> x2 = jnp.array([[2.5, nan, -2], + ... [nan, 5, 6], + ... [-4, 3, 7]]) + >>> y2 = jnp.array([1, nan, 5]) + >>> jnp.minimum(x2, y2) + Array([[ 1., nan, -2.], + [nan, nan, 5.], + [-4., nan, 5.]], dtype=float32) + """ return lax.min(*promote_args("minimum", x, y)) -@implements(np.maximum, module='numpy') + @partial(jit, inline=True) def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise maximum of the input arrays. + + JAX implementation of :obj:`numpy.maximum`. + + Args: + x: input array or scalar. + y: input array or scalar. Both ``x`` and ``y`` should either have same shape + or be broadcast compatible. + + Returns: + An array containing the element-wise maximum of ``x`` and ``y``. + + Note: + For each pair of elements, ``jnp.maximum`` returns: + - larger of the two if both elements are finite numbers. + - ``nan`` if one element is ``nan``. + + See also: + - :func:`jax.numpy.minimum`: Returns element-wise minimum of the input + arrays. + - :func:`jax.numpy.fmax`: Returns element-wise maximum of the input arrays, + ignoring NaNs. + - :func:`jax.numpy.amax`: Retruns the maximum of array elements along a given + axis. + - :func:`jax.numpy.nanmax`: Returns the maximum of the array elements along + a given axis, ignoring NaNs. + + Examples: + Inputs with ``x.shape == y.shape``: + + >>> x = jnp.array([1, -5, 3, 2]) + >>> y = jnp.array([-2, 4, 7, -6]) + >>> jnp.maximum(x, y) + Array([1, 4, 7, 2], dtype=int32) + + Inputs with broadcast compatibility: + + >>> x1 = jnp.array([[-2, 5, 7, 4], + ... [1, -6, 3, 8]]) + >>> y1 = jnp.array([-5, 3, 6, 9]) + >>> jnp.maximum(x1, y1) + Array([[-2, 5, 7, 9], + [ 1, 3, 6, 9]], dtype=int32) + + Inputs having ``nan``: + + >>> nan = jnp.nan + >>> x2 = jnp.array([nan, -3, 9]) + >>> y2 = jnp.array([[4, -2, nan], + ... [-3, -5, 10]]) + >>> jnp.maximum(x2, y2) + Array([[nan, -2., nan], + [nan, -3., 10.]], dtype=float32) + """ return lax.max(*promote_args("maximum", x, y)) @implements(np.float_power, module='numpy') From 91df9d1a1733126dadbe3136100192669ab5c789 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 9 Sep 2024 10:41:08 +0100 Subject: [PATCH 409/702] Fixed validation in `jax.debug.format` This commit ensures that no formatting is done during validation, because the arguments could be abstract values. Closes #23475. --- jax/_src/debugging.py | 6 +++++- tests/debugging_primitives_test.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 62bfdf031c7c..3e7082ab10ec 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -265,6 +265,10 @@ def _flat_callback(*flat_args): class _DebugPrintFormatChecker(string.Formatter): + def format_field(self, value, format_spec): + del value, format_spec + return "" # No formatting is done. + def check_unused_args(self, used_args, args, kwargs): unused_args = [arg for i, arg in enumerate(args) if i not in used_args] unused_kwargs = [k for k in kwargs if k not in used_args] @@ -314,7 +318,7 @@ def debug_print(fmt: str, *args, **kwargs): **kwargs: Additional keyword arguments to be formatted, as if passed to ``fmt.format``. """ - # Check that we provide the correct arguments to be formatted + # Check that we provide the correct arguments to be formatted. formatter.format(fmt, *args, **kwargs) debug_callback(functools.partial(_format_print_callback, fmt), *args, diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index f5d7c47115b6..273c12f1b13c 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -106,6 +106,16 @@ def f(x): jax.effects_barrier() self.assertEqual(output(), "x: 2\n") + def test_can_stage_out_debug_print_with_formatting(self): + @jax.jit + def f(x): + debug_print('x: {x:.2f}', x=x) + + with jtu.capture_stdout() as output: + f(2) + jax.effects_barrier() + self.assertEqual(output(), "x: 2.00\n") + @jtu.device_supports_buffer_donation() def test_can_stage_out_debug_print_with_donate_argnums(self): def f(x, y): From aa16abe511f3210df10bb5e415ac85c8e7ec3f3a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Sep 2024 05:14:19 -0700 Subject: [PATCH 410/702] [pallas] Fix test failures on Windows. Avoid importing Triton modules on Windows, since we don't build it. Also avoid using an unescaped `\` in a regular expression. PiperOrigin-RevId: 672507555 --- tests/pallas/gpu_attention_test.py | 5 ++++- tests/pallas/gpu_ops_test.py | 14 ++++++++++---- tests/pallas/ops_test.py | 8 ++++++++ tests/pallas/pallas_test.py | 6 +++--- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index e7bc88cab811..9428b79c0a55 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -21,7 +21,10 @@ from jax import random from jax._src import config from jax._src import test_util as jtu -from jax.experimental.pallas.ops.gpu import decode_attention +if sys.platform != "win32": + from jax.experimental.pallas.ops.gpu import decode_attention +else: + decode_attention = None import jax.numpy as jnp import numpy as np diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index a18051b002fe..7c5fa2db630c 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -28,10 +28,16 @@ from jax._src.lax.control_flow.for_loop import for_loop from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl -from jax.experimental.pallas.ops.gpu import attention -from jax.experimental.pallas.ops.gpu import layer_norm -from jax.experimental.pallas.ops.gpu import rms_norm -from jax.experimental.pallas.ops.gpu import softmax +if sys.platform != "win32": + from jax.experimental.pallas.ops.gpu import attention + from jax.experimental.pallas.ops.gpu import layer_norm + from jax.experimental.pallas.ops.gpu import rms_norm + from jax.experimental.pallas.ops.gpu import softmax +else: + attention = None + layer_norm = None + rms_norm = None + softmax = None import jax.numpy as jnp import numpy as np diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 564a59ec2552..84844151eb63 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -988,6 +988,10 @@ def kernel(x_ref, o_ref): x = jnp.array([4.2, 2.4]).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x) + @unittest.skipIf( + sys.platform == "win32", + "plgpu.TritonCompilerParams unavailable on Windows", + ) def test_debug_print(self): # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): @@ -1008,6 +1012,10 @@ def kernel(x_ref, o_ref): self.assertIn("It works!", output()) + @unittest.skipIf( + sys.platform == "win32", + "plgpu.TritonCompilerParams unavailable on Windows", + ) def test_debug_print_with_values(self): # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 2762eb28755e..1d3316760fe8 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -768,7 +768,7 @@ def my_index_map(): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - f"Index map function my_index_map at .*{os.sep}pallas_test.py:.* for " + "Index map function my_index_map at .*pallas_test.py:.* for " "x_ref must return 1 values to match .*" "Currently returning 2 values."): f(a) @@ -783,7 +783,7 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - f"Index map function my_index_map at .*{os.sep}pallas_test.py:.* for " + "Index map function my_index_map at .*pallas_test.py:.* for " "x_ref must return integer scalars. Output\\[0\\] has " "type .*float"): f(a) @@ -798,7 +798,7 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - f"Index map function my_index_map at .*{os.sep}pallas_test.py:.* for " + "Index map function my_index_map at .*pallas_test.py:.* for " "x_ref must return integer scalars. Output\\[0\\] has " "type .*int32\\[4\\]"): f(a) From fe63b991dda8a5a39398bad8e20d0d3fba9b48f0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Sep 2024 05:14:27 -0700 Subject: [PATCH 411/702] Disable cudnn_fusion_test from CI. This test isn't passing in our internal CI. PiperOrigin-RevId: 672507574 --- tests/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index f8cf35c5d1e0..2fbed0601d96 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1540,7 +1540,10 @@ jax_test( "gpu_a100", "gpu_h100", ], - tags = ["multiaccelerator"], + tags = [ + "multiaccelerator", + "notap", # TODO(phawkins): this test fails in our internal CI. + ], ) exports_files( From 0320a792ba2319c89014c60d301f4bd7fa8dda9d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Sep 2024 05:34:45 -0700 Subject: [PATCH 412/702] Improve docs for jnp.split & related APIs --- jax/_src/numpy/lax_numpy.py | 209 +++++++++++++++++++++++++++++++++--- tests/lax_numpy_test.py | 3 +- 2 files changed, 195 insertions(+), 17 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 74dd152f2684..a4d1aab41d13 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2593,30 +2593,207 @@ def _split(op: str, ary: ArrayLike, return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) for start, end in zip(split_indices[:-1], split_indices[1:])] -@util.implements(np.split, lax_description=_ARRAY_VIEW_DOC) + def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: + """Split an array into sub-arrays. + + JAX implementation of :func:`numpy.split`. + + Args: + ary: N-dimensional array-like object to split + indices_or_sections: either a single integer or a sequence of indices. + + - if ``indices_or_sections`` is an integer *N*, then *N* must evenly divide + ``ary.shape[axis]`` and ``ary`` will be divided into *N* equally-sized + chunks along ``axis``. + - if ``indices_or_sections`` is a sequence of integers, then these integers + specify the boundary between unevenly-sized chunks along ``axis``; see + examples below. + + axis: the axis along which to split; defaults to 0. + + Returns: + A list of arrays. If ``indices_or_sections`` is an integer *N*, then the list is + of length *N*. If ``indices_or_sections`` is a sequence *seq*, then the list is + is of length *len(seq) + 1*. + + Examples: + Splitting a 1-dimensional array: + + >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + + Split into three equal sections: + + >>> chunks = jnp.split(x, 3) + >>> print(*chunks) + [1 2 3] [4 5 6] [7 8 9] + + Split into sections by index: + + >>> chunks = jnp.split(x, [2, 7]) # [x[0:2], x[2:7], x[7:]] + >>> print(*chunks) + [1 2] [3 4 5 6 7] [8 9] + + Splitting a two-dimensional array along axis 1: + + >>> x = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8]]) + >>> x1, x2 = jnp.split(x, 2, axis=1) + >>> print(x1) + [[1 2] + [5 6]] + >>> print(x2) + [[3 4] + [7 8]] + + See also: + - :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections`` + to be an integer that does not evenly divide the size of the array. + - :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0 + - :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1 + - :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2 + """ return _split("split", ary, indices_or_sections, axis=axis) -def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, int | ArrayLike], list[Array]]: - @util.implements(getattr(np, op), update_doc=False) - def f(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: - # for 1-D array, hsplit becomes vsplit - nonlocal axis - util.check_arraylike(op, ary) - a = asarray(ary) - if axis == 1 and len(a.shape) == 1: - axis = 0 - return _split(op, ary, indices_or_sections, axis=axis) - return f -vsplit = _split_on_axis("vsplit", axis=0) -hsplit = _split_on_axis("hsplit", axis=1) -dsplit = _split_on_axis("dsplit", axis=2) +def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: + """Split an array into sub-arrays vertically. + + JAX implementation of :func:`numpy.vsplit`. + + Refer to the documentation of :func:`jax.numpy.split` for details; ``vsplit`` is + equivalent to ``split`` with ``axis=0``. + + Examples: + 1D array: + + >>> x = jnp.array([1, 2, 3, 4, 5, 6]) + >>> x1, x2 = jnp.vsplit(x, 2) + >>> print(x1, x2) + [1 2 3] [4 5 6] + + 2D array: + + >>> x = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8]]) + >>> x1, x2 = jnp.vsplit(x, 2) + >>> print(x1, x2) + [[1 2 3 4]] [[5 6 7 8]] + + See also: + - :func:`jax.numpy.split`: split an array along any axis. + - :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1 + - :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2 + - :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections`` + to be an integer that does not evenly divide the size of the array. + """ + return _split("vsplit", ary, indices_or_sections, axis=0) + + +def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: + """Split an array into sub-arrays horizontally. + + JAX implementation of :func:`numpy.hsplit`. + + Refer to the documentation of :func:`jax.numpy.split` for details. ``hsplit`` is + equivalent to ``split`` with ``axis=1``, or ``axis=0`` for one-dimensional arrays. + + Examples: + 1D array: + + >>> x = jnp.array([1, 2, 3, 4, 5, 6]) + >>> x1, x2 = jnp.hsplit(x, 2) + >>> print(x1, x2) + [1 2 3] [4 5 6] + + 2D array: + + >>> x = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8]]) + >>> x1, x2 = jnp.hsplit(x, 2) + >>> print(x1) + [[1 2] + [5 6]] + >>> print(x2) + [[3 4] + [7 8]] + + See also: + - :func:`jax.numpy.split`: split an array along any axis. + - :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0 + - :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2 + - :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections`` + to be an integer that does not evenly divide the size of the array. + """ + util.check_arraylike("hsplit", ary) + a = asarray(ary) + return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1) + + +def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: + """Split an array into sub-arrays depth-wise. + + JAX implementation of :func:`numpy.dsplit`. + + Refer to the documentation of :func:`jax.numpy.split` for details. ``dsplit`` is + equivalent to ``split`` with ``axis=2``. + + Examples: + + >>> x = jnp.arange(12).reshape(3, 1, 4) + >>> print(x) + [[[ 0 1 2 3]] + + [[ 4 5 6 7]] + + [[ 8 9 10 11]]] + >>> x1, x2 = jnp.dsplit(x, 2) + >>> print(x1) + [[[0 1]] + + [[4 5]] + + [[8 9]]] + >>> print(x2) + [[[ 2 3]] + + [[ 6 7]] + + [[10 11]]] + + See also: + - :func:`jax.numpy.split`: split an array along any axis. + - :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0 + - :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1 + - :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections`` + to be an integer that does not evenly divide the size of the array. + """ + return _split("dsplit", ary, indices_or_sections, axis=2) + -@util.implements(np.array_split) def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: + """Split an array into sub-arrays. + + JAX implementation of :func:`numpy.array_split`. + + Refer to the documentation of :func:`jax.numpy.split` for details; ``array_split`` + is equivalent to ``split``, but allows integer ``indices_or_sections`` which does + not evenly divide the split axis. + + Examples: + >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + >>> chunks = jnp.array_split(x, 4) + >>> print(*chunks) + [1 2 3] [4 5] [6 7] [8 9] + + See also: + - :func:`jax.numpy.split`: split an array along any axis. + - :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0 + - :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1 + - :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2 + """ return _split("array_split", ary, indices_or_sections, axis=axis) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 534a395ef777..1323196feda0 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6289,6 +6289,7 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', 'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'round_'] + skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split'] for name in dir(jnp): if name.startswith('_') or name in unimplemented: @@ -6313,7 +6314,7 @@ def test_lax_numpy_docstrings(self): raise Exception(f"jnp.{name} does not have a wrapped docstring.") elif name in aliases: assert "Alias of" in obj.__doc__ - else: + elif name not in skip_args_check: # Other functions should have nontrivial docs including "Args" and "Returns". doc = obj.__doc__ self.assertNotEmpty(doc) From 05cdcb8ce5dee568de06513f111a64ff4f0dac3d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 9 Sep 2024 06:56:03 -0700 Subject: [PATCH 413/702] Slightly re-arranged Pallas Mosaic GPU pipelining logic This change prepares a few pipelining optimizations which will be done in a follow up. PiperOrigin-RevId: 672530087 --- jax/_src/pallas/mosaic_gpu/lowering.py | 76 ++++++++++++++------------ 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index fa546949780e..cdcd1aa97b18 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -251,7 +251,7 @@ def gmem_slice( start_indices: Sequence[ir.Value], step: ir.Value, shape: Sequence[int], - ) -> ir.Value: + ) -> Sequence[mgpu.DynamicSlice]: return tuple( mgpu.ds( arith_dialect.addi( @@ -264,37 +264,35 @@ def gmem_slice( for axis, (start_index, dim) in enumerate(zip(start_indices, shape)) ) - @mgpu.single_thread() - def fetch(step: ir.Value, slot: ir.Value) -> None: - for start_indices, b_gmem, b_smem in zip( - in_start_indices, in_buffers_gmem, in_buffers_smem - ): - # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. - b_smem_shape = ir.MemRefType(b_smem.type).shape[1:] - launch_ctx.async_copy( - src_ref=b_gmem, - dst_ref=mgpu.memref_slice(b_smem, slot), - gmem_slice=gmem_slice(start_indices, step, b_smem_shape), - barrier=barriers[slot], - swizzle=None, - arrive=True, - uniform=False, - ) + def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: + # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. + launch_ctx.async_copy( + src_ref=in_buffers_gmem[idx], + dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot), + gmem_slice=gmem_slice( + in_start_indices[idx], + step, + ir.MemRefType(in_buffers_smem[idx].type).shape[1:], + ), + barrier=barriers[slot], + swizzle=None, + arrive=True, + uniform=False, + ) - @mgpu.single_thread() - def store(step: ir.Value, slot: ir.Value) -> None: - for start_indices, b_gmem, b_smem in zip( - out_start_indices, out_buffers_gmem, out_buffers_smem - ): - # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. - b_smem_shape = ir.MemRefType(b_smem.type).shape[1:] - launch_ctx.async_copy( - src_ref=mgpu.memref_slice(b_smem, slot), - dst_ref=b_gmem, - gmem_slice=gmem_slice(start_indices, step, b_smem_shape), - swizzle=None, - uniform=False, - ) + def store(idx: int, step: ir.Value, slot: ir.Value) -> None: + # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. + launch_ctx.async_copy( + src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot), + dst_ref=out_buffers_gmem[idx], + gmem_slice=gmem_slice( + out_start_indices[idx], + step, + ir.MemRefType(out_buffers_smem[idx].type).shape[1:], + ), + swizzle=None, + uniform=False, + ) # Compute the number of steps along each sequential axis. if sequential_axes: @@ -325,8 +323,10 @@ def store(step: ir.Value, slot: ir.Value) -> None: else: num_steps = 1 - for slot in range(min(num_stages, num_steps)): - fetch(_as_index(slot), _as_index(slot)) + with mgpu.single_thread(): + for slot in range(min(num_stages, num_steps)): + for idx in range(grid_mapping.num_inputs): + fetch(idx, _as_index(slot), _as_index(slot)) @mgpu.fori(_as_index(num_steps), ()) def _(step, _): @@ -341,14 +341,18 @@ def _(step, _): [mgpu.memref_slice(b_smem, slot) for b_smem in buffers_smem], ) mgpu.commit_shared() - store(step, slot) + + with mgpu.single_thread(): + for idx in range(grid_mapping.num_outputs): + store(idx, step, slot) next_step = arith_dialect.addi(step, _as_index(num_stages)) next_step_in_bounds = arith_dialect.cmpi( arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps) ) - with mgpu.when(next_step_in_bounds): - fetch(next_step, slot) + with mgpu.when(next_step_in_bounds), mgpu.single_thread(): + for idx in range(grid_mapping.num_inputs): + fetch(idx, next_step, slot) return () From b7b58e9983d4d0cc0b0a5f7c7a8317b763efcb66 Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Mon, 9 Sep 2024 17:16:13 +0200 Subject: [PATCH 414/702] More explicit docstring on the limitations of `spsolve` --- jax/experimental/sparse/linalg.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index b0ac1fa5d380..184eb9741dd2 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -602,7 +602,9 @@ def spsolve(data, indices, indptr, b, tol=1e-6, reorder=1): """A sparse direct solver using QR factorization. Accepts a sparse matrix in CSR format `data, indices, indptr` arrays. - Currently only the CUDA GPU backend is implemented. + Currently only the CUDA GPU backend is implemented, the CPU backend will fall + back to `scipy.sparse.linalg.spsolve`. Neither the CPU nor the GPU + implementation support batching with `vmap`. Args: data : An array containing the non-zero entries of the CSR matrix. From ab29fee7637d699a2ccc96baeca8425b0a754ea4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Sep 2024 08:23:38 -0700 Subject: [PATCH 415/702] Add array_api intersphinx & document jnp.permute_dims --- docs/conf.py | 1 + jax/_src/numpy/lax_numpy.py | 25 ++++++++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 6d64f7a90285..1a7bf32842f0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -85,6 +85,7 @@ def _do_not_evaluate_in_jax( ] intersphinx_mapping = { + 'array_api': ('https://data-apis.org/array-api/2023.12/', None), 'python': ('https://docs.python.org/3/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 74dd152f2684..99574f856ea2 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -902,8 +902,31 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: return lax.transpose(a, axes_) -@util.implements(getattr(np, "permute_dims", None)) def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: + """Permute the axes/dimensions of an array. + + JAX implementation of :func:`array_api.permute_dims`. + + Args: + a: input array + axes: tuple of integers in range ``[0, a.ndim)`` specifying the + axes permutation. + + Returns: + a copy of ``a`` with axes permuted. + + See also: + - :func:`jax.numpy.transpose` + - :func:`jax.numpy.matrix_transpose` + + Examples: + >>> a = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.permute_dims(a, (1, 0)) + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + """ util.check_arraylike("permute_dims", a) return lax.transpose(a, axes) From f86cc91cce4f6d74825f9579a126a3c770008715 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Sep 2024 08:58:31 -0700 Subject: [PATCH 416/702] array API: use latest array-api-tests commit hash --- .github/workflows/jax-array-api.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 3709f0557a46..cdba39b3642a 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -25,7 +25,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'db95e67b29235249e5776ca2b6bb4e77117e0690' # Latest commit as of 2024-08-08 + ref: 'b4c0823469c02d6ce6e512ad4c2bd8ba42b1b4b2' # Latest commit as of 2024-09-09 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} From 95f38d95d7bf7fb39a7891feb27396ad063892ac Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Sep 2024 09:01:51 -0700 Subject: [PATCH 417/702] Update TPU test configuration tags. PiperOrigin-RevId: 672562923 --- tests/BUILD | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 2fbed0601d96..d1fb4dcc7cde 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -235,8 +235,8 @@ jax_test( }, enable_configs = [ "gpu_2gpu_shardy", - "tpu_df_2x2_shardy", - "tpu_pf_2x2_shardy", + "tpu_v3_2x2_shardy", + "tpu_v4_2x2_shardy", ], shard_count = { "cpu": 5, @@ -1426,7 +1426,7 @@ jax_test( name = "export_test", srcs = ["export_test.py"], enable_configs = [ - "tpu_df_2x2", + "tpu_v3_2x2", ], tags = [], deps = [ From 4bdfe09241493778a4e466486156a3ffd73a65c2 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 9 Sep 2024 10:20:47 -0700 Subject: [PATCH 418/702] [Pallas] Fully skip GPU attention tests on win32. PiperOrigin-RevId: 672588009 --- tests/pallas/gpu_attention_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index 9428b79c0a55..ed059c235329 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -51,7 +51,7 @@ def setUp(self): if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: + if sys.platform == "win32": self.skipTest("Only works on non-Windows platforms") super().setUp() From b97559247870c9fcf38a2269ac7ea8771e9f34b2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Sep 2024 14:36:23 -0400 Subject: [PATCH 419/702] Change nightly install commands to include all packages. pip doesn't update transitive dependencies, and we probably want the latest versions of everything when installing a nightly. --- docs/installation.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 4a831750e42b..7a12f7c541a2 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -269,22 +269,26 @@ for more details. Nightly releases reflect the state of the main JAX repository at the time they are built, and may not pass the full test suite. +Unlike the instructions for installing a JAX release, here we name all of JAX's +packages explicitly on the command line, so `pip` will upgrade them if a newer +version is available. + - CPU only: ```bash -pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ``` - Google Cloud TPU: ```bash -pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install -U --pre jax[tpu] jaxlib libtpu-nightly -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` - NVIDIA GPU (CUDA 12): ```bash -pip install -U --pre jax[cuda12] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax[cuda12] jaxlib jax-cuda12-plugin jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ``` - NVIDIA GPU (CUDA 12) legacy: From 5cc5ed2c5cd6efcc474aaf97963d8994037f9b44 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Sep 2024 11:44:51 -0700 Subject: [PATCH 420/702] Disable a shard_map test case that fails on TPU v5e. PiperOrigin-RevId: 672618556 --- tests/shard_map_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 2df477454646..e9c23b3e5f0d 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1544,6 +1544,9 @@ def f(x): self.assertEqual(e2.primitive.name, 'pbroadcast') def test_check_rep_false_grads(self): + if jtu.is_device_tpu(5, 'e'): + self.skipTest('TODO(b/307508823): Test currently fails on TPU v5e') + # This test is redundant with the systematic tests below, but it serves as a # direct regression test for a bug. mesh = jtu.create_mesh((4,), ('heads',)) From 623cbb8ce777052481a3ba35e0a48c422d2788f5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 9 Sep 2024 13:10:04 -0700 Subject: [PATCH 421/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/32004c272725505f068d2f0d997f68bef383a618. PiperOrigin-RevId: 672647417 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 088c1ba10466..b48634781869 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "492151fc193162218dba2d89a5d7f7415737b092" -XLA_SHA256 = "b531573a0e4d97068615f4a93423a1320e89384b02f288b85eabe496cf1d246e" +XLA_COMMIT = "32004c272725505f068d2f0d997f68bef383a618" +XLA_SHA256 = "d46a54c5539485de704c72a884812e191e2a3f7954bd97f6d7bf16d99d12cfc5" def repo(): tf_http_archive( From d6c36255e832785aa4e3f8ebbaa0f8b46b266eb6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 9 Sep 2024 13:23:42 -0700 Subject: [PATCH 422/702] Create optimal order for v5e:8 devices which is [0, 1, 2, 3, 7, 6, 5, 4] PiperOrigin-RevId: 672652104 --- jax/_src/mesh_utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index 7cac0338a923..da3b54058fbb 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -67,6 +67,7 @@ _TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) _TRAY_2x2_RING_ORDER = (0, 1, 3, 2) _TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4) +_V5E_TRAY_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 4) def _tpu_v2_v3_create_device_mesh( mesh_shape: Sequence[int], @@ -96,7 +97,7 @@ def _tpu_v2_v3_create_device_mesh( return np.asarray(devices).reshape(mesh_shape) -def _vlc_create_device_mesh( +def _v5e_create_device_mesh( mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs ) -> np.ndarray | None: """Creates rotated pincer device assignment for selected topologies. @@ -118,13 +119,19 @@ def _vlc_create_device_mesh( devices, key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) - if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: # VLC2x2 + if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: device_mesh = np.asarray(sequential_devices) device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)] device_mesh = device_mesh.reshape(mesh_shape) return device_mesh - if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # VLP4x4 + if len(devices) == 8: + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_V5E_TRAY_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # v5e4x4 # Only uses ring order if the whole mesh is a replica group. if max(mesh_shape) == len(devices): device_mesh = np.asarray(sequential_devices) @@ -144,7 +151,7 @@ def _vlc_create_device_mesh( ] = { _TPU_V2: _tpu_v2_v3_create_device_mesh, _TPU_V3: _tpu_v2_v3_create_device_mesh, - _TPU_V5_LITE: _vlc_create_device_mesh, + _TPU_V5_LITE: _v5e_create_device_mesh, } From 72c095261fb2d678005bdaa963d4146d11634af0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Sep 2024 14:03:19 -0700 Subject: [PATCH 423/702] Improve the docstring for `jax.Array.copy_to_host_async`. PiperOrigin-RevId: 672666190 --- jax/_src/basearray.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index 4848d83d5315..c0b4f9f51c8b 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -124,7 +124,19 @@ def device(self) -> Device | Sharding: @abc.abstractmethod def copy_to_host_async(self): - """Copies jax.Array to host asynchronously.""" + """Copies an ``Array`` to the host asynchronously. + + For arrays that live an an accelerator, such as a GPU or a TPU, JAX may + cache the value of the array on the host. Normally this happens + behind the scenes when the value of an on-device array is requested by the + user, but waiting to initiate a device-to-host copy until the value is + requested requires that JAX block the caller while waiting for the copy to + complete. + + ``copy_to_host_async`` requests that JAX populate its on-host cache of an + array, but does not wait for the copy to complete. This may speed up a + future on-host access to the array's contents. + """ Array.__module__ = "jax" From 7d2f0a75c1c10aeae63ed1a0ced17b7aeb979989 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 9 Sep 2024 14:48:50 -0700 Subject: [PATCH 424/702] [Pallas GPU] Fix the behavior of `jnp.sign(jnp.nan)` and move the TPU test case for `jnp.sign` into the general test This PR is similar to https://github.com/google/jax/pull/23192, which moves TPU test case for `lax.erf_inv` into the general test Fixes https://github.com/google/jax/issues/23504 PiperOrigin-RevId: 672682048 --- jax/_src/pallas/mosaic/lowering.py | 18 ++------ jax/_src/pallas/triton/lowering.py | 12 ++---- jax/_src/pallas/utils.py | 14 +++++++ tests/pallas/ops_test.py | 66 +++++++++++++++--------------- 4 files changed, 52 insertions(+), 58 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index c8f910e185f3..bd897deb3d1f 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1838,22 +1838,10 @@ def _neg_lowering_rule(ctx: LoweringRuleContext, x): skip_mlir_conversions.add(lax.neg_p) -def _sign_lowering_helper(x): - if jnp.issubdtype(x.dtype, jnp.unsignedinteger): - return (x != 0).astype(x.dtype) - - if jnp.issubdtype(x.dtype, jnp.integer): - return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype) - - if jnp.issubdtype(x.dtype, jnp.floating): - out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype) - return jnp.where(jnp.isnan(x), jnp.nan, out) - - raise NotImplementedError - - def _sign_lowering_rule(ctx: LoweringRuleContext, x): - return lower_fun(_sign_lowering_helper, multiple_results=False)(ctx, x) + return lower_fun( + pallas_utils.sign_lowering_helper, multiple_results=False, + )(ctx, x) lowering_rules[lax.sign_p] = _sign_lowering_rule diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 446d87a5f347..dec48847520a 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1359,15 +1359,9 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): return _floordiv(x, y, signed=signed) -@register_lowering(lax.sign_p) -def _sign_lowering_rule(ctx: LoweringRuleContext, x): - [x_aval] = ctx.avals_in - signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger) - zero = _full(x.type, 0) - return _sub( - _cast(_greater_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype), - _cast(_less_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype), - ) +register_lowering(lax.sign_p)( + lower_fun(pallas_utils.sign_lowering_helper, multiple_results=False) +) @register_lowering(lax.iota_p) diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 6fc816e27b53..cfca0769d13d 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -210,3 +210,17 @@ def erf_inv_32_lowering_helper(x): p = c + p * w return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x) + + +def sign_lowering_helper(x): + if jnp.issubdtype(x.dtype, jnp.unsignedinteger): + return (x != 0).astype(x.dtype) + + if jnp.issubdtype(x.dtype, jnp.integer): + return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype) + + if jnp.issubdtype(x.dtype, jnp.floating): + out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype) + return jnp.where(jnp.isnan(x), jnp.nan, out) + + raise NotImplementedError(f"sign_lowering_helper not implemented for {x.dtype}") diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 84844151eb63..52a1d62a63a3 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -669,6 +669,38 @@ def run(interpret=False): actual = run(False) self.assertAllClose(actual, expected) + SIGN_PARAMS = [ + (jnp.int32, (-3, 0, 5)), + (jnp.uint32, (0, 5)), + (jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)), + (jnp.float64, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)), + ] + + @parameterized.named_parameters( + (f"{dtype.__name__}_{value}", dtype, value) + for dtype, values in SIGN_PARAMS + for value in values + ) + def test_sign(self, dtype, value): + if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64: + self.skipTest("float64 is not supported on TPU") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sign(x_ref[...]) + + with contextlib.ExitStack() as stack: + if jnp.dtype(dtype).itemsize == 8: + stack.enter_context(config.enable_x64(True)) + + x = jnp.full((8, 128,), value, dtype=dtype) + out = kernel(x) + expected = jnp.sign(x) + np.testing.assert_array_equal(out, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True @@ -1614,39 +1646,5 @@ class PallasPrimitivesInterpretTest(PallasPrimitivesTest): INTERPRET = True -class TpuOpsTest(PallasBaseTest): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Test requires TPU device.") - - super().setUp() - - SIGN_PARAMS = [ - (jnp.int32, (-3, 0, 5)), - (jnp.uint32, (0, 5)), - (jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)), - ] - - @parameterized.named_parameters( - (f"{dtype.__name__}_{value}", dtype, value) - for dtype, values in SIGN_PARAMS - for value in values - ) - def test_sign(self, dtype, value): - @jax.jit - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct((8, 128), dtype), - ) - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sign(x_ref[...]) - - x = jnp.full((8, 128,), value, dtype=dtype) - out = kernel(x) - expected = jnp.sign(x) - np.testing.assert_array_equal(out, expected) - - if __name__ == "__main__": absltest.main() From e0faa596b3d6797b68fa695c65c19407619ef8ee Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 10 Sep 2024 02:46:14 +0100 Subject: [PATCH 425/702] [Pallas] Fix array indexing error when dimension size is not a multiple of stride --- jax/_src/state/discharge.py | 2 +- tests/pallas/indexing_test.py | 44 +++++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 4795af054280..8e1b3732dd3d 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -169,7 +169,7 @@ def _maybe_convert_to_slice( return None start = i.start - end = i.start + i.size * i.stride + end = i.start + (i.size - 1) * i.stride + 1 stride = i.stride # cannot convert to static `slice` if `start` or `end` is dynamic diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 59e28db6d9e2..d49b83fe160b 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -647,27 +647,37 @@ class IndexerOpsInterpretTest(IndexerOpsTest): # TODO(ayx): Fix all test cases here _ADVANCED_INDEXER_TEST_CASES = [ - ((8, 2), lambda arr, a, b, c, d: arr[2]), - ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]), - ((16, 3), lambda arr, a, b, c, d: arr[a, a]), - ((16, 16), lambda arr, a, b, c, d: arr[::4, ::4]), + # integer + ((3, 2), lambda arr, a, b, c, d: arr[2]), + # slice + ((12, 12), lambda arr, a, b, c, d: arr[::4, ::4]), ((16, 16), lambda arr, a, b, c, d: arr[1:14:2, 2:13:4]), - ((16, 3), lambda arr, a, b, c, d: arr[a, :]), - # ((16, 3), lambda arr, a, b, c, d: arr[:, a]), - ((16, 3), lambda arr, a, b, c, d: arr[a, ::4]), - # ((16, 3), lambda arr, a, b, c, d: arr[::4, a]), + ((8, 2), lambda arr, a, b, c, d: arr[1::3, :]), + # array + ((4, 3), lambda arr, a, b, c, d: arr[a]), + ((4, 3, 2), lambda arr, a, b, c, d: arr[c, c]), + # integer + 1-D array + ((4, 3), lambda arr, a, b, c, d: arr[2, a]), + ((4, 3), lambda arr, a, b, c, d: arr[a, 2]), + # slice + 1-D array + ((4, 3), lambda arr, a, b, c, d: arr[a, :]), + # ((4, 3), lambda arr, a, b, c, d: arr[:, a]), + ((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]), + # ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), - # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, ::2, ::2]), - # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, ::2, ::2]), - # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]), - # ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]), + ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]), + ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]), # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]), - # ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, c]), - ((8, 6, 4), lambda arr, a, b, c, d: arr[a]), - ((6, 8, 4), lambda arr, a, b, c, d: arr[c, c]), - ((6, 8, 4), lambda arr, a, b, c, d: arr[c, ::3]), - # ((8, 6, 4), lambda arr, a, b, c, d: arr[::3, c]), + ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]), + ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]), + # slice + array w/ broadcasting + ((8, 8, 3, 6), lambda arr, a, b, c, d: \ + arr[b[:, None], ::4, a[None], a[:, None]]), + # integer + slice + 1-D array + ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]), + ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, a, ::2]), + # boolean # ((6, 2), lambda arr, a, b, c, d: arr[d]), # ((8, 6), lambda arr, a, b, c, d: arr[::4, d]), ] From 062a69a97e84e257f1ff455aeae1d7c25a302567 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Tue, 10 Sep 2024 03:06:16 -0700 Subject: [PATCH 426/702] Make JAX extract the mesh from an `AUTO` in/out sharding. Automatic partitioners using JAX+Shardy want to partition models which are fully marked as `AUTO` - so no in/out sharding with a `NamedSharding`. In such a case they weren't seeing the mesh on the MLIR module. This makes sure we extract it from the `AUTO` sharding. PiperOrigin-RevId: 672881018 --- jax/_src/interpreters/pxla.py | 2 +- tests/pjit_test.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7db928f12704..882f71d58671 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2217,7 +2217,7 @@ def lower_sharding_computation( if config.use_shardy_partitioner.value or prim_requires_devices: for sharding in it.chain(in_shardings, out_shardings, [js for js, _ in unique_intermediate_shardings]): - if isinstance(sharding, sharding_impls.NamedSharding): + if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)): if (mesh_shape_tuple is not None and mesh_shape_tuple != sharding.mesh.shape_tuple): raise ValueError( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 25b2680e8d15..dbb867ab9a39 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5288,9 +5288,9 @@ def f(x): def test_compile_with_inferred_out_sharding(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jax.device_put(np.arange(8 * 4).reshape(8, 4), - jax.sharding.NamedSharding(mesh, P('x', 'y'))) + NamedSharding(mesh, P('x', 'y'))) y = jax.device_put(np.arange(4 * 16).reshape(4, 16), - jax.sharding.NamedSharding(mesh, P('y'))) + NamedSharding(mesh, P('y'))) @jax.jit def f(x, y): @@ -5298,7 +5298,18 @@ def f(x, y): out = f(x, y) self.assertArraysEqual(out, x @ y) - self.assertEqual(out.sharding, jax.sharding.NamedSharding(mesh, P('x'))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_fully_automatic_sharding(self): + mesh = jtu.create_mesh((8,), ('x',)) + x = jax.ShapeDtypeStruct((128, 128), jnp.float32) + + @jax.jit + def f(x, y): + return x @ y + + lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text() + self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str) if __name__ == '__main__': From 1b2ba9d1c27e37237610df405d4bfcd99ab4b6ec Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 10 Sep 2024 08:25:40 -0700 Subject: [PATCH 427/702] Disable two lax_scipy_test testcases that fail on TPU v6e. PiperOrigin-RevId: 672973757 --- jax/_src/test_util.py | 2 ++ tests/lax_scipy_test.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 102afc42d079..f2d422e512fc 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -506,6 +506,8 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool: # Special case v5e until the name is updated in device_kind if expected_version == "v5e": return "v5 lite" in device_kind + elif expected_version == "v6e": + return "v6 lite" in device_kind return expected_version in device_kind def is_cuda_compute_capability_at_least(capability: str) -> bool: diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index b3c373b63424..df19750b1877 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -333,6 +333,8 @@ def scipy_fun(z): dtype=float_dtypes, ) def testLpmn(self, l_max, shape, dtype): + if jtu.is_device_tpu(6, "e"): + self.skipTest("TODO(b/364258243): fails on TPU v6e") rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -442,6 +444,8 @@ def testSphHarmOrderOneDegreeOne(self): @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): """Tests against JIT compatibility and Numpy.""" + if jtu.is_device_tpu(6, "e"): + self.skipTest("TODO(b/364258243): fails on TPU v6e") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) From c5bc2412a74837634e2931eeb73b31cba4575cee Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Tue, 10 Sep 2024 20:56:49 +0530 Subject: [PATCH 428/702] Improve doc for jnp.trim_zeros --- jax/_src/numpy/lax_numpy.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 9586ede8e127..820a05f56548 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6375,8 +6375,28 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: return res -@util.implements(np.trim_zeros) def trim_zeros(filt, trim='fb'): + """Trim leading and/or trailing zeros of the input array. + + JAX implementation of :func:`numpy.trim_zeros`. + + Args: + filt: input array. Must have ``filt.ndim == 1``. + trim: string, optional, default = ``fb``. Specifies from which end the input + is trimmed. + + - ``f`` - trims only the leading zeros. + - ``b`` - trims only the trailing zeros. + - ``fb`` - trims both leading and trailing zeros. + + Returns: + An array containig the trimmed input with same dtype as ``filt``. + + Examples: + >>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0]) + >>> jnp.trim_zeros(x) + Array([2, 0, 1, 4, 3], dtype=int32) + """ filt = core.concrete_or_error(asarray, filt, "Error arose in the `filt` argument of trim_zeros()") nz = (filt == 0) From ee04646f33f9c240b0a57601c7adec3b8622c32b Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Tue, 10 Sep 2024 22:08:44 +0530 Subject: [PATCH 429/702] Improve docs for jax.numpy: float_power and nextafter --- jax/_src/numpy/ufuncs.py | 67 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index eb4ab343ed5b..c3b38e57ab76 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -872,14 +872,77 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.max(*promote_args("maximum", x, y)) -@implements(np.float_power, module='numpy') + @partial(jit, inline=True) def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: + """Calculate element-wise base ``x`` exponential of ``y``. + + JAX implementation of :obj:`numpy.float_power`. + + Args: + x: scalar or array. Specifies the bases. + y: scalar or array. Specifies the exponents. ``x`` and ``y`` should either + have same shape or be broadcast compatible. + + Returns: + An array containing the base ``x`` exponentials of ``y``, promoting to the + inexact dtype. + + See also: + - :func:`jax.numpy.exp`: Calculates element-wise exponential of the input. + - :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of + the input. + + Examples: + Inputs with same shape: + + >>> x = jnp.array([3, 1, -5]) + >>> y = jnp.array([2, 4, -1]) + >>> jnp.float_power(x, y) + Array([ 9. , 1. , -0.2], dtype=float32) + + Inputs with broacast compatibility: + + >>> x1 = jnp.array([[2, -4, 1], + ... [-1, 2, 3]]) + >>> y1 = jnp.array([-2, 1, 4]) + >>> jnp.float_power(x1, y1) + Array([[ 0.25, -4. , 1. ], + [ 1. , 2. , 81. ]], dtype=float32) + + ``jnp.float_power`` produces ``nan`` for negative values raised to a non-integer + values. + + >>> jnp.float_power(-3, 1.7) + Array(nan, dtype=float32, weak_type=True) + """ return lax.pow(*promote_args_inexact("float_power", x, y)) -@implements(np.nextafter, module='numpy') + @partial(jit, inline=True) def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise next floating point value after ``x`` towards ``y``. + + JAX implementation of :obj:`numpy.nextafter`. + + Args: + x: scalar or array. Specifies the value after which the next number is found. + y: scalar or array. Specifies the direction towards which the next number is + found. ``x`` and ``y`` should either have same shape or be broadcast + compatible. + + Returns: + An array containing the next representable number of ``x`` in the direction + of ``y``. + + Examples: + >>> jnp.nextafter(2, 1) # doctest: +SKIP + Array(1.9999999, dtype=float32, weak_type=True) + >>> x = jnp.array([3, -2, 1]) + >>> y = jnp.array([2, -1, 2]) + >>> jnp.nextafter(x, y) # doctest: +SKIP + Array([ 2.9999998, -1.9999999, 1.0000001], dtype=float32) + """ return lax.nextafter(*promote_args_inexact("nextafter", x, y)) # Logical ops From 9fa0164ad286594318886f34a4d3723bb726062e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 10 Sep 2024 09:42:29 -0700 Subject: [PATCH 430/702] Estimate the amount of required scratch SMEM automatically in Pallas Mosaic GPU lowering No estimation is done if `smem_scratch_bytes` was explicitly specified via `compiler_params=`. PiperOrigin-RevId: 672998660 --- jax/_src/pallas/mosaic_gpu/lowering.py | 49 ++++++++++++++++++++++---- tests/pallas/mosaic_gpu_test.py | 1 - 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index cdcd1aa97b18..9666edcc2321 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -51,6 +51,45 @@ partial = functools.partial +_smem_estimators = {} + + +def _regiter_smem_estimator(primitive: jax_core.Primitive): + def deco(fn): + _smem_estimators[primitive] = fn + return fn + + return deco + + +def _estimate_smem_scratch_bytes(jaxpr: jax_core.Jaxpr) -> int: + """Estimates the amount of SMEM scratch bytes required by the kernel.""" + max_used = 0 + for eqn in jaxpr.eqns: + # TODO(slebedev): Add support for other primitives, notably control flow. + rule = _smem_estimators.get(eqn.primitive) + if rule is None: + # Assume that unsupported primitives are neutral wrt SMEM usage. + continue + max_used = max( + max_used, rule(*(invar.aval for invar in eqn.invars), **eqn.params) + ) + return max_used + + +@_regiter_smem_estimator(primitives.run_scoped_p) +def _run_scoped_smem_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int: + del consts # Unused. + in_avals = (v.aval.inner_aval for v in jaxpr.invars) + return sum(math.prod(aval.shape) * aval.dtype.itemsize for aval in in_avals) + + +@_regiter_smem_estimator(lax.reduce_sum_p) +def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int: + if axes != (0,): + raise NotImplementedError("No support for axes other than 0 yet") + return 4 * x_aval.dtype.itemsize + @dataclasses.dataclass class ModuleContext: @@ -358,13 +397,11 @@ def _(step, _): launch_ctx.await_async_copy(0) - # TODO(b/354568888): Add a jaxpr traversal to calculate the precise - # amount of memory required. + smem_scratch_bytes = compiler_params.get("smem_scratch_bytes"), + if smem_scratch_bytes is None: + smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr) extra_smem_scratch = [ - jax.ShapeDtypeStruct( - shape=[compiler_params.get("smem_scratch_bytes", 100000)], - dtype=np.int8, - ) + jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) ] module, out_structs_smem, _ = mosaic_gpu._lower_as_gpu_kernel( body, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 80cfd04c44d6..51e4d06f1179 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -118,7 +118,6 @@ def test_layer_norm(self, input_factor): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), - compiler_params={"smem_scratch_bytes": 4 * 4}, ) def layer_norm(x_ref, o_ref): x_mean = jnp.mean(x_ref[...]) From 95cb347d04ceeb5640840bad6ae1858ee2065dcb Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 10 Sep 2024 17:15:29 +0000 Subject: [PATCH 431/702] trying on another readme header --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 52dedbe80746..35307bee33f6 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ logo -# Scalable, transformable, high-performance machine learning +# Transformable numerical computing at scale ![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg) ![PyPI version](https://img.shields.io/pypi/v/jax) From bacda603fc3cba7f9a5a1201ef37c38ea4621c62 Mon Sep 17 00:00:00 2001 From: selamw1 Date: Wed, 4 Sep 2024 16:53:44 -0700 Subject: [PATCH 432/702] kron_and_outer_docstring_added description_fixed_and_kron_desc_added description_text_and_return_fixed description_text_and_return_fixed return_fixed --- jax/_src/numpy/lax_numpy.py | 58 +++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8d17f09ac3db..859411b2080c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -7405,9 +7405,33 @@ def inner( preferred_element_type=preferred_element_type) -@util.implements(np.outer, skip_params=['out']) @partial(jit, inline=True) def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: + """Compute the outer product of two arrays. + + JAX implementation of :func:`numpy.outer`. + + Args: + a: first input array, if not 1D it will be flattened. + b: second input array, if not 1D it will be flattened. + out: unsupported by JAX. + + Returns: + The outer product of the inputs ``a`` and ``b``. Returned array + will be of shape ``(a.size, b.size)``. + + See also: + - :func:`jax.numpy.inner`: compute the inner product of two arrays. + - :func:`jax.numpy.einsum`: Einstein summation. + + Examples: + >>> a = jnp.array([1, 2, 3]) + >>> b = jnp.array([4, 5, 6]) + >>> jnp.outer(a, b) + Array([[ 4, 5, 6], + [ 8, 10, 12], + [12, 15, 18]], dtype=int32) + """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.outer is not supported.") util.check_arraylike("outer", a, b) @@ -7443,9 +7467,39 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, return moveaxis(c, 0, axisc) -@util.implements(np.kron) @jit def kron(a: ArrayLike, b: ArrayLike) -> Array: + """Compute the Kronecker product of two input arrays. + + JAX implementation of :func:`numpy.kron`. + + The Kronecker product is an operation on two matrices of arbitrary size that + produces a block matrix. Each element of the first matrix ``a`` is multiplied by + the entire second matrix ``b``. If ``a`` has shape (m, n) and ``b`` + has shape (p, q), the resulting matrix will have shape (m * p, n * q). + + Args: + a: first input array with any shape. + b: second input array with any shape. + + Returns: + A new array representing the Kronecker product of the inputs ``a`` and ``b``. + The shape of the output is the element-wise product of the input shapes. + + See also: + - :func:`jax.numpy.outer`: compute the outer product of two arrays. + + Examples: + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> b = jnp.array([[5, 6], + ... [7, 8]]) + >>> jnp.kron(a, b) + Array([[ 5, 6, 10, 12], + [ 7, 8, 14, 16], + [15, 18, 20, 24], + [21, 24, 28, 32]], dtype=int32) + """ util.check_arraylike("kron", a, b) a, b = util.promote_dtypes(a, b) if ndim(a) < ndim(b): From 90892f533a22d2c433bdb761e334f0548f9ae7c6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 10 Sep 2024 12:43:09 -0700 Subject: [PATCH 433/702] Check for `jax.Sharding`'s number of devices instead of `py_array.num_shards` which looks at IFRT sharding's num_devices to check against `global_devices` and deciding whether to fall back to python shard_arg. This is because IFRT sharding's `num_shards` method is busted. It doesn't return the global shards (in some cases) which leads to JAX program unnecessarily falling back to python. PiperOrigin-RevId: 673067095 --- jax/_src/test_util.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index f2d422e512fc..870268e99384 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -39,6 +39,7 @@ import jax from jax import lax from jax._src import api +from jax._src import array from jax._src import config from jax._src import core from jax._src import dispatch @@ -383,6 +384,25 @@ def mlir_lower_and_count(*args, **kwargs): mlir.lower_jaxpr_to_module = mlir_lower +@contextmanager +def count_jax_array_shard_arg_calls(): + # No need to clear any caches since we generally jit and pmap fresh callables + # in tests. + + array_shard_arg = array._array_shard_arg + count = [0] + + def array_shard_arg_and_count(*args, **kwargs): + count[0] += 1 + return array_shard_arg(*args, **kwargs) + + pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg_and_count + try: + yield count + finally: + pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg + + @contextmanager def count_jit_compilation_cache_miss(): # No need to clear any caches since we generally jit and pmap fresh callables @@ -1965,7 +1985,7 @@ def arcsin(self, x): # On branch cut, mpmath.mp.asin returns different value compared # to mpmath.fp.asin and numpy.arcsin (see # mpmath/mpmath#786). The following if-block ensures - # compatibiliy with numpy.arcsin. + # compatibility with numpy.arcsin. if x.real > 1 and x.imag == 0: return ctx.asin(x).conjugate() @@ -1997,7 +2017,7 @@ def arccos(self, x): return ctx.make_mpc((real._mpf_, (-sign_imag * inf)._mpf_)) # On branch cut, mpmath.mp.acos returns different value # compared to mpmath.fp.acos and numpy.arccos. The - # following if-block ensures compatibiliy with + # following if-block ensures compatibility with # numpy.arccos. if x.imag == 0 and x.real > 1: return -ctx.acos(x) @@ -2026,7 +2046,7 @@ def arcsinh(self, x): # On branch cut, mpmath.mp.asinh returns different value # compared to mpmath.fp.asinh and numpy.arcsinh (see # mpmath/mpmath#786). The following if-block ensures - # compatibiliy with numpy.arcsinh. + # compatibility with numpy.arcsinh. if x.real == 0 and x.imag < -1: return (-ctx.asinh(x)).conjugate() return ctx.asinh(x) From 8681bf6dc2d07ed465da1b5f4624a5ecb20f94d5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 10 Sep 2024 13:45:00 -0700 Subject: [PATCH 434/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/720b2c53346660e95abbed7cf3309a8b85e979f9. PiperOrigin-RevId: 673090332 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b48634781869..8f4accca508c 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "32004c272725505f068d2f0d997f68bef383a618" -XLA_SHA256 = "d46a54c5539485de704c72a884812e191e2a3f7954bd97f6d7bf16d99d12cfc5" +XLA_COMMIT = "720b2c53346660e95abbed7cf3309a8b85e979f9" +XLA_SHA256 = "a93bb0414c33025f6cb775c374d5695c84055f2bd88d6ea826d51d99612baaef" def repo(): tf_http_archive( From 46bcb1e0577ec5fde871e0c5448d864b480052c3 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 10 Sep 2024 13:59:55 -0700 Subject: [PATCH 435/702] [Pallas] Simplify lowering and fix the test for `lax.erf_inv_p` This PR is a follow-up of https://github.com/google/jax/pull/23192, which implements the lowering rule for `lax.erf_inv_p`. However, I've realised that the lowering rule can be simplified, and the test for it was moved to the wrong place. This PR resolves the above 2 issues. After merging this PR, I will continue with https://github.com/google/jax/pull/22310, which adds 64-bit lowering support for `lax.erf_inv_p`. PiperOrigin-RevId: 673095319 --- jax/_src/pallas/triton/lowering.py | 26 +++++------------------ tests/pallas/ops_test.py | 33 ++++++++++++++++-------------- 2 files changed, 23 insertions(+), 36 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index dec48847520a..5e495f4bef3e 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1002,27 +1002,6 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): ), ], ), - lax.erf_inv_p: _make_dispatch_table( - "erf_inv", - cuda=[ - _Fallback( - [jnp.float32], - lower_fun( - pallas_utils.erf_inv_32_lowering_helper, - multiple_results=False, - ), - ), - ], - rocm=[ - _Fallback( - [jnp.float32], - lower_fun( - pallas_utils.erf_inv_32_lowering_helper, - multiple_results=False, - ), - ), - ], - ), }) @@ -1364,6 +1343,11 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): ) +register_lowering(lax.erf_inv_p)( + lower_fun(pallas_utils.erf_inv_32_lowering_helper, multiple_results=False) +) + + @register_lowering(lax.iota_p) def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension): iota = _make_range(0, shape[dimension]) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 52a1d62a63a3..627ee0e8a227 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -701,6 +701,24 @@ def kernel(x_ref, o_ref): expected = jnp.sign(x) np.testing.assert_array_equal(out, expected) + @parameterized.product( + dtype=[jnp.float32], + value=[-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4], + ) + def test_erf_inv(self, dtype, value): + @functools.partial( + self.pallas_call, + # TODO(ayx): add float64 support for `erf_inv` + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.erf_inv(x_ref[...]) + + x = jnp.full((8, 128), value, dtype=dtype) + out = kernel(x) + expected = lax.erf_inv(x) + np.testing.assert_array_equal(out, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True @@ -1569,21 +1587,6 @@ def reduce(x_ref, y_ref): y_ref = jnp.cumsum(x, axis=axis) np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) - @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4]) - def test_erf_inv(self, x): - @functools.partial( - self.pallas_call, - # TODO(ayx): add float64 support for `erf_inv` - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - ) - def kernel(x_ref, o_ref): - o_ref[...] = lax.erf_inv(x_ref[...]) - - x = jnp.full((8, 128), x) - out = kernel(x) - expected = lax.erf_inv(x) - np.testing.assert_array_equal(out, expected) - class OpsExtraInterpretTest(OpsExtraTest): INTERPRET = True From 030b6c655df1339e248433025325fc82a8103515 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 10 Sep 2024 14:02:06 -0700 Subject: [PATCH 436/702] Update the docs for conv_general_dilated to clarify 'W' 'H'. --- jax/_src/lax/convolution.py | 2 +- jax/_src/lax/other.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 2b2ad5bbb515..0e41fe5bb18f 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -114,7 +114,7 @@ def conv_general_dilated( - the input and output feature dimensions in rhs with the characters 'I' and 'O' respectively, and - spatial dimension correspondences between lhs, rhs, and the output using - any distinct characters. + any distinct characters. The examples below use 'W' and 'H'. For example, to indicate dimension numbers consistent with the ``conv`` function with two spatial dimensions, one could use ``('NCHW', 'OIHW', diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 69c7fdc0228b..67f274e829ff 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -187,7 +187,7 @@ def conv_general_dilated_local( - the input and output feature dimensions in rhs with the characters 'I' and 'O' respectively, and - spatial dimension correspondences between `lhs`, `rhs`, and the output using - any distinct characters. + any distinct characters. The examples below use 'W' and 'H'. For example, to indicate dimension numbers consistent with the `conv` function with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As From 4957ab9a5e2e49ba34f07eebafbc1fdbfacf3cc2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 10 Sep 2024 14:18:19 -0700 Subject: [PATCH 437/702] Clean up JAX backend for all backends to avoid dangling PyClient references. PiperOrigin-RevId: 673102539 --- jax/_src/api.py | 7 ++++--- jax/_src/distributed.py | 2 -- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index d19f751f1251..935995ec5cba 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -49,6 +49,7 @@ from jax._src import effects from jax._src import array from jax._src import basearray +from jax._src import distributed from jax._src import dtypes from jax._src import sharding_impls from jax._src import sharding_specs @@ -2957,7 +2958,6 @@ def try_to_block(x): return x - def clear_backends(): """ Clear all backend clients so that new backend clients can be created later. @@ -2975,9 +2975,10 @@ def clear_backends(): @atexit.register def clean_up(): - db = xb._default_backend - if db is not None and db.platform == "cpu": # pytype: disable=attribute-error + if xb._default_backend is not None: clear_backends() + # Shut down distributed system if it exists. Otherwise, this is a no-op. + distributed.shutdown() def live_arrays(platform=None): """Return all live arrays in the backend for `platform`. diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 308387d21b20..3ea9304b67aa 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -14,7 +14,6 @@ from __future__ import annotations -import atexit from collections.abc import Sequence import logging import os @@ -234,7 +233,6 @@ def initialize(coordinator_address: str | None = None, initialization_timeout, coordinator_bind_address) -@atexit.register def shutdown(): """Shuts down the distributed system. From 5e4250e64bb415be94ddc8a80dba6083a6a4123a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 10 Sep 2024 14:25:40 -0700 Subject: [PATCH 438/702] Prepare for jax 0.4.32 release PiperOrigin-RevId: 673105544 --- jax/version.py | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/version.py b/jax/version.py index cc690e02cb46..f2c34d275b01 100644 --- a/jax/version.py +++ b/jax/version.py @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.31" +_minimum_jaxlib_version = "0.4.32" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index 08ce8dbcb4ed..0830a1ae52bd 100644 --- a/setup.py +++ b/setup.py @@ -19,10 +19,10 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.31' +_current_jaxlib_version = '0.4.32' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.31' -_libtpu_version = '0.1.dev20240729' +_libtpu_version = '0.1.dev20240910' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( From e7b261c38606505358f0f55b8ae8023af577af34 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 10 Sep 2024 14:36:18 -0700 Subject: [PATCH 439/702] Removed a sneaky comma in Pallas Mosaic GPU lowering PiperOrigin-RevId: 673109846 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 9666edcc2321..48fc7d8b4212 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -397,7 +397,7 @@ def _(step, _): launch_ctx.await_async_copy(0) - smem_scratch_bytes = compiler_params.get("smem_scratch_bytes"), + smem_scratch_bytes = compiler_params.get("smem_scratch_bytes") if smem_scratch_bytes is None: smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr) extra_smem_scratch = [ From 0b04dd022a281b81338d57107ce706389a7cb804 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 10 Sep 2024 16:00:20 -0700 Subject: [PATCH 440/702] Reverts 5e4250e64bb415be94ddc8a80dba6083a6a4123a PiperOrigin-RevId: 673141373 --- jax/version.py | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/version.py b/jax/version.py index f2c34d275b01..cc690e02cb46 100644 --- a/jax/version.py +++ b/jax/version.py @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.32" +_minimum_jaxlib_version = "0.4.31" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index 0830a1ae52bd..08ce8dbcb4ed 100644 --- a/setup.py +++ b/setup.py @@ -19,10 +19,10 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.32' +_current_jaxlib_version = '0.4.31' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.31' -_libtpu_version = '0.1.dev20240910' +_libtpu_version = '0.1.dev20240729' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( From c659dc9a011bf8ff604a7e23f916920ff717288b Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 10 Sep 2024 16:22:33 -0700 Subject: [PATCH 441/702] [Pallas] Disable win32 gpu_ops_test. PiperOrigin-RevId: 673149107 --- tests/pallas/gpu_ops_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index 7c5fa2db630c..e2f0e2152dc5 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -131,7 +131,7 @@ def setUp(self): if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: + if sys.platform == "win32": self.skipTest("Only works on non-Windows platforms") super().setUp() From e3c4b20fa04893ad986c3184387fbd3817f1515d Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 10 Sep 2024 17:20:42 -0700 Subject: [PATCH 442/702] [Pallas] Implement tiled and swizzled Memref loads for Mosaic GPU via "GPUBlockSpec" PiperOrigin-RevId: 673165201 --- jax/_src/pallas/core.py | 205 +++++++++++++++---------- jax/_src/pallas/mosaic_gpu/BUILD | 5 +- jax/_src/pallas/mosaic_gpu/core.py | 86 +++++++++++ jax/_src/pallas/mosaic_gpu/lowering.py | 34 +++- tests/pallas/mosaic_gpu_test.py | 20 +++ 5 files changed, 264 insertions(+), 86 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 01b027d386c3..a6f02f7da62e 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -318,6 +318,105 @@ def __init__( self.memory_space = memory_space self.indexing_mode = indexing_mode + def to_block_mapping( + self, + origin: OriginStr, + array_aval: jax_core.ShapedArray, + *, + # Inputs for the index_map + index_map_avals: Sequence[jax_core.AbstractValue], + index_map_tree: tree_util.PyTreeDef, + grid: GridMappingGrid, + mapped_dims: tuple[int, ...], + ) -> BlockMapping: + if self.index_map is None: + index_map_func = lambda *args: (0,) * len(array_aval.shape) + else: + index_map_func = self.index_map + if self.block_shape is None: + block_shape = array_aval.shape + else: + block_shape = self.block_shape + if len(array_aval.shape) != len(block_shape): + raise ValueError( + f"Block shape for {origin} (= {block_shape}) " + "must have the same number of dimensions as the " + f"array shape {array_aval.shape}." + ) + + unmapped_block_shape = tuple(s for s in block_shape if s is not None) + block_array_aval = array_aval.update(shape=unmapped_block_shape) + if isinstance(array_aval, jax_core.DShapedArray): + # Get the "max" shape for the ragged array. + block_array_aval = jax_core.ShapedArray( + block_array_aval.shape, + block_array_aval.dtype, + block_array_aval.weak_type, + ) + block_aval = AbstractMemoryRef(block_array_aval, self.memory_space) + + if not jax_core.is_constant_shape(block_aval.shape): + raise ValueError( + "shape polymorphism for Pallas does not support " + "dynamically-shaped blocks. " + f"Block spec for {origin} has block_shape: {block_aval.shape}" + ) + + flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( + lu.wrap_init(index_map_func), index_map_tree + ) + debug = pe.debug_info( + index_map_func, + index_map_tree, + index_map_out_tree_thunk, + False, + "pallas_call index_map", + ) + index_map_src_info = NameAndSrcInfo.from_pallas_call( + None, debug.func_src_info + ) + with tracing_grid_env(grid, mapped_dims): + jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( + flat_index_map_fun, index_map_avals, debug_info=debug + ) + mapped_block_shape = tuple(mapped if s is None else s for s in block_shape) + if len(out_avals) != len(block_shape): + raise ValueError( + f"Index map function {index_map_src_info} for " + f"{origin} must return " + f"{len(block_shape)} values to match {block_shape=}. " + f"Currently returning {len(out_avals)} values." + ) + for i, ov in enumerate(out_avals): + if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]: + raise ValueError( + f"Index map function {index_map_src_info} for " + f"{origin} must return integer scalars. Output[{i}] has type " + f"{ov}." + ) + + if consts: + raise ValueError( + f"Index map function {index_map_src_info} for " + f"{origin} must not capture constants: {consts}" + ) + + array_aval_shape = _max_shape_from_aval(array_aval) + + mapping = BlockMapping( + block_shape=mapped_block_shape, + block_aval=block_aval, + index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), + index_map_src_info=index_map_src_info, + indexing_mode=self.indexing_mode, + array_shape_dtype=jax.ShapeDtypeStruct( + array_aval_shape, array_aval.dtype + ), + origin=origin, + ) + mapping.check_invariants() + return mapping + class NoBlockSpec: def __repr__(self): @@ -329,6 +428,15 @@ def __repr__(self): # BlockSpecTree = Sequence[BlockSpec | NoBlockSpec, ...] | NoBlockSpec BlockSpecTree = Any + +class MemrefTransform(Protocol): + """Represents a transformation applied to a Memref on load or store.""" + + def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef: + """Returns the transformed aval given an input aval.""" + raise NotImplementedError("Abstract evaluation not implemented.") + + @dataclasses.dataclass(frozen=True) class BlockMapping: """An internal canonicalized version of BlockSpec. @@ -342,6 +450,9 @@ class BlockMapping: indexing_mode: IndexingMode array_shape_dtype: jax.ShapeDtypeStruct # The whole array origin: OriginStr + transforms: Sequence[MemrefTransform] = dataclasses.field( + default_factory=tuple + ) def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -368,6 +479,14 @@ def replace(self, **kwargs): new_self.check_invariants() return new_self + @property + def ref_aval(self) -> AbstractMemoryRef: + """Returns the abstract value of the Ref after transformations.""" + block_aval = self.block_aval + for transform in self.transforms: + block_aval = transform(block_aval) + return block_aval + def compute_start_indices_interpret(self, loop_idx, *args): discharged_jaxpr, discharged_consts = state_discharge.discharge_state( self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts @@ -603,82 +722,14 @@ def _convert_block_spec_to_block_mapping( ) -> BlockMapping: if block_spec is no_block_spec: block_spec = BlockSpec(None, None) - if block_spec.index_map is None: - index_map_func = lambda *args: (0,) * len(array_aval.shape) - else: - index_map_func = block_spec.index_map - if block_spec.block_shape is None: - block_shape = array_aval.shape - else: - block_shape = block_spec.block_shape - if len(array_aval.shape) != len(block_shape): - raise ValueError( - f"Block shape for {origin} (= {block_shape}) " - "must have the same number of dimensions as the " - f"array shape {array_aval.shape}.") - - unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_array_aval = array_aval.update(shape=unmapped_block_shape) - if isinstance(array_aval, jax_core.DShapedArray): - # Get the "max" shape for the ragged array. - block_array_aval = jax_core.ShapedArray( - block_array_aval.shape, - block_array_aval.dtype, - block_array_aval.weak_type, - ) - block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space) - - if not jax_core.is_constant_shape(block_aval.shape): - raise ValueError( - "shape polymorphism for Pallas does not support " - "dynamically-shaped blocks. " - f"Block spec for {origin} has block_shape: {block_aval.shape}") - - flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(index_map_func), index_map_tree) - debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk, - False, "pallas_call index_map") - index_map_src_info = NameAndSrcInfo.from_pallas_call(None, - debug.func_src_info) - with tracing_grid_env(grid, mapped_dims): - jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun, - index_map_avals, - debug_info=debug) - mapped_block_shape = tuple( - mapped if s is None else s for s in block_shape) - if len(out_avals) != len(block_shape): - raise ValueError( - f"Index map function {index_map_src_info} for " - f"{origin} must return " - f"{len(block_shape)} values to match {block_shape=}. " - f"Currently returning {len(out_avals)} values.") - for i, ov in enumerate(out_avals): - if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]: - raise ValueError( - f"Index map function {index_map_src_info} for " - f"{origin} must return integer scalars. Output[{i}] has type " - f"{ov}.") - - if consts: - raise ValueError( - f"Index map function {index_map_src_info} for " - f"{origin} must not capture constants: {consts}") - - array_aval_shape = _max_shape_from_aval(array_aval) - - mapping = BlockMapping( - block_shape=mapped_block_shape, - block_aval=block_aval, - index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), - index_map_src_info=index_map_src_info, - indexing_mode=block_spec.indexing_mode, - array_shape_dtype=jax.ShapeDtypeStruct( - array_aval_shape, array_aval.dtype - ), - origin=origin, + return block_spec.to_block_mapping( + origin, + array_aval, + index_map_avals=index_map_avals, + index_map_tree=index_map_tree, + grid=grid, + mapped_dims=mapped_dims, ) - mapping.check_invariants() - return mapping index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) @@ -846,11 +897,11 @@ def get_grid_mapping( num_scratch_operands=num_flat_scratch_operands, ) grid_mapping.check_invariants() - in_ref_avals = [bm.block_aval for bm in in_block_mappings] + in_ref_avals = [bm.ref_aval for bm in in_block_mappings] jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals) jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals) - out_ref_avals = [bm.block_aval for bm in out_block_mappings] + out_ref_avals = [bm.ref_aval for bm in out_block_mappings] jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals) if not isinstance(jaxpr_out_avals, (tuple, list)): jaxpr_out_avals = (jaxpr_out_avals,) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 9d2dfd8dfa0f..c3e8fc8b83de 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -71,6 +71,9 @@ pytype_strict_library( srcs = ["core.py"], deps = [ "//jax", + "//jax:core", + "//jax:mosaic_gpu", + "//jax:tree_util", "//jax/_src/pallas", - ], + ] + py_deps("numpy"), ) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 6619a9acfd02..1a0c489af47d 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -19,9 +19,13 @@ import enum from typing import ClassVar, Literal from jax import core as jax_core +from jax._src import core +from jax._src import tree_util from jax._src.pallas import core as pallas_core +from jax.experimental.mosaic import gpu as mosaic_gpu import jax.numpy as jnp + AbstractMemoryRef = pallas_core.AbstractMemoryRef @@ -55,6 +59,88 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): return MemoryRef(shape, dtype, self) +class TilingTransform(pallas_core.MemrefTransform): + """Represents a tiling transformation for Memrefs. + + A tiling of (X, Y) on an array of shape (M, N) will result in a transformed + shape of (M // X, N // Y, X, Y). Ex. A (256, 256) block that is tiled with a + tiling of (64, 32) will be tiled as (4, 8, 64, 32). + """ + + def __init__(self, tiling: tuple[int, ...]): + self.tiling = tiling + + def __call__( + self, block_aval: pallas_core.AbstractMemoryRef + ) -> pallas_core.AbstractMemoryRef: + block_shape = block_aval.inner_aval.shape # pytype: disable=attribute-error + old_tiled_dims = block_shape[-len(self.tiling) :] + num_tiles = tuple( + block_dim // tiling_dim + for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling) + ) + rem = ( + block_dim % tiling_dim + for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling) + ) + if any(rem): + raise ValueError( + f"Block shape {block_shape} is not divisible by tiling {self.tiling}" + ) + new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling + return block_aval.update( + inner_aval=block_aval.inner_aval.update(shape=new_block_shape) + ) + + def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform: + return mosaic_gpu.TileTransform(self.tiling) + + +@dataclasses.dataclass(frozen=True) +class GPUBlockMapping(pallas_core.BlockMapping): + swizzle: int | None = None + + +@dataclasses.dataclass +class GPUBlockSpec(pallas_core.BlockSpec): + # TODO(justinfu): Replace tiling a list of transforms. + tiling: tuple[int, ...] | None = None + swizzle: int | None = None + + def to_block_mapping( + self, + origin: pallas_core.OriginStr, + array_aval: core.ShapedArray, + *, + index_map_avals: Sequence[core.AbstractValue], + index_map_tree: tree_util.PyTreeDef, + grid: pallas_core.GridMappingGrid, + mapped_dims: tuple[int, ...], + ) -> GPUBlockMapping: + bm = super().to_block_mapping( + origin, + array_aval, + index_map_avals=index_map_avals, + index_map_tree=index_map_tree, + grid=grid, + mapped_dims=mapped_dims, + ) + transforms: tuple[pallas_core.MemrefTransform, ...] = () + if self.tiling is not None: + transforms += (TilingTransform(self.tiling),) + return GPUBlockMapping( + block_shape=bm.block_shape, + block_aval=bm.block_aval, + origin=bm.origin, + index_map_jaxpr=bm.index_map_jaxpr, + index_map_src_info=bm.index_map_src_info, + indexing_mode=bm.indexing_mode, + array_shape_dtype=bm.array_shape_dtype, + transforms=transforms, + swizzle=self.swizzle, + ) + + # TODO(b/354568887): Cosolidate this with TPU's MemoryRef. @dataclasses.dataclass(frozen=True) class MemoryRef: diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 48fc7d8b4212..87a147c91c81 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -35,6 +35,7 @@ from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives +from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import primitives as sp from jax.experimental.mosaic import gpu as mosaic_gpu from jax.experimental.mosaic.gpu import dsl as mgpu @@ -231,6 +232,7 @@ def lower_jaxpr_to_module( grid += (1,) * (3 - len(grid)) block = (128,) + (1,) * (len(grid) - 1) + num_inputs = grid_mapping.num_inputs params = compiler_params.get("mosaic_gpu", {}) num_stages = params.get("num_stages", 1) dimension_semantics = params.get( @@ -248,14 +250,28 @@ def lower_jaxpr_to_module( assert all(block[axis] == 1 for axis in sequential_axes) in_structs_gmem = [*grid_mapping.in_shapes] + in_block_shapes = [ + bm.block_shape + for bm in grid_mapping.block_mappings[:num_inputs] + ] in_structs_smem = [ - jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype) - for bm, s in zip( - block_mappings[: grid_mapping.num_inputs], - grid_mapping.in_shapes, - ) + jax.ShapeDtypeStruct( + [num_stages, + *bm.ref_aval.inner_aval.shape], # pytype: disable=attribute-error + bm.ref_aval.inner_aval.dtype) # pytype: disable=attribute-error + for bm in block_mappings[:num_inputs] ] + in_gmem_transforms = [ + bm.transforms for bm in grid_mapping.block_mappings[:num_inputs] + ] + _get_swizzle = ( + lambda bm: bm.swizzle + if isinstance(bm, gpu_core.GPUBlockMapping) + else None + ) + in_swizzles = map(_get_swizzle, grid_mapping.block_mappings[:num_inputs]) out_structs_gmem = [*grid_mapping.out_shapes] + # TODO(justinfu): Implement output Memref transforms out_structs_smem = [ jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype) for bm, s in zip( @@ -264,7 +280,7 @@ def lower_jaxpr_to_module( ) ] - def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): + def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value): *buffers_gmem, (*buffers_smem, runtime_smem, barriers) = buffers assert len(buffers_gmem) == len(buffers_smem) in_buffers_gmem, out_buffers_gmem = util.split_list( @@ -305,16 +321,18 @@ def gmem_slice( def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. + gmem_transforms = (x.to_gpu_transform() for x in in_gmem_transforms[idx]) launch_ctx.async_copy( src_ref=in_buffers_gmem[idx], dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot), gmem_slice=gmem_slice( in_start_indices[idx], step, - ir.MemRefType(in_buffers_smem[idx].type).shape[1:], + in_block_shapes[idx], ), barrier=barriers[slot], - swizzle=None, + gmem_transform=tuple(gmem_transforms), + swizzle=in_swizzles[idx], arrive=True, uniform=False, ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 51e4d06f1179..e90033be151d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -225,6 +225,26 @@ def kernel(o_ref): jnp.full([256], 2, dtype=jnp.int32), ) + def test_swizzled_blockspec_shapes(self): + @functools.partial( + pl.pallas_call, + in_specs=[ + plgpu.GPUBlockSpec( + (128, 64), lambda *i: i, tiling=(64, 64), swizzle=128 + ), + ], + out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)), + out_shape=jax.ShapeDtypeStruct((4, 2, 64, 64), jnp.float16), + grid=(2, 2), + ) + def kernel(x_ref, o_ref): + assert x_ref.shape == (2, 1, 64, 64), x_ref.shape + o_ref[...] = x_ref[...] + + x = jnp.zeros((256, 128), dtype=jnp.float16) + result = kernel(x) + self.assertEqual(result.shape, (4, 2, 64, 64)) + if __name__ == "__main__": absltest.main() From 7c660c4ea0eea5bb9526f2ba8734e3d6233d50c8 Mon Sep 17 00:00:00 2001 From: Keshav Date: Tue, 10 Sep 2024 10:02:05 -0700 Subject: [PATCH 443/702] Squashed commit of the following: commit 1abe9559d1ba7a6ec4e2081c52ebdf0eef6b5e56 Merge: 1e1cc3e07 1b2ba9d1c Author: Keshav Date: Tue Sep 10 09:42:04 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit 1e1cc3e0733cca77e2f1ee928f96edcf63f673cf Author: Keshav Date: Tue Sep 10 09:37:22 2024 -0700 added comment commit 631c41fcbdbbac864fadd72c984b07801872f218 Merge: b93b52f27 ce3ea109a Author: Keshav Date: Wed Aug 21 08:54:00 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit b93b52f27aacf7f58eba914a91810b5d0ac06316 Author: Keshav Date: Tue Aug 20 19:00:08 2024 -0700 remove stray breakpoint commit 9ee0842ea98557bcdca0ecfd9031a8ea5274e9a4 Merge: 799e359a5 be53ee10b Author: Keshav Date: Wed Aug 7 18:09:19 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit 799e359a522acd1a83dd7868a3a9278e189664f6 Author: Keshav Date: Wed Aug 7 17:31:27 2024 -0700 added tests and minor changes fix commit c973004493f633526b14a6b5acb3afe50d58c977 Merge: 5900969cc b3924da2a Author: Keshav Date: Thu Aug 1 11:28:59 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit 5900969cc9178bf3629baa49c6a300446bf6d4a9 Author: Keshav Date: Thu Aug 1 11:20:52 2024 -0700 minor edits commit a7cc85a1cb8ddd07b783cc538f25c56f5fb78543 Merge: 89b876270 091eba195 Author: Keshav Date: Mon Jul 29 14:17:13 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit 89b876270bf5f16dc10c2f8700d69715752ca184 Author: Keshav Date: Mon Jul 29 14:11:39 2024 -0700 native IR traversal instead of string manipulation commit 3b161a414d9579c50e1902047dbd45bac840a767 Author: Keshav Date: Sun Jul 28 20:12:30 2024 -0700 longer match string and string search optimization commit 224ee59d2115ec43000105b97bd6e73c40777ab9 Merge: c7664aa61 6a7822a73 Author: Keshav Date: Sun Jul 28 17:08:29 2024 -0700 Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer commit c7664aa61fa9cec55fba9d5ee1d3ffb146a4c2b1 Author: Keshav Date: Sun Jul 28 17:07:04 2024 -0700 remove custom partitioning ptr from pre-compiled hlo during cache key computation linter fixes more linter fixes more linter fixes alternate imports --- jax/_src/cache_key.py | 21 ++++++++++++++- jax/_src/compilation_cache.py | 4 ++- jax/_src/config.py | 10 +++++++ tests/cache_key_test.py | 50 +++++++++++++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 2 deletions(-) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 6fdf0c600b7d..9bce9d0e4308 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -83,7 +83,8 @@ def get(module: ir.Module, 'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf' """ entries = [ - ("computation", lambda hash_obj: _hash_computation(hash_obj, module)), + ("computation", + lambda hash_obj: _hash_computation(hash_obj, module)), ("jax_lib version", lambda hash_obj: hash_obj.update( bytes(jaxlib_version_str.encode("utf-8")))), @@ -129,8 +130,26 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn): ) +def _remove_custom_partitioning_ptr(m: ir.Module): + """ + Removes custom_partitioning callback pointer from precompiled IR. + Python function pointers are not deterministic across executions. + """ + def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult: + if (op.name == "stablehlo.custom_call" and + op.attributes["call_target_name"].value == "CustomSPMDPartitioning"): + op.attributes["backend_config"] = ir.StringAttr.get("REMOVED") + return ir.WalkResult.ADVANCE + + m.operation.walk(_update_bc_attribute) + return m + + def _serialize_ir(m: ir.Module) -> bytes: output = io.BytesIO() + if config.remove_custom_partitioning_ptr_from_cache_key.value: + m = _remove_custom_partitioning_ptr(type_cast(ir.Module, + m.operation.clone())) m.operation.write_bytecode(file=output) return output.getvalue() diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 8117f871a969..b946dc0a2897 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -265,7 +265,9 @@ def put_executable_and_time( cache.put(cache_key, executable_and_time) -def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options, +def get_cache_key(module: ir.Module, + devices: np.ndarray, + compile_options, backend) -> str: return cache_key.get(module, devices, compile_options, backend, "zstandard" if zstandard is not None else "zlib") diff --git a/jax/_src/config.py b/jax/_src/config.py index b2d1aa52ef2a..51d7dab585af 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1347,6 +1347,16 @@ def _update_jax_memories_thread_local(val): 'size to grow indefinitely.'), ) +remove_custom_partitioning_ptr_from_cache_key = bool_state( + name='jax_remove_custom_partitioning_ptr_from_cache_key', + default=False, + help=('If set to True, remove the custom partitioning pointer ' + 'present in the precompiled stableHLO before hashing ' + 'during cache key computation. This is a potentially ' + 'unsafe flag to set and only users who are sure of ' + 'what they are trying to achieve should set it.'), +) + default_dtype_bits = enum_state( name='jax_default_dtype_bits', enum_values=['32', '64'], diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 508dbacc2a98..00925c5f7dfc 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -14,8 +14,10 @@ import hashlib import os +import re import sys import unittest +from typing import cast as type_cast import numpy as np @@ -29,6 +31,11 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib.mlir import ir +from jax._src.mesh import Mesh +from jax._src.partition_spec import PartitionSpec as P +from jax._src.sharding_impls import NamedSharding +from jax._src.custom_partitioning import custom_partitioning config.parse_flags_with_absl() @@ -155,6 +162,49 @@ def test_different_computations(self): cache_key.get(computation2, devices, compile_options, backend), ) + def test_custom_partitioning_ptr_removal(self): + def _partition(mesh, arg_shapes, result_shape): + arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) + result_shardings = NamedSharding(mesh, arg_shapes[0].sharding.spec) + return mesh, jax.numpy.add, result_shardings, arg_shardings + + def _infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return NamedSharding(mesh, arg_shapes[0].sharding.spec) + + @custom_partitioning + def _cp_add(x, y): + return jax.numpy.add(x, y) + + _cp_add.def_partition( + infer_sharding_from_operands=_infer_sharding_from_operands, + partition=_partition) + + devices = np.asarray(jax.devices()) + with Mesh(devices, ('x',)) as m: + computation = jax.jit( + _cp_add, + in_shardings=(NamedSharding(m, P('x')), + NamedSharding(m, P('x'))), + out_shardings=NamedSharding(m, P('x')) + ).lower( + jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32), + jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32), + ).compiler_ir() + pattern = ( + r'stablehlo\.custom_call @CustomSPMDPartitioning\(' + r'(.*?)\) \{' + r'(.*?backend_config\s*=\s*"([^"]*)".*?)' + r'\}' + ) + with config.remove_custom_partitioning_ptr_from_cache_key(True): + with computation.context: + updated_module = cache_key._remove_custom_partitioning_ptr( + type_cast(ir.Module, computation.operation.clone())) + bcs = [match[2] for + match in re.findall(pattern, str(updated_module), re.DOTALL)] + for bc in bcs: + self.assertEqual(bc, "REMOVED") + def test_different_device_assignment(self): computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() devices = np.array([[jax.local_devices()[0]]]) From 808003b4e29e878349192e0f63fa1a2454ace56b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 10 Sep 2024 23:53:24 -0700 Subject: [PATCH 444/702] Update users of jax.tree.map() to be more careful about how they handle Nones. Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself. Fix code that was relying on this bug. Most commonly, the fix is to write `jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`. PiperOrigin-RevId: 673258116 --- jax/experimental/jax2tf/call_tf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 04dd9f17933d..037f8bbc2a02 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -266,6 +266,8 @@ def replace_non_float_or_none(arg_tf): ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax) # We must make the float0s that JAX expects def fix_float0(arg_jax, ct_arg_jax): + if arg_jax is None: + return None arg_dtype = dtypes.result_type(arg_jax) # May be scalar ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype) if ct_arg_dtype != ct_arg_jax.dtype: @@ -273,7 +275,8 @@ def fix_float0(arg_jax, ct_arg_jax): ct_arg_dtype)) return ct_arg_jax - ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax) + ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax, + is_leaf=lambda x: x is None) return ct_args_jax_fixed make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd) From 49dd6ed8d891ee6b7bbfcf7cc425382a7235556b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Sep 2024 02:00:02 -0700 Subject: [PATCH 445/702] Disable a pallas export compatibility test that fails on TPU v6e. PiperOrigin-RevId: 673295487 --- tests/pallas/export_back_compat_pallas_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index 0804cf04af9b..ff1d828f27fc 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -62,6 +62,8 @@ def add_one(x_ref, o_ref): @jax.default_matmul_precision("bfloat16") def test_mosaic_matmul(self): + if jtu.is_device_tpu(6, "e"): + self.skipTest("TODO(apaszke): Test fails on TPU v6e") dtype = jnp.float32 def func(): # Build the inputs here, to reduce the size of the golden inputs. From e5107c125d46793cc5e2b52972d4c92e66a1ecad Mon Sep 17 00:00:00 2001 From: Luke Yang <15712023+puct9@users.noreply.github.com> Date: Wed, 11 Sep 2024 23:34:03 +1000 Subject: [PATCH 446/702] Fix typo in `jax.typing` module doc --- jax/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/typing.py b/jax/typing.py index c75e0567e002..89efa1f2ca66 100644 --- a/jax/typing.py +++ b/jax/typing.py @@ -24,7 +24,7 @@ - :obj:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to a JAX array; this includes :class:`jax.Array`, :class:`numpy.ndarray`, as well as Python builtin numeric values (e.g. :class:`int`, :class:`float`, etc.) and numpy scalar values - (e.g. :class:`numpy.int32`, :class:`numpy.flota64`, etc.) + (e.g. :class:`numpy.int32`, :class:`numpy.float64`, etc.) - :obj:`jax.typing.DTypeLike`: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. `'float32'`, `'int32'`), scalar types (e.g. `float`, `np.float32`), dtypes (e.g. `np.dtype('float32')`), or objects with a dtype attribute From ea68f4569c5474f20e52b96ab88c287ab843130a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 11 Sep 2024 08:47:03 -0700 Subject: [PATCH 447/702] Internal change PiperOrigin-RevId: 673409076 --- tests/mosaic/BUILD | 10 ++++++++-- tests/pallas/BUILD | 3 +-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 255b03d3a002..abab212d8618 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -34,17 +34,23 @@ DISABLED_BACKENDS = [ ] DISABLED_CONFIGS = [ - "gpu", "gpu_a100", + "gpu_a100_x32", "gpu_p100", "gpu_p100_x32", - "gpu_x32", "gpu_pjrt_c_api", + "gpu_x32", + "gpu", ] jax_test( name = "gpu_test", srcs = ["gpu_test.py"], + config_tags_overrides = { + "gpu_h100_2gpu": { + "ondemand": False, # Include in presubmit. + }, + }, disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, enable_configs = ["gpu_h100_2gpu"], diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 5559a0552f9f..9b8167527b92 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -182,9 +182,8 @@ jax_test( "mosaic_gpu_test.py", ], config_tags_overrides = { - # TODO(slebedev): Switch to False once Mosaic GPU is unconditionally enabled. "gpu_h100_x32": { - "ondemand": True, # Include in presubmit. + "ondemand": False, # Include in presubmit. }, }, disable_backends = [ From 2bd1fdead81581db08ee84a0d1f82c407ccd6b11 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Sep 2024 08:49:23 -0700 Subject: [PATCH 448/702] Relax test tolerance in pinv test to fix a CI failure on Windows CPU. https://github.com/google/jax/actions/runs/10812364182/job/29993831201 PiperOrigin-RevId: 673409820 --- tests/linalg_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 901bfca997dc..4dcdeb19e1ef 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1152,7 +1152,7 @@ def np_fn(a): a = (a + T(a.conj())) / 2 return np.linalg.pinv(a, hermitian=hermitian) self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4) - self._CompileAndCheck(jnp_fn, args_maker) + self._CompileAndCheck(jnp_fn, args_maker, atol=1e-5) # TODO(phawkins): 6e-2 seems like a very loose tolerance. jtu.check_grads(jnp_fn, args_maker(), 1, rtol=6e-2, atol=1e-3) From ed849ff9e0576dcee2514741b5ffa951a94e20a8 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Sep 2024 08:54:08 -0700 Subject: [PATCH 449/702] Make sure to call the superclass' __init__() on a newly created instance in PositionalSharding._remake(). If we don't do this, the C++ base class is left in an uninitialized state, leading to failures elsewhere in the test suite. PiperOrigin-RevId: 673411282 --- jax/_src/sharding_impls.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 0b1dc082765e..add297b6a351 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -691,13 +691,9 @@ def check_compatible_aval(self, aval_shape: Shape) -> None: def _remake( cls, devices: tuple[xc.Device, ...], ids: np.ndarray, *, memory_kind: str | None = None) -> PositionalSharding: - self = cls.__new__(cls) - self._devices = devices - self._ids = ids - self._internal_device_list = xc.DeviceList(self._devices) - self._memory_kind = xc.check_and_canonicalize_memory_kind( - memory_kind, self._internal_device_list) - return self + sharding = cls(devices, memory_kind=memory_kind) + sharding._ids = ids + return sharding # Hashable From c33a9545242723d72de38d9cb12029d6bd1c2ce7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 10 Sep 2024 04:46:42 -0700 Subject: [PATCH 450/702] Improve docs for jnp.stack, jnp.concat, & related functions --- jax/_src/numpy/lax_numpy.py | 349 +++++++++++++++++++++++++++++++++++- 1 file changed, 340 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d6c2c2715add..806022c8a34b 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3629,9 +3629,53 @@ def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], ### Array-creation functions -@util.implements(np.stack, skip_params=['out']) def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array: + """Join arrays along a new axis. + + JAX implementation of :func:`numpy.stack`. + + Args: + arrays: a sequence of arrays to stack; each must have the same shape. If a + single array is given it will be treated equivalently to + `arrays = unstack(arrays)`, but the implementation will avoid explicit + unstacking. + axis: specify the axis along which to stack. + out: unused by JAX + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.unstack`: inverse of ``stack``. + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. + - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. + - :func:`jax.numpy.dstack`: stack depth-wise, i.e. along axis 2. + - :func:`jax.numpy.column_stack`: stack columns. + + Examples: + >>> x = jnp.array([1, 2, 3]) + >>> y = jnp.array([4, 5, 6]) + >>> jnp.stack([x, y]) + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.stack([x, y], axis=1) + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + + :func:`~jax.numpy.unstack` performs the inverse operation: + + >>> arr = jnp.stack([x, y], axis=1) + >>> x, y = jnp.unstack(arr, axis=1) + >>> x + Array([1, 2, 3], dtype=int32) + >>> y + Array([4, 5, 6], dtype=int32) + """ if not len(arrays): raise ValueError("Need at least one array to stack.") if out is not None: @@ -3650,9 +3694,38 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], new_arrays.append(expand_dims(a, axis)) return concatenate(new_arrays, axis=axis, dtype=dtype) -@util.implements(getattr(np, 'unstack', None)) + @partial(jit, static_argnames="axis") def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: + """Unstack an array along an axis. + + JAX implementation of :func:`array_api.unstack`. + + Args: + x: array to unstack. Must have ``x.ndim >= 1``. + axis: integer axis along which to unstack. Must satisfy + ``-x.ndim <= axis < x.ndim``. + + Returns: + tuple of unstacked arrays. + + See also: + - :func:`jax.numpy.stack`: inverse of ``unstack`` + - :func:`jax.numpy.split`: split array into batches along an axis. + + Examples: + >>> arr = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> arrs = jnp.unstack(arr) + >>> print(*arrs) + [1 2 3] [4 5 6] + + :func:`~jax.numpy.stack` provides the inverse of this: + + >>> jnp.stack(arrs) + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + """ util.check_arraylike("unstack", x) x = asarray(x) if x.ndim == 0: @@ -3694,9 +3767,46 @@ def _concatenate_array(arr: ArrayLike, axis: int | None, dimensions = [*range(1, axis + 1), 0, *range(axis + 1, arr.ndim)] return lax.reshape(arr, shape, dimensions) -@util.implements(np.concatenate) + def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int | None = 0, dtype: DTypeLike | None = None) -> Array: + """Join arrays along an existing axis. + + JAX implementation of :func:`numpy.concatenate`. + + Args: + arrays: a sequence of arrays to concatenate; each must have the same shape + except along the specified axis. If a single array is given it will be + treated equivalently to `arrays = unstack(arrays)`, but the implementation + will avoid explicit unstacking. + axis: specify the axis along which to concatenate. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the concatenated result. + + See also: + - :func:`jax.lax.concatenate`: XLA concatenation API. + - :func:`jax.numpy.concat`: Array API version of this function. + - :func:`jax.numpy.stack`: concatenate arrays along a new axis. + + Examples: + One-dimensional concatenation: + + >>> x = jnp.arange(3) + >>> y = jnp.zeros(3, dtype=int) + >>> jnp.concatenate([x, y]) + Array([0, 1, 2, 0, 0, 0], dtype=int32) + + Two-dimensional concatenation: + + >>> x = jnp.ones((2, 3)) + >>> y = jnp.zeros((2, 1)) + >>> jnp.concatenate([x, y], axis=1) + Array([[1., 1., 1., 0.], + [1., 1., 1., 0.]], dtype=float32) + """ if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) util.check_arraylike("concatenate", *arrays) @@ -3721,15 +3831,96 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], return arrays_out[0] -@util.implements(getattr(np, "concat", None)) def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: + """Join arrays along an existing axis. + + JAX implementation of :func:`array_api.concat`. + + Args: + arrays: a sequence of arrays to concatenate; each must have the same shape + except along the specified axis. If a single array is given it will be + treated equivalently to `arrays = unstack(arrays)`, but the implementation + will avoid explicit unstacking. + axis: specify the axis along which to concatenate. + + Returns: + the concatenated result. + + See also: + - :func:`jax.lax.concatenate`: XLA concatenation API. + - :func:`jax.numpy.concatenate`: NumPy version of this function. + - :func:`jax.numpy.stack`: concatenate arrays along a new axis. + + Examples: + One-dimensional concatenation: + + >>> x = jnp.arange(3) + >>> y = jnp.zeros(3, dtype=int) + >>> jnp.concat([x, y]) + Array([0, 1, 2, 0, 0, 0], dtype=int32) + + Two-dimensional concatenation: + + >>> x = jnp.ones((2, 3)) + >>> y = jnp.zeros((2, 1)) + >>> jnp.concat([x, y], axis=1) + Array([[1., 1., 1., 0.], + [1., 1., 1., 0.]], dtype=float32) + """ util.check_arraylike("concat", *arrays) return jax.numpy.concatenate(arrays, axis=axis) -@util.implements(np.vstack) def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: + """Vertically stack arrays. + + JAX implementation of :func:`numpy.vstack`. + + For arrays of two or more dimensions, this is equivalent to + :func:`jax.numpy.concatenate` with ``axis=0``. + + Args: + tup: a sequence of arrays to stack; each must have the same shape along all + but the first axis. If a single array is given it will be treated + equivalently to `tup = unstack(tup)`, but the implementation will avoid + explicit unstacking. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.stack`: stack along arbitrary axes + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. + - :func:`jax.numpy.dstack`: stack depth-wise, i.e. along axis 2. + + Examples: + Scalar values: + + >>> jnp.vstack([1, 2, 3]) + Array([[1], + [2], + [3]], dtype=int32, weak_type=True) + + 1D arrays: + + >>> x = jnp.arange(4) + >>> y = jnp.ones(4) + >>> jnp.vstack([x, y]) + Array([[0., 1., 2., 3.], + [1., 1., 1., 1.]], dtype=float32) + + 2D arrays: + + >>> x = x.reshape(1, 4) + >>> y = y.reshape(1, 4) + >>> jnp.vstack([x, y]) + Array([[0., 1., 2., 3.], + [1., 1., 1., 1.]], dtype=float32) + """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(atleast_2d)(tup) @@ -3740,9 +3931,54 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0, dtype=dtype) -@util.implements(np.hstack) def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: + """Horizontally stack arrays. + + JAX implementation of :func:`numpy.hstack`. + + For arrays of one or more dimensions, this is equivalent to + :func:`jax.numpy.concatenate` with ``axis=1``. + + Args: + tup: a sequence of arrays to stack; each must have the same shape along all + but the second axis. Input arrays will be promoted to at least rank 1. + If a single array is given it will be treated equivalently to + `tup = unstack(tup)`, but the implementation will avoid explicit unstacking. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.stack`: stack along arbitrary axes + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. + - :func:`jax.numpy.dstack`: stack depth-wise, i.e. along axis 2. + + Examples: + Scalar values: + + >>> jnp.hstack([1, 2, 3]) + Array([1, 2, 3], dtype=int32, weak_type=True) + + 1D arrays: + + >>> x = jnp.arange(3) + >>> y = jnp.ones(3) + >>> jnp.hstack([x, y]) + Array([0., 1., 2., 1., 1., 1.], dtype=float32) + + 2D arrays: + + >>> x = x.reshape(3, 1) + >>> y = y.reshape(3, 1) + >>> jnp.hstack([x, y]) + Array([[0., 1.], + [1., 1.], + [2., 1.]], dtype=float32) + """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(atleast_1d)(tup) @@ -3755,9 +3991,56 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype) -@util.implements(np.dstack) def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: + """Stack arrays depth-wise. + + JAX implementation of :func:`numpy.dstack`. + + For arrays of three or more dimensions, this is equivalent to + :func:`jax.numpy.concatenate` with ``axis=2``. + + Args: + tup: a sequence of arrays to stack; each must have the same shape along all + but the third axis. Input arrays will be promoted to at least rank 3. If a + single array is given it will be treated equivalently to `tup = unstack(tup)`, + but the implementation will avoid explicit unstacking. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.stack`: stack along arbitrary axes + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. + - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. + + Examples: + Scalar values: + + >>> jnp.dstack([1, 2, 3]) + Array([[[1, 2, 3]]], dtype=int32, weak_type=True) + + 1D arrays: + + >>> x = jnp.arange(3) + >>> y = jnp.ones(3) + >>> jnp.dstack([x, y]) + Array([[[0., 1.], + [1., 1.], + [2., 1.]]], dtype=float32) + + 2D arrays: + + >>> x = x.reshape(1, 3) + >>> y = y.reshape(1, 3) + >>> jnp.dstack([x, y]) + Array([[[0., 1.], + [1., 1.], + [2., 1.]]], dtype=float32) + """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(atleast_3d)(tup) @@ -3768,8 +4051,56 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=2, dtype=dtype) -@util.implements(np.column_stack) def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: + """Stack arrays column-wise. + + JAX implementation of :func:`numpy.column_stack`. + + For arrays of two or more dimensions, this is equivalent to + :func:`jax.numpy.concatenate` with ``axis=1``. + + Args: + tup: a sequence of arrays to stack; each must have the same leading dimension. + Input arrays will be promoted to at least rank 2. If a single array is given + it will be treated equivalently to `tup = unstack(tup)`, but the implementation + will avoid explicit unstacking. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.stack`: stack along arbitrary axes + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. + - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. + - :func:`jax.numpy.hstack`: stack depth=wise, i.e. along axis 2. + + Examples: + Scalar values: + + >>> jnp.column_stack([1, 2, 3]) + Array([[1, 2, 3]], dtype=int32, weak_type=True) + + 1D arrays: + + >>> x = jnp.arange(3) + >>> y = jnp.ones(3) + >>> jnp.column_stack([x, y]) + Array([[0., 1.], + [1., 1.], + [2., 1.]], dtype=float32) + + 2D arrays: + + >>> x = x.reshape(3, 1) + >>> y = y.reshape(3, 1) + >>> jnp.column_stack([x, y]) + Array([[0., 1.], + [1., 1.], + [2., 1.]], dtype=float32) + """ arrs: Array | list[Array] | np.ndarray if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup @@ -3777,7 +4108,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("column_stack", *tup, emit_warning=True) arrs = [atleast_2d(arr).T if arr.ndim < 2 else arr for arr in map(asarray, tup)] - return concatenate(arrs, 1) + return concatenate(arrs, axis=1) @util.implements(np.choose, skip_params=['out']) From 3a0ee844be750010aa693f0878af6c0ae0d82894 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 11 Sep 2024 23:04:27 +0530 Subject: [PATCH 451/702] Improve docs for jax.numpy arithmetic comparison operations --- jax/_src/numpy/ufuncs.py | 170 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 166 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index c3b38e57ab76..09546287cbdb 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -1005,24 +1005,186 @@ def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], lax_op(x.real, y.real)) return lax_op(x, y) -@implements(np.greater_equal, module='numpy') @partial(jit, inline=True) def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise truth value of ``x >= y``. + + JAX implementation of :obj:`numpy.greater_equal`. + + Args: + x: input array or scalar. + y: input array or scalar. ``x`` and ``y`` must either have same shape or be + broadcast compatible. + + Returns: + An array containing boolean values. ``True`` if the elements of ``x >= y``, + and ``False`` otherwise. + + See also: + - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``. + - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``. + - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``. + + Examples: + Scalar inputs: + + >>> jnp.greater_equal(4, 7) + Array(False, dtype=bool, weak_type=True) + + Inputs with same shape: + + >>> x = jnp.array([2, 5, -1]) + >>> y = jnp.array([-6, 4, 3]) + >>> jnp.greater_equal(x, y) + Array([ True, True, False], dtype=bool) + + Inputs with broadcast compatibility: + + >>> x1 = jnp.array([[3, -1, 4], + ... [5, 9, -6]]) + >>> y1 = jnp.array([-1, 4, 2]) + >>> jnp.greater_equal(x1, y1) + Array([[ True, False, True], + [ True, True, False]], dtype=bool) + """ return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y)) -@implements(np.greater, module='numpy') + @partial(jit, inline=True) def greater(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise truth value of ``x > y``. + + JAX implementation of :obj:`numpy.greater`. + + Args: + x: input array or scalar. + y: input array or scalar. ``x`` and ``y`` must either have same shape or be + broadcast compatible. + + Returns: + An array containing boolean values. ``True`` if the elements of ``x > y``, + and ``False`` otherwise. + + See also: + - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``. + - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of + ``x >= y``. + - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``. + + Examples: + Scalar inputs: + + >>> jnp.greater(5, 2) + Array(True, dtype=bool, weak_type=True) + + Inputs with same shape: + + >>> x = jnp.array([5, 9, -2]) + >>> y = jnp.array([4, -1, 6]) + >>> jnp.greater(x, y) + Array([ True, True, False], dtype=bool) + + Inputs with broadcast compatibility: + + >>> x1 = jnp.array([[5, -6, 7], + ... [-2, 5, 9]]) + >>> y1 = jnp.array([-4, 3, 10]) + >>> jnp.greater(x1, y1) + Array([[ True, False, False], + [ True, True, False]], dtype=bool) + """ return _complex_comparison(lax.gt, *promote_args("greater", x, y)) -@implements(np.less_equal, module='numpy') + @partial(jit, inline=True) def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise truth value of ``x <= y``. + + JAX implementation of :obj:`numpy.less_equal`. + + Args: + x: input array or scalar. + y: input array or scalar. ``x`` and ``y`` must have either same shape or be + broadcast compatible. + + Returns: + An array containing the boolean values. ``True`` if the elements of ``x <= y``, + and ``False`` otherwise. + + See also: + - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of + ``x >= y``. + - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``. + - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``. + + Examples: + Scalar inputs: + + >>> jnp.less_equal(6, -2) + Array(False, dtype=bool, weak_type=True) + + Inputs with same shape: + + >>> x = jnp.array([-4, 1, 7]) + >>> y = jnp.array([2, -3, 8]) + >>> jnp.less_equal(x, y) + Array([ True, False, True], dtype=bool) + + Inputs with broadcast compatibility: + + >>> x1 = jnp.array([2, -5, 9]) + >>> y1 = jnp.array([[1, -6, 5], + ... [-2, 4, -6]]) + >>> jnp.less_equal(x1, y1) + Array([[False, False, False], + [False, True, False]], dtype=bool) + """ return _complex_comparison(lax.le, *promote_args("less_equal", x, y)) -@implements(np.less, module='numpy') + @partial(jit, inline=True) def less(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise truth value of ``x < y``. + + JAX implementation of :obj:`numpy.less`. + + Args: + x: input array or scalar. + y: input array or scalar. ``x`` and ``y`` must either have same shape or be + broadcast compatible. + + Returns: + An array containing boolean values. ``True`` if the elements of ``x < y``, + and ``False`` otherwise. + + See also: + - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``. + - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of + ``x >= y``. + - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``. + + Examples: + Scalar inputs: + + >>> jnp.less(3, 7) + Array(True, dtype=bool, weak_type=True) + + Inputs with same shape: + + >>> x = jnp.array([5, 9, -3]) + >>> y = jnp.array([1, 6, 4]) + >>> jnp.less(x, y) + Array([False, False, True], dtype=bool) + + Inputs with broadcast compatibility: + + >>> x1 = jnp.array([[2, -4, 6, -8], + ... [-1, 5, -3, 7]]) + >>> y1 = jnp.array([0, 3, -5, 9]) + >>> jnp.less(x1, y1) + Array([[False, True, False, True], + [ True, False, False, True]], dtype=bool) + """ return _complex_comparison(lax.lt, *promote_args("less", x, y)) # Array API aliases From 10057eb7395187e9e3accec1f42533672fd774c7 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 11 Sep 2024 10:44:00 -0700 Subject: [PATCH 452/702] [Pallas] Fix TPU large array indexing tests. - On TPU, this test OOMs on some chips. We fix this by forcing a garbage collect before the test. - On interpret mode, semaphores were overflowing with a large copy size. We cap the inc/dec value at maxint to prevent overflow. PiperOrigin-RevId: 673451668 --- jax/_src/pallas/core.py | 1 + jax/_src/pallas/mosaic/primitives.py | 10 ++++++---- tests/pallas/tpu_pallas_test.py | 4 ++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index a6f02f7da62e..8dbf37587f8f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -64,6 +64,7 @@ def __repr__(self): # identifiable in kernels. # TODO(justinfu): Handle semaphores with a custom extended dtype. SEMAPHORE_INTERPRET_DTYPE = jnp.int16 +SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max @runtime_checkable diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 7b4faa002d1a..348820907ed0 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -615,14 +615,15 @@ def dma_start_discharge_rule(in_avals, out_avals, # Update semaphore values. # TODO(justinfu): Potentially handle asymmetric copy sizes. - recv_size = jnp.array(updates.size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) + recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) + recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) dst_sem_value = _index_semaphore(dst_sem, dst_sem_indexers, dst_sem_aval) _, new_dst_sem = state_discharge.index_swap_array( dst_sem, dst_sem_indexers, dst_sem_value + recv_size ) if is_remote: - send_size = jnp.array( - local_src.size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) + send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE) + send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) src_sem_value = _index_semaphore(src_sem, src_sem_indexers, src_sem_aval) _, new_src_sem = state_discharge.index_swap_array( src_sem, src_sem_indexers, src_sem_value + send_size @@ -685,7 +686,8 @@ def dma_wait_discharge_rule(in_avals, out_avals, num_sem_indexers = len(tree_util.tree_leaves(sem_indexers_avals)) num_indexers = len(tree_util.tree_leaves(ref_indexers_avals)) updates = state_discharge.index_array(ref, ref_indexers) - copy_size = jnp.array(updates.size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) + copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) + copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) sem_value = _index_semaphore(sem, sem_indexers, sem_aval) _, new_sem = state_discharge.index_swap_array( sem, sem_indexers, sem_value - copy_size diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index e1c01a4e84ec..83ca6a5787cc 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -16,6 +16,7 @@ import contextlib import functools +import gc import io import math import re @@ -1421,6 +1422,9 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): def test_large_array_indexing(self): n = 6 dtype = jnp.bfloat16 + # This test sometimes OOMs on smaller chips. We garbage collect + # to increase the chance there is 6GB memory available. + gc.collect() x = jax.lax.broadcasted_iota(dtype, (n, 1024 * 1024, 512), 0) def kernel(index, x, y, sem): From 0bb5e59d85f879f3948d5dcec0c361a093bb724b Mon Sep 17 00:00:00 2001 From: selamw1 Date: Tue, 10 Sep 2024 14:53:46 -0700 Subject: [PATCH 453/702] fromstring_docstring --- jax/_src/numpy/lax_numpy.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 806022c8a34b..267e8f95c839 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5175,8 +5175,31 @@ def fromfunction(function: Callable[..., Array], shape: Any, return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) -@util.implements(np.fromstring) def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array: + """Convert a string of text into 1-D JAX array. + + JAX implementation of :func:`numpy.fromstring`. + + Args: + string: input string containing the data. + dtype: optional. Desired data type for the array. Default is ``float``. + count: optional integer specifying the number of items to read from the string. + If -1 (default), all items are read. + sep: the string used to separate values in the input string. + + Returns: + A 1-D JAX array containing the parsed data from the input string. + + See also: + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + >>> jnp.fromstring("1 2 3", dtype=int, sep=" ") + Array([1, 2, 3], dtype=int32) + >>> jnp.fromstring("0.1, 0.2, 0.3", dtype=float, count=2, sep=",") + Array([0.1, 0.2], dtype=float32) + """ return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) From 1594d2f30fdbfebf693aba4a2b264e4a3e52acc6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Sep 2024 11:58:03 -0400 Subject: [PATCH 454/702] Prepare for v0.4.32 release. --- .github/workflows/wheel_win_x64.yml | 2 +- jax/version.py | 2 +- setup.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 61912ed8978e..03b5de37b85b 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -58,7 +58,7 @@ jobs: JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | + python -m pip install --find-links ${{ github.workspace }}\dist jaxlib python -m pip install -e ${{ github.workspace }} - python -m pip install --no-index --find-links ${{ github.workspace }}\dist jaxlib echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" pytest -n auto --tb=short tests examples diff --git a/jax/version.py b/jax/version.py index cc690e02cb46..f2c34d275b01 100644 --- a/jax/version.py +++ b/jax/version.py @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.31" +_minimum_jaxlib_version = "0.4.32" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index 08ce8dbcb4ed..027e5aefbc2f 100644 --- a/setup.py +++ b/setup.py @@ -19,10 +19,10 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.31' +_current_jaxlib_version = '0.4.32' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.31' -_libtpu_version = '0.1.dev20240729' +_libtpu_version = '0.1.dev20240911' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( From 541b3a3f7565b0e3f826b388dd094d22b28efb54 Mon Sep 17 00:00:00 2001 From: kaixih Date: Mon, 26 Aug 2024 17:32:38 +0000 Subject: [PATCH 455/702] New feature --- jax/_src/nn/functions.py | 58 +++++++++++++++++++++++++++++++++------- tests/nn_test.py | 19 ++++++++----- 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index a5b5aaf31799..c1f4831e5ec0 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -785,6 +785,14 @@ def _get_causal_mask(T, S): mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) return mask[None, None, :, :] +def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]): + query_pos = jnp.array(range(T)) + key_pos = jnp.array(range(S)) + left_window, right_window = local_window_size + left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window + right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window + return jnp.logical_and(right_mask, left_mask)[None, None, :, :] + def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen): q_mask = True kv_mask = True @@ -802,7 +810,8 @@ def _get_padding_mask_encoded(T, q_seqlen): mask = q_indices < q_seqlen[:, None] return mask[:, :, None, None] -def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen): +def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, + local_window_size): if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None: return logits @@ -817,6 +826,10 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen): mask = _get_causal_mask(T, S) combined_mask = jnp.logical_and(combined_mask, mask) + if local_window_size is not None: + mask = _get_window_mask(T, S, local_window_size) + combined_mask = jnp.logical_and(combined_mask, mask) + if q_seqlen is not None or kv_seqlen is not None: mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen) combined_mask = jnp.logical_and(combined_mask, mask) @@ -826,7 +839,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen): return padded_logits def _dot_product_attention_core(query, key, value, bias, mask, is_causal, - scale, q_seqlen, kv_seqlen): + scale, q_seqlen, kv_seqlen, local_window_size): logits_dtype = jnp.promote_types(query.dtype, jnp.float32) logits = jnp.einsum('BTNH,BSNH->BNTS', query, key, preferred_element_type=logits_dtype) @@ -836,7 +849,8 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal, if bias is not None: logits = (logits + bias).astype(logits.dtype) - padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen) + padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, + local_window_size) # Softmax and it is always carried out in fp32. padded_logits = padded_logits.astype(jnp.float32) @@ -857,7 +871,8 @@ def _dot_product_attention_xla( is_causal: bool, scale: float, q_seqlen: Array | None, - kv_seqlen: Array | None): + kv_seqlen: Array | None, + local_window_size: tuple[int, int] | None): B, T, N, H = query.shape _, S, K, _ = key.shape @@ -875,11 +890,13 @@ def _reshape_to_grouped(t): return t bias = _reshape_to_grouped(bias) mask = _reshape_to_grouped(mask) - vmapped_fn = jax.vmap(_dot_product_attention_core, - in_axes=(3, None, None, 2, 2, None, None, None, None), - out_axes=3) + vmapped_fn = jax.vmap( + _dot_product_attention_core, + in_axes=(3, None, None, 2, 2, None, None, None, None, None), + out_axes=3, + ) encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale, - q_seqlen, kv_seqlen) + q_seqlen, kv_seqlen, local_window_size) encoded = jnp.reshape(encoded, (B, T, N, H)) return encoded @@ -894,6 +911,7 @@ def dot_product_attention( is_causal: bool = False, query_seq_lengths: ArrayLike | None = None, key_value_seq_lengths: ArrayLike | None = None, + local_window_size: int | tuple[int, int] | None = None, implementation: Literal['xla', 'cudnn'] | None = None) -> Array: r"""Scaled dot product attention function. @@ -943,6 +961,12 @@ def dot_product_attention( :code:`(B)` key_value_seq_lengths: `int32` array of sequence lengths for key and value; shape :code:`(B)` + local_window_size: Window sizes to make self attention to attend to each + token's local window. If set, this specifies the (left_window_size, + right_window_size) for each token. E.g., if local_window_size == (3, 2) + and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend + to [3, 4, 5, c, 7, 8]. If a single int is given, it will be intepreted as + a symmetric window (window_size, window_size). implementation: A string to control which implementation backend to use. Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults to `None`, which will automatically select the best available backend. @@ -969,6 +993,8 @@ def _ensure_4d(t): query_seq_lengths = jnp.asarray(query_seq_lengths) if key_value_seq_lengths is not None: key_value_seq_lengths = jnp.asarray(key_value_seq_lengths) + if isinstance(local_window_size, int): + local_window_size = (local_window_size, local_window_size) def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], dtype: DType | None, name: str) -> None: @@ -1003,6 +1029,7 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val, q_seqlen=query_seq_lengths, kv_seqlen=key_value_seq_lengths, + local_window_size=local_window_size, ) case 'cudnn': use_padding = ( @@ -1022,9 +1049,21 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], mask_type = MaskType.CAUSAL elif use_padding: mask_type = MaskType.PADDING + # CuDNN supports only the left window with an exclusive boundary when + # causal mask is enabled. + sliding_window = None + if local_window_size is not None: + l_window, r_window = local_window_size + if r_window == 0 or mask_type == MaskType.CAUSAL: + sliding_window = l_window + 1 + else: + raise ValueError(f"cuDNN doesn't support right window: {r_window} " + "when causal mask is not used.") + out = cudnn_dot_product_attention( query_arr, key_arr, value_arr, bias, mask, query_seq_lengths, - key_value_seq_lengths, scale=scale_val, mask_type=mask_type + key_value_seq_lengths, scale=scale_val, mask_type=mask_type, + sliding_window_length=sliding_window, ) case None: # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select @@ -1033,6 +1072,7 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val, q_seqlen=query_seq_lengths, kv_seqlen=key_value_seq_lengths, + local_window_size=local_window_size, ) case _: raise ValueError(f"Unsupported implementation option: {implementation}") diff --git a/tests/nn_test.py b/tests/nn_test.py index 3722db42671c..be07de184e60 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -38,11 +38,11 @@ config.parse_flags_with_absl() -def _is_required_cudnn_version_satisfied(): +def _is_required_cudnn_version_satisfied(min_cudnn_version): return ( jtu.is_cuda_compute_capability_at_least("8.0") and cuda_versions is not None and - cuda_versions.cudnn_get_version() >= 8904 + cuda_versions.cudnn_get_version() >= min_cudnn_version ) def _check_cudnn_backend(fn, *args, **kwargs): @@ -60,7 +60,7 @@ class NNFunctionsTest(jtu.JaxTestCase): impl=['cudnn', 'xla'], ) def testDotProductAttention(self, dtype, group_num, use_vmap, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): + if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") @@ -102,13 +102,15 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl): @parameterized.product( mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'), - ('custom', 'padding'), ('bias', 'causal')], + ('custom', 'padding'), ('bias', 'causal'), + ('causal', 'sliding_window')], ) def testDotProductAttentionMask(self, mask_mode): - if not _is_required_cudnn_version_satisfied(): - raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if isinstance(mask_mode, str): mask_mode = (mask_mode,) + min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904 + if not _is_required_cudnn_version_satisfied(min_cudnn_version): + raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") dtype = jnp.bfloat16 B, S, T, N, H = 2, 128, 128, 4, 32 @@ -119,6 +121,7 @@ def testDotProductAttentionMask(self, mask_mode): grad = random.normal(keys[3], (B, T, N, H), dtype) bias, mask = None, None q_seqlen, kv_seqlen = None, None + window_size = None is_causal = 'causal' in mask_mode if 'padding' in mask_mode: @@ -130,6 +133,8 @@ def testDotProductAttentionMask(self, mask_mode): mask = custom_mask[None, None, :, :] if 'bias' in mask_mode: bias = random.normal(keys[4], (1, N, T, S), dtype) + if 'sliding_window' in mask_mode: + window_size = (3, 2) if is_causal else (3, 0) sdpa = nn.dot_product_attention sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) @@ -141,9 +146,11 @@ def testDotProductAttentionMask(self, mask_mode): # Convert the kargs to positional args for the jax.vjp. fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref( q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, + local_window_size=window_size, ) fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans( q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, + local_window_size=window_size, ) out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen) out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen) From 3e81ae530df4b39d4fb17b79ae2442c9a3ea0cc3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Sep 2024 16:08:47 -0400 Subject: [PATCH 456/702] Update version numbers after v0.4.32 release. --- CHANGELOG.md | 9 +++++++-- jax/version.py | 2 +- setup.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 869b9dfdd196..85dca66ce5ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,12 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.4.32 +## jax 0.4.33 + +## jaxlib 0.4.33 + + +## jax 0.4.32 (September 11, 2024) * New Functionality * Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering` @@ -65,7 +70,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. The argument to {func}`jax.dlpack.from_dlpack` should be an array from another framework that implements the ``__dlpack__`` protocol. -## jaxlib 0.4.32 +## jaxlib 0.4.32 (September 11, 2024) * Breaking changes * Hermetic CUDA support is added. diff --git a/jax/version.py b/jax/version.py index f2c34d275b01..0fc66d0d3006 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.32" +_version = "0.4.33" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None diff --git a/setup.py b/setup.py index 027e5aefbc2f..52cb8dda347d 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ _current_jaxlib_version = '0.4.32' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.4.31' +_latest_jaxlib_version_on_pypi = '0.4.32' _libtpu_version = '0.1.dev20240911' def load_version_module(pkg_path): From 7c8508e593eeafe8624dc4ead686ea1b47614cc0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 11 Sep 2024 13:20:03 -0700 Subject: [PATCH 457/702] Add link to XLA documentation for building JAX with CUDA from sources. PiperOrigin-RevId: 673510767 --- docs/developer.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/developer.md b/docs/developer.md index 954cf7982a3a..af2e451a22ef 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -108,6 +108,8 @@ current directory. --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" ``` + Please see the full list of instructions in [XLA documentation](https://github.com/openxla/xla/blob/main/docs/hermetic_cuda.md). + * JAX versions prior v.0.4.32: you must have CUDA and CUDNN installed and provide paths to them using configuration options. From 7ce8ff29e734f7f83743a9336460611450b26d22 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 11 Sep 2024 13:22:30 -0700 Subject: [PATCH 458/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/acef59a8b9454702b4a876027a505bc3362fa906. PiperOrigin-RevId: 673511649 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8f4accca508c..77f6f1bd978e 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "720b2c53346660e95abbed7cf3309a8b85e979f9" -XLA_SHA256 = "a93bb0414c33025f6cb775c374d5695c84055f2bd88d6ea826d51d99612baaef" +XLA_COMMIT = "acef59a8b9454702b4a876027a505bc3362fa906" +XLA_SHA256 = "ae1963475613acc8c364553feb99ec1fec9eb6b817dda350d7c660ca9de2e6ae" def repo(): tf_http_archive( From 9a8dc6b4a3d7433f18a3b53a8b39d434ba584586 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Sep 2024 14:24:24 -0700 Subject: [PATCH 459/702] Disable complex arctan test in mnegj.real part of plane. This test started failing in CI with the 0.4.32 release. PiperOrigin-RevId: 673536548 --- tests/lax_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index ce30131953af..3c812cb91c09 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3957,9 +3957,9 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'arctan': if dtype == np.complex64: regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', - 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag') + 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.real', 'mnegj.imag', 'mposj.imag') if dtype == np.complex128: - regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj') + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mnegj.real') elif name == 'arctanh': regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag') From e64de982a82cd60f80776d8a7f8b7565a3b032db Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Sep 2024 15:27:28 -0700 Subject: [PATCH 460/702] Enable the enhanced TPU launch barrier on all TPU generations. As best I can tell, it works on all TPUs at this point. PiperOrigin-RevId: 673559950 --- jax/_src/cloud_tpu_init.py | 3 +-- jax/_src/hardware_utils.py | 13 ------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 5b39994c7523..6033e1bbb928 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,8 +80,7 @@ def cloud_tpu_init() -> None: os.environ['TPU_ML_PLATFORM'] = 'JAX' os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__ os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - if hardware_utils.tpu_enhanced_barrier_supported(): - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" + os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" # this makes tensorstore serialization work better on TPU os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60') diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py index 7ab5de297752..81ef07a71b19 100644 --- a/jax/_src/hardware_utils.py +++ b/jax/_src/hardware_utils.py @@ -32,13 +32,6 @@ '0x006f', ] -_TPU_ENHANCED_BARRIER_SUPPORTED = [ - # TPU v2, v3 - '0x0027', - # TPU v4 - '0x005e', -] - _NVIDIA_GPU_DEVICES = [ '/dev/nvidia0', '/dev/nvidiactl', # Docker/Kubernetes @@ -62,12 +55,6 @@ def num_available_tpu_chips_and_device_id(): return num_chips, device_id -def tpu_enhanced_barrier_supported() -> bool: - """Returns if tpu_enhanced_barrier flag is supported on this TPU version.""" - _, device_id = num_available_tpu_chips_and_device_id() - return device_id in _TPU_ENHANCED_BARRIER_SUPPORTED - - def has_visible_nvidia_gpu() -> bool: """True if there's a visible nvidia gpu available on device, False otherwise.""" From bf2237a10282dd383642aa3f11faf52a9904d5ec Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 11 Sep 2024 15:41:05 -0700 Subject: [PATCH 461/702] Flip jax_pmap_no_rank_reduction by default to True. This changes: * The performance of array[0] (use array[0:1] instead). * The shape of jax_array.addressable_shards or jax_array.addressable_data(0) of arrays that come from pmap. PiperOrigin-RevId: 673564995 --- CHANGELOG.md | 8 ++++++++ jax/_src/config.py | 6 ++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85dca66ce5ff..df2b5813ce1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. C++ and CUDA code from JAX. * Changes + * `jax_pmap_no_rank_reduction` flag is set to `True` by default. + * array[0] on a pmap result now introduces a reshape (use array[0:1] + instead). + * The per-shard shape (accessable via jax_array.addressable_shards or + jax_array.addressable_data(0)) now has a leading (1, ...). Update code + that directly accesses shards accordingly. The rank of the per-shard-shape + now matches that of the global shape which is the same behavior as jit. + This avoids costly reshapes when passing results from pmap into jit. * `jax_enable_memories` flag is set to `True` by default. * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information. diff --git a/jax/_src/config.py b/jax/_src/config.py index 51d7dab585af..75df6be7d9e7 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1720,10 +1720,8 @@ def _update_debug_log_modules(module_names_str: str | None): pmap_no_rank_reduction = bool_state( name='jax_pmap_no_rank_reduction', - default=False, - help=( - "If True, pmap shards have a the same rank as their enclosing array." - ) + default=True, + help='If True, pmap shards have a the same rank as their enclosing array.', ) use_shardy_partitioner = bool_state( From 3c0242948101fa33b2c271621b0178ae916f03f4 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Thu, 12 Sep 2024 11:56:09 +0300 Subject: [PATCH 462/702] Update complex arctan and arctanh accuracy tests --- jax/_src/test_util.py | 45 +++++++++++++++++++++++++++++++++++++++++++ tests/lax_test.py | 14 ++------------ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 870268e99384..5afcd5e3a718 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -2074,6 +2074,51 @@ def arccosh(self, x): return ctx.make_mpc((inf._mpf_, imag._mpf_)) return ctx.acosh(x) + def arctan(self, x): + ctx = x.context + + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in atan(+-inf+-infj) evaluation + # (see mpmath/mpmath#775 with the fix). + # TODO(pearu): remove the if-block below when mpmath 1.4 or + # newer will be the required test dependency. + pi = ctx.pi + zero = ctx.zero + if ctx.isinf(x.real) or ctx.isinf(x.imag): + if x.real < 0: + return ctx.make_mpc(((-pi / 2)._mpf_, zero._mpf_)) + return ctx.make_mpc(((pi / 2)._mpf_, zero._mpf_)) + + # On branch cut, mpmath.mp.atan returns different value compared + # to mpmath.fp.atan and numpy.arctan (see mpmath/mpmath#865). + # The following if-block ensures compatibility with + # numpy.arctan. + if x.real == 0 and x.imag < -1: + return (-ctx.atan(x)).conjugate() + return ctx.atan(x) + + def arctanh(self, x): + ctx = x.context + + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in atanh(+-inf+-infj) evaluation + # (see mpmath/mpmath#775 with the fix). + # TODO(pearu): remove the if-block below when mpmath 1.4 or + # newer will be the required test dependency. + pi = ctx.pi + zero = ctx.zero + if ctx.isinf(x.real) or ctx.isinf(x.imag): + if x.imag < 0: + return ctx.make_mpc((zero._mpf_, (-pi / 2)._mpf_)) + return ctx.make_mpc((zero._mpf_, (pi / 2)._mpf_)) + + # On branch cut, mpmath.mp.atanh returns different value + # compared to mpmath.fp.atanh and numpy.arctanh. The following + # if-block ensures compatibility with numpy.arctanh. + if x.imag == 0 and x.real > 1: + return ctx.atanh(x).conjugate() + return ctx.atanh(x) + def normalize(self, exact, reference, value): """Normalize reference and value using precision defined by the difference of exact and reference. diff --git a/tests/lax_test.py b/tests/lax_test.py index 3c812cb91c09..74fef668aa47 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3798,7 +3798,7 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind): size_im = 11 atol = None - if name in {"arccos", "arcsin", "arcsinh", "arccosh"}: + if name in {"arccos", "arcsin", "arcsinh", "arccosh", "arctan", "arctanh"}: # TODO(pearu): eliminate this if-block when a fix to mpmath#787 # becomes available extra_prec_multiplier = 20 @@ -3954,21 +3954,11 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'arccos': regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real') - elif name == 'arctan': - if dtype == np.complex64: - regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', - 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.real', 'mnegj.imag', 'mposj.imag') - if dtype == np.complex128: - regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mnegj.real') - - elif name == 'arctanh': - regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag') - elif name in {'cos', 'sin'}: regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p', 'tan', - 'arcsinh', 'arcsin', 'arccosh'}: + 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh'}: regions_with_inaccuracies.clear() else: assert 0 # unreachable From 5234173f78e60885388c3ee54ecdc7f51917353f Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Thu, 12 Sep 2024 16:30:00 +0530 Subject: [PATCH 463/702] Improve doc for jnp.resize --- jax/_src/numpy/lax_numpy.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 806022c8a34b..6991c30c2efa 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1774,9 +1774,40 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i)) for s, i in safe_zip(shape, out_indices)) -@util.implements(np.resize) + @partial(jit, static_argnames=('new_shape',)) def resize(a: ArrayLike, new_shape: Shape) -> Array: + """Return a new array with specified shape. + + JAX implementation of :func:`numpy.resize`. + + Args: + a: input array or scalar. + new_shape: int or tuple of ints. Specifies the shape of the resized array. + + Returns: + A resized array with specified shape. The elements of ``a`` are repeated in + the resized array, if the resized array is larger than the original aray. + + See also: + - :func:`jax.numpy.reshape`: Returns a reshaped copy of an array. + - :func:`jax.numpy.repeat`: Constructs an array from repeated elements. + + Examples: + >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + >>> jnp.resize(x, (3, 3)) + Array([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], dtype=int32) + >>> jnp.resize(x, (3, 4)) + Array([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 1, 2, 3]], dtype=int32) + >>> jnp.resize(4, (3, 2)) + Array([[4, 4], + [4, 4], + [4, 4]], dtype=int32, weak_type=True) + """ util.check_arraylike("resize", a) new_shape = _ensure_index_tuple(new_shape) From 191d7ccd8c77d5452a530e84fdda59149cdf4ec6 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Thu, 12 Sep 2024 17:55:23 +0530 Subject: [PATCH 464/702] Improve docs for jax.numpy: diff and ediff1d --- jax/_src/numpy/lax_numpy.py | 113 +++++++++++++++++++++++++++++++++--- 1 file changed, 106 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 806022c8a34b..4319004e594c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1283,11 +1283,68 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: return ufuncs.degrees(result) if deg else result -@util.implements(np.diff) @partial(jit, static_argnames=('n', 'axis')) def diff(a: ArrayLike, n: int = 1, axis: int = -1, prepend: ArrayLike | None = None, append: ArrayLike | None = None) -> Array: + """Calculate n-th order difference between array elements along a given axis. + + JAX implementation of :func:`numpy.diff`. + + The first order difference is computed by ``a[i+1] - a[i]``, and the n-th order + difference is computed ``n`` times recursively. + + Args: + a: input array. Must have ``a.ndim >= 1``. + n: int, optional, default=1. Order of the difference. Specifies the number + of times the difference is computed. If n=0, no difference is computed and + input is returned as is. + axis: int, optional, default=-1. Specifies the axis along which the difference + is computed. The difference is computed along ``axis -1`` by default. + prepend: scalar or array, optional, defualt=None. Specifies the values to be + prepended along ``axis`` before computing the difference. + append: scalar or array, optional, defualt=None. Specifies the values to be + appended along ``axis`` before computing the difference. + + Returns: + An array containing the n-th order difference between the elements of ``a``. + + See also: + - :func:`jax.numpy.ediff1d`: Computes the differences between consecutive + elements of an array. + - :func:`jax.numpy.cumsum`: Computes the cumulative sum of the elements of + the array along a given axis. + - :func:`jax.numpy.gradient`: Computes the gradient of an N-dimensional array. + + Examples: + ``jnp.diff`` computes the first order difference along ``axis``, by default. + + >>> a = jnp.array([[1, 5, 2, 9], + ... [3, 8, 7, 4]]) + >>> jnp.diff(a) + Array([[ 4, -3, 7], + [ 5, -1, -3]], dtype=int32) + + When ``n = 2``, second order difference is computed along ``axis``. + + >>> jnp.diff(a, n=2) + Array([[-7, 10], + [-6, -2]], dtype=int32) + + When ``prepend = 2``, it is prepended to ``a`` along ``axis`` before computing + the difference. + + >>> jnp.diff(a, prepend=2) + Array([[-1, 4, -3, 7], + [ 1, 5, -1, -3]], dtype=int32) + + When ``append = jnp.array([[3],[1]])``, it is appended to ``a`` along ``axis`` + before computing the difference. + + >>> jnp.diff(a, append=jnp.array([[3],[1]])) + Array([[ 4, -3, 7, -6], + [ 5, -1, -3, -3]], dtype=int32) + """ util.check_arraylike("diff", a) arr = asarray(a) n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff") @@ -1337,16 +1394,58 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, return arr -_EDIFF1D_DOC = """\ -Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not -issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary`` -loses precision. -""" -@util.implements(np.ediff1d, lax_description=_EDIFF1D_DOC) @jit def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, to_begin: ArrayLike | None = None) -> Array: + """Compute the differences of the elements of the flattened array. + + JAX implementation of :func:`numpy.ediff1d`. + + Args: + ary: input array or scalar. + to_end: scalar or array, optional, default=None. Specifies the numbers to + append to the resulting array. + to_begin: scalar or array, optional, default=None. Specifies the numbers to + prepend to the resulting array. + + Returns: + An array containing the differences between the elements of the input array. + + Note: + Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will + not issue an error if casting ``to_end`` or ``to_begin`` to the type of + ``ary`` loses precision. + + See also: + - :func:`jax.numpy.diff`: Computes the n-th order difference between elements + of the array along a given axis. + - :func:`jax.numpy.cumsum`: Computes the cumulative sum of the elements of + the array along a given axis. + - :func:`jax.numpy.gradient`: Computes the gradient of an N-dimensional array. + + Examples: + >>> a = jnp.array([2, 3, 5, 9, 1, 4]) + >>> jnp.ediff1d(a) + Array([ 1, 2, 4, -8, 3], dtype=int32) + >>> jnp.ediff1d(a, to_begin=-10) + Array([-10, 1, 2, 4, -8, 3], dtype=int32) + >>> jnp.ediff1d(a, to_end=jnp.array([20, 30])) + Array([ 1, 2, 4, -8, 3, 20, 30], dtype=int32) + >>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30])) + Array([-10, 1, 2, 4, -8, 3, 20, 30], dtype=int32) + + For array with ``ndim > 1``, the differences are computed after flattening + the input array. + + >>> a1 = jnp.array([[2, -1, 4, 7], + ... [3, 5, -6, 9]]) + >>> jnp.ediff1d(a1) + Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32) + >>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9]) + >>> jnp.ediff1d(a2) + Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32) + """ util.check_arraylike("ediff1d", ary) arr = ravel(ary) result = lax.sub(arr[1:], arr[:-1]) From a3bf75e442e0eb21bc8d25cf1b687373796830df Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 12 Sep 2024 07:11:01 -0700 Subject: [PATCH 465/702] Refactor gpusolver kernel definitions into separate build target. There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the framework wrappers in the same library was getting unwieldy. This change adds a new "interface" target which just includes the shims to wrap cuSolver/BLAS functions, and then these are used from `solver_kernels_ffi` where the FFI logic lives. PiperOrigin-RevId: 673832309 --- jaxlib/cuda/BUILD | 17 ++ jaxlib/gpu/BUILD | 2 + jaxlib/gpu/solver_interface.cc | 237 +++++++++++++++++ jaxlib/gpu/solver_interface.h | 174 +++++++++++++ jaxlib/gpu/solver_kernels_ffi.cc | 435 +++++++------------------------ jaxlib/rocm/BUILD | 16 ++ 6 files changed, 537 insertions(+), 344 deletions(-) create mode 100644 jaxlib/gpu/solver_interface.cc create mode 100644 jaxlib/gpu/solver_interface.h diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 5cf85f3697c7..34e40d12d5be 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -227,6 +227,22 @@ cc_library( ], ) +cc_library( + name = "cusolver_interface", + srcs = ["//jaxlib/gpu:solver_interface.cc"], + hdrs = ["//jaxlib/gpu:solver_interface.h"], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", + ], +) + cc_library( name = "cusolver_kernels_ffi", srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"], @@ -237,6 +253,7 @@ cc_library( ":cuda_make_batch_pointers", ":cuda_solver_handle_pool", ":cuda_vendor", + ":cusolver_interface", "//jaxlib:ffi_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 8c4144974b4a..048ea23a9cff 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -53,6 +53,8 @@ exports_files(srcs = [ "solver.cc", "solver_handle_pool.cc", "solver_handle_pool.h", + "solver_interface.cc", + "solver_interface.h", "solver_kernels.cc", "solver_kernels.h", "solver_kernels_ffi.cc", diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc new file mode 100644 index 000000000000..3c8282ec603a --- /dev/null +++ b/jaxlib/gpu/solver_interface.cc @@ -0,0 +1,237 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/solver_interface.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace solver { + +// LU decomposition: getrf + +#define JAX_GPU_DEFINE_GETRF(Type, Name) \ + template <> \ + absl::StatusOr GetrfBufferSize(gpusolverDnHandle_t handle, int m, \ + int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Getrf(gpusolverDnHandle_t handle, int m, int n, Type *a, \ + Type *workspace, int lwork, int *ipiv, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \ + } + +JAX_GPU_DEFINE_GETRF(float, gpusolverDnSgetrf); +JAX_GPU_DEFINE_GETRF(double, gpusolverDnDgetrf); +JAX_GPU_DEFINE_GETRF(gpuComplex, gpusolverDnCgetrf); +JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf); +#undef JAX_GPU_DEFINE_GETRF + +#define JAX_GPU_DEFINE_GETRF_BATCHED(Type, Name) \ + template <> \ + absl::Status GetrfBatched(gpublasHandle_t handle, int n, Type **a, \ + int lda, int *ipiv, int *info, int batch) { \ + return JAX_AS_STATUS(Name(handle, n, a, lda, ipiv, info, batch)); \ + } + +JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched); +#undef JAX_GPU_DEFINE_GETRF_BATCHED + +// QR decomposition: geqrf + +#define JAX_GPU_DEFINE_GEQRF(Type, Name) \ + template <> \ + absl::StatusOr GeqrfBufferSize(gpusolverDnHandle_t handle, int m, \ + int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Geqrf(gpusolverDnHandle_t handle, int m, int n, Type *a, \ + Type *tau, Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, m, n, a, m, tau, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_GEQRF(float, gpusolverDnSgeqrf); +JAX_GPU_DEFINE_GEQRF(double, gpusolverDnDgeqrf); +JAX_GPU_DEFINE_GEQRF(gpuComplex, gpusolverDnCgeqrf); +JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf); +#undef JAX_GPU_DEFINE_GEQRF + +#define JAX_GPU_DEFINE_GEQRF_BATCHED(Type, Name) \ + template <> \ + absl::Status GeqrfBatched(gpublasHandle_t handle, int m, int n, \ + Type **a, Type **tau, int *info, \ + int batch) { \ + return JAX_AS_STATUS(Name(handle, m, n, a, m, tau, info, batch)); \ + } + +JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched); +#undef JAX_GPU_DEFINE_GEQRF_BATCHED + +// Householder transformations: orgqr + +#define JAX_GPU_DEFINE_ORGQR(Type, Name) \ + template <> \ + absl::StatusOr OrgqrBufferSize(gpusolverDnHandle_t handle, int m, \ + int n, int k) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \ + handle, m, n, k, /*A=*/nullptr, /*lda=*/m, /*tau=*/nullptr, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Orgqr(gpusolverDnHandle_t handle, int m, int n, int k, \ + Type *a, Type *tau, Type *workspace, int lwork, \ + int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_ORGQR(float, gpusolverDnSorgqr); +JAX_GPU_DEFINE_ORGQR(double, gpusolverDnDorgqr); +JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr); +JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr); +#undef JAX_GPU_DEFINE_ORGQR + +// Symmetric (Hermitian) eigendecomposition: +// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) +// * QR algorithm: syevd/heevd + +#define JAX_GPU_DEFINE_SYEVJ(Type, Name) \ + template <> \ + absl::StatusOr SyevjBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ + /*w=*/nullptr, &lwork, params))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Syevj( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, RealType::value *w, \ + Type *workspace, int lwork, int *info, gpuSyevjInfo_t params) { \ + return JAX_AS_STATUS( \ + Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info, params)); \ + } + +JAX_GPU_DEFINE_SYEVJ(float, gpusolverDnSsyevj); +JAX_GPU_DEFINE_SYEVJ(double, gpusolverDnDsyevj); +JAX_GPU_DEFINE_SYEVJ(gpuComplex, gpusolverDnCheevj); +JAX_GPU_DEFINE_SYEVJ(gpuDoubleComplex, gpusolverDnZheevj); +#undef JAX_GPU_DEFINE_SYEVJ + +#define JAX_GPU_DEFINE_SYEVJ_BATCHED(Type, Name) \ + template <> \ + absl::StatusOr SyevjBatchedBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ + /*w=*/nullptr, &lwork, params, batch))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status SyevjBatched( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, RealType::value *w, \ + Type *workspace, int lwork, int *info, gpuSyevjInfo_t params, \ + int batch) { \ + return JAX_AS_STATUS(Name(handle, jobz, uplo, n, a, n, w, workspace, \ + lwork, info, params, batch)); \ + } + +JAX_GPU_DEFINE_SYEVJ_BATCHED(float, gpusolverDnSsyevjBatched); +JAX_GPU_DEFINE_SYEVJ_BATCHED(double, gpusolverDnDsyevjBatched); +JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuComplex, gpusolverDnCheevjBatched); +JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuDoubleComplex, gpusolverDnZheevjBatched); +#undef JAX_GPU_DEFINE_SYEVJ_BATCHED + +#define JAX_GPU_DEFINE_SYEVD(Type, Name) \ + template <> \ + absl::StatusOr SyevdBufferSize(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR( \ + JAX_AS_STATUS(Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, \ + /*lda=*/n, /*w=*/nullptr, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Syevd(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ + int n, Type *a, RealType::value *w, \ + Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_SYEVD(float, gpusolverDnSsyevd); +JAX_GPU_DEFINE_SYEVD(double, gpusolverDnDsyevd); +JAX_GPU_DEFINE_SYEVD(gpuComplex, gpusolverDnCheevd); +JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd); +#undef JAX_GPU_DEFINE_SYEVD + +// Symmetric rank-k update: syrk + +#define JAX_GPU_DEFINE_SYRK(Type, Name) \ + template <> \ + absl::Status Syrk(gpublasHandle_t handle, gpublasFillMode_t uplo, \ + gpublasOperation_t trans, int n, int k, \ + const Type *alpha, const Type *a, const Type *beta, \ + Type *c) { \ + int lda = trans == GPUBLAS_OP_N ? n : k; \ + return JAX_AS_STATUS( \ + Name(handle, uplo, trans, n, k, alpha, a, lda, beta, c, n)); \ + } + +JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk); +JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk); +JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk); +JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk); +#undef JAX_GPU_DEFINE_SYRK + +} // namespace solver +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h new file mode 100644 index 000000000000..5072be98489f --- /dev/null +++ b/jaxlib/gpu/solver_interface.h @@ -0,0 +1,174 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines a standard interface to the GPU linear algebra libraries. + +#ifndef JAXLIB_GPU_SOLVER_INTERFACE_H_ +#define JAXLIB_GPU_SOLVER_INTERFACE_H_ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "jaxlib/gpu/vendor.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace solver { + +template +struct RealType { + using value = T; +}; + +template <> +struct RealType { + using value = float; +}; + +template <> +struct RealType { + using value = double; +}; + +#define JAX_GPU_SOLVER_EXPAND_DEFINITION(ReturnType, FunctionName) \ + template \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(T, typename RealType::value)) { \ + return absl::UnimplementedError(absl::StrFormat( \ + #FunctionName " not implemented for type %s", typeid(T).name())); \ + } \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(float, float)); \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(double, double)); \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuComplex, float)); \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuDoubleComplex, double)) + +// LU decomposition: getrf + +#define JAX_GPU_SOLVER_GetrfBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GetrfBufferSize); +#undef JAX_GPU_SOLVER_GetrfBufferSize_ARGS + +#define JAX_GPU_SOLVER_Getrf_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, Type *a, Type *workspace, \ + int lwork, int *ipiv, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Getrf); +#undef JAX_GPU_SOLVER_Getrf_ARGS + +#define JAX_GPU_SOLVER_GetrfBatched_ARGS(Type, ...) \ + gpublasHandle_t handle, int n, Type **a, int lda, int *ipiv, int *info, \ + int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GetrfBatched); +#undef JAX_GPU_SOLVER_GetrfBatched_ARGS + +// QR decomposition: geqrf + +#define JAX_GPU_SOLVER_GeqrfBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GeqrfBufferSize); +#undef JAX_GPU_SOLVER_GeqrfBufferSize_ARGS + +#define JAX_GPU_SOLVER_Geqrf_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, Type *a, Type *tau, \ + Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Geqrf); +#undef JAX_GPU_SOLVER_Geqrf_ARGS + +#define JAX_GPU_SOLVER_GeqrfBatched_ARGS(Type, ...) \ + gpublasHandle_t handle, int m, int n, Type **a, Type **tau, int *info, \ + int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GeqrfBatched); +#undef JAX_GPU_SOLVER_GeqrfBatched_ARGS + +// Householder transformations: orgqr + +#define JAX_GPU_SOLVER_OrgqrBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, int k +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, OrgqrBufferSize); +#undef JAX_GPU_SOLVER_OrgqrBufferSize_ARGS + +#define JAX_GPU_SOLVER_Orgqr_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, int k, Type *a, Type *tau, \ + Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Orgqr); +#undef JAX_GPU_SOLVER_Orgqr_ARGS + +// Symmetric (Hermitian) eigendecomposition: +// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) +// * QR algorithm: syevd/heevd + +#define JAX_GPU_SOLVER_SyevjBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SyevjBufferSize); +#undef JAX_GPU_SOLVER_SyevjBufferSize_ARGS + +#define JAX_GPU_SOLVER_Syevj_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \ + int lwork, int *info, gpuSyevjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevj); +#undef JAX_GPU_SOLVER_Syevj_ARGS + +#define JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SyevjBatchedBufferSize); +#undef JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS + +#define JAX_GPU_SOLVER_SyevjBatched_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \ + int lwork, int *info, gpuSyevjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, SyevjBatched); +#undef JAX_GPU_SOLVER_SyevjBatched_ARGS + +#define JAX_GPU_SOLVER_SyevdBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SyevdBufferSize); +#undef JAX_GPU_SOLVER_SyevdBufferSize_ARGS + +#define JAX_GPU_SOLVER_Syevd_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \ + int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd); +#undef JAX_GPU_SOLVER_Syevd_ARGS + +// Symmetric rank-k update: syrk + +#define JAX_GPU_SOLVER_Syrk_ARGS(Type, ...) \ + gpublasHandle_t handle, gpublasFillMode_t uplo, gpublasOperation_t trans, \ + int n, int k, const Type *alpha, const Type *a, const Type *beta, \ + Type *c +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk); +#undef JAX_GPU_SOLVER_Syrk_ARGS + +#undef JAX_GPU_SOLVER_EXPAND_DEFINITION + +} // namespace solver +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_SOLVER_INTERFACE_H_ diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 3c74b85192ad..e3f63234f538 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -29,9 +29,13 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/make_batch_pointers.h" #include "jaxlib/gpu/solver_handle_pool.h" +#include "jaxlib/gpu/solver_interface.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" +#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \ + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) + XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::JAX_GPU_NAMESPACE::SyevdAlgorithm); namespace jax { @@ -39,7 +43,6 @@ namespace JAX_GPU_NAMESPACE { namespace ffi = ::xla::ffi; -namespace { template inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, int64_t size, @@ -53,22 +56,6 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, return static_cast(maybe_workspace.value()); } -template -struct RealType { - using Type = T; -}; - -template <> -struct RealType { - using Type = float; -}; - -template <> -struct RealType { - using Type = double; -}; -} // namespace - #define SOLVER_DISPATCH_IMPL(impl, ...) \ if (dataType == ffi::F32) { \ return impl(__VA_ARGS__); \ @@ -93,33 +80,6 @@ struct RealType { // LU decomposition: getrf -namespace { -#define GETRF_KERNEL_IMPL(type, name) \ - template <> \ - struct GetrfKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ - int n) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \ - type* workspace, int lwork, int* ipiv, \ - int* info) { \ - return JAX_AS_STATUS( \ - name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \ - } \ - } - -template -struct GetrfKernel; -GETRF_KERNEL_IMPL(float, gpusolverDnSgetrf); -GETRF_KERNEL_IMPL(double, gpusolverDnDgetrf); -GETRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgetrf); -GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf); -#undef GETRF_KERNEL_IMPL - template ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, @@ -131,7 +91,7 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(int lwork, - GetrfKernel::BufferSize(handle.get(), m, n)); + solver::GetrfBufferSize(handle.get(), m, n)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "getrf")); @@ -140,13 +100,13 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, auto ipiv_data = ipiv->typed_data(); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } int ipiv_step = std::min(m, n); for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(GetrfKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::Getrf( handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data)); out_data += m * n; ipiv_data += ipiv_step; @@ -155,23 +115,6 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, return ffi::Error::Success(); } -#define GETRF_BATCHED_KERNEL_IMPL(type, name) \ - template <> \ - struct GetrfBatchedKernel { \ - static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \ - int* ipiv, int* info, int batch) { \ - return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \ - } \ - } - -template -struct GetrfBatchedKernel; -GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched); -#undef GETRF_BATCHED_KERNEL_IMPL - template ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, @@ -188,15 +131,15 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, auto ipiv_data = ipiv->typed_data(); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch, sizeof(T) * n * n); - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); - FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::GetrfBatched( handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch)); return ffi::Error::Success(); @@ -228,7 +171,6 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in getrf", absl::FormatStreamed(dataType))); } -} // namespace XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch, ffi::Ffi::Bind() @@ -242,33 +184,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch, // QR decomposition: geqrf -namespace { -#define GEQRF_KERNEL_IMPL(type, name) \ - template <> \ - struct GeqrfKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ - int n) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \ - type* tau, type* workspace, int lwork, \ - int* info) { \ - return JAX_AS_STATUS( \ - name(handle, m, n, a, m, tau, workspace, lwork, info)); \ - } \ - } - -template -struct GeqrfKernel; -GEQRF_KERNEL_IMPL(float, gpusolverDnSgeqrf); -GEQRF_KERNEL_IMPL(double, gpusolverDnDgeqrf); -GEQRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgeqrf); -GEQRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgeqrf); -#undef GEQRF_KERNEL_IMPL - template ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, @@ -279,7 +194,7 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(int lwork, - GeqrfKernel::BufferSize(handle.get(), m, n)); + solver::GeqrfBufferSize(handle.get(), m, n)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "geqrf")); @@ -292,14 +207,14 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, auto out_data = static_cast(out->untyped_data()); auto tau_data = static_cast(tau->untyped_data()); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } int out_step = m * n; int tau_step = std::min(m, n); for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::Geqrf( handle.get(), m, n, out_data, tau_data, workspace, lwork, info)); out_data += out_step; tau_data += tau_step; @@ -307,23 +222,6 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, return ffi::Error::Success(); } -#define GEQRF_BATCHED_KERNEL_IMPL(type, name) \ - template <> \ - struct GeqrfBatchedKernel { \ - static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \ - type** tau, int* info, int batch) { \ - return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \ - } \ - } - -template -struct GeqrfBatchedKernel; -GEQRF_BATCHED_KERNEL_IMPL(float, gpublasSgeqrfBatched); -GEQRF_BATCHED_KERNEL_IMPL(double, gpublasDgeqrfBatched); -GEQRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgeqrfBatched); -GEQRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgeqrfBatched); -#undef GEQRF_BATCHED_KERNEL_IMPL - template ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, @@ -341,21 +239,21 @@ ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols, auto out_data = out->untyped_data(); auto tau_data = tau->untyped_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch, sizeof(T) * m * n); - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); MakeBatchPointersAsync(stream, tau_data, tau_batch_ptrs, batch, sizeof(T) * std::min(m, n)); - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); // We ignore the output value of `info` because it is only used for shape // checking. int info; - FFI_RETURN_IF_ERROR_STATUS(GeqrfBatchedKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::GeqrfBatched( handle.get(), m, n, out_batch_ptrs, tau_batch_ptrs, &info, batch)); return ffi::Error::Success(); @@ -385,7 +283,6 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in geqrf", absl::FormatStreamed(dataType))); } -} // namespace XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch, ffi::Ffi::Bind() @@ -398,34 +295,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch, // Householder transformations: orgqr -namespace { -#define ORGQR_KERNEL_IMPL(type, name) \ - template <> \ - struct OrgqrKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ - int n, int k) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m, \ - /*tau=*/nullptr, &lwork))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \ - type* a, type* tau, type* workspace, int lwork, \ - int* info) { \ - return JAX_AS_STATUS( \ - name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \ - } \ - } - -template -struct OrgqrKernel; -ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr); -ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr); -ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr); -ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr); -#undef ORGQR_KERNEL_IMPL - template ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size, gpuStream_t stream, ffi::ScratchAllocator& scratch, @@ -437,7 +306,7 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size, FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(int lwork, - OrgqrKernel::BufferSize(handle.get(), m, n, k)); + solver::OrgqrBufferSize(handle.get(), m, n, k)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "orgqr")); @@ -450,13 +319,13 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size, auto tau_data = static_cast(tau.untyped_data()); auto out_data = static_cast(out->untyped_data()); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } int out_step = m * n; for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::Orgqr( handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info)); out_data += out_step; tau_data += k; @@ -492,7 +361,6 @@ ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in orgqr", absl::FormatStreamed(dataType))); } -} // namespace XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, ffi::Ffi::Bind() @@ -510,98 +378,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, // dispatches dynamically to both syevd and syevj depending on the problem // size and the algorithm selected by the user via the `algorithm` attribute. -namespace { -#define SYEVJ_KERNEL_IMPL(type, name) \ - template <> \ - struct SyevjKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, \ - gpusolverFillMode_t uplo, int n, \ - gpuSyevjInfo_t params) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ - /*w=*/nullptr, &lwork, params))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ - int n, type* a, RealType::Type* w, \ - type* workspace, int lwork, int* info, \ - gpuSyevjInfo_t params) { \ - return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \ - lwork, info, params)); \ - } \ - } - -template -struct SyevjKernel; -SYEVJ_KERNEL_IMPL(float, gpusolverDnSsyevj); -SYEVJ_KERNEL_IMPL(double, gpusolverDnDsyevj); -SYEVJ_KERNEL_IMPL(gpuComplex, gpusolverDnCheevj); -SYEVJ_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevj); -#undef SYEVJ_KERNEL_IMPL - -#define SYEVJ_BATCHED_KERNEL_IMPL(type, name) \ - template <> \ - struct SyevjBatchedKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, \ - gpusolverFillMode_t uplo, int n, \ - gpuSyevjInfo_t params, int batch) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ - /*w=*/nullptr, &lwork, params, batch))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ - int n, type* a, RealType::Type* w, \ - type* workspace, int lwork, int* info, \ - gpuSyevjInfo_t params, int batch) { \ - return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \ - lwork, info, params, batch)); \ - } \ - } - -template -struct SyevjBatchedKernel; -SYEVJ_BATCHED_KERNEL_IMPL(float, gpusolverDnSsyevjBatched); -SYEVJ_BATCHED_KERNEL_IMPL(double, gpusolverDnDsyevjBatched); -SYEVJ_BATCHED_KERNEL_IMPL(gpuComplex, gpusolverDnCheevjBatched); -SYEVJ_BATCHED_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevjBatched); -#undef SYEVJ_BATCHED_KERNEL_IMPL - -#define SYEVD_KERNEL_IMPL(type, name) \ - template <> \ - struct SyevdKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, \ - gpusolverFillMode_t uplo, int n) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ - /*w=*/nullptr, &lwork))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ - int n, type* a, RealType::Type* w, \ - type* workspace, int lwork, int* info) { \ - return JAX_AS_STATUS( \ - name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \ - } \ - } - -template -struct SyevdKernel; -SYEVD_KERNEL_IMPL(float, gpusolverDnSsyevd); -SYEVD_KERNEL_IMPL(double, gpusolverDnDsyevd); -SYEVD_KERNEL_IMPL(gpuComplex, gpusolverDnCheevd); -SYEVD_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevd); -#undef SYEVD_KERNEL_IMPL - template ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm, @@ -618,49 +394,48 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto w_data = static_cast::Type*>(w->untyped_data()); + auto w_data = static_cast::value*>(w->untyped_data()); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } if (algorithm == SyevdAlgorithm::kJacobi || (algorithm == SyevdAlgorithm::kDefault && size <= 32)) { gpuSyevjInfo_t params; - FFI_RETURN_IF_ERROR_STATUS( - JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(¶ms)); std::unique_ptr params_cleanup( params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); if (batch == 1) { - FFI_ASSIGN_OR_RETURN(int lwork, SyevjKernel::BufferSize( + FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize( handle.get(), jobz, uplo, n, params)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "syevj")); - FFI_RETURN_IF_ERROR_STATUS( - SyevjKernel::Run(handle.get(), jobz, uplo, n, out_data, w_data, - workspace, lwork, info_data, params)); + FFI_RETURN_IF_ERROR_STATUS(solver::Syevj(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data, params)); } else { FFI_ASSIGN_OR_RETURN( - int lwork, SyevjBatchedKernel::BufferSize(handle.get(), jobz, uplo, + int lwork, solver::SyevjBatchedBufferSize(handle.get(), jobz, uplo, n, params, batch)); FFI_ASSIGN_OR_RETURN( auto workspace, AllocateWorkspace(scratch, lwork, "syevj_batched")); - FFI_RETURN_IF_ERROR_STATUS(SyevjBatchedKernel::Run( - handle.get(), jobz, uplo, n, out_data, w_data, workspace, lwork, - info_data, params, batch)); + FFI_RETURN_IF_ERROR_STATUS( + solver::SyevjBatched(handle.get(), jobz, uplo, n, out_data, w_data, + workspace, lwork, info_data, params, batch)); } } else { FFI_ASSIGN_OR_RETURN( - int lwork, SyevdKernel::BufferSize(handle.get(), jobz, uplo, n)); + int lwork, solver::SyevdBufferSize(handle.get(), jobz, uplo, n)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "syevd")); int out_step = n * n; for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS( - SyevdKernel::Run(handle.get(), jobz, uplo, n, out_data, w_data, - workspace, lwork, info_data)); + FFI_RETURN_IF_ERROR_STATUS(solver::Syevd(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data)); out_data += out_step; w_data += n; ++info_data; @@ -695,7 +470,6 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in syevd", absl::FormatStreamed(dataType))); } -} // namespace XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch, ffi::Ffi::Bind() @@ -709,110 +483,83 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch, .Ret>() // info ); -#define SYRK_KERNEL_IMPL(type, fn) \ - template <> \ - struct SyrkKernel { \ - static absl::Status Run(gpublasHandle_t handle, std::int64_t n, \ - std::int64_t k, bool transpose, \ - const type* alpha, const type* beta, \ - const type* a_matrix, type* c_matrix) { \ - gpublasOperation_t op = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; \ - gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; \ - int lda = transpose ? n : k; \ - return JAX_AS_STATUS(fn(handle, uplo, op, n, k, \ - alpha, a_matrix, lda, beta, \ - c_matrix, n)); \ - } \ - } - -template -struct SyrkKernel; - -SYRK_KERNEL_IMPL(float, gpublasSsyrk); -SYRK_KERNEL_IMPL(double, gpublasDsyrk); -SYRK_KERNEL_IMPL(gpublasComplex, gpublasCsyrk); -SYRK_KERNEL_IMPL(gpublasDoubleComplex, gpublasZsyrk); -#undef SYRK_KERNEL_IMPL +// Symmetric rank-k update: syrk template -ffi::Error SyrkImpl(gpuStream_t stream, - ffi::AnyBuffer a_matrix, - ffi::AnyBuffer c_matrix, - bool transpose, - ffi::AnyBuffer alpha, - ffi::AnyBuffer beta, - ffi::Result c_matrix_out) { +ffi::Error SyrkImpl(gpuStream_t stream, bool transpose, ffi::AnyBuffer a, + ffi::AnyBuffer c_in, ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, ffi::Result c_out) { FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), - SplitBatch2D(a_matrix.dimensions())); - FFI_ASSIGN_OR_RETURN((auto [batch_c, rows_c, cols_c]), - SplitBatch2D(c_matrix.dimensions())); - FFI_ASSIGN_OR_RETURN((auto [batch_out, rows_out, cols_out]), - SplitBatch2D(c_matrix_out->dimensions())); - if (batch != batch_c || batch != batch_out) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "a_matrix, c_matrix and c_matrix_out must have the same " - "batch size."); + SplitBatch2D(a.dimensions())); + if (alpha.element_count() != 1 || beta.element_count() != 1) { + return ffi::Error::InvalidArgument( + "The alpha and beta inputs to syrk must be scalars"); } - int n = transpose ? cols : rows; - int k = transpose ? rows : cols; - + auto size = transpose ? cols : rows; FFI_RETURN_IF_ERROR( - CheckShape(c_matrix_out->dimensions().last(2), {n, n}, "out", "Syrk")); + CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk")); FFI_RETURN_IF_ERROR( - CheckShape(c_matrix.dimensions().last(2), {n, n}, "C", "Syrk")); + CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk")); + + FFI_ASSIGN_OR_RETURN(auto n, + MaybeCastNoOverflow(transpose ? cols : rows)); + FFI_ASSIGN_OR_RETURN(auto k, + MaybeCastNoOverflow(transpose ? rows : cols)); + gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; + gpublasOperation_t trans = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; - const T* a_data = static_cast(a_matrix.untyped_data()); - T* c_data = static_cast(c_matrix.untyped_data()); - T* c_out_data = static_cast(c_matrix_out->untyped_data()); + const T* a_data = static_cast(a.untyped_data()); + T* c_data = static_cast(c_in.untyped_data()); + T* c_out_data = static_cast(c_out->untyped_data()); // with alpha or beta provided as device_pointers, cublassyrk will SIGSEGV T host_alpha; - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - &host_alpha, alpha.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost, - stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_alpha, alpha.untyped_data(), + sizeof(T), gpuMemcpyDeviceToHost, + stream)); T host_beta; - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - &host_beta, beta.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost, - stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_beta, beta.untyped_data(), + sizeof(T), gpuMemcpyDeviceToHost, + stream)); if (c_data != c_out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - c_out_data, c_data, c_matrix.size_bytes(), gpuMemcpyDeviceToDevice, - stream))); + JAX_FFI_RETURN_IF_GPU_ERROR( + gpuMemcpyAsync(c_out_data, c_data, c_in.size_bytes(), + gpuMemcpyDeviceToDevice, stream)); } FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); for (int i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(SyrkKernel::Run( - handle.get(), n, k, transpose, &host_alpha, &host_beta, - a_data + i * k * n, c_out_data + i * n * n)); + FFI_RETURN_IF_ERROR_STATUS(solver::Syrk(handle.get(), uplo, trans, n, k, + &host_alpha, a_data, &host_beta, + c_out_data)); + a_data += k * n; + c_out_data += n * n; } return ffi::Error::Success(); } -ffi::Error SyrkDispatch( - gpuStream_t stream, - ffi::AnyBuffer a_matrix, - ffi::AnyBuffer c_matrix, - bool transpose, - ffi::AnyBuffer alpha, - ffi::AnyBuffer beta, - ffi::Result c_matrix_out) { - auto dataType = a_matrix.element_type(); - SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, a_matrix, c_matrix, transpose, - alpha, beta, c_matrix_out); - return ffi::Error::InvalidArgument("Unsupported element type for Syrk"); +ffi::Error SyrkDispatch(gpuStream_t stream, bool transpose, ffi::AnyBuffer a, + ffi::AnyBuffer c_in, ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, + ffi::Result c_out) { + auto dataType = a.element_type(); + SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, transpose, a, c_in, alpha, beta, + c_out); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in syrk", absl::FormatStreamed(dataType))); } XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch, ffi::Ffi::Bind() .Ctx>() - .Arg() // a_matrix - .Arg() // c_matrix .Attr("transpose") // transpose - .Arg() // alpha - .Arg() // beta - .Ret()); // c_matrix_out + .Arg() // a + .Arg() // c_in + .Arg() // alpha + .Arg() // beta + .Ret() // c_out +); #undef SOLVER_DISPATCH_IMPL #undef SOLVER_BLAS_DISPATCH_IMPL diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index ce856ae5f83d..5987415224c7 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -168,6 +168,21 @@ cc_library( ], ) +cc_library( + name = "hipsolver_interface", + srcs = ["//jaxlib/gpu:solver_interface.cc"], + hdrs = ["//jaxlib/gpu:solver_interface.h"], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:hipblas", + "@local_config_rocm//rocm:hipsolver", + ], +) + cc_library( name = "hipsolver_kernels_ffi", srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"], @@ -178,6 +193,7 @@ cc_library( ":hip_make_batch_pointers", ":hip_solver_handle_pool", ":hip_vendor", + ":hipsolver_interface", "//jaxlib:ffi_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", From 208116489f526bc2a1151a353d1ade79b5e3f73c Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Thu, 12 Sep 2024 02:35:19 -0700 Subject: [PATCH 466/702] Enable --- tests/layout_test.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/tests/layout_test.py b/tests/layout_test.py index 2f240195f22d..1603320d2531 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -47,8 +47,6 @@ def setUp(self): super().setUp() def test_auto_layout(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape1 = (128, 128) shape2 = (128, 128) @@ -114,8 +112,6 @@ def init(x, y): self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T) def test_default_layout(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -155,8 +151,6 @@ def f(x): out_shardings=DLL.AUTO).lower(sds).compile() def test_in_layouts_out_layouts(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (8, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -181,8 +175,6 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -246,6 +238,10 @@ def f(x, y): def test_aot_layout_mismatch(self): if jtu.test_device_matches(["gpu"]): + # The test fails on GPU because the compilation with both input and + # output set to auto layout is underspecified. The GPU compiler chooses + # the default layout as the input layout and that choice does not + # raise an exception. self.skipTest("This test does not work on GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (256, 4, 2) @@ -416,8 +412,6 @@ def f(x): self.assertArraysEqual(out, inp.T) def test_device_put_user_concrete_layout(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") shape = (8, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -472,8 +466,6 @@ def test_incompatible_aval_error_device_put(self): jax.device_put(inp, l) def test_concrete_layout_in_shardings(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (16, 128) @@ -482,7 +474,9 @@ def test_concrete_layout_in_shardings(self): custom_dll = DLL(major_to_minor=(0, 1)) - @partial(jax.jit, in_shardings=Layout(custom_dll, s)) + @partial(jax.jit, + in_shardings=Layout(custom_dll, s), + out_shardings=Layout(DLL.AUTO)) def f(x): return x.T From de9b98e0a8fb27e04a2758866e62b1c1c0ab37b1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 12 Sep 2024 11:47:03 -0700 Subject: [PATCH 467/702] Delete jax.xla_computation since it's been 3 months since it was deprecated. PiperOrigin-RevId: 673938336 --- CHANGELOG.md | 12 +++ docs/jax.rst | 1 - jax/__init__.py | 9 +- jax/_src/api.py | 245 +--------------------------------------------- tests/api_test.py | 187 +++-------------------------------- 5 files changed, 30 insertions(+), 424 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df2b5813ce1a..1c29ae7dc6d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,18 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.33 +* Deletion: + * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation + in 0.4.30 JAX release. + Please use the AOT APIs to get the same functionality as `jax.xla_computation`. + * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with + `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`. + * You can also use `.out_info` property of `jax.stages.Lowered` to get the + output information (like tree structure, shape and dtype). + * For cross-backend lowering, you can replace + `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with + `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. + ## jaxlib 0.4.33 diff --git a/docs/jax.rst b/docs/jax.rst index b2c4ba60739b..a8781d31a448 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -69,7 +69,6 @@ Just-in-time compilation (:code:`jit`) jit disable_jit ensure_compile_time_eval - xla_computation make_jaxpr eval_shape ShapeDtypeStruct diff --git a/jax/__init__.py b/jax/__init__.py index 7e958b21c5dd..e2e302adb855 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -127,7 +127,6 @@ from jax._src.api import value_and_grad as value_and_grad from jax._src.api import vjp as vjp from jax._src.api import vmap as vmap -from jax._src.api import xla_computation as _deprecated_xla_computation from jax._src.sharding_impls import NamedSharding as NamedSharding from jax._src.sharding_impls import make_mesh as make_mesh @@ -224,20 +223,18 @@ "jax.clear_backends is deprecated.", _deprecated_clear_backends ), - # Added Jun 16, 2024 + # Remove after jax 0.4.35 release. "xla_computation": ( - "jax.xla_computation is deprecated. Please use the AOT APIs; see " + "jax.xla_computation is deleted. Please use the AOT APIs; see " "https://jax.readthedocs.io/en/latest/aot.html. For example, replace " "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See " - "CHANGELOG.md for 0.4.30 for more examples.", - _deprecated_xla_computation + "CHANGELOG.md for 0.4.30 for more examples.", None ), } import typing as _typing if _typing.TYPE_CHECKING: from jax._src.api import clear_backends as clear_backends - from jax._src.api import xla_computation as xla_computation from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves diff --git a/jax/_src/api.py b/jax/_src/api.py index 935995ec5cba..8ed03e8e3c87 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -46,7 +46,6 @@ from jax._src import config from jax._src import core from jax._src import dispatch -from jax._src import effects from jax._src import array from jax._src import basearray from jax._src import distributed @@ -60,7 +59,7 @@ from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray from jax._src.api_util import ( flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, - argnums_partial_except, flatten_axes, donation_vector, + flatten_axes, donation_vector, rebase_donate_argnums, _ensure_index, _ensure_index_tuple, shaped_abstractify, apply_flat_fun_nokwargs, check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, fun_sourceinfo) @@ -73,13 +72,11 @@ from jax._src.layout import Layout, AutoLayout from jax._src.traceback_util import api_boundary from jax._src import tree_util -from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps, - split_list) +from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.interpreters import xla @@ -337,244 +334,6 @@ def disable_jit(disable: bool = True): yield -def xla_computation(fun: Callable, - static_argnums: int | Iterable[int] = (), - axis_env: Sequence[tuple[AxisName, int]] | None = None, - in_parts=None, out_parts=None, - backend: str | None = None, - tuple_args: bool = False, - instantiate_const_outputs: bool | None = None, - return_shape: bool = False, - donate_argnums: int | Iterable[int] = ()) -> Callable: - """Creates a function that produces its XLA computation given example args. - - .. warning:: - - This function is deprecated as of JAX v0.4.30, and will be removed in a future - JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for - example, ``jax.xla_computation(fn)(*args)`` can be replaced with - ``jax.jit(fn).lower(*args).compiler_ir('hlo')``. - See the `JAX 0.4.30 Change log`_ for more examples. - - Args: - fun: Function from which to form XLA computations. - static_argnums: See the :py:func:`jax.jit` docstring. - axis_env: Optional, a sequence of pairs where the first element is an axis - name and the second element is a positive integer representing the size of - the mapped axis with that name. This parameter is useful when lowering - functions that involve parallel communication collectives, and it - specifies the axis name/size environment that would be set up by - applications of :py:func:`jax.pmap`. See the examples below. - in_parts: Optional, how each argument to ``fun`` should be partitioned or - replicated. This is used to specify partitioned XLA computations, see - ``sharded_jit`` for more info. - out_parts: Optional, how each output of ``fun`` should be partitioned or - replicated. This is used to specify partitioned XLA computations, see - ``sharded_jit`` for more info. - backend: This is an experimental feature and the API is likely to change. - Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or - ``'tpu'``. - tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting - XLA computation will have a single tuple argument that is unpacked into - the specified function arguments. If `None`, tupling will be enabled when - there are more than 100 arguments, since some platforms have limits on - argument arity. - instantiate_const_outputs: Deprecated argument, does nothing. - return_shape: Optional boolean, defaults to ``False``. If ``True``, the - wrapped function returns a pair where the first element is the XLA - computation and the second element is a pytree with the same structure as - the output of ``fun`` and where the leaves are objects with ``shape`` and - ``dtype`` attributes representing the corresponding types of the output - leaves. - donate_argnums: Specify which arguments are "donated" to the computation. - It is safe to donate arguments if you no longer need them once the - computation has finished. In some cases XLA can make use of donated - buffers to reduce the amount of memory needed to perform a computation, - for example recycling one of your input buffers to store a result. You - should not reuse buffers that you donate to a computation, JAX will raise - an error if you try to. - - Returns: - A wrapped version of ``fun`` that when applied to example arguments returns - a built XLA Computation (see xla_client.py), from which representations of - the unoptimized XLA HLO computation can be extracted using methods like - ``as_hlo_text``, ``as_serialized_hlo_module_proto``, and - ``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the - wrapped function returns a pair where the first element is the XLA - Computation and the second element is a pytree representing the structure, - shapes, dtypes, and named shapes of the output of ``fun``. - - Concrete example arguments are not always necessary. For those arguments not - indicated by ``static_argnums``, any object with ``shape`` and ``dtype`` - attributes is acceptable (excepting namedtuples, which are treated as Python - containers). - - For example: - - >>> import jax - >>> - >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x)) - >>> c = jax.xla_computation(f)(3.) # doctest: +SKIP - >>> print(c.as_hlo_text()) # doctest: +SKIP - HloModule xla_computation_f.6 - - ENTRY xla_computation_f.6 { - constant.2 = pred[] constant(false) - parameter.1 = f32[] parameter(0) - cosine.3 = f32[] cosine(parameter.1) - sine.4 = f32[] sine(cosine.3) - ROOT tuple.5 = (f32[]) tuple(sine.4) - } - - - - - Alternatively, the assignment to ``c`` above could be written: - - >>> import types - >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)) - >>> c = jax.xla_computation(f)(scalar) # doctest: +SKIP - - - Here's an example that involves a parallel collective and axis name: - - >>> def f(x): return x - jax.lax.psum(x, 'i') - >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) # doctest: +SKIP - >>> print(c.as_hlo_text()) # doctest: +SKIP - HloModule jaxpr_computation.9 - primitive_computation.3 { - parameter.4 = s32[] parameter(0) - parameter.5 = s32[] parameter(1) - ROOT add.6 = s32[] add(parameter.4, parameter.5) - } - ENTRY jaxpr_computation.9 { - tuple.1 = () tuple() - parameter.2 = s32[] parameter(0) - all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3 - ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7) - } - - - - Notice the ``replica_groups`` that were generated. Here's an example that - generates more interesting ``replica_groups``: - - >>> from jax import lax - >>> def g(x): - ... rowsum = lax.psum(x, 'i') - ... colsum = lax.psum(x, 'j') - ... allsum = lax.psum(x, ('i', 'j')) - ... return rowsum, colsum, allsum - ... - >>> axis_env = [('i', 4), ('j', 2)] - >>> c = jax.xla_computation(g, axis_env=axis_env)(5.) # doctest: +SKIP - >>> print(c.as_hlo_text()) # doctest: +SKIP - HloModule jaxpr_computation__1.19 - [removed uninteresting text here] - ENTRY jaxpr_computation__1.19 { - tuple.1 = () tuple() - parameter.2 = f32[] parameter(0) - all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3 - all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8 - all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13 - ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17) - } - - .. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024 - """ - if instantiate_const_outputs is not None: - raise ValueError( - "instantiate_const_outputs has been deprecated. Please use the ahead of" - " time APIs. You can read more here:" - " https://jax.readthedocs.io/en/latest/aot.html") - if in_parts is not None: - raise ValueError( - "in_parts has been deprecated. Please use the ahead of time APIs. You" - " can read more here: https://jax.readthedocs.io/en/latest/aot.html") - if out_parts is not None: - raise ValueError( - "out_parts has been deprecated. Please use the ahead of time APIs. You" - " can read more here: https://jax.readthedocs.io/en/latest/aot.html") - - check_callable(fun) - static_argnums = _ensure_index_tuple(static_argnums) - donate_argnums = _ensure_index_tuple(donate_argnums) - donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums) - - fun_name = getattr(fun, "__name__", "unknown") - - platform = backend if backend is not None else xb.get_backend().platform - - def make_axis_env(nreps): - if axis_env is None: - return sharding_impls.AxisEnv(nreps, (), ()) - else: - nreps = nreps * math.prod(size for name, size in axis_env) - names, sizes = unzip2(axis_env) - return sharding_impls.AxisEnv(nreps, names, sizes) - - @wraps(fun) - @api_boundary - def computation_maker(*args, **kwargs): - if max(static_argnums + donate_argnums, default=-1) >= len(args): - raise ValueError(f"jitted function has {static_argnums=}, {donate_argnums=} but " - f"was called with only {len(args)} positional arguments.") - - f = lu.wrap_init(fun) - f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False) - args_flat, in_tree = tree_flatten((dyn_args, kwargs)) - if donate_argnums: - donated_invars = donation_vector(donate_argnums, (), in_tree) - else: - donated_invars = (False,) * len(args_flat) - - jaxtree_fun, out_tree = flatten_fun(f, in_tree) - avals = map(shaped_abstractify, args_flat) - with ExitStack() as stack: - for axis_name, size in axis_env or []: - stack.enter_context(core.extend_axis_env(axis_name, size, None)) - jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals) - jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) - if axis_env: - jaxpr = core.remove_named_axis_effects( - jaxpr, {axis_name for axis_name, _ in axis_env} - ) - axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr)) - ordered_effects = list( - effects.ordered_effects.filter_in(jaxpr.effects)) - lowering_result = mlir.lower_jaxpr_to_module( - f"xla_computation_{fun_name}", - core.ClosedJaxpr(jaxpr, consts), - ordered_effects=ordered_effects, - backend_or_name=backend, - platforms=[platform], - axis_context=sharding_impls.ReplicaAxisContext(axis_env_), - name_stack=source_info_util.new_name_stack( - wrap_name(fun_name, "xla_computation")), - donated_args=donated_invars, - arg_shardings=None, - result_shardings=None, - lowering_parameters=mlir.LoweringParameters()) - - m = mlir.module_to_bytecode(lowering_result.module) - built = xc._xla.mlir.mlir_module_to_xla_computation( - m, use_tuple_args=tuple_args, return_tuple=True) - out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] - out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] - out_shape = tree_unflatten(out_tree(), out_shapes_flat) - for out_aval in out_avals: - if not isinstance(out_aval, ShapedArray): - raise RuntimeError("As we want to propagate the weak_type, we need " - "to get a ShapedArray, otherwise this " - "information is lost") - - if return_shape: - return built, out_shape - else: - return built - - return computation_maker - def grad(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, diff --git a/tests/api_test.py b/tests/api_test.py index 1a119846be9c..fabdb9ffe503 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -50,7 +50,6 @@ from jax._src import config from jax._src import core from jax._src import custom_derivatives -from jax._src import deprecations from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import xla_bridge @@ -60,7 +59,6 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled -from jax._src.lib import xla_client from jax._src.lib import xla_extension import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint @@ -2904,74 +2902,6 @@ def test_jacfwd_of_complex_errors(self): r"sub-dtype of np.floating\), but got complex.*"), lambda: dfn(3. + 1j)) - def test_xla_computation(self): - # these tests basically check the examples in the xla_computation docstring - - def e(x): - return jnp.sin(jnp.cos(x)) - c = api.xla_computation(e)(2.) - self.assertIn('cosine', c.as_hlo_text()) - self.assertIn('sine', c.as_hlo_text()) - - def f(x): - return x - lax.psum(x, 'i') - axis_env = [('i', 4)] - c = api.xla_computation(f, axis_env=axis_env)(2) - self.assertIn('all-reduce', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1,2,3}}', c.as_hlo_text()) - - def g(x): - rowsum = lax.psum(x, 'i') - colsum = lax.psum(x, 'j') - allsum = lax.psum(x, ('i', 'j')) - return rowsum, colsum, allsum - axis_env = [('i', 4), ('j', 2)] - c = api.xla_computation(g, axis_env=axis_env)(5.) - self.assertIn('all-reduce', c.as_hlo_text()) - self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.as_hlo_text()) - - def h(x): - rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]) - colsum = lax.psum(x, 'j') - return rowsum, colsum - axis_env = [('i', 4), ('j', 2)] - c = api.xla_computation(h, axis_env=axis_env)(5.) - self.assertIn('all-reduce', c.as_hlo_text()) - self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text()) - - def test_xla_computation_args(self): - def foo(x, y, z): - return x + y + z - - c = api.xla_computation(foo)(1., 2., 3.) - self.assertEqual(len(c.program_shape().parameter_shapes()), 3) - - c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.) - param_shapes = c.program_shape().parameter_shapes() - self.assertEqual(len(param_shapes), 1) - self.assertEqual(param_shapes[0].xla_element_type(), - xla_client.PrimitiveType.TUPLE) - - def test_xla_computation_duck_typing(self): - def foo(x, y, z): - return x + y + z - - x = jax.ShapeDtypeStruct((), np.float32) - y = jax.ShapeDtypeStruct((), np.float32) - z = jax.ShapeDtypeStruct((), np.float32) - - c = api.xla_computation(foo)(x, y, z) - self.assertEqual(len(c.program_shape().parameter_shapes()), 3) - - c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.) - param_shapes = c.program_shape().parameter_shapes() - self.assertEqual(len(param_shapes), 1) - self.assertEqual(param_shapes[0].xla_element_type(), - xla_client.PrimitiveType.TUPLE) - def test_compiler_ir(self): # TODO(phawkins): merge these tests with the `xla_computation` tests. def e(x): @@ -2983,72 +2913,6 @@ def e(x): self.assertIn("stablehlo.cosine", stablehlo) self.assertIn("stablehlo.sine", stablehlo) - def test_staging_out_multi_replica(self): - def f(x): - return api.pmap(jnp.mean)(x) - xla_comp = api.xla_computation(f) - xla_comp(jnp.arange(8)).as_hlo_text() # doesn't crash - - def test_xla_computation_instantiate_constant_outputs(self): - def f(): - return jnp.zeros((3, 4)) - - xla_comp = api.xla_computation(f)() - out_shape, = xla_comp.program_shape().result_shape().tuple_shapes() - self.assertEqual(out_shape.dimensions(), (3, 4)) - - def test_xla_computation_static_argnums(self): - def f(x, y): - return x + y - - xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3) - hlo_text = xla_comp.as_hlo_text() - self.assertIn("constant(3)", hlo_text) - # The static arguments should be removed from the function being compiled, - # thus the function should have only a single argument. - self.assertIn("parameter(0)", hlo_text) - self.assertNotIn("parameter(1)", hlo_text) - - def test_xla_computation_return_shape(self): - _, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)), - return_shape=True)(np.int32(1)) - expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32), - api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32)) - self.assertEqual(shape_tree, expected) - - def test_xla_computation_psum_constant(self): - f = lambda: jax.lax.psum(1, "i") - api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash - - @jtu.ignore_warning(message="Some donated buffers were not usable") - def test_xla_computation_donate_argnums(self): - api.xla_computation(lambda x: None, donate_argnums=(0,))(3) # doesn't crash - - def test_xla_computation_lower_fun_axis_env(self): - axis_name = 'i' - def fn(x): - y = lax.all_gather( - x, axis_name=axis_name) - return y * lax.axis_index(axis_name).astype(jnp.float32) - - input_x = jnp.ones((5,6,4), dtype=jnp.float32) - axis_env = [(axis_name, jax.local_device_count())] - _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x) - - @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated') - def test_xla_computation_axis_env(self): - is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation') - xla_computation = api.xla_computation if is_accelerated else jax.xla_computation - - def fn(x): - z = x * jax.lax.axis_index('i').astype(jnp.float32) - def inner_fn(carry, a): - return carry + a, () - return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z) - - x = jnp.ones((5, 6, 4), dtype=jnp.float32) - _ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x) - def test_concurrent_device_get_and_put(self): def f(x): for _ in range(100): @@ -3678,7 +3542,7 @@ def f(x): return x + y + y x = np.array([1, 2], dtype=np.float32) - hlo_lines = jax.xla_computation(f)(x).as_hlo_text().split('\n') + hlo_lines = jax.jit(f).lower(x).as_text('hlo').split('\n') hlo_lines = {s.strip() for s in hlo_lines} self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines) self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines) @@ -3805,11 +3669,6 @@ def g(x): with self.assertRaisesRegex(core.ConcretizationTypeError, msg): g(1) - def test_xla_computation_zeros_doesnt_device_put(self): - with jtu.count_device_put() as count: - api.xla_computation(lambda: jnp.zeros(3))() - self.assertEqual(count[0], 0) - def test_join_concrete_arrays_with_omnistaging(self): # https://github.com/google/jax/issues/4622 x = jnp.array([1., 2., 3.]) @@ -5532,13 +5391,12 @@ def f(x): x, _ = g(x) return x - c = api.xla_computation(f)(2.) - self.assertNotIn('while', c.as_hlo_text()) - self.assertNotIn('conditional', c.as_hlo_text()) - self.assertNotIn('opt-barrier', c.as_hlo_text()) + text = jax.jit(f).lower(2.).as_text('hlo') + self.assertNotIn('while', text) + self.assertNotIn('conditional', text) + self.assertNotIn('opt-barrier', text) - c = api.xla_computation(grad(f))(2.) - text = c.as_hlo_text() + text = jax.jit(grad(f)).lower(2.).as_text('hlo') self.assertTrue('while' in text or 'conditional' in text or 'opt-barrier' in text) @@ -5557,13 +5415,13 @@ def f(x): x, _ = g(x) return x - c = api.xla_computation(f)(2.) - self.assertNotIn('while', c.as_hlo_text()) - self.assertNotIn('conditional', c.as_hlo_text()) + text = jax.jit(f).lower(2.).as_text('hlo') + self.assertNotIn('while', text) + self.assertNotIn('conditional', text) - c = api.xla_computation(grad(f))(2.) - self.assertNotIn('while', c.as_hlo_text()) - self.assertNotIn('conditional', c.as_hlo_text()) + text = jax.jit(grad(f)).lower(2.).as_text('hlo') + self.assertNotIn('while', text) + self.assertNotIn('conditional', text) @parameterized.named_parameters( {"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat, @@ -6679,7 +6537,7 @@ def test_elide_trivial_broadcasts(self): self.assertLen(jaxpr.jaxpr.eqns, 0) def test_convert_element_type_literal_constant_folding(self): - # this convert_elemnt_type is nontrivial, but because it's on a scalar we + # this convert_element_type is nontrivial, but because it's on a scalar we # constant-fold it cet = partial(lax.convert_element_type, new_dtype='float16') jaxpr = api.make_jaxpr(lambda: cet(3.))() @@ -10966,25 +10824,6 @@ def test_pmap_nested_donate_ignored(self): class NamedCallTest(jtu.JaxTestCase): - @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated') - def test_default_name(self): - is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation') - xla_computation = api.xla_computation if is_accelerated else jax.xla_computation - - @api.named_call - def my_test_function(x): - return x**2 - - @jax.jit - def f(x): - return my_test_function(x) - - c = xla_computation(f)(2) - print_opts = xla_client._xla.HloPrintOptions.short_parsable() - print_opts.print_metadata = True - hlo_text = c.as_hlo_module().to_string(print_opts) - self.assertIn("my_test_function", hlo_text) - def test_non_jaxtype_arg(self): # For the test to fail without the invalid JaxType filter we need to pass # in a valid JaxType that forces the invalid Jaxtype to be raised to an From 39c39acd19453deaa65b4cde3888e16fa3ba5556 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 12 Sep 2024 13:40:04 -0700 Subject: [PATCH 468/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/4aee555551c2be2e3e7891eab7b4343bf14ab279. PiperOrigin-RevId: 673990349 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 77f6f1bd978e..066e102c2f44 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "acef59a8b9454702b4a876027a505bc3362fa906" -XLA_SHA256 = "ae1963475613acc8c364553feb99ec1fec9eb6b817dda350d7c660ca9de2e6ae" +XLA_COMMIT = "4aee555551c2be2e3e7891eab7b4343bf14ab279" +XLA_SHA256 = "efecdfd85763d0374eb76b4948b2413f68fb154ba4a5827fa852afee659f08e5" def repo(): tf_http_archive( From 46a7be6d7a9e63d6fe97f516a8576fffc3db08ee Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Sep 2024 13:45:33 -0700 Subject: [PATCH 469/702] Lower the minimum jaxlib version to 0.4.31, since 0.4.32 was yanked from pypi. PiperOrigin-RevId: 673992844 --- jax/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/version.py b/jax/version.py index 0fc66d0d3006..ff8f14f2d05a 100644 --- a/jax/version.py +++ b/jax/version.py @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.32" +_minimum_jaxlib_version = "0.4.31" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) From c2adbf9e711ab148d485f97210d148a6a2168fe0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Sep 2024 20:50:04 +0000 Subject: [PATCH 470/702] Redisable complex arctan tests on older jaxlibs. --- tests/lax_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/lax_test.py b/tests/lax_test.py index 74fef668aa47..a29dcadeed0e 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3954,6 +3954,16 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'arccos': regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real') + elif name == 'arctan' and jax._src.lib.version <= (0, 4, 31): + if dtype == np.complex64: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', + 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.real', 'mnegj.imag', 'mposj.imag') + if dtype == np.complex128: + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mnegj.real') + + elif name == 'arctanh' and jax._src.lib.version <= (0, 4, 31): + regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag') + elif name in {'cos', 'sin'}: regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') From 42a295c735ffd4aff2182e05dd77757ff5126bc2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Sep 2024 21:01:44 +0000 Subject: [PATCH 471/702] Redisable one more complex arctan test. --- tests/lax_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index a29dcadeed0e..3d31bcb7d555 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3798,7 +3798,8 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind): size_im = 11 atol = None - if name in {"arccos", "arcsin", "arcsinh", "arccosh", "arctan", "arctanh"}: + if (name in {"arccos", "arcsin", "arcsinh", "arccosh"} + or name in {"arctan", "arctanh"} and jax._src.lib.version > (0, 4, 31)): # TODO(pearu): eliminate this if-block when a fix to mpmath#787 # becomes available extra_prec_multiplier = 20 From 5b9c73f371922c9148269c9a0a235236668a2787 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 12 Sep 2024 14:27:01 -0700 Subject: [PATCH 472/702] Fixed a static type error in Mosaic GPU lowering PiperOrigin-RevId: 674010791 --- jax/_src/pallas/core.py | 9 +++------ jax/_src/pallas/mosaic_gpu/core.py | 13 +++++++++---- jax/_src/pallas/mosaic_gpu/lowering.py | 3 ++- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 8dbf37587f8f..8ad3bca8c055 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -430,11 +430,10 @@ def __repr__(self): BlockSpecTree = Any -class MemrefTransform(Protocol): - """Represents a transformation applied to a Memref on load or store.""" +class MemoryRefTransform(Protocol): + """Transforms a memory reference on load or store.""" def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef: - """Returns the transformed aval given an input aval.""" raise NotImplementedError("Abstract evaluation not implemented.") @@ -451,9 +450,7 @@ class BlockMapping: indexing_mode: IndexingMode array_shape_dtype: jax.ShapeDtypeStruct # The whole array origin: OriginStr - transforms: Sequence[MemrefTransform] = dataclasses.field( - default_factory=tuple - ) + transforms: Sequence[MemoryRefTransform] = () def check_invariants(self) -> None: if not config.enable_checks.value: return diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 1a0c489af47d..76e3c6d1b3f6 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -17,7 +17,7 @@ from collections.abc import Sequence import dataclasses import enum -from typing import ClassVar, Literal +from typing import ClassVar, Literal, Protocol from jax import core as jax_core from jax._src import core from jax._src import tree_util @@ -59,8 +59,13 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): return MemoryRef(shape, dtype, self) -class TilingTransform(pallas_core.MemrefTransform): - """Represents a tiling transformation for Memrefs. +class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol): + def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform: + ... + + +class TilingTransform(MemoryRefTransform): + """Represents a tiling transformation for memory refs. A tiling of (X, Y) on an array of shape (M, N) will result in a transformed shape of (M // X, N // Y, X, Y). Ex. A (256, 256) block that is tiled with a @@ -125,7 +130,7 @@ def to_block_mapping( grid=grid, mapped_dims=mapped_dims, ) - transforms: tuple[pallas_core.MemrefTransform, ...] = () + transforms: tuple[pallas_core.MemoryRefTransform, ...] = () if self.tiling is not None: transforms += (TilingTransform(self.tiling),) return GPUBlockMapping( diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 87a147c91c81..4a71c594e261 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -262,7 +262,8 @@ def lower_jaxpr_to_module( for bm in block_mappings[:num_inputs] ] in_gmem_transforms = [ - bm.transforms for bm in grid_mapping.block_mappings[:num_inputs] + cast(gpu_core.MemoryRefTransform, bm.transforms) + for bm in grid_mapping.block_mappings[:num_inputs] ] _get_swizzle = ( lambda bm: bm.swizzle From 255c30303d32e7473262b2e35348175c87e4348f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Sep 2024 14:48:41 -0700 Subject: [PATCH 473/702] Fix a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs. For example, tree_map(..., None, [2, 3]) did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case. PiperOrigin-RevId: 674019460 --- CHANGELOG.md | 24 +++++++++++++++--------- tests/tree_util_test.py | 7 +++++++ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c29ae7dc6d9..223dde9123bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.33 -* Deletion: +* Changes + * `jax_pmap_no_rank_reduction` flag is set to `True` by default. + * array[0] on a pmap result now introduces a reshape (use array[0:1] + instead). + * The per-shard shape (accessable via jax_array.addressable_shards or + jax_array.addressable_data(0)) now has a leading (1, ...). Update code + that directly accesses shards accordingly. The rank of the per-shard-shape + now matches that of the global shape which is the same behavior as jit. + This avoids costly reshapes when passing results from pmap into jit. * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation in 0.4.30 JAX release. Please use the AOT APIs to get the same functionality as `jax.xla_computation`. @@ -23,6 +31,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * For cross-backend lowering, you can replace `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. + * `jax.tree.map(f, None, non-None)`, which previously emitted a + `DeprecationWarning`, now raises an error in a future version of jax. `None` + is only a tree-prefix of itself. To preserve the current behavior, you can + ask `jax.tree.map` to treat `None` as a leaf value by writing: + `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. + ## jaxlib 0.4.33 @@ -35,14 +49,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. C++ and CUDA code from JAX. * Changes - * `jax_pmap_no_rank_reduction` flag is set to `True` by default. - * array[0] on a pmap result now introduces a reshape (use array[0:1] - instead). - * The per-shard shape (accessable via jax_array.addressable_shards or - jax_array.addressable_data(0)) now has a leading (1, ...). Update code - that directly accesses shards accordingly. The rank of the per-shard-shape - now matches that of the global shape which is the same behavior as jit. - This avoids costly reshapes when passing results from pmap into jit. * `jax_enable_memories` flag is set to `True` by default. * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information. diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index bc741702ce58..2c4bb904eaa5 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -24,6 +24,7 @@ import jax from jax import flatten_util from jax import tree_util +from jax._src.lib import xla_extension_version from jax._src import test_util as jtu from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp @@ -395,6 +396,7 @@ def testFlattenOrder(self): ({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]), ({"a": 1}, {"a": (7,)}, [(7,)]), ({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]), + (None, None, []) ) def testFlattenUpTo(self, tree, xs, expected): _, tree_def = tree_util.tree_flatten(tree) @@ -483,6 +485,11 @@ def testFlattenUpTo(self, tree, xs, expected): [([1], (2,), {"a": [1]})], re.escape("Custom node type mismatch"), ), + *( + [] + if xla_extension_version < 284 + else [(None, [2], re.escape("Expected None, got [2]."))] + ), ) def testFlattenUpToErrors(self, tree, xs, error): _, tree_def = tree_util.tree_flatten(tree) From 358f00d5e046b8705a5ecb3f9e4f01b5f193e823 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 12 Sep 2024 23:03:06 +0000 Subject: [PATCH 474/702] shmap in_spec None shouldn't require hashability Co-authored-by: Roy Frostig --- jax/experimental/shard_map.py | 2 +- tests/shard_map_test.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index bf331bbb913f..8319e3fba70f 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -166,7 +166,7 @@ def wrapped(*args): raise e('shard_map in_specs') from None dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) if s is not None) - fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat) + fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat, False) _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index e9c23b3e5f0d..0c6848f94c03 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2164,6 +2164,19 @@ def f(x, y): with config.disable_vmap_shmap_error(): _ = jax.vmap(f, in_axes=(0, None), spmd_axis_name='i')(xs, y) + def test_in_spec_none_hashability(self): + mesh = jtu.create_mesh((2,), ('i',)) + + class A: + def __hash__(self): + raise Exception + + @partial(shard_map, mesh=mesh, in_specs=(None,), out_specs=()) + def f(a): + return () + + f(A()) # don't crash + class FunSpec(NamedTuple): name: str From 178fb03050e20e49660be5cd039a941a6905ed35 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 12 Sep 2024 16:55:53 -0700 Subject: [PATCH 475/702] [Mosaic TPU] Better error message when shape of memref bitcast is invalid. PiperOrigin-RevId: 674062237 --- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index bc1d30893537..d80db4e1394e 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -322,13 +322,21 @@ LogicalResult MemRefBitcastOp::verify() { auto src_dim_size = src_ty.getDimSize(i); auto tgt_dim_size = tgt_ty.getDimSize(i); if (i == src_ty.getRank() - 2) { - src_dim_size *= src_bitwidth; - tgt_dim_size *= tgt_bitwidth; - } - if (src_dim_size != tgt_dim_size) { - return emitOpError( - "Expected the same dim size on the 2nd minormost dim: ") - << src_dim_size << " vs " << tgt_dim_size; + auto src_bits = src_dim_size * src_bitwidth; + auto tgt_bits = tgt_dim_size * tgt_bitwidth; + if (src_bits != tgt_bits) { + return emitOpError( + "Expected the same number of bits on the 2nd minormost " + "dim: (") + << src_dim_size << " * " << src_bitwidth << ") vs (" + << tgt_dim_size << " * " << tgt_bitwidth << ")"; + ; + } + } else { + if (src_dim_size != tgt_dim_size) { + return emitOpError("Expected the same dim size on dim ") + << i << ": " << src_dim_size << " vs " << tgt_dim_size; + } } } // Source and target attributes may be different before propagation is done by From 8d93e101b90f7acc947685c735130f2b5290b593 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 12 Sep 2024 17:13:46 -0700 Subject: [PATCH 476/702] [Mosaic TPU] Propagate the memory space change for memref bitcast and reshape. PiperOrigin-RevId: 674067380 --- .../dialect/tpu/transforms/memory_space_specialization.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index 37112666f542..569038500067 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -78,6 +78,14 @@ LogicalResult specializeMemorySpace(TypedValue value, updateResultFrom(op, op.getInput().getType()); continue; } + if (auto op = dyn_cast(some_op)) { + updateResultFrom(op, op.getInput().getType()); + continue; + } + if (auto op = dyn_cast(some_op)) { + updateResultFrom(op, op.getInput().getType()); + continue; + } if (auto op = dyn_cast(some_op)) { updateResultFrom(op, op.getOperand().getType()); continue; From dffac29e63de6a51047fe77cf9d553ab762ef19b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Sep 2024 18:13:51 -0700 Subject: [PATCH 477/702] Reverts 255c30303d32e7473262b2e35348175c87e4348f PiperOrigin-RevId: 674083626 --- CHANGELOG.md | 24 +++++++++--------------- tests/tree_util_test.py | 7 ------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 223dde9123bf..1c29ae7dc6d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,15 +12,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.33 -* Changes - * `jax_pmap_no_rank_reduction` flag is set to `True` by default. - * array[0] on a pmap result now introduces a reshape (use array[0:1] - instead). - * The per-shard shape (accessable via jax_array.addressable_shards or - jax_array.addressable_data(0)) now has a leading (1, ...). Update code - that directly accesses shards accordingly. The rank of the per-shard-shape - now matches that of the global shape which is the same behavior as jit. - This avoids costly reshapes when passing results from pmap into jit. +* Deletion: * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation in 0.4.30 JAX release. Please use the AOT APIs to get the same functionality as `jax.xla_computation`. @@ -31,12 +23,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * For cross-backend lowering, you can replace `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. - * `jax.tree.map(f, None, non-None)`, which previously emitted a - `DeprecationWarning`, now raises an error in a future version of jax. `None` - is only a tree-prefix of itself. To preserve the current behavior, you can - ask `jax.tree.map` to treat `None` as a leaf value by writing: - `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. - ## jaxlib 0.4.33 @@ -49,6 +35,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. C++ and CUDA code from JAX. * Changes + * `jax_pmap_no_rank_reduction` flag is set to `True` by default. + * array[0] on a pmap result now introduces a reshape (use array[0:1] + instead). + * The per-shard shape (accessable via jax_array.addressable_shards or + jax_array.addressable_data(0)) now has a leading (1, ...). Update code + that directly accesses shards accordingly. The rank of the per-shard-shape + now matches that of the global shape which is the same behavior as jit. + This avoids costly reshapes when passing results from pmap into jit. * `jax_enable_memories` flag is set to `True` by default. * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information. diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 2c4bb904eaa5..bc741702ce58 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -24,7 +24,6 @@ import jax from jax import flatten_util from jax import tree_util -from jax._src.lib import xla_extension_version from jax._src import test_util as jtu from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp @@ -396,7 +395,6 @@ def testFlattenOrder(self): ({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]), ({"a": 1}, {"a": (7,)}, [(7,)]), ({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]), - (None, None, []) ) def testFlattenUpTo(self, tree, xs, expected): _, tree_def = tree_util.tree_flatten(tree) @@ -485,11 +483,6 @@ def testFlattenUpTo(self, tree, xs, expected): [([1], (2,), {"a": [1]})], re.escape("Custom node type mismatch"), ), - *( - [] - if xla_extension_version < 284 - else [(None, [2], re.escape("Expected None, got [2]."))] - ), ) def testFlattenUpToErrors(self, tree, xs, error): _, tree_def = tree_util.tree_flatten(tree) From 3d1d5e94ab327057e47e79f0284afddc87a88a5c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 12 Sep 2024 18:47:25 -0700 Subject: [PATCH 478/702] Remove the device assignment check in _resolve_in_shardings since that's historical and not needed anymore PiperOrigin-RevId: 674091716 --- jax/_src/numpy/lax_numpy.py | 2 +- jax/_src/pjit.py | 24 +++--------------------- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f2d90b95afa1..b234eb878376 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1340,7 +1340,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, When ``append = jnp.array([[3],[1]])``, it is appended to ``a`` along ``axis`` before computing the difference. - + >>> jnp.diff(a, append=jnp.array([[3],[1]])) Array([[ 4, -3, 7, -6], [ 5, -1, -3, -3]], dtype=int32) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ed5b825c62b4..dc88e42ec0a1 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -19,7 +19,6 @@ import dataclasses from functools import partial import inspect -import itertools as it import logging import operator as op import weakref @@ -1494,11 +1493,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): return tuple(resolved_in_layouts) -def _resolve_in_shardings( - args, pjit_in_shardings: Sequence[PjitSharding], - out_shardings: Sequence[PjitSharding], - pjit_mesh: pxla.Mesh | None, - check_device_assignment: bool = True) -> Sequence[PjitSharding]: +def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] + ) -> Sequence[PjitSharding]: # If True, means that device or backend is set by the user on pjit and it # has the same semantics as device_put i.e. doesn't matter which device the # arg is on, reshard it to the device mentioned. So don't do any of the @@ -1521,18 +1517,6 @@ def _resolve_in_shardings( if getattr(a, '_committed', True): committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) - # Check if the device_assignment across inputs, outputs and arguments is the - # same. - if check_device_assignment: - pxla._get_and_check_device_assignment( - it.chain( - util.stable_unique(committed_arg_shardings), - ((i, pxla.MismatchType.IN_SHARDING, None) - for i in util.stable_unique(pjit_in_shardings)), - ((o, pxla.MismatchType.OUT_SHARDING, None) - for o in util.stable_unique(out_shardings))), - (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat))) - resolved_in_shardings = [] for arg, pjit_in_s in zip(args, pjit_in_shardings): # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does @@ -1602,9 +1586,7 @@ def _resolve_and_lower( args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, lowering_platforms, lowering_parameters, pgle_profiler): - in_shardings = _resolve_in_shardings( - args, in_shardings, out_shardings, - resource_env.physical_mesh if resource_env is not None else None) + in_shardings = _resolve_in_shardings(args, in_shardings) in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals) lowered = _pjit_lower( From 634fbb5bec29d77207d9cb3a7dd23cffd7932521 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 12 Sep 2024 19:02:57 -0700 Subject: [PATCH 479/702] Move `DeviceAssignmentMismatchError` exception catching code to `def lower` method of `Traced` so that all libraries calling `traced.lower()` see a better error message PiperOrigin-RevId: 674095608 --- jax/_src/pjit.py | 13 ++----------- jax/_src/stages.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index dc88e42ec0a1..fb76f7931c01 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -474,16 +474,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): @api_boundary def lower(*args, **kwargs): - traced = trace(*args, **kwargs) - try: - return traced.lower() - except pxla.DeviceAssignmentMismatchError as e: - fails, = e.args - fun_name = getattr(fun, '__qualname__', - getattr(fun, '__name__', str(fun))) - msg = _device_assignment_mismatch_error( - fun_name, fails, traced._args_flat, 'jit', traced._arg_names) - raise ValueError(msg) from None + return trace(*args, **kwargs).lower() @api_boundary def eval_shape(*args, **kwargs): @@ -503,7 +494,7 @@ def trace(*args, **kwargs) -> stages.Traced: lower_callable = partial(_resolve_and_lower, args_flat, **p.params, pgle_profiler=None) return stages.Traced( - p.params['jaxpr'], args_info, p.params["name"],p.out_tree, + p.params['jaxpr'], args_info, p.params["name"], p.out_tree, lower_callable, args_flat, p.arg_names, p.num_consts) wrapped = _cpp_pjit(fun, jit_info) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index b924072fc044..3a2c375b64db 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -734,12 +734,22 @@ def out_info(self): def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, _private_parameters: mlir.LoweringParameters | None = None): + from jax._src.interpreters import pxla + from jax._src import pjit + if _private_parameters is None: _private_parameters = mlir.LoweringParameters() new_callable = functools.partial( self._lower_callable, lowering_platforms=lowering_platforms, lowering_parameters=_private_parameters) - return Lowered(new_callable(), self.args_info, self._out_tree) + try: + lowering = new_callable() + except pxla.DeviceAssignmentMismatchError as e: + fails, = e.args + msg = pjit._device_assignment_mismatch_error( + self.fun_name, fails, self._args_flat, 'jit', self._arg_names) + raise ValueError(msg) from None + return Lowered(lowering, self.args_info, self._out_tree) @runtime_checkable From 16699952aaa0ee8365267a843fcf362de905b61d Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 12 Sep 2024 19:50:01 -0700 Subject: [PATCH 480/702] ParsedPartitionSpec needs to check that it is the proper instance type before comparing for equality or it will throw an exception in the later code. PiperOrigin-RevId: 674106064 --- jax/_src/sharding_impls.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index add297b6a351..af86425128c9 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1030,6 +1030,8 @@ def __hash__(self): return hash((self.partitions, self.sync)) def __eq__(self, other): + if not isinstance(other, ParsedPartitionSpec): + return False return (self.partitions == other.partitions and self.sync == other.sync) From 67980d6af48d3eb3d47c341fc0fd85fd159aa0db Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 13 Sep 2024 08:35:21 +0300 Subject: [PATCH 481/702] [export] Improve the forward compatibility documentation Update the documentation to use the `LoweringRuleContext.is_forward_compat` helper function. --- docs/export/export.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/export/export.md b/docs/export/export.md index 9e6597cef49b..0ca1a64800e0 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -732,10 +732,7 @@ that live in jaxlib): from jax._src.lib import version as jaxlib_version def my_lowering_rule(ctx: LoweringRuleContext, ...): - lowering_parameters = ctx.module_context.lowering_parameters - forward_compat_mode = (lowering_parameters.for_export and - not lowering_parameters.export_ignore_forward_compatibility) - if forward_compat_mode or jaxlib_version < (0, 4, 31): + if ctx.is_forward_compat() or jaxlib_version < (0, 4, 31): # this is the old lowering, using target T, while we # are in forward compatibility mode for T, or we # are in OSS and are using an old jaxlib. From e2d7ef2a49e4ce70972c038c042bb9e08f646ab3 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 13 Sep 2024 00:09:23 -0700 Subject: [PATCH 482/702] Pallas Mosaic GPU now supports scratch buffers in SMEM PiperOrigin-RevId: 674173250 --- jax/_src/pallas/core.py | 7 +-- jax/_src/pallas/mosaic_gpu/core.py | 26 +++++++++-- jax/_src/pallas/mosaic_gpu/lowering.py | 63 +++++++++++++++++--------- tests/pallas/mosaic_gpu_test.py | 18 ++++++++ 4 files changed, 84 insertions(+), 30 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 8ad3bca8c055..56c47b9401cc 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -126,12 +126,9 @@ def from_pallas_call(pallas_call_name: str | None, class AbstractMemoryRef(state.AbstractRef): __slots__ = ["inner_aval", "memory_space"] - def __init__(self, inner_aval: jax_core.AbstractValue, - memory_space: Any): + inner_aval: jax_core.ShapedArray - assert isinstance( - inner_aval, jax_core.ShapedArray - ), f"Illegal ref, got {type(inner_aval)}" + def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): self.inner_aval = inner_aval self.memory_space = memory_space diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 76e3c6d1b3f6..99f7235fda4a 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -17,7 +17,7 @@ from collections.abc import Sequence import dataclasses import enum -from typing import ClassVar, Literal, Protocol +from typing import Any, ClassVar, Literal, Protocol from jax import core as jax_core from jax._src import core from jax._src import tree_util @@ -56,7 +56,7 @@ def __str__(self) -> str: def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. - return MemoryRef(shape, dtype, self) + return MemoryRef(shape, dtype, memory_space=self) class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol): @@ -146,6 +146,26 @@ def to_block_mapping( ) +@dataclasses.dataclass(init=False, kw_only=True) +class GPUGridSpec(pallas_core.GridSpec): + scratch_shapes: Sequence[Any] + + def __init__( + self, + grid: pallas_core.Grid = (), + in_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec, + out_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec, + scratch_shapes: Sequence[Any] = () + ): + super().__init__(grid, in_specs, out_specs) + self.scratch_shapes = tuple(scratch_shapes) + + def _make_scratch_aval(self, obj: object) -> core.AbstractValue: + if isinstance(obj, MemoryRef): + return obj.get_aval() + raise TypeError(f"Cannot convert {obj} to an abstract value") + + # TODO(b/354568887): Cosolidate this with TPU's MemoryRef. @dataclasses.dataclass(frozen=True) class MemoryRef: @@ -153,7 +173,7 @@ class MemoryRef: shape: tuple[int, ...] dtype: jnp.dtype - memory_space: GPUMemorySpace + memory_space: GPUMemorySpace = dataclasses.field(kw_only=True) def get_aval(self) -> AbstractMemoryRef: return AbstractMemoryRef( diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 4a71c594e261..f2c81234b073 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -19,6 +19,7 @@ from collections.abc import Sequence import dataclasses import functools +import itertools as it import math from typing import Any, cast @@ -231,8 +232,6 @@ def lower_jaxpr_to_module( if len(grid) < 3: grid += (1,) * (3 - len(grid)) block = (128,) + (1,) * (len(grid) - 1) - - num_inputs = grid_mapping.num_inputs params = compiler_params.get("mosaic_gpu", {}) num_stages = params.get("num_stages", 1) dimension_semantics = params.get( @@ -252,25 +251,26 @@ def lower_jaxpr_to_module( in_structs_gmem = [*grid_mapping.in_shapes] in_block_shapes = [ bm.block_shape - for bm in grid_mapping.block_mappings[:num_inputs] + for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] ] in_structs_smem = [ jax.ShapeDtypeStruct( - [num_stages, - *bm.ref_aval.inner_aval.shape], # pytype: disable=attribute-error - bm.ref_aval.inner_aval.dtype) # pytype: disable=attribute-error - for bm in block_mappings[:num_inputs] + [num_stages, *bm.ref_aval.inner_aval.shape], + bm.ref_aval.inner_aval.dtype, + ) + for bm in block_mappings[: grid_mapping.num_inputs] ] in_gmem_transforms = [ - cast(gpu_core.MemoryRefTransform, bm.transforms) - for bm in grid_mapping.block_mappings[:num_inputs] + cast(gpu_core.MemoryRefTransform, bm.transforms) + + for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] ] - _get_swizzle = ( + in_swizzles = map( lambda bm: bm.swizzle if isinstance(bm, gpu_core.GPUBlockMapping) - else None + else None, + grid_mapping.block_mappings[: grid_mapping.num_inputs], ) - in_swizzles = map(_get_swizzle, grid_mapping.block_mappings[:num_inputs]) out_structs_gmem = [*grid_mapping.out_shapes] # TODO(justinfu): Implement output Memref transforms out_structs_smem = [ @@ -283,12 +283,15 @@ def lower_jaxpr_to_module( def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value): *buffers_gmem, (*buffers_smem, runtime_smem, barriers) = buffers - assert len(buffers_gmem) == len(buffers_smem) + assert ( + len(buffers_gmem) + == len(buffers_smem) - grid_mapping.num_scratch_operands + ) in_buffers_gmem, out_buffers_gmem = util.split_list( buffers_gmem, [grid_mapping.num_inputs] ) - in_buffers_smem, out_buffers_smem = util.split_list( - buffers_smem, [grid_mapping.num_inputs] + in_buffers_smem, out_buffers_smem, scratch_buffers_smem = util.split_list( + buffers_smem, [grid_mapping.num_inputs, grid_mapping.num_outputs] ) module_ctx = ModuleContext( @@ -393,11 +396,12 @@ def _(step, _): # Only wait if async copies were issued. barriers[slot].wait() - _ = lower_jaxpr_to_mosaic_gpu( - module_ctx, - jaxpr, - [mgpu.memref_slice(b_smem, slot) for b_smem in buffers_smem], - ) + args = [ + mgpu.memref_slice(b_smem, slot) + for b_smem in it.chain(in_buffers_smem, out_buffers_smem) + ] + args.extend(scratch_buffers_smem) + _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, args) mgpu.commit_shared() with mgpu.single_thread(): @@ -416,12 +420,27 @@ def _(step, _): launch_ctx.await_async_copy(0) + scratch_avals = [ + var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops] + ] + if not all( + isinstance(aval, pallas_core.AbstractMemoryRef) + and aval.memory_space is gpu_core.SMEM + for aval in scratch_avals + ): + raise TypeError( + f"All scratch operands must be in SMEM, but got: {scratch_avals}" + ) + extra_smem_scratch = [ + jax.ShapeDtypeStruct(aval.shape, aval.dtype) for aval in scratch_avals + ] smem_scratch_bytes = compiler_params.get("smem_scratch_bytes") if smem_scratch_bytes is None: smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr) - extra_smem_scratch = [ + extra_smem_scratch.append( jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) - ] + ) + module, out_structs_smem, _ = mosaic_gpu._lower_as_gpu_kernel( body, grid=grid, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index e90033be151d..2c7277c6deb3 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -78,6 +78,24 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 2).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_add_one_grid_with_scratch(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), + grid_spec=plgpu.GPUGridSpec( + in_specs=[pl.BlockSpec((128,), lambda *i: i)], + out_specs=pl.BlockSpec((128,), lambda *i: i), + scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], + grid=2, + ), + ) + def kernel(x_ref, o_ref, scratch_ref): + scratch_ref[...] = x_ref[...] + 1 + o_ref[...] = scratch_ref[...] + + x = jnp.arange(256).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + @parameterized.product(num_stages=[1, 2, 3]) def test_add_one_grid_pipelined(self, num_stages): From 978905606154e88cabb9beca0abf6e20e6b592a4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 13 Sep 2024 02:05:56 -0700 Subject: [PATCH 483/702] Fix a small typo for the condition of scipy.entr. PiperOrigin-RevId: 674205855 --- jax/_src/scipy/special.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 3401edd9e112..837aa011f165 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -558,7 +558,7 @@ def entr(x: ArrayLike) -> Array: \mathrm{entr}(x) = \begin{cases} -x\log(x) & x > 0 \\ 0 & x = 0\\ - -\infty & x > 0 + -\infty & \mathrm{otherwise} \end{cases} Args: From 8159d3352c6818c77fd0f9c31699850f262ec500 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 13 Sep 2024 04:23:29 -0700 Subject: [PATCH 484/702] Updated :gpu_test configuration PiperOrigin-RevId: 674242448 --- tests/mosaic/BUILD | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index abab212d8618..6e5c94982d47 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -47,13 +47,16 @@ jax_test( name = "gpu_test", srcs = ["gpu_test.py"], config_tags_overrides = { - "gpu_h100_2gpu": { + "gpu_h100": { "ondemand": False, # Include in presubmit. }, }, disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, - enable_configs = ["gpu_h100_2gpu"], + enable_configs = [ + "gpu_h100", + "gpu_h100_2gpu", + ], shard_count = 4, deps = [ "//jax:mosaic_gpu", From 427a490d2b4147c41c84560792936d693d3c921a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 13 Sep 2024 04:31:56 -0700 Subject: [PATCH 485/702] Ported a few changes to FragmentArray by cperivol@ * It now supports unary negation * and pointwise operations between scalars and FragmentedArrays PiperOrigin-RevId: 674244294 --- jax/experimental/mosaic/gpu/fragmented_array.py | 12 +++++++++++- tests/mosaic/gpu_test.py | 2 ++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 7e0d43a05551..0b228833cbdb 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -245,7 +245,9 @@ def _pointwise(self, op, *other): other_arrs = [] for o in other: if not isinstance(o, FragmentedArray): - if not isinstance(o, ir.Value): + if isinstance(o, (float, int)): + o = utils.c(o, self.mlir_dtype) + elif not isinstance(o, ir.Value): raise NotImplementedError(o) o = FragmentedArray.splat(o, shape=self.shape, layout=self.layout) @@ -267,6 +269,14 @@ def _pointwise(self, op, *other): new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) return FragmentedArray(_registers=new_regs, _layout=self.layout) + def __neg__(self): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.negf) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise(arith.negsi) + else: + raise NotImplementedError(self.mlir_dtype) + def __add__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise(arith.addf, other) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index ec9a7cd8b64e..1d6f6eb9e584 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1243,6 +1243,8 @@ def kernel(ctx, dst, _): (lambda x: mgpu.FragmentedArray.cos(x, approx=True), np.cos, True), (lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt, False), (lambda x: mgpu.FragmentedArray.rsqrt(x, approx=True), jax.lax.rsqrt, True), + (lambda x: -x, jax.lax.neg, False), + (lambda x: x + 42.0, lambda x: x + 42.0, False), ), m=(64, 128), n=(8, 16, 32, 64, 80, 128, 256), From db7484f392fb868a54f57b179d4b5f533aaf1c40 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 13 Sep 2024 05:37:41 -0700 Subject: [PATCH 486/702] Do a single mbarrier.arrive.expect_tx per fetch in Pallas Mosaic GPU PiperOrigin-RevId: 674260767 --- jax/_src/pallas/mosaic_gpu/lowering.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f2c81234b073..ce1900521e51 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -306,6 +306,15 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value): start_indices, [grid_mapping.num_inputs] ) + # Precompute the total number of bytes transferred from GMEM to SMEM, + # so that we can do a single arrive instruction for all of the inputs. + in_transfer_bytes = 0 + for b_smem in in_buffers_smem: + b_smem_type = ir.MemRefType(b_smem.type) + in_transfer_bytes += math.prod(b_smem_type.shape[1:]) * mgpu.bytewidth( + b_smem_type.element_type + ) + def gmem_slice( start_indices: Sequence[ir.Value], step: ir.Value, @@ -337,7 +346,7 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: barrier=barriers[slot], gmem_transform=tuple(gmem_transforms), swizzle=in_swizzles[idx], - arrive=True, + arrive=False, # The caller must do ``arrive_expect_tx`` manually! uniform=False, ) @@ -386,6 +395,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: with mgpu.single_thread(): for slot in range(min(num_stages, num_steps)): + barriers[slot].arrive_expect_tx(in_transfer_bytes) for idx in range(grid_mapping.num_inputs): fetch(idx, _as_index(slot), _as_index(slot)) @@ -415,6 +425,7 @@ def _(step, _): with mgpu.when(next_step_in_bounds), mgpu.single_thread(): for idx in range(grid_mapping.num_inputs): fetch(idx, next_step, slot) + barriers[slot].arrive_expect_tx(in_transfer_bytes) return () @@ -452,10 +463,7 @@ def _(step, _): *in_structs_smem, *out_structs_smem, *extra_smem_scratch, - mgpu.Barrier( - arrival_count=len(in_structs_gmem), - num_barriers=num_stages, - ), + mgpu.Barrier(arrival_count=1, num_barriers=num_stages), ), module_name=name_and_src_info.name, ) From 8fa0e925dd035f013910dd6a4fb2211f45ebb34a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 13 Sep 2024 08:16:48 -0700 Subject: [PATCH 487/702] Added a docstring to `dce_jaxpr` PiperOrigin-RevId: 674304558 --- jax/_src/interpreters/partial_eval.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 5bb3e204ced0..5406aad172c4 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1536,6 +1536,18 @@ def _prune_closed_jaxpr_outputs( def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool], instantiate: bool | Sequence[bool] = False, ) -> tuple[Jaxpr, list[bool]]: + """Runs dead-code elementation on a given jaxpr. + + Args: + jaxpr: The jaxpr to DCE. + used_outputs: A list of bools indicating which outputs are used. + instantiate: A bool or a list of bools indicating which inputs should be + considered used, regardless of whether they are actually used in a jaxpr. + If a bool, the same value is used for all inputs. + + Returns: + A tuple of ``(new_jaxpr, used_inputs)``. + """ if type(instantiate) is bool: instantiate = (instantiate,) * len(jaxpr.invars) return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate)) @@ -1545,7 +1557,7 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], instantiate: bool | Sequence[bool] = False, ) -> tuple[Jaxpr, list[bool], list[bool]]: jaxpr_ = convert_constvars_jaxpr(jaxpr) - new_jaxpr, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs) + new_jaxpr, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs, instantiate) used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)]) if sum(used_consts): new_jaxpr = convert_invars_to_constvars(new_jaxpr, sum(used_consts)) From 40040e3f6915b77ab2b228056a29b6329c31a97c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 13 Sep 2024 08:19:57 -0700 Subject: [PATCH 488/702] Added a new `approx_math` flag to Mosaic GPU params in Pallas The flag allows to control the precision of some operations, e.g. `exp`. PiperOrigin-RevId: 674305430 --- jax/_src/pallas/mosaic_gpu/core.py | 5 ++++- jax/_src/pallas/mosaic_gpu/lowering.py | 16 +++++++++------- tests/pallas/mosaic_gpu_test.py | 13 +++++++++++++ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 99f7235fda4a..025b9f1b57d0 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -29,11 +29,13 @@ AbstractMemoryRef = pallas_core.AbstractMemoryRef -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class GPUCompilerParams(pallas_core.CompilerParams): """Mosaic GPU compiler parameters. Attributes: + approx_math: If True, the compiler is allowed to use approximate + implementations of some math operations, e.g. ``exp``. Defaults to False. dimension_semantics: A list of dimension semantics for each grid dimension of the kernel. Either "parallel" for dimensions that can execute in any order, or "sequential" for dimensions that must be @@ -42,6 +44,7 @@ class GPUCompilerParams(pallas_core.CompilerParams): meaning no pipelining is done. """ PLATFORM: ClassVar[str] = "mosaic_gpu" + approx_math: bool = False dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None num_stages: int = 1 diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ce1900521e51..7c8f2e85f27b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -97,8 +97,9 @@ def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int: class ModuleContext: name: str grid_mapping: pallas_core.GridMapping + approx_math: bool runtime_smem: ir.Value # ir.MemRefType - smem_used_bytes: int + smem_used_bytes: int = 0 # TODO(cperivol): Only return the shapes and figure out the sizes when freeing. def scratch_view( @@ -233,11 +234,12 @@ def lower_jaxpr_to_module( grid += (1,) * (3 - len(grid)) block = (128,) + (1,) * (len(grid) - 1) params = compiler_params.get("mosaic_gpu", {}) + approx_math = params.get("approx_math", False) num_stages = params.get("num_stages", 1) - dimension_semantics = params.get( - "dimension_semantics", ["parallel"] * len(grid_mapping.grid) - ) - if len(dimension_semantics) != len(grid_mapping.grid): + dimension_semantics = params.get("dimension_semantics") + if dimension_semantics is None: + dimension_semantics = ["parallel"] * len(grid_mapping.grid) + elif len(dimension_semantics) != len(grid_mapping.grid): raise ValueError( "dimension_semantics must have an entrey for each grid dimension:" f" {len(dimension_semantics)=}, but len(grid={grid_mapping.grid})." @@ -295,7 +297,7 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value): ) module_ctx = ModuleContext( - name_and_src_info.name, grid_mapping, runtime_smem, smem_used_bytes=0 + name_and_src_info.name, grid_mapping, approx_math, runtime_smem ) program_ids = map(_program_id, range(len(grid_mapping.grid))) start_indices = map( @@ -622,7 +624,7 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): @register_lowering_rule(lax.rsqrt_p) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): - return _ensure_fa(x, *ctx.avals_in).rsqrt() + return _ensure_fa(x, *ctx.avals_in).rsqrt(ctx.module_context.approx_math) @register_lowering_rule(lax.reduce_sum_p) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 2c7277c6deb3..0ee12edd883e 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -127,6 +127,19 @@ def kernel(x_ref, o_ref): x = jnp.arange(128).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + x.sum()*2) + @parameterized.parameters(False, True) + def test_rsqrt(self, approx_math): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jax.lax.rsqrt(x_ref[...]) + + x = jnp.arange(128).astype(jnp.float32) + np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x)) + @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): eps = 1e-5 From b886bd7300376efc59fd9a159fa83314ed8c51a4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 13 Sep 2024 08:37:32 -0700 Subject: [PATCH 489/702] Removed the `named_shape` argument from `jex.core.ShapedArray` and `jax.ShapeDtypeStruct` It is unused and was only kept around to avoid breaking internal users. PiperOrigin-RevId: 674310795 --- CHANGELOG.md | 2 ++ jax/_src/api.py | 5 +-- jax/_src/array.py | 11 +++++++ jax/_src/core.py | 9 ++---- jax/_src/deprecations.py | 1 + tests/api_test.py | 63 ------------------------------------- tests/custom_object_test.py | 18 ++++------- 7 files changed, 23 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c29ae7dc6d9..78cb5daad284 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * For cross-backend lowering, you can replace `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. + * {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. + The argument was only used by `xmap` which was removed in 0.4.31. ## jaxlib 0.4.33 diff --git a/jax/_src/api.py b/jax/_src/api.py index 8ed03e8e3c87..3c0f28532810 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2454,11 +2454,8 @@ class ShapeDtypeStruct: sharding: (optional) a :class:`jax.Sharding` object """ __slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"] - named_shape = {} # type: ignore - def __init__(self, shape, dtype, named_shape=None, sharding=None, - weak_type=False): - del named_shape # ignored, vestigial + def __init__(self, shape, dtype, sharding=None, weak_type=False): self.shape = tuple(shape) if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") diff --git a/jax/_src/array.py b/jax/_src/array.py index 909f5acf0d43..9e5595aacca3 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -28,6 +28,7 @@ from jax._src import basearray from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import errors @@ -115,6 +116,16 @@ def _reconstruct_array(fun, args, arr_state, aval_state): np_value = fun(*args) np_value.__setstate__(arr_state) jnp_value = api.device_put(np_value) + # TODO(slebedev): Remove this branch after December 10th 2024. + if "named_shape" in aval_state: + deprecations.warn( + "jax-aval-named-shape", + "Pickled array contains an aval with a named_shape attribute. This is" + " deprecated and the code path supporting such avals will be removed." + " Please re-pickle the array.", + stacklevel=2, + ) + del aval_state["named_shape"] jnp_value.aval = jnp_value.aval.update(**aval_state) return jnp_value diff --git a/jax/_src/core.py b/jax/_src/core.py index ef3ace2e0e31..6c3b4093b071 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1745,20 +1745,15 @@ def _invalid_shape_error(shape: Shape, context: str=""): class ShapedArray(UnshapedArray): __slots__ = ['shape', 'sharding'] # inherits slots from parent array_abstraction_level = 2 - named_shape = {} # type: ignore - def __init__(self, shape, dtype, weak_type=False, named_shape=None, - sharding=None): - del named_shape # unused, vestigial + def __init__(self, shape, dtype, weak_type=False, sharding=None): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type if config.sharding_in_types.value: self.sharding = sharding - def update(self, shape=None, dtype=None, weak_type=None, named_shape=None, - sharding=None): - del named_shape # unused, vestigial + def update(self, shape=None, dtype=None, weak_type=None, sharding=None): if shape is None: shape = self.shape if dtype is None: diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 4e01c88afd1f..10850357f677 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -121,6 +121,7 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: # Register a number of deprecations: we do this here to ensure they're # always registered by the time `accelerate` and `is_acelerated` are called. +register('jax-aval-named-shape') register('jax-dlpack-import-legacy') register("jax-numpy-astype-complex-to-real") register("jax-numpy-array-none") diff --git a/tests/api_test.py b/tests/api_test.py index fabdb9ffe503..0390e2e4b636 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2715,26 +2715,6 @@ def __init__(self, *args, **kwargs): out_shape = api.eval_shape(lambda x: x, x) # doesn't crash self.assertEqual(out_shape.shape, (3,)) - def test_eval_shape_names(self): - raise unittest.SkipTest("named shape are deprecated") - - def fun(x, y): - return lax.psum(x, 'i') + y - - class MyArgArray: - def __init__(self, shape, dtype, named_shape): - self.shape = shape - self.dtype = jnp.dtype(dtype) - self.named_shape = named_shape - - x = MyArgArray((3, 2), jnp.float32, {'i': 10}) - y = MyArgArray((3, 2), jnp.float32, {'j': 5}) - with core.extend_axis_env('i', 10, None): - with core.extend_axis_env('j', 5, None): - out_shape = api.eval_shape(fun, x, y) - - self.assertEqual(out_shape.named_shape, {'j': 5}) - def test_issue_871(self): T = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) x = jnp.array([1, 2, 3]) @@ -6466,49 +6446,6 @@ def f(x): jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2) self.assertIn('psum', str(jaxpr)) - def test_make_jaxpr_named(self): - raise unittest.SkipTest("named shape are deprecated") - def f(x): - return x - lax.psum(x, 'i') - - x = api.ShapeDtypeStruct( - shape=(2, 3), dtype=jnp.dtype(jnp.float32), named_shape={'i': 10}) - jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x) - named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars] - self.assertEqual(named_shapes, [{'i': 10}, {}]) - - @parameterized.parameters(True, False) - def test_vjp_reduce_axes_jaxpr(self, gy_batched): - raise unittest.SkipTest("reduce_axes autodiff is removed") - def f(w, x): - return jnp.sin(jnp.dot(x, w)) - - w = api.ShapeDtypeStruct( - shape=(3, 4), dtype=jnp.float32, named_shape={}) - x = api.ShapeDtypeStruct( - shape=(3,), dtype=jnp.float32, named_shape={'batch': 2}) - gy = api.ShapeDtypeStruct( - shape=(4,), dtype=jnp.float32, - named_shape={'batch': 2} if gy_batched else {}) - - # per-example - jaxpr, shapes = api.make_jaxpr( - lambda w, x, gy: api.vjp(f, w, x)[1](gy), axis_env=[('batch', 2)], - return_shape=True)(w, x, gy) - expected = (api.ShapeDtypeStruct( - shape=(3, 4), dtype=jnp.float32, named_shape={'batch': 2}), x) - self.assertEqual(shapes, expected) - self.assertNotIn('psum', str(jaxpr)) - - # reduced - jaxpr, shapes = api.make_jaxpr( - lambda w, x, gy: api.vjp(f, w, x, reduce_axes=('batch',))[1](gy), - axis_env=[('batch', 2)], - return_shape=True)(w, x, gy) - expected = (w, x) - self.assertEqual(shapes, expected) - self.assertIn('psum', str(jaxpr)) - def test_weak_type_jit_invariance(self): y = jnp.broadcast_to(3., (3,)) self.assertTrue(y.aval.weak_type) diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index 75ff39630705..4b1182e16b5a 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -68,20 +68,17 @@ def __repr__(self): class AbstractSparseArray(core.ShapedArray): __slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval'] - def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False, - named_shape=None): + def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False): super().__init__(shape, dtypes.canonicalize_dtype(dtype)) - named_shape = {} if named_shape is None else named_shape self.index_dtype = index_dtype self.nnz = nnz - self.data_aval = core.ShapedArray((nnz,), dtypes.canonicalize_dtype(dtype), - weak_type, named_shape) + self.data_aval = core.ShapedArray( + (nnz,), dtypes.canonicalize_dtype(dtype), weak_type) self.indices_aval = core.ShapedArray( - (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype), - named_shape=named_shape) + (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype)) def update(self, shape=None, dtype=None, index_dtype=None, nnz=None, - weak_type=None, named_shape=None): + weak_type=None): if shape is None: shape = self.shape if dtype is None: @@ -92,10 +89,7 @@ def update(self, shape=None, dtype=None, index_dtype=None, nnz=None, nnz = self.nnz if weak_type is None: weak_type = self.weak_type - if named_shape is None: - named_shape = self.named_shape - return AbstractSparseArray( - shape, dtype, index_dtype, nnz, weak_type, named_shape) + return AbstractSparseArray(shape, dtype, index_dtype, nnz, weak_type) def strip_weak_type(self): return self From 3f2bc9b60846b8a32c29804ba5e8caac7766add7 Mon Sep 17 00:00:00 2001 From: Joao Sousa-Pinto Date: Mon, 26 Aug 2024 17:25:16 -0700 Subject: [PATCH 490/702] Lower tan to StableHLO instead of CHLO. Fixes #23259 --- jax/_src/lax/lax.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8d2c24d6e64c..09f0b312df71 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -340,6 +340,10 @@ def cos(x: ArrayLike) -> Array: r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.""" return cos_p.bind(x) +def tan(x: ArrayLike) -> Array: + r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.""" + return tan_p.bind(x) + def atan2(x: ArrayLike, y: ArrayLike) -> Array: r"""Elementwise arc tangent of two variables: :math:`\mathrm{atan}({x \over y})`.""" @@ -1549,10 +1553,6 @@ def f_wrapped(x): return f_wrapped -def tan(x: ArrayLike) -> Array: - r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.""" - return tan_p.bind(x) - def asin(x: ArrayLike) -> Array: r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.""" return asin_p.bind(x) @@ -2014,7 +2014,11 @@ def _tan_impl(x): tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) +def _lower_tan(ctx, x): + if ctx.is_forward_compat(): + return _nary_lower_hlo(chlo.tan, ctx, x) + return _nary_lower_hlo(hlo.tan, ctx, x) +mlir.register_lowering(tan_p, _lower_tan) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): From 83bccdd289f475466b9542aa4c8833499b2913af Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 13 Sep 2024 09:23:47 -0700 Subject: [PATCH 491/702] `sharding` and `weak_type` parameters of `ShapeDtypeStruct` are now keyword-only We decided not to go through a deprecation cycle for this change, because in the vast majority of cases internally these parameters are bound via a keyword argument anyway. PiperOrigin-RevId: 674324964 --- jax/_src/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 3c0f28532810..8ca3803aec35 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2455,7 +2455,7 @@ class ShapeDtypeStruct: """ __slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"] - def __init__(self, shape, dtype, sharding=None, weak_type=False): + def __init__(self, shape, dtype, *, sharding=None, weak_type=False): self.shape = tuple(shape) if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") From 5b8d5ce3422882b2d605c7958fa1acd008b57561 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Fri, 13 Sep 2024 09:56:50 -0700 Subject: [PATCH 492/702] Fix some layout test failures on gpu backend PiperOrigin-RevId: 674336502 --- tests/layout_test.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/layout_test.py b/tests/layout_test.py index 1603320d2531..f14120e46116 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -47,6 +47,9 @@ def setUp(self): super().setUp() def test_auto_layout(self): + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape1 = (128, 128) shape2 = (128, 128) @@ -112,6 +115,9 @@ def init(x, y): self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T) def test_default_layout(self): + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -151,6 +157,9 @@ def f(x): out_shardings=DLL.AUTO).lower(sds).compile() def test_in_layouts_out_layouts(self): + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (8, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -175,6 +184,9 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -466,6 +478,9 @@ def test_incompatible_aval_error_device_put(self): jax.device_put(inp, l) def test_concrete_layout_in_shardings(self): + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (16, 128) From 28b5dee0324c1c951ee3ccf896a08676707303d8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 13 Sep 2024 10:02:44 -0700 Subject: [PATCH 493/702] Disable flaky tsan tests temporarily. PiperOrigin-RevId: 674338720 --- tests/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index d1fb4dcc7cde..4635a48cede1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1187,7 +1187,10 @@ jax_test( shard_count = { "tpu": 5, }, - tags = ["noasan"], # Times out + tags = [ + "noasan", # Times out. + "notsan", # TODO(b/309111150): Re-enable after rolling forward cl/666056414. + ], deps = [ "//jax:experimental", "//jax:experimental_host_callback", From 0daca4646428abbdb728c48f73429460c5456f87 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 13 Sep 2024 13:00:55 -0700 Subject: [PATCH 494/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/32ebd694c4d0442e241d76324ff1a721831366b4. PiperOrigin-RevId: 674404604 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 066e102c2f44..63cd60688727 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "4aee555551c2be2e3e7891eab7b4343bf14ab279" -XLA_SHA256 = "efecdfd85763d0374eb76b4948b2413f68fb154ba4a5827fa852afee659f08e5" +XLA_COMMIT = "32ebd694c4d0442e241d76324ff1a721831366b4" +XLA_SHA256 = "d12ca6212b1bff774db28efaf292131d84e31205b720fa0055d5492a41d928ef" def repo(): tf_http_archive( From ee6f098fa9ae05ecf41cda31ab5bb3fc01e2f0e6 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 14 Sep 2024 02:31:31 -0700 Subject: [PATCH 495/702] [pallas] Clean up forward-compatibility conditionals in Pallas lowering In cl/657184114 (July 29th) I have made some changes in error reporting for invalid block shapes, but have left behind some conditionals to ensure forward compatibility. We are now out of the forward compatibility windows, and we clean up those conditionals. PiperOrigin-RevId: 674603915 --- jax/_src/pallas/mosaic/lowering.py | 52 +++++++----------------------- tests/pallas/pallas_test.py | 16 +++------ 2 files changed, 16 insertions(+), 52 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index bd897deb3d1f..3f631412d391 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -40,7 +40,6 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func @@ -448,20 +447,10 @@ def err_details(): f"and index_map returning {bm.index_map_jaxpr.jaxpr.outvars}, in " f"memory space {bm.block_aval.memory_space}." "\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec") - if lowering_context.is_forward_compat() or jaxlib_version < (0, 4, 32): - # TODO(b/356116061): Remove the old rank condition - if rank < 2: - raise ValueError( - "The Pallas TPU lowering currently supports only blocks of " - "rank >= 2 for blocks, except those in the SMEM memory space " - "having the same block shape as the array shape and a " - "trivial index_map (returning all 0s). " + err_details()) - else: - if rank < 1: - raise ValueError( - "The Pallas TPU lowering currently supports only blocks of " - "rank >= 1. " + err_details()) - + if rank < 1: + raise ValueError( + "The Pallas TPU lowering currently supports only blocks of " + "rank >= 1. " + err_details()) if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and not bm.has_trivial_window()): @@ -476,34 +465,17 @@ def err_details(): bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2] else: bs1, as1 = 1, 1 - if lowering_context.is_forward_compat(): - # TODO(b/356116061): Remove the old divisibility condition - # With shape polymorphism block_shape is static, but the array shape may - # be symbolic. Write the divisibility comparisons to defer inequality - # comparisons on dimensions as much as possible. + + if rank >= 2: evenly_divisible = ( - (bs0 % 128 == 0 or (bs0 == as0 and as0 < 128)) and - (bs1 % 8 == 0 or (bs1 == as1 and as1 < 8)) + (bs0 == as0 or bs0 % 128 == 0) and + (bs1 == as1 or bs1 % 8 == 0) ) - if not evenly_divisible: - raise ValueError( - "The Pallas TPU lowering currently requires that the last two " - "dimensions of your block shape are divisible by 8 and 128 " - "respectively, if the respective dimensions of the overall array " - "are larger than the respective factors. If array dimensions are " - "smaller, the block should span the full array dimension. " - + err_details()) else: - if rank >= 2: - evenly_divisible = ( - (bs0 == as0 or bs0 % 128 == 0) and - (bs1 == as1 or bs1 % 8 == 0) - ) - else: - assert rank == 1 - # TODO(necula): test this for bool. What should it do? - tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype)) - evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0) + assert rank == 1 + # TODO(necula): test this for bool. What should it do? + tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype)) + evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0) if not evenly_divisible: raise ValueError( diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 1d3316760fe8..5ee30ba3382a 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -33,7 +33,6 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax.control_flow.for_loop import for_loop -from jax._src.lib import version as jaxlib_version from jax._src.pallas import core as pallas_core from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl @@ -371,17 +370,10 @@ def copy_kernel(x_ref, o_ref): test_context = contextlib.nullcontext() if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - if jaxlib_version < (0, 4, 32): - # TODO(b/356116061): Remove the old rank condition - if rank < 2: - test_context = self.assertRaisesRegex( - ValueError, - "TPU lowering currently supports only blocks of rank >= 2") - else: - if rank < 1: - test_context = self.assertRaisesRegex( - ValueError, - "TPU lowering currently supports only blocks of rank >= 1") + if rank < 1: + test_context = self.assertRaisesRegex( + ValueError, + "TPU lowering currently supports only blocks of rank >= 1") if rank >= 1: bs0, as0 = block_shape[-1], shape[-1] From 02bb3d1c845e519ff65190af4b2c0196c62af759 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 14 Sep 2024 17:26:23 +0000 Subject: [PATCH 496/702] tweak error logic to save a comment :) --- jax/_src/interpreters/partial_eval.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 5406aad172c4..2d27bf064fce 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2062,11 +2062,9 @@ def process_primitive(self, primitive, tracers, params): def default_process_primitive(self, primitive, tracers, params): avals = [t.aval for t in tracers] out_avals, effects = primitive.abstract_eval(*avals, **params) - # == serve as a "not xor" here. - if not (isinstance(out_avals, (tuple,list)) == primitive.multiple_results): - raise ValueError(f"{primitive}.abstract_eval() method should return" - f" a tuple or a list if {primitive}.multiple_results" - " is true. Otherwise it shouldn't.") + if isinstance(out_avals, (tuple, list)) != primitive.multiple_results: + raise ValueError(f"{primitive}.abstract_eval() method should return " + f"a tuple or a list iff {primitive}.multiple_results.") out_avals = [out_avals] if not primitive.multiple_results else out_avals source_info = source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] From 45d448c143c8930c33ead5d100408cfc39f8dd28 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 14 Sep 2024 12:50:39 -0700 Subject: [PATCH 497/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/dedab4f8cf4a6f892311d331ea0a6b93568ff2b0. PiperOrigin-RevId: 674697897 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 63cd60688727..e87fde31e659 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "32ebd694c4d0442e241d76324ff1a721831366b4" -XLA_SHA256 = "d12ca6212b1bff774db28efaf292131d84e31205b720fa0055d5492a41d928ef" +XLA_COMMIT = "dedab4f8cf4a6f892311d331ea0a6b93568ff2b0" +XLA_SHA256 = "7b1e9ba324184896a4646a97de2cd66b500e3fda2b3b95eaea1222b7a7036c47" def repo(): tf_http_archive( From fcc8c3759d9026143fd3166e71d0e70e8915a482 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Sat, 14 Sep 2024 15:22:12 -0500 Subject: [PATCH 498/702] Fixed func ref in shared-computation --- docs/developer.md | 2 +- docs/sharded-computation.ipynb | 2 +- docs/sharded-computation.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/developer.md b/docs/developer.md index af2e451a22ef..53b6f0cf0f45 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -697,7 +697,7 @@ using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync` notebooks; for example: ``` -pip install jupytext==1.16.0 +pip install jupytext==1.16.4 jupytext --sync docs/notebooks/thinking_in_jax.ipynb ``` diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 60bf4d41a7a6..a4b6f2e0ced2 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -360,7 +360,7 @@ "\n", "## 2. Semi-automated sharding with constraints\n", "\n", - "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", + "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", "\n", "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:" ] diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index ef4dc2d3288d..c273e23c771e 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -133,7 +133,7 @@ The result is partially replicated: that is, the first two elements of the array ## 2. Semi-automated sharding with constraints -If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. +If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices: From b8d135aa0592f05a9594ff47bc6e95fa4c44501c Mon Sep 17 00:00:00 2001 From: enerrio <12959255+enerrio@users.noreply.github.com> Date: Sat, 14 Sep 2024 13:53:19 -0700 Subject: [PATCH 499/702] fix small typos in docs --- docs/sharded-computation.ipynb | 2 +- docs/sharded-computation.md | 2 +- docs/stateful-computations.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 60bf4d41a7a6..8629b481d228 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -60,7 +60,7 @@ "\n", "Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.\n", "\n", - "How can JAX can understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n", + "How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n", "\n", "In the simplest cases, arrays are sharded on a single device, as demonstrated below:" ] diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index ef4dc2d3288d..e02c5c185b53 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -39,7 +39,7 @@ jax.devices() Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices. -How can JAX can understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`. +How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`. In the simplest cases, arrays are sharded on a single device, as demonstrated below: diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 2eeffc30b255..4eb6e7a66cdd 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -144,7 +144,7 @@ This is because, like the strategy we just applied, object-oriented programming In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class? -Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, {mod}`jax.random`, shown in the :ref:`pseudorandom-numbers` section. +Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, {mod}`jax.random`, shown in the {ref}`pseudorandom-numbers` section. Unlike Numpy, which manages random state using implicitly updated stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key. From d60371b5db3320d9dee9783333dd645d86c40bdd Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Sun, 15 Sep 2024 10:33:05 +0530 Subject: [PATCH 500/702] Improve docs for jax.numpy: power and pow --- jax/_src/numpy/ufuncs.py | 60 ++++++++++++++++++++++++++++++++++++++-- tests/lax_numpy_test.py | 3 +- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 09546287cbdb..d54795f02b5d 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -1585,8 +1585,61 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]: return lax.round(div), mod -@implements(np.power, module='numpy') def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Calculate element-wise base ``x1`` exponential of ``x2``. + + JAX implementation of :obj:`numpy.power`. + + Args: + x1: scalar or array. Specifies the bases. + x2: scalar or array. Specifies the exponent. ``x1`` and ``x2`` should either + have same shape or be broadcast compatible. + + Returns: + An array containing the base ``x1`` exponentials of ``x2`` with same dtype + as input. + + Note: + - When ``x2`` is a concrete integer scalar, ``jnp.power`` lowers to + :func:`jax.lax.integer_pow`. + - When ``x2`` is a traced scalar or an array, ``jnp.power`` lowers to + :func:`jax.lax.pow`. + - ``jnp.power`` raises a ``TypeError`` for integer type raised to negative + integer power. + - ``jnp.power`` returns ``nan`` for negative value raised to the power of + non-integer values. + + See also: + - :func:`jax.lax.pow`: Computes element-wise power, :math:`x^y`. + - :func:`jax.lax.integer_pow`: Computes element-wise power :math:`x^y`, where + :math:`y` is a fixed integer. + - :func:`jax.numpy.float_power`: Computes the first array raised to the power + of second array, element-wise, by promoting to the inexact dtype. + - :func:`jax.numpy.pow`: Computes the first array raised to the power of second + array, element-wise. + + Examples: + Inputs with scalar integers: + + >>> jnp.power(4, 3) + Array(64, dtype=int32, weak_type=True) + + Inputs with same shape: + + >>> x1 = jnp.array([2, 4, 5]) + >>> x2 = jnp.array([3, 0.5, 2]) + >>> jnp.power(x1, x2) + Array([ 8., 2., 25.], dtype=float32) + + Inputs with broadcast compatibility: + + >>> x3 = jnp.array([-2, 3, 1]) + >>> x4 = jnp.array([[4, 1, 6], + ... [1.3, 3, 5]]) + >>> jnp.power(x3, x4) + Array([[16., 3., 1.], + [nan, 27., 1.]], dtype=float32) + """ check_arraylike("power", x1, x2) check_no_float0s("power", x1, x2) @@ -1616,8 +1669,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # Handle cases #2 and #3 under a jit: return _power(x1, x2) -# Array API alias -pow = power +def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.power`""" + return power(x1, x2) @partial(jit, inline=True) def _power(x1: ArrayLike, x2: ArrayLike) -> Array: diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 1323196feda0..704bb90116ea 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6288,7 +6288,8 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', - 'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'round_'] + 'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'pow', + 'round_'] skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split'] for name in dir(jnp): From 6bfa53d8c31a0936ca68942491d752b055dba984 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 15 Sep 2024 12:57:36 -0700 Subject: [PATCH 501/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/af733ec6fb9885ddebffdac13acf94b839e049df. PiperOrigin-RevId: 674918791 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e87fde31e659..efd799616237 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "dedab4f8cf4a6f892311d331ea0a6b93568ff2b0" -XLA_SHA256 = "7b1e9ba324184896a4646a97de2cd66b500e3fda2b3b95eaea1222b7a7036c47" +XLA_COMMIT = "af733ec6fb9885ddebffdac13acf94b839e049df" +XLA_SHA256 = "7ee9cd3a28d18b3fa2a699b6f5baae0d65351f3c631486b816d45296b1e3328a" def repo(): tf_http_archive( From 839ce9a11db994c65bc0b4f595fbadc5f4443e94 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Sun, 15 Sep 2024 17:52:43 -0700 Subject: [PATCH 502/702] [Pallas TPU] Refactor ref indexers to transforms and support ref bitcast. This cl refactors Pallas memref indexers to transforms which can support different ref transforms: indexing, bitcast (added in this cl), reshape (to be added) and others. Like indexer, user can apply multiple transforms to same memref, eg: ``` ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]... ``` Jaxpr Preview (apply multiple transforms to same ref): ``` { lambda ; a:MemRef{int32[16,256]} b:MemRef{int32[8,128]}. let c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:] b[:,:] <- c in () } ``` Tested: * DMA with bitcasted ref * Load from bitcasted ref * Store to bitcasted ref * Multiple transforms * Interpret Mode for ref transforms (updated discharge rules) PiperOrigin-RevId: 674961388 --- jax/BUILD | 1 + jax/_src/pallas/mosaic/lowering.py | 159 ++++++++---- jax/_src/pallas/mosaic/primitives.py | 331 ++++++++++++++----------- jax/_src/pallas/mosaic/verification.py | 14 +- jax/_src/pallas/primitives.py | 22 +- jax/_src/state/__init__.py | 16 +- jax/_src/state/discharge.py | 154 +++++++----- jax/_src/state/primitives.py | 215 ++++++++++------ jax/_src/state/types.py | 70 +++++- jax/_src/state/utils.py | 46 +++- tests/pallas/tpu_ops_test.py | 39 +-- tests/pallas/tpu_pallas_test.py | 95 +++++++ 12 files changed, 780 insertions(+), 382 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 74072dc44644..c6d8fe25af59 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -848,6 +848,7 @@ pytype_strict_library( ], deps = [ ":core", + ":dtypes", ":effects", ":pretty_printer", ":tree_util", diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 3f631412d391..13d861033e90 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -47,8 +47,8 @@ from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector -from jax._src.pallas import pallas_call from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core @@ -57,6 +57,9 @@ from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives +from jax._src.state.types import RefBitcaster +from jax._src.state.utils import dtype_bitwidth +from jax._src.typing import DTypeLike from jax._src.util import safe_map from jax._src.util import safe_zip from jax._src.util import split_list @@ -957,11 +960,12 @@ def _indexer_to_start_size_stride( ) -def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef, - indexer: NDIndexer, - ref_block_shape: tuple[int | pallas_core.Mapped, ...] - ) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...], - tuple[int | pallas_core.Mapped, ...]]: +def _slice_memref( + ref: ir.Value, + indexer: NDIndexer, + ref_dtype: DTypeLike, + ref_block_shape: tuple[int | pallas_core.Mapped, ...], +) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]: assert ref_block_shape is not None target_shape = indexer.get_indexer_shape() starts, sizes, strides, squeeze_dims, ref_block_shape = ( @@ -978,26 +982,79 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef, static_sizes = tuple(s if not isinstance(s, ir.Value) else ir_dynamic_size for s in sizes) target_ref_ty = ir.MemRefType.get( - static_sizes, _dtype_to_ir_type(ref_aval.dtype), - memory_space=ref.type.memory_space) + static_sizes, + _dtype_to_ir_type(ref_dtype), + memory_space=ref.type.memory_space, + ) out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result if any(squeeze_dims): # We need to squeeze out some dimensions static_sizes = tuple(s if not isinstance(s, ir.Value) else ir_dynamic_size for s in target_shape) squeezed_ref_ty = ir.MemRefType.get( - static_sizes, _dtype_to_ir_type(ref_aval.dtype), - memory_space=ref.type.memory_space) + static_sizes, + _dtype_to_ir_type(ref_dtype), + memory_space=ref.type.memory_space, + ) out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result return out, ref_block_shape -def _index_ref(ref, ref_aval, ref_block_shape, indexers): - for indexer in indexers: - ref, ref_block_shape = _slice_memref(ref, ref_aval, indexer, - ref_block_shape) +def _bitcast_memref( + ref: ir.Value, + bitcaster: RefBitcaster, + ref_dtype: DTypeLike, + ref_block_shape: tuple[int | pallas_core.Mapped, ...], +) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]: + src_bitwidth = dtype_bitwidth(ref_dtype) + dst_bitwidth = dtype_bitwidth(bitcaster.dtype) + if src_bitwidth != dst_bitwidth: + if len(ref_block_shape) < 2: + raise NotImplementedError( + "Bitcast 1D ref with bitwidth change is not supported." + ) + if ref_block_shape[-2] is pallas_core.mapped: + raise NotImplementedError( + "Bitcast a ref whose 2nd minormost dimension is squeezed when" + " bitwidth changes." + ) + new_ref_dtype = bitcaster.dtype + target_ref_ty = ir.MemRefType.get( + bitcaster.shape, + _dtype_to_ir_type(new_ref_dtype), + memory_space=ref.type.memory_space, + ) + new_ref_block_shape = list(ref_block_shape) + if ( + len(new_ref_block_shape) >= 2 + and new_ref_block_shape[-2] is not pallas_core.mapped + ): + new_ref_block_shape[-2] = ( + new_ref_block_shape[-2] * src_bitwidth // dst_bitwidth + ) + return ( + tpu.memref_bitcast(target_ref_ty, ref), + new_ref_dtype, + tuple(new_ref_block_shape), + ) + + +def _transform_ref(ref, ref_dtype, ref_block_shape, transforms): + for transform in transforms: + match transform: + case NDIndexer(): + ref, ref_block_shape = _slice_memref( + ref, transform, ref_dtype, ref_block_shape + ) + case RefBitcaster(): + ref, ref_dtype, ref_block_shape = _bitcast_memref( + ref, transform, ref_dtype, ref_block_shape + ) + case _: + raise NotImplementedError(f"Unsupported transform: {transform}") return ref, ref_block_shape + @dataclasses.dataclass(frozen=True) class KeyScalarBundle: """A container class for PRNG key data. @@ -1016,21 +1073,21 @@ class KeyScalarBundle: scalars: list[ir.OpResult] def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): - ref, indexers, mask, _ = args_tree.unflatten(args_flat) - ref_aval, indexers_avals, _, _ = args_tree.unflatten(ctx.avals_in) - (*slice_indexers, idx) = indexers + ref, transforms, mask, _ = args_tree.unflatten(args_flat) + ref_aval, transforms_avals, _, _ = args_tree.unflatten(ctx.avals_in) + (*prev_transforms, idx) = transforms # Select last aval, which is the one that will be used for the load. - (*_, idx_aval) = indexers_avals + (*_, idx_aval) = transforms_avals if mask is not None: raise NotImplementedError ref_block_shape, *_ = ctx.block_shapes - ref, ref_block_shape = _index_ref( - ref, ref_aval, ref_block_shape, slice_indexers) + ref, ref_block_shape = _transform_ref( + ref, ref_aval.dtype, ref_block_shape, prev_transforms + ) ref_type = ir.MemRefType(ref.type) is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space" - ref_aval, *_ = ctx.avals_in (aval_out,) = ctx.avals_out if isinstance(aval_out.dtype, prng.KeyTy): if not is_smem_load: @@ -1064,7 +1121,7 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): raise ValueError( "Loads are only allowed on VMEM and SMEM references." + extra ) - load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype) + load_aval = jax_core.ShapedArray(sizes, dtype=aval_out.dtype) if need_stride: load_val = tpu.StridedLoadOp( aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts, strides @@ -1159,17 +1216,18 @@ def _maybe_cast_store_to_memref_type( def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, **_ ): - ref, indexers, val, mask = args_tree.unflatten(args_flat) - ref_aval, indexers_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in) - (*slice_indexers, idx) = indexers - (*_, idx_aval) = indexers_avals + ref, transforms, val, mask = args_tree.unflatten(args_flat) + ref_aval, transforms_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in) + (*prev_transforms, idx) = transforms + (*_, idx_aval) = transforms_avals if mask is not None: raise NotImplementedError ref_block_shape, *_ = ctx.block_shapes - ref, ref_block_shape = _index_ref( - ref, ref_aval, ref_block_shape, slice_indexers) + ref, ref_block_shape = _transform_ref( + ref, ref_aval.dtype, ref_block_shape, prev_transforms + ) ref_type = ir.MemRefType(ref.type) is_smem_store = str(ref_type.memory_space) == "#tpu.memory_space" @@ -2553,8 +2611,8 @@ def _semaphore_read_lowering_rule( args_tree, ): sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) - sem, indexers = tree_util.tree_unflatten(args_tree, args) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem, transforms = tree_util.tree_unflatten(args_tree, args) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) return tpu.SemaphoreReadOp(sem).result @@ -2567,8 +2625,10 @@ def _semaphore_signal_lowering_rule( device_id_type: tpu_primitives.DeviceIdType, ): sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) - sem, indexers, value, device_id, core_index = tree_util.tree_unflatten(args_tree, args) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( + args_tree, args + ) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) return tpu.SemaphoreSignalOp( @@ -2582,8 +2642,8 @@ def _semaphore_signal_lowering_rule( def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) - sem, indexers, value = tree_util.tree_unflatten(args_tree, args) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem, transforms, value = tree_util.tree_unflatten(args_tree, args) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) return tpu.SemaphoreWaitOp(sem, value).results lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule @@ -2591,13 +2651,13 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, device_id_type: tpu_primitives.DeviceIdType): ( src_ref, - src_indexers, + src_transforms, dst_ref, - dst_indexers, + dst_transforms, sem, - sem_indexers, + sem_transforms, src_sem, - src_sem_indexers, + src_sem_transforms, device_id, ) = tree_util.tree_unflatten(tree, args) (src_ref_aval, _, dst_ref_aval, _, sem_aval, _, src_sem_aval, _, _) = ( @@ -2607,16 +2667,17 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, raise NotImplementedError("DMAs with bool dtypes are not supported.") block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes) src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2] - src_ref, _ = _index_ref( - src_ref, src_ref_aval, src_ref_block_shape, src_indexers + src_ref, _ = _transform_ref( + src_ref, src_ref_aval.dtype, src_ref_block_shape, src_transforms ) if src_sem is not None: - src_sem, _ = _index_ref( - src_sem, src_sem_aval, src_sem_aval.shape, src_sem_indexers) - dst_ref, _ = _index_ref( - dst_ref, dst_ref_aval, dst_ref_block_shape, dst_indexers + src_sem, _ = _transform_ref( + src_sem, src_sem_aval.dtype, src_sem_aval.shape, src_sem_transforms + ) + dst_ref, _ = _transform_ref( + dst_ref, dst_ref_aval.dtype, dst_ref_block_shape, dst_transforms ) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) return tpu.EnqueueDMAOp(src_ref, dst_ref, sem, source_semaphore=src_sem, @@ -2627,14 +2688,12 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, device_id_type: tpu_primitives.DeviceIdType): del device_id_type - sem, sem_indexers, ref, indexers = tree_util.tree_unflatten(tree, args) + sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, args) sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in) block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes) ref_block_shape = block_shapes[2] - ref, _ = _index_ref( - ref, ref_aval, ref_block_shape, indexers - ) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) + ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) return tpu.WaitDMAOp(sem, ref).results lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 348820907ed0..aab214a2d700 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -33,6 +33,7 @@ from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as sp +from jax._src.state.types import Transform from jax._src.typing import DTypeLike import jax.numpy as jnp @@ -164,17 +165,21 @@ class DeviceIdType(enum.Enum): LOGICAL = "logical" -def check_sem_avals(sem_aval, sem_indexers_avals, name, allowed_semaphore_types=None): +def check_sem_avals( + sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None +): if allowed_semaphore_types is None: - allowed_semaphore_types = {tpu_core.semaphore, - tpu_core.barrier_semaphore, - # For interpret mode. - pl_core.SEMAPHORE_INTERPRET_DTYPE} + allowed_semaphore_types = { + tpu_core.semaphore, + tpu_core.barrier_semaphore, + # For interpret mode. + pl_core.SEMAPHORE_INTERPRET_DTYPE, + } if not isinstance(sem_aval, state.AbstractRef): raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") sem_shape = sem_aval.shape - if sem_indexers_avals: - sem_shape = sem_indexers_avals[-1].get_indexer_shape() + if sem_transforms_avals: + sem_shape = sem_transforms_avals[-1].get_indexer_shape() if sem_shape: raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") sem_dtype = sem_aval.dtype @@ -187,10 +192,11 @@ def check_sem_avals(sem_aval, sem_indexers_avals, name, allowed_semaphore_types= f" {allowed_semaphore_types}." ) -def _index_semaphore(ref_value, indexers, ref_aval): + +def _transform_semaphore(ref_value, transforms, ref_aval): """Helper function for indexing into a semaphore during state_discharge.""" if ref_value.shape == ref_aval.shape: - return state_discharge.index_array(ref_value, indexers) + return state_discharge.transform_array(ref_value, transforms) elif len(ref_value.shape) == 0: return ref_value else: @@ -199,13 +205,14 @@ def _index_semaphore(ref_value, indexers, ref_aval): f" {ref_aval.shape}" ) + semaphore_read_p = jax_core.Primitive("semaphore_read") semaphore_read_p.multiple_results = False def semaphore_read(sem_or_view): - ref, indexers = _get_ref_and_indexers(sem_or_view) - args = [ref, indexers] + ref, transforms = _get_ref_and_transforms(sem_or_view) + args = [ref, transforms] flat_args, args_tree = tree_util.tree_flatten(args) return semaphore_read_p.bind(*flat_args, args_tree=args_tree) @@ -214,10 +221,10 @@ def _semaphore_read_abstract_eval( *avals, args_tree, ): - sem_aval, sem_indexers_avals = tree_util.tree_unflatten(args_tree, avals) + sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) check_sem_avals( sem_aval, - sem_indexers_avals, + sem_transforms_avals, "read", allowed_semaphore_types={ tpu_core.dma_semaphore, @@ -233,8 +240,8 @@ def _semaphore_read_discharge_rule(in_avals, *flat_args, args_tree): del out_avals - [ref, indexers] = args_tree.unflatten(flat_args) - sem_value = _index_semaphore(ref, indexers, in_avals[0]) + [ref, transforms] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) sem_value = sem_value.astype(jnp.int32) return (None,) * len(in_avals), sem_value state_discharge.register_discharge_rule(semaphore_read_p)( @@ -254,9 +261,9 @@ def semaphore_signal( device_id_type: DeviceIdType = DeviceIdType.MESH, core_index: int | jax.Array | None = None, ): - ref, indexers = _get_ref_and_indexers(sem_or_view) + ref, transforms = _get_ref_and_transforms(sem_or_view) inc = jnp.asarray(inc, dtype=jnp.int32) - args = [ref, indexers, inc, device_id, core_index] + args = [ref, transforms, inc, device_id, core_index] flat_args, args_tree = tree_util.tree_flatten(args) semaphore_signal_p.bind( *flat_args, @@ -272,10 +279,14 @@ def _semaphore_signal_abstract_eval( device_id_type: DeviceIdType, ): del device_id_type - sem_aval, sem_indexers_avals, value_aval, device_id_avals, core_index_aval = ( - tree_util.tree_unflatten(args_tree, avals) - ) - check_sem_avals(sem_aval, sem_indexers_avals, "signal") + ( + sem_aval, + sem_transforms_avals, + value_aval, + device_id_avals, + core_index_aval, + ) = tree_util.tree_unflatten(args_tree, avals) + check_sem_avals(sem_aval, sem_transforms_avals, "signal") if value_aval.dtype != jnp.dtype("int32"): raise ValueError("Must signal an int32 value.") if device_id_avals is not None: @@ -294,16 +305,16 @@ def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, tree = eqn.params["args_tree"] ( sem, - sem_indexers, + sem_transforms, value, device_ids, _, ) = tree_util.tree_unflatten(tree, invars) out = pp.concat([ - pp.text('semaphore_signal'), - pp.text(' '), - sp.pp_ref_indexers(context, sem, sem_indexers), - pp.text(' '), + pp.text("semaphore_signal"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), pp.text(jax_core.pp_var(value, context)), ]) if device_ids is not None: @@ -325,15 +336,15 @@ def _semaphore_signal_discharge_rule(in_avals, args_tree, device_id_type): del out_avals, device_id_type - [ref, indexers, inc, device_id, core_index] = args_tree.unflatten(flat_args) + [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) if device_id is not None: raise NotImplementedError("Remote signal not implemented.") if core_index is not None: raise NotImplementedError("Multiple core support not implemented.") - sem_value = _index_semaphore(ref, indexers, in_avals[0]) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.index_swap_array( - ref, indexers, sem_value + inc + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value + inc ) return (new_sem_value,) + (None,) * (len(in_avals) - 1), () state_discharge.register_discharge_rule(semaphore_signal_p)( @@ -345,16 +356,18 @@ def _semaphore_signal_discharge_rule(in_avals, semaphore_wait_p.multiple_results = True def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): - ref, indexers = _get_ref_and_indexers(sem_or_view) + ref, transforms = _get_ref_and_transforms(sem_or_view) dec = jnp.asarray(dec, dtype=jnp.int32) - args = [ref, indexers, dec] + args = [ref, transforms, dec] flat_args, args_tree = tree_util.tree_flatten(args) semaphore_wait_p.bind(*flat_args, args_tree=args_tree) @semaphore_wait_p.def_abstract_eval def _semaphore_wait_abstract_eval(*avals, args_tree): - sem_aval, sem_indexers_avals, value_aval = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals(sem_aval, sem_indexers_avals, "wait") + sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( + args_tree, avals + ) + check_sem_avals(sem_aval, sem_transforms_avals, "wait") if value_aval.dtype != jnp.dtype("int32"): raise ValueError("Must wait an int32 value.") return [] @@ -367,14 +380,14 @@ def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, tree = eqn.params["args_tree"] ( sem, - sem_indexers, + sem_transforms, value, ) = tree_util.tree_unflatten(tree, invars) return pp.concat([ - pp.text('semaphore_wait'), - pp.text(' '), - sp.pp_ref_indexers(context, sem, sem_indexers), - pp.text(' '), + pp.text("semaphore_wait"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), pp.text(jax_core.pp_var(value, context)), ]) jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn @@ -384,11 +397,11 @@ def _semaphore_wait_discharge_rule(in_avals, *flat_args, args_tree): del out_avals - [ref, indexers, dec] = args_tree.unflatten(flat_args) - sem_value = _index_semaphore(ref, indexers, in_avals[0]) + [ref, transforms, dec] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.index_swap_array( - ref, indexers, sem_value -dec + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value - dec ) return (new_sem_value,) + (None,) * (len(in_avals) - 1), () state_discharge.register_discharge_rule(semaphore_wait_p)( @@ -399,13 +412,13 @@ def _semaphore_wait_discharge_rule(in_avals, @dataclasses.dataclass class AsyncCopyDescriptor: src_ref: Any - src_indexers: tuple[indexing.NDIndexer, ...] + src_transforms: tuple[Transform, ...] dst_ref: Any - dst_indexers: tuple[indexing.NDIndexer, ...] + dst_transforms: tuple[Transform, ...] dst_sem: int | jax.Array - dst_sem_indexers: tuple[indexing.NDIndexer, ...] + dst_sem_transforms: tuple[Transform, ...] src_sem: int | jax.Array | None - src_sem_indexers: tuple[indexing.NDIndexer, ...] | None + src_sem_transforms: tuple[Transform, ...] | None device_id: int | jax.Array | None device_id_type: DeviceIdType = DeviceIdType.MESH @@ -421,13 +434,13 @@ def is_remote(self): def start(self): flat_args, tree = tree_util.tree_flatten(( self.src_ref, - self.src_indexers, + self.src_transforms, self.dst_ref, - self.dst_indexers, + self.dst_transforms, self.dst_sem, - self.dst_sem_indexers, + self.dst_sem_transforms, self.src_sem, - self.src_sem_indexers, + self.src_sem_transforms, self.device_id, )) dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type) @@ -438,9 +451,12 @@ def wait(self): self.wait_recv() def wait_recv(self): - wait_args, tree = tree_util.tree_flatten( - (self.dst_sem, self.dst_sem_indexers, self.dst_ref, self.dst_indexers) - ) + wait_args, tree = tree_util.tree_flatten(( + self.dst_sem, + self.dst_sem_transforms, + self.dst_ref, + self.dst_transforms, + )) dma_wait_p.bind( *wait_args, tree=tree, device_id_type=self.device_id_type ) @@ -448,9 +464,12 @@ def wait_recv(self): def wait_send(self): if not self.is_remote: raise ValueError("Cannot `wait_send` on a local copy.") - wait_args, tree = tree_util.tree_flatten( - (self.src_sem, self.src_sem_indexers, self.src_ref, self.src_indexers) - ) + wait_args, tree = tree_util.tree_flatten(( + self.src_sem, + self.src_sem_transforms, + self.src_ref, + self.src_transforms, + )) dma_wait_p.bind( *wait_args, tree=tree, device_id_type=self.device_id_type ) @@ -463,32 +482,32 @@ def wait_send(self): def _dma_start_abstract_eval(*args, tree, device_id_type): ( src_ref_aval, - src_indexers_avals, + src_transforms_avals, dst_ref_aval, - dst_indexers_avals, + dst_transforms_avals, dst_sem_aval, - dst_sem_indexers_avals, + dst_sem_transforms_avals, src_sem_aval, - src_sem_indexers_avals, + src_sem_transforms_avals, device_id_aval, ) = tree_util.tree_unflatten(tree, args) dst_sem_shape = dst_sem_aval.shape - if dst_sem_indexers_avals: - dst_sem_shape = dst_sem_indexers_avals[-1].get_indexer_shape() + if dst_sem_transforms_avals: + dst_sem_shape = dst_sem_transforms_avals[-1].get_indexer_shape() if dst_sem_shape: raise ValueError( f"Cannot signal on a non-()-shaped semaphore: {dst_sem_shape}" ) if src_sem_aval is not None: src_sem_shape = src_sem_aval.shape - if src_sem_indexers_avals: - src_sem_shape = src_sem_indexers_avals[-1].get_indexer_shape() + if src_sem_transforms_avals: + src_sem_shape = src_sem_transforms_avals[-1].get_indexer_shape() if src_sem_shape: raise ValueError( f"Cannot signal on a non-()-shaped semaphore: {src_sem_shape}" ) - n_src_indexers = len(tree_util.tree_leaves(src_indexers_avals)) - return [], {state.ReadEffect(0), state.WriteEffect(n_src_indexers + 1)} + n_src_transforms = len(tree_util.tree_leaves(src_transforms_avals)) + return [], {state.ReadEffect(0), state.WriteEffect(n_src_transforms + 1)} def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, context: jax_core.JaxprPpContext, @@ -497,27 +516,27 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, tree = eqn.params["tree"] ( src_ref, - src_indexers, + src_transforms, dst_ref, - dst_indexers, + dst_transforms, dst_sem, - dst_sem_indexers, + dst_sem_transforms, src_sem, - src_sem_indexers, + src_sem_transforms, device_id, ) = tree_util.tree_unflatten(tree, invars) - del src_sem_indexers + del src_sem_transforms # TODO(sharadmv): pretty print source semaphores and device id if src_sem or device_id: return jax_core._pp_eqn(eqn, context, settings) return pp.concat([ - pp.text('dma_start'), - pp.text(' '), - sp.pp_ref_indexers(context, src_ref, src_indexers), - pp.text(' -> '), - sp.pp_ref_indexers(context, dst_ref, dst_indexers), - pp.text(' '), - sp.pp_ref_indexers(context, dst_sem, dst_sem_indexers), + pp.text("dma_start"), + pp.text(" "), + sp.pp_ref_transforms(context, src_ref, src_transforms), + pp.text(" -> "), + sp.pp_ref_transforms(context, dst_ref, dst_transforms), + pp.text(" "), + sp.pp_ref_transforms(context, dst_sem, dst_sem_transforms), ]) jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn @@ -526,24 +545,24 @@ def dma_start_discharge_rule(in_avals, out_avals, *args, tree, device_id_type): ( src_ref, - src_indexers, + src_transforms, dst_ref, - dst_indexers, + dst_transforms, dst_sem, - dst_sem_indexers, + dst_sem_transforms, src_sem, - src_sem_indexers, + src_sem_transforms, device_id, ) = tree_util.tree_unflatten(tree, args) ( _, - src_indexers_avals, + src_transforms_avals, _, - dst_indexers_avals, + dst_transforms_avals, dst_sem_aval, - dst_sem_indexers_avals, + dst_sem_transforms_avals, src_sem_aval, - src_sem_indexers_avals, + src_sem_transforms_avals, _, ) = tree_util.tree_unflatten(tree, in_avals) del out_avals @@ -551,14 +570,14 @@ def dma_start_discharge_rule(in_avals, out_avals, if not is_remote: # Local async copies only use one semaphore. assert src_sem is None - assert src_sem_indexers is None + assert src_sem_transforms is None - num_src_sem_indexers = len(tree_util.tree_leaves(src_sem_indexers_avals)) - num_dst_sem_indexers = len(tree_util.tree_leaves(dst_sem_indexers_avals)) - num_src_index_vals = len(tree_util.tree_leaves(src_indexers_avals)) - num_dst_index_vals = len(tree_util.tree_leaves(dst_indexers_avals)) + num_src_sem_transforms = len(tree_util.tree_leaves(src_sem_transforms_avals)) + num_dst_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals)) + num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals)) + num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals)) - updates = state_discharge.index_array(src_ref, src_indexers) + updates = state_discharge.transform_array(src_ref, src_transforms) local_src = updates if is_remote: @@ -602,44 +621,52 @@ def dma_start_discharge_rule(in_avals, out_avals, global_updates, index, axis=0, keepdims=False) # Handle asymmetrical indexing when devices do not share the same - # dst_indexer. - global_dst_indexers = tree_util.tree_map( - lambda x: jax.lax.all_gather(x, shard_axis), dst_indexers) - dst_indexers = tree_util.tree_map( + # dst_transform. + global_dst_transforms = tree_util.tree_map( + lambda x: jax.lax.all_gather(x, shard_axis), dst_transforms + ) + dst_transforms = tree_util.tree_map( lambda x: jax.lax.dynamic_index_in_dim( - x, index, axis=0, keepdims=False), global_dst_indexers) + x, index, axis=0, keepdims=False + ), + global_dst_transforms, + ) - _, new_dst = state_discharge.index_swap_array( - dst_ref, dst_indexers, updates + _, new_dst = state_discharge.transform_swap_array( + dst_ref, dst_transforms, updates ) # Update semaphore values. # TODO(justinfu): Potentially handle asymmetric copy sizes. recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - dst_sem_value = _index_semaphore(dst_sem, dst_sem_indexers, dst_sem_aval) - _, new_dst_sem = state_discharge.index_swap_array( - dst_sem, dst_sem_indexers, dst_sem_value + recv_size + dst_sem_value = _transform_semaphore( + dst_sem, dst_sem_transforms, dst_sem_aval + ) + _, new_dst_sem = state_discharge.transform_swap_array( + dst_sem, dst_sem_transforms, dst_sem_value + recv_size ) if is_remote: send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE) send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - src_sem_value = _index_semaphore(src_sem, src_sem_indexers, src_sem_aval) - _, new_src_sem = state_discharge.index_swap_array( - src_sem, src_sem_indexers, src_sem_value + send_size + src_sem_value = _transform_semaphore( + src_sem, src_sem_transforms, src_sem_aval + ) + _, new_src_sem = state_discharge.transform_swap_array( + src_sem, src_sem_transforms, src_sem_value + send_size ) else: new_src_sem = None new_vals = (None,) # src_val - new_vals += (None,) * num_src_index_vals + new_vals += (None,) * num_src_transform_vals new_vals += (new_dst,) # dst_val - new_vals += (None,) * num_dst_index_vals + new_vals += (None,) * num_dst_transform_vals new_vals += (new_dst_sem,) # dst_sem - new_vals += (None,) * num_dst_sem_indexers + new_vals += (None,) * num_dst_sem_transforms if is_remote: new_vals += (new_src_sem,) # src_sem - new_vals += (None,) * num_src_sem_indexers + new_vals += (None,) * num_src_sem_transforms new_vals += (None,) # device_id assert (len(new_vals) == len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}" @@ -662,13 +689,13 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, del settings invars = eqn.invars tree = eqn.params["tree"] - sem, sem_indexers, ref, indexers = tree_util.tree_unflatten(tree, invars) + sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, invars) return pp.concat([ - pp.text('dma_wait'), - pp.text(' '), - sp.pp_ref_indexers(context, ref, indexers), - pp.text(' '), - sp.pp_ref_indexers(context, sem, sem_indexers), + pp.text("dma_wait"), + pp.text(" "), + sp.pp_ref_transforms(context, ref, transforms), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), ]) jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn @@ -676,42 +703,53 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, def dma_wait_discharge_rule(in_avals, out_avals, *args, tree, device_id_type): del out_avals, device_id_type - (sem, sem_indexers, ref, ref_indexers) = tree_util.tree_unflatten(tree, args) + (sem, sem_transforms, ref, ref_transforms) = tree_util.tree_unflatten( + tree, args + ) ( sem_aval, - sem_indexers_avals, + sem_transforms_avals, _, - ref_indexers_avals, + ref_transforms_avals, ) = tree_util.tree_unflatten(tree, in_avals) - num_sem_indexers = len(tree_util.tree_leaves(sem_indexers_avals)) - num_indexers = len(tree_util.tree_leaves(ref_indexers_avals)) - updates = state_discharge.index_array(ref, ref_indexers) + num_sem_transforms = len(tree_util.tree_leaves(sem_transforms_avals)) + num_transforms = len(tree_util.tree_leaves(ref_transforms_avals)) + updates = state_discharge.transform_array(ref, ref_transforms) copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - sem_value = _index_semaphore(sem, sem_indexers, sem_aval) - _, new_sem = state_discharge.index_swap_array( - sem, sem_indexers, sem_value - copy_size + sem_value = _transform_semaphore(sem, sem_transforms, sem_aval) + _, new_sem = state_discharge.transform_swap_array( + sem, sem_transforms, sem_value - copy_size ) new_vals = (new_sem,) # sem - new_vals += (None,) * num_sem_indexers + new_vals += (None,) * num_sem_transforms new_vals += (None,) # ref - new_vals += (None,) * num_indexers + new_vals += (None,) * num_transforms return new_vals, [] state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule) -def _get_ref_and_indexers(ref): - if isinstance(ref, state.RefView): - return ref.ref, ref.indexers +def _get_ref_and_transforms(ref): + if isinstance(ref, state.TransformedRef): + return ref.ref, ref.transforms return ref, () def make_async_copy(src_ref, dst_ref, sem): """Issues a DMA copying from src_ref to dst_ref.""" - src_ref, src_indexers = _get_ref_and_indexers(src_ref) - dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref) - sem, sem_indexers = _get_ref_and_indexers(sem) - return AsyncCopyDescriptor(src_ref, src_indexers, dst_ref, dst_indexers, - sem, sem_indexers, None, None, None, - DeviceIdType.MESH) + src_ref, src_transforms = _get_ref_and_transforms(src_ref) + dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref) + sem, sem_transforms = _get_ref_and_transforms(sem) + return AsyncCopyDescriptor( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + sem, + sem_transforms, + None, + None, + None, + DeviceIdType.MESH, + ) def async_copy(src_ref, dst_ref, sem): """Issues a DMA copying from src_ref to dst_ref.""" @@ -739,13 +777,22 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, Returns: An AsyncCopyDescriptor. """ - src_ref, src_indexers = _get_ref_and_indexers(src_ref) - send_sem, send_sem_indexers = _get_ref_and_indexers(send_sem) - dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref) - recv_sem, recv_sem_indexers = _get_ref_and_indexers(recv_sem) + src_ref, src_transforms = _get_ref_and_transforms(src_ref) + send_sem, send_sem_transforms = _get_ref_and_transforms(send_sem) + dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref) + recv_sem, recv_sem_transforms = _get_ref_and_transforms(recv_sem) return AsyncCopyDescriptor( - src_ref, src_indexers, dst_ref, dst_indexers, recv_sem, recv_sem_indexers, - send_sem, send_sem_indexers, device_id, device_id_type=device_id_type) + src_ref, + src_transforms, + dst_ref, + dst_transforms, + recv_sem, + recv_sem_transforms, + send_sem, + send_sem_transforms, + device_id, + device_id_type=device_id_type, + ) def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type: DeviceIdType = DeviceIdType.MESH): diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py index df186d46373a..bae87226c664 100644 --- a/jax/_src/pallas/mosaic/verification.py +++ b/jax/_src/pallas/mosaic/verification.py @@ -550,13 +550,17 @@ def _pretend_abstract_eval(*_, **params): def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree): if ctx.lowering_context.for_verification: - (base_read_refs, indexers) = tree_util.tree_unflatten(tree, flat_args) + (base_read_refs, transforms) = tree_util.tree_unflatten(tree, flat_args) read_ref_avals, _ = tree_util.tree_unflatten(tree, ctx.avals_in) block_shapes, _ = tree_util.tree_unflatten(tree, ctx.block_shapes) read_refs = [ lowering._index_ref(ref, aval, block_shape, indexer)[0] for ref, aval, block_shape, indexer in zip( - base_read_refs, read_ref_avals, block_shapes, indexers, strict=True, + base_read_refs, + read_ref_avals, + block_shapes, + transforms, + strict=True, ) ] ir.Operation.create("verification.pretend", operands=read_refs) @@ -565,8 +569,10 @@ def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree): lowering.lowering_rules[pretend_p] = _pretend_lowering # type: ignore def pretend(read_refs): - refs, indexers = unzip2(primitives._get_ref_and_indexers(r) for r in read_refs) - flat_args, tree = tree_util.tree_flatten((refs, indexers)) + refs, transforms = unzip2( + primitives._get_ref_and_transforms(r) for r in read_refs + ) + flat_args, tree = tree_util.tree_flatten((refs, transforms)) return pretend_p.bind(*flat_args, tree=tree) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index e41a8cf59975..fbc389aae3fb 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -177,8 +177,10 @@ def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType): def _atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None, atomic_type: AtomicOpType): - x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "atomic_rmw") - args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask)) + x_ref, transforms = sp.get_ref_and_transforms( + x_ref_or_view, idx, "atomic_rmw" + ) + args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask)) return atomic_rmw_p.bind( *args_flat, args_tree=args_tree, atomic_type=atomic_type ) @@ -379,7 +381,7 @@ def _load_pp_rule(eqn, context, settings): result = [ lhs, pp.text(' <- '), - sp.pp_ref_indexers(context, x, indexers) + sp.pp_ref_transforms(context, x, indexers) ] if mask is not None: result += [ @@ -529,7 +531,7 @@ def _swap_pp_rule(eqn, context, settings): # Pretty prints `_ = swap x v i` as `x[i] <- v` y, = eqn.outvars x, indexers, val, mask = eqn.params["args_tree"].unflatten(eqn.invars) - x_i = sp.pp_ref_indexers(context, x, indexers) + x_i = sp.pp_ref_transforms(context, x, indexers) if isinstance(y, jax_core.DropVar): return pp.concat([ x_i, @@ -638,8 +640,10 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None, eviction_policy: TO BE DOCUMENTED. volatile: TO BE DOCUMENTED. """ - x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load") - args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other)) + x_ref, transforms = sp.get_ref_and_transforms(x_ref_or_view, idx, "load") + args_flat, args_tree = tree_util.tree_flatten( + (x_ref, transforms, mask, other) + ) return load_p.bind( *args_flat, args_tree=args_tree, @@ -657,8 +661,10 @@ def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, Returns: The value stored in the ref prior to the swap. """ - x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name) - args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask)) + x_ref, transforms = sp.get_ref_and_transforms( + x_ref_or_view, idx, _function_name + ) + args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask)) return swap_p.bind( *args_flat, args_tree=args_tree, eviction_policy=eviction_policy ) diff --git a/jax/_src/state/__init__.py b/jax/_src/state/__init__.py index 0041b2506061..2f1c88be495b 100644 --- a/jax/_src/state/__init__.py +++ b/jax/_src/state/__init__.py @@ -12,7 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for state.""" -from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect, - AccumEffect, StateEffect, RefEffect, - get_ref_state_effects, shaped_array_ref, - RefView) +from jax._src.state.types import ( + AbstractRef, + AccumEffect, + ReadEffect, + RefEffect, + StateEffect, + Transform, + TransformedRef, + WriteEffect, + get_ref_state_effects, + shaped_array_ref, +) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 8e1b3732dd3d..6a912abf215b 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -20,10 +20,8 @@ import operator from typing import Any, Protocol, TypeVar -import numpy as np - -from jax._src import api_util from jax._src import ad_util +from jax._src import api_util from jax._src import config from jax._src import core from jax._src import linear_util as lu @@ -35,12 +33,20 @@ from jax._src.lax import lax from jax._src.lax import slicing as lax_slicing from jax._src.state import indexing -from jax._src.state.types import AbstractRef, RefEffect -from jax._src.state.primitives import get_p, swap_p, addupdate_p -from jax._src.state.utils import hoist_consts_to_refs +from jax._src.state.primitives import addupdate_p, get_p, swap_p +from jax._src.state.types import AbstractRef, RefBitcaster, RefEffect +from jax._src.state.utils import bitcast, hoist_consts_to_refs from jax._src.typing import Array -from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache, - partition_list, merge_lists, split_dict) +from jax._src.util import ( + merge_lists, + partition_list, + safe_map, + safe_zip, + split_dict, + split_list, + weakref_lru_cache, +) +import numpy as np ## JAX utilities @@ -264,73 +270,95 @@ def _prepend_scatter(x, indexer, val, *, add=False): return x[None].at[(0, *indexer)].add(val)[0] return x[None].at[(0, *indexer)].set(val)[0] +def _bitcast_array(x, bitcaster: RefBitcaster): + return bitcast(x, bitcaster.dtype) -def index_array(x, indexers): - if indexers is None: - indexers = [] +def _index_array(x, indexer): + if _is_trivial_indexer(indexer): + return x + # Try the three APIs in the following order: `lax.slice`, + # `lax.dynamic_slice` and gather + if maybe_slice := _maybe_convert_to_slice(indexer): + x = lax_slicing.slice(x, *zip(*maybe_slice)) + # If everything in the indexer is a slice or ()-shaped, we can also + # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. + # We need to squeeze out the 1-sized slices at the end. + elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_slice + y = lax_slicing.dynamic_slice(x, starts, sizes) + x = lax.squeeze(y, squeeze_dims) + else: + indexer = _convert_to_array_indexer(indexer) + x = x[None][(np.array(0, "int32"), *indexer)] + return x + + +def transform_array(x, transforms): + if transforms is None: + transforms = [] result = x - for indexer in indexers: - if _is_trivial_indexer(indexer): + for transform in transforms: + if transform is None: continue - if indexer is None: - continue - - # Try the three APIs in the following order: `lax.slice`, - # `lax.dynamic_slice` and gather - if maybe_slice := _maybe_convert_to_slice(indexer): - result = lax_slicing.slice(result, *zip(*maybe_slice)) - # If everything in the indexer is a slice or ()-shaped, we can also - # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. - # We need to squeeze out the 1-sized slices at the end. - elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice - y = lax_slicing.dynamic_slice(result, starts, sizes) - result = lax.squeeze(y, squeeze_dims) + if isinstance(transform, indexing.NDIndexer): + result = _index_array(result, transform) + elif isinstance(transform, RefBitcaster): + result = _bitcast_array(result, transform) else: - indexer = _convert_to_array_indexer(indexer) - result = result[None][(np.array(0, "int32"), *indexer)] + raise NotImplementedError(f"Unsupported transform: {transform}") return result -def index_swap_array(x, indexers, val): - if indexers is None: - indexers = [] +def transform_swap_array(x, transforms, val): + if transforms is None: + transforms = [] result = x result_val = val # Compute updated "val" (result). _results = [x] - for indexer in indexers: - if _is_trivial_indexer(indexer): - _results.append(None) - continue - # If everything in the indexer is a slice or ()-shaped, we can also - # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. - # We need to squeeze out the 1-sized slices at the end. - if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice - result_old = lax_slicing.dynamic_slice(result, starts, sizes) - result = lax.squeeze(result_old, squeeze_dims) + for transform in transforms: + if isinstance(transform, indexing.NDIndexer): + indexer = transform + if _is_trivial_indexer(indexer): + _results.append(None) + continue + # If everything in the indexer is a slice or ()-shaped, we can also + # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. + # We need to squeeze out the 1-sized slices at the end. + if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_slice + result_old = lax_slicing.dynamic_slice(result, starts, sizes) + result = lax.squeeze(result_old, squeeze_dims) + else: + indexer = _convert_to_array_indexer(indexer) + result = _prepend_gather(result, indexer) + _results.append(result) + elif isinstance(transform, RefBitcaster): + _results.append(_bitcast_array(result, transform)) else: - indexer = _convert_to_array_indexer(indexer) - result = _prepend_gather(result, indexer) - _results.append(result) + raise NotImplementedError(f"Unsupported transform: {transform}") # Compute updated "x" (result_val) - for i, indexer in reversed(list(enumerate(indexers))): - if _is_trivial_indexer(indexer): - continue - if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, _, squeeze_dims = maybe_slice - result_val = lax.expand_dims(result_val, squeeze_dims) - result_val = lax_slicing.dynamic_update_slice( - _results[i], result_val, starts) + for i, transform in reversed(list(enumerate(transforms))): + if isinstance(transform, indexing.NDIndexer): + indexer = transform + if _is_trivial_indexer(indexer): + continue + if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, _, squeeze_dims = maybe_slice + result_val = lax.expand_dims(result_val, squeeze_dims) + result_val = lax_slicing.dynamic_update_slice( + _results[i], result_val, starts + ) + else: + indexer = _convert_to_array_indexer(indexer) + result_val = _prepend_scatter(_results[i], indexer, result_val) else: - indexer = _convert_to_array_indexer(indexer) - result_val = _prepend_scatter(_results[i], indexer, result_val) + raise NotImplementedError(f"Unsupported transform: {transform}") return result, result_val def _get_discharge(x, idx, tree): - indexers = tree_util.tree_unflatten(tree, idx) - return index_array(x, indexers) + transforms = tree_util.tree_unflatten(tree, idx) + return transform_array(x, transforms) @register_discharge_rule(swap_p) def _swap_discharge_rule( @@ -342,8 +370,8 @@ def _swap_discharge_rule( return (x_new, None) + (None,) * len(idx), z def _swap_discharge(x, val, idx, tree): - indexers = tree_util.tree_unflatten(tree, idx) - return index_swap_array(x, indexers, val) + transforms = tree_util.tree_unflatten(tree, idx) + return transform_swap_array(x, transforms, val) @register_discharge_rule(addupdate_p) def _addupdate_discharge_rule( @@ -355,10 +383,10 @@ def _addupdate_discharge_rule( return (ans, None) + (None,) * len(idx), [] def _addupdate_discharge(x, val, idx, tree): - indexers = tree_util.tree_unflatten(tree, idx) - if len(indexers) > 1: + transforms = tree_util.tree_unflatten(tree, idx) + if len(transforms) > 1: raise NotImplementedError("Only single indexer is supported.") - indexer = indexers[0] + indexer = transforms[0] if _is_trivial_indexer(indexer): return x + val # If everything in the indexer is a slice or ()-shaped, we can also diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 750d3239a019..988f362290f0 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -18,9 +18,6 @@ import types from typing import Any, Union -import numpy as np - - from jax._src import ad_util from jax._src import core from jax._src import dispatch @@ -28,14 +25,22 @@ from jax._src import tree_util from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax -from jax._src.typing import Array from jax._src.state import indexing -from jax._src.state.types import (AbstractRef, RefView, ReadEffect, WriteEffect, - AccumEffect) +from jax._src.state.types import ( + AbstractRef, + AccumEffect, + ReadEffect, + RefBitcaster, + Transform, + TransformedRef, + WriteEffect, +) +from jax._src.typing import Array from jax._src.util import safe_map, safe_zip +import numpy as np ## General utilities @@ -59,29 +64,29 @@ Indexer = tuple[Union[int, slice, Array, types.EllipsisType], ...] -def get_ref_and_indexers( +def get_ref_and_transforms( ref_or_view: Any, idx: Indexer | None, function_name: str -) -> tuple[Any, tuple[indexing.NDIndexer, ...]]: - if isinstance(ref_or_view, RefView): - ref, indexers = ref_or_view.ref, ref_or_view.indexers +) -> tuple[Any, tuple[Transform, ...]]: + if isinstance(ref_or_view, TransformedRef): + ref, transforms = ref_or_view.ref, ref_or_view.transforms else: - ref, indexers = ref_or_view, () + ref, transforms = ref_or_view, () ref_aval = core.get_aval(ref) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.") if not isinstance(ref_aval.inner_aval, core.ShapedArray): return ref, () if idx is None: - return ref, indexers + return ref, transforms nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape) - return ref, (*indexers, nd_indexer) + return ref, (*transforms, nd_indexer) def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array: """Reads a value from a `Ref`, a.k.a. value <- ref[idx].""" - ref, indexers = get_ref_and_indexers(ref_or_view, idx, "ref_get") - flat_indexers, tree = tree_util.tree_flatten(indexers) - return get_p.bind(ref, *flat_indexers, tree=tree) + ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_get") + flat_transforms, tree = tree_util.tree_flatten(transforms) + return get_p.bind(ref, *flat_transforms, tree=tree) # `swap` mutates a `Ref`, setting its value and returns its previous value. # b = swap_p.bind(x, a) @@ -102,14 +107,22 @@ def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array: swap_p = core.Primitive("swap") swap_p.def_impl(partial(dispatch.apply_primitive, swap_p)) -def ref_swap(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array, - _function_name: str = "ref_swap") -> Array: + +def ref_swap( + ref_or_view: AbstractRef | TransformedRef, + idx: Indexer | None, + value: Array, + _function_name: str = "ref_swap", +) -> Array: """Sets a `Ref`'s value and returns the original value.""" - ref, indexers = get_ref_and_indexers(ref_or_view, idx, _function_name) - flat_indexers, tree = tree_util.tree_flatten(indexers) - return swap_p.bind(ref, value, *flat_indexers, tree=tree) + ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name) + flat_transforms, tree = tree_util.tree_flatten(transforms) + return swap_p.bind(ref, value, *flat_transforms, tree=tree) -def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array) -> None: + +def ref_set( + ref_or_view: AbstractRef | TransformedRef, idx: Indexer | None, value: Array +) -> None: """Sets a `Ref`'s value, a.k.a. ref[idx] <- value.""" ref_swap(ref_or_view, idx, value, _function_name="ref_set") @@ -130,34 +143,50 @@ def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Arra def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None: """Mutates a ref with an additive update i.e. `ref[idx] += x`.""" - ref, indexers = get_ref_and_indexers(ref_or_view, idx, "ref_addupdate") - flat_indexers, tree = tree_util.tree_flatten(indexers) - return addupdate_p.bind(ref, x, *flat_indexers, tree=tree) + ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_addupdate") + flat_transforms, tree = tree_util.tree_flatten(transforms) + return addupdate_p.bind(ref, x, *flat_transforms, tree=tree) ## get/set/addupdate abstract evaluation rules -def _shape_after_indexing( - shape: tuple[int | Array, ...], indexers: tuple[indexing.NDIndexer, ...] +def _shape_after_transforming( + shape: tuple[int | Array, ...], transforms: tuple[Transform, ...] ) -> tuple[int | Array, ...]: - for indexer in indexers: - # Run some simple checks that all the indexers have consistent shapes - if not indexer.is_dynamic_size: - assert indexer.shape == shape, (indexer.shape, shape) - shape = indexer.get_indexer_shape() + for transform in transforms: + match transform: + case indexing.NDIndexer(): + # Run some simple checks that all the indexers have consistent shapes + if not transform.is_dynamic_size: + assert transform.shape == shape, (transform.shape, shape) + shape = transform.get_indexer_shape() + case RefBitcaster(): + shape = transform.shape + case _: + raise ValueError(f"Unsupported transform: {transform}") return shape +def _dtype_after_transforming( + dtype: Any, transforms: tuple[Transform, ...] +) -> Any: + for transform in reversed(transforms): + if isinstance(transform, RefBitcaster): + return transform.dtype + return dtype + + def _get_abstract_eval(ref_aval: AbstractRef, *args, tree): - indexers = tree_util.tree_unflatten(tree, args) + transforms = tree_util.tree_unflatten(tree, args) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - out_shape = _shape_after_indexing(ref_aval.shape, indexers) - out_aval = ref_aval.inner_aval.update(shape=out_shape) + out_shape = _shape_after_transforming(ref_aval.shape, transforms) + out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) + out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype) else: - if indexers: + if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval return (out_aval, {ReadEffect(0)}) @@ -166,27 +195,30 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args, def _swap_abstract_eval(ref_aval: AbstractRef, val_aval: core.AbstractValue, *args: Any, tree): - indexers = tree_util.tree_unflatten(tree, args) + transforms = tree_util.tree_unflatten(tree, args) out_aval: core.AbstractValue if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): val_aval = core.raise_to_shaped(val_aval) assert isinstance(val_aval, core.ShapedArray) - expected_out_shape = _shape_after_indexing(ref_aval.shape, indexers) + expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms) + expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) if expected_out_shape != val_aval.shape: raise ValueError("Invalid shape for `swap`. " f"Ref shape: {ref_aval.shape}. " f"Expected shape: {expected_out_shape}. " f"Value shape: {val_aval.shape}. " - f"Indices: {indexers}. ") - if ref_aval.dtype != val_aval.dtype and not val_aval.weak_type: - raise ValueError("Invalid dtype for `swap`. " - f"Ref dtype: {ref_aval.dtype}. " - f"Value dtype: {val_aval.dtype}. ") - out_aval = core.ShapedArray(expected_out_shape, ref_aval.dtype) + f"Transforms: {transforms}. ") + if expected_out_dtype != val_aval.dtype and not val_aval.weak_type: + raise ValueError( + "Invalid dtype for `swap`. " + f"Ref dtype: {expected_out_dtype}. " + f"Value dtype: {val_aval.dtype}. " + ) + out_aval = core.ShapedArray(expected_out_shape, expected_out_dtype) else: - if indexers: + if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval return (out_aval, {WriteEffect(0)}) @@ -196,26 +228,29 @@ def _swap_abstract_eval(ref_aval: AbstractRef, def _addupdate_abstract_eval(ref_aval: AbstractRef, val_aval: core.AbstractValue, *args: Any, tree): - indexers = tree_util.tree_unflatten(tree, args) + transforms = tree_util.tree_unflatten(tree, args) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): val_aval = core.raise_to_shaped(val_aval) - slice_shape = _shape_after_indexing(ref_aval.shape, indexers) + out_shape = _shape_after_transforming(ref_aval.shape, transforms) + out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) assert isinstance(val_aval, core.ShapedArray) - if slice_shape != val_aval.shape: - raise ValueError("Invalid shape for `addupdate`. " - f"Ref shape: {ref_aval.shape}. " - f"Slice shape: {slice_shape}. " - f"Value shape: {val_aval.shape}. " - f"Indices: {indexers}. ") - if ref_aval.dtype != val_aval.dtype: + if out_shape != val_aval.shape: + raise ValueError( + "Invalid shape for `addupdate`. " + f"Ref shape: {ref_aval.shape}. " + f"Expected shape: {out_shape}. " + f"Value shape: {val_aval.shape}. " + f"Transforms: {transforms}. " + ) + if out_dtype != val_aval.dtype: raise ValueError("Invalid dtype for `addupdate`. " f"Ref dtype: {ref_aval.dtype}. " f"Value shape: {val_aval.dtype}. ") else: - # Check that the indexers are valid - if indexers: + # Check that the transforms are valid + if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") return [], {AccumEffect(0)} addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval) @@ -261,52 +296,73 @@ def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer indices.append(core.pp_var(idx, context)) # type: ignore return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")]) -def _pp_indexers( - context: core.JaxprPpContext, indexers: tuple[indexing.NDIndexer, ...], + +def pp_bitcaster( + context: core.JaxprPpContext, bitcaster: RefBitcaster +) -> pp.Doc: + del context + return pp.text( + f"[bitcast({bitcaster.dtype}[{','.join(str(d) for d in bitcaster.shape)}])]" + ) + + +def pp_transform(context: core.JaxprPpContext, transform: Transform) -> pp.Doc: + match transform: + case indexing.NDIndexer(): + return pp_indexer(context, transform) + case RefBitcaster(): + return pp_bitcaster(context, transform) + case _: + raise ValueError(f"Unsupported transform: {transform}") + + +def _pp_transforms( + context: core.JaxprPpContext, + transforms: tuple[Transform, ...], ): - if not indexers: + if not transforms: return pp.text("[...]") return pp.concat( - [pp_indexer(context, indexer) for indexer in indexers] + [pp_transform(context, transform) for transform in transforms] ) -def pp_ref_indexers(context: core.JaxprPpContext, ref, indexers): + +def pp_ref_transforms(context: core.JaxprPpContext, ref, transforms): return pp_ref_var( pp.concat([ pp.text(core.pp_var(ref, context)), - _pp_indexers(context, indexers), + _pp_transforms(context, transforms), ]) ) + def _get_pp_rule(eqn, context, settings) -> pp.Doc: # Pretty prints `a = get x i` as `x[i] <- a` y, = eqn.outvars x, *flat_idx = eqn.invars - indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) + transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes) - return pp.concat([ - lhs, - pp.text(' <- '), - pp_ref_indexers(context, x, indexers) - ]) + return pp.concat( + [lhs, pp.text(" <- "), pp_ref_transforms(context, x, transforms)] + ) core.pp_eqn_rules[get_p] = _get_pp_rule def _swap_pp_rule(eqn, context, settings) -> pp.Doc: y, = eqn.outvars x, v, *flat_idx = eqn.invars - indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) + transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) if type(y) is core.DropVar: # In the case of a set (ignored return value), # pretty print `_ = swap x v i` as `x[i] <- v` del y return pp.concat([ - pp_ref_indexers(context, x, indexers), - pp.text(' <- '), - pp.text(core.pp_var(v, context)) - ]) + pp_ref_transforms(context, x, transforms), + pp.text(" <- "), + pp.text(core.pp_var(v, context)), + ]) else: # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v` - x_i = pp_ref_indexers(context, x, indexers) + x_i = pp_ref_transforms(context, x, transforms) y = core.pp_vars([y], context, print_shapes=settings.print_shapes) return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '), x_i, pp.text(', '), @@ -318,11 +374,12 @@ def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc: # pretty-print ` = addupdate x i v` as `x[i] += v` () = eqn.outvars x, v, *flat_idx = eqn.invars - indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) + transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) return pp.concat([ - pp_ref_indexers(context, x, indexers), - pp.text(' += '), - pp.text(core.pp_var(v, context))]) + pp_ref_transforms(context, x, transforms), + pp.text(" += "), + pp.text(core.pp_var(v, context)), + ]) core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule ## get/swap/addupdate JVP rules diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index a71d671c5345..05368e978593 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -21,11 +21,13 @@ from typing import Any, Union from jax._src import core +from jax._src import dtypes from jax._src import effects from jax._src import pretty_printer as pp +from jax._src import tree_util from jax._src.state import indexing -from jax._src.util import safe_map, safe_zip from jax._src.typing import Array +from jax._src.util import safe_map, safe_zip ## JAX utilities @@ -72,7 +74,39 @@ class AccumEffect(RefEffect): StateEffect = Union[ReadEffect, WriteEffect, AccumEffect] + # ## `Ref`s +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class RefBitcaster: + dtype: dtypes.DType + shape: tuple[int, ...] + + @classmethod + def from_ref_new_dtype(cls, ref_or_view: Any, dtype) -> RefBitcaster: + if isinstance(ref_or_view, TransformedRef): + if ref_or_view.is_dynamic_size: + raise NotImplementedError( + "Bitcast ref with dynamic size is not supported." + ) + from jax._src.state.utils import eval_bitcast_shape # pytype: disable=import-error + dtype = dtypes.dtype(dtype) + return cls(dtype, eval_bitcast_shape(ref_or_view, dtype)) + + @property + def is_dynamic_size(self): + return False + + def tree_flatten(self): + return (), (self.dtype, self.shape) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + +Transform = indexing.NDIndexer | RefBitcaster @dataclasses.dataclass class RefIndexer: @@ -82,37 +116,47 @@ def __getitem__(self, slc): if not isinstance(slc, tuple): slc = (slc,) indexer = indexing.NDIndexer.from_indices_shape(slc, self.ref_or_view.shape) - if isinstance(self.ref_or_view, RefView): + if isinstance(self.ref_or_view, TransformedRef): view = self.ref_or_view - return RefView(view.ref, (*view.indexers, indexer)) - return RefView(self.ref_or_view, (indexer,)) + return TransformedRef(view.ref, (*view.transforms, indexer)) + return TransformedRef(self.ref_or_view, (indexer,)) -Indexer = Any @dataclasses.dataclass -class RefView: +class TransformedRef: ref: Any - indexers: tuple[indexing.NDIndexer, ...] + transforms: tuple[Transform, ...] @property def is_dynamic_size(self): - return self.indexers[-1].is_dynamic_size + return self.transforms[-1].is_dynamic_size @property def shape(self) -> tuple[int | Array, ...]: assert ( - len(self.indexers) > 0 - ), "Should not be able to create a trivial RefView" - return self.indexers[-1].get_indexer_shape() + len(self.transforms) > 0 + ), "Should not be able to create a trivial TransformedRef" + if isinstance(self.transforms[-1], indexing.NDIndexer): + return self.transforms[-1].get_indexer_shape() + return self.transforms[-1].shape @property def dtype(self): + for transform in reversed(self.transforms): + if isinstance(transform, RefBitcaster): + return transform.dtype return self.ref.dtype @property def at(self) -> RefIndexer: return RefIndexer(self) + def bitcast(self, dtype): + return TransformedRef( + self.ref, + (*self.transforms, RefBitcaster.from_ref_new_dtype(self, dtype)), + ) + def __getattr__(self, name): return getattr(self.ref, name) @@ -166,6 +210,10 @@ def dtype(self): def at(self): return RefIndexer(self) + @core.aval_method + def bitcast(self, dtype): + return TransformedRef(self, (RefBitcaster.from_ref_new_dtype(self, dtype),)) + @core.aval_method @staticmethod def get(tracer, idx=()): diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py index 33fced775fad..909e84c3a6e3 100644 --- a/jax/_src/state/utils.py +++ b/jax/_src/state/utils.py @@ -13,14 +13,18 @@ # limitations under the License. """Utilities for tracing stateful functions.""" +from functools import partial from typing import Callable -from jax._src.interpreters import partial_eval as pe +import jax from jax._src import core +from jax._src import dtypes from jax._src import linear_util as lu +from jax._src.interpreters import partial_eval as pe from jax._src.state import AbstractRef -from jax._src.util import split_list, safe_map, safe_zip from jax._src.state.primitives import ref_get +from jax._src.typing import DTypeLike +from jax._src.util import safe_map, safe_zip, split_list map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -79,3 +83,41 @@ def val_to_ref_aval(x) -> AbstractRef: if type(aval) is not core.ShapedArray: raise TypeError(f"can't make ref from {x}") return AbstractRef(aval) + + +def dtype_bitwidth(dtype: DTypeLike) -> int: + if dtypes.isdtype(dtype, "integral"): + return dtypes.iinfo(dtype).bits + return dtypes.dtype(dtype).itemsize * 8 + + +def bitcast(x, dtype: DTypeLike): + x_bitwidth = dtype_bitwidth(x.dtype) + y_bitwidth = dtype_bitwidth(dtype) + shape = list(x.shape) + if x_bitwidth != y_bitwidth: + if len(shape) < 2: + raise NotImplementedError( + "Bitcast 1D ref with bitwidth change is not supported." + ) + # Note: this is only valid on TPU. + if shape[-2] * x_bitwidth % y_bitwidth != 0: + raise ValueError( + "Expected input and output shapes are the same after multiplying" + " the second-minor dimension by the bitwidths." + ) + shape[-2] = shape[-2] * x_bitwidth // y_bitwidth + if x_bitwidth < y_bitwidth: + ratio = y_bitwidth // x_bitwidth + x = x.reshape(*x.shape[:-2], x.shape[-2] // ratio, ratio, -1).swapaxes( + -1, -2 + ) + y = jax.lax.bitcast_convert_type(x, dtype) + if x_bitwidth > y_bitwidth: + y = y.swapaxes(-1, -2).reshape(shape) + return y + + +def eval_bitcast_shape(x, dtype: DTypeLike): + f = partial(bitcast, dtype=dtype) + return jax.eval_shape(f, jax.ShapeDtypeStruct(x.shape, x.dtype)).shape diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index a34c2b2f2f61..1d57dc164294 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -16,15 +16,15 @@ import sys import unittest -import numpy as np from absl.testing import absltest from absl.testing import parameterized - import jax from jax import lax -import jax.numpy as jnp from jax._src import test_util as jtu +from jax._src.pallas import utils as pallas_utils from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np if sys.platform != "win32": from jax.experimental.pallas import tpu as pltpu @@ -67,28 +67,29 @@ def pallas_call(cls, *args, **kwargs): class OpsTest(PallasBaseTest): - @parameterized.product(from_dtype=_JAX_DTYPES, to_dtype=_JAX_DTYPES) - def test_bitcast(self, from_dtype, to_dtype): - # TODO(jevinjiang): remove this after 2nd minor large tiling is enabled. - if (not jtu.is_device_tpu_at_least(version=5)) and ( - from_dtype in (jnp.int8, jnp.int16) or to_dtype in (jnp.int8, jnp.int16) - ): - self.skipTest( - "Not implemented: packing and unpacking int8, int16 are not supported" - " on < TPUv5" - ) + @parameterized.product( + from_dtype=_JAX_DTYPES, to_dtype=_JAX_DTYPES, is_ref_bitcast=[False, True] + ) + def test_bitcast(self, from_dtype, to_dtype, is_ref_bitcast): + if not jtu.is_device_tpu_at_least(version=4): + self.skipTest("Run on TPUv4+ to have expected memory layout") if from_dtype == to_dtype: self.skipTest("No bitcast needed") if from_dtype == jnp.bool_ or to_dtype == jnp.bool_: self.skipTest("Bitcasting with bool is not supported") def kernel(x_ref, y_ref): - y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype) - - m, n = 32, 256 - shape = (m, n) - out_shape = (m * from_dtype.dtype.itemsize // to_dtype.dtype.itemsize, n) - inp = np.arange(np.prod(shape), dtype=from_dtype).reshape(shape) + if is_ref_bitcast: + y_ref[...] = x_ref.bitcast(to_dtype)[...] + else: + y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype) + + m, n = 1, 256 + in_packing = 32 // pallas_utils.dtype_bitwidth(from_dtype) + out_packing = 32 // pallas_utils.dtype_bitwidth(to_dtype) + in_shape = (m * in_packing, n) + out_shape = (m * out_packing, n) + inp = np.arange(np.prod(in_shape), dtype=from_dtype).reshape(in_shape) out = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype), diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 83ca6a5787cc..e100a5a39e49 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -31,6 +31,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lib import xla_extension from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax._src.state import utils as state_utils from jax.experimental import mesh_utils from jax.experimental import mosaic from jax.experimental import pallas as pl @@ -1926,6 +1927,100 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): np.testing.assert_array_equal(out, expected) +class PallasCallRefTransformTest(PallasBaseTest): + + @parameterized.product(slice_first=[True, False]) + def test_dma_bitcasted_ref(self, slice_first): + if not jtu.is_device_tpu_at_least(4): + self.skipTest('DMAs not supported on TPU generations <= 3') + + def kernel(x_hbm_ref, y_hbm_ref): + def body(sem): + ref = ( + x_hbm_ref.at[:8, :, :128].bitcast(jnp.int16) + if slice_first + else x_hbm_ref.bitcast(jnp.int16).at[:8, :, :128] + ) + pltpu.async_copy(ref, y_hbm_ref.at[...], sem).wait() + + pl.run_scoped(body, pltpu.SemaphoreType.DMA) + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 1, 256)) + y = self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_shape=jax.ShapeDtypeStruct((8, 2, 128), jnp.int16), + )(x) + expected = ( + state_utils.bitcast(x[:8, :, :128], jnp.int16) + if slice_first + else state_utils.bitcast(x, jnp.int16)[:8, :, :128] + ) + np.testing.assert_array_equal(y, expected) + + @parameterized.product(slice_first=[True, False]) + def test_load_bitcasted_ref(self, slice_first: bool): + def kernel(x_ref, y_ref): + ref = ( + x_ref.at[:8, :128].bitcast(jnp.int16) + if slice_first + else x_ref.bitcast(jnp.int16).at[:16, :128] + ) + y_ref[...] = ref[...] + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 256)) + y = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int16), + )(x) + expected = ( + state_utils.bitcast(x[:8, :128], jnp.int16) + if slice_first + else state_utils.bitcast(x, jnp.int16)[:16, :128] + ) + np.testing.assert_array_equal(y, expected) + + @parameterized.product(slice_first=[True, False]) + def test_store_bitcasted_ref(self, slice_first): + def kernel(x_ref, y_ref): + ref = ( + y_ref.at[:8, :128].bitcast(jnp.bfloat16) + if slice_first + else y_ref.bitcast(jnp.bfloat16).at[:16, :128] + ) + ref[...] = x_ref[...] + + x = jnp.arange(16 * 128, dtype=jnp.bfloat16).reshape((16, 128)) + y = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32), + )(x) + expected = state_utils.bitcast(x, jnp.int32) + np.testing.assert_array_equal(y[:8, :128], expected) + + def test_multiple_ref_transforms(self): + + def kernel(x_ref, y_ref): + ref = ( + x_ref.at[:8, :256] + .bitcast(jnp.int16) + .bitcast(jnp.float16) + .at[:, :128] + .bitcast(jnp.int32) + ) + y_ref[...] = ref[...] + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 256)) + y = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32), + )(x) + np.testing.assert_array_equal(y, x[:8, :128]) + + class PallasCallPrintTest(PallasBaseTest): def test_debug_print(self): From d5ceb78708e4228c1b81aa3c5ff90fe163e94253 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 16 Sep 2024 06:05:19 -0700 Subject: [PATCH 503/702] Better documentation for jnp.atleast_*d --- jax/_src/numpy/lax_numpy.py | 134 ++++++++++++++++++++++++++++++++++-- 1 file changed, 130 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6c6cacef673e..5137b20bc898 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4304,10 +4304,44 @@ def atleast_1d(x: ArrayLike, /) -> Array: @overload def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... -@util.implements(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: - # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. + """Convert inputs to arrays with at least 1 dimension. + + JAX implementation of :func:`numpy.atleast_1d`. + + Args: + zero or more arraylike arguments. + + Returns: + an array or list of arrays corresponding to the input values. Arrays + of shape ``()`` are converted to shape ``(1,)``, and arrays with other + shapes are returned unchanged. + + See also: + - :func:`jax.numpy.asarray` + - :func:`jax.numpy.atleast_2d` + - :func:`jax.numpy.atleast_3d` + + Examples: + Scalar arguments are converted to 1D, length-1 arrays: + + >>> x = jnp.float32(1.0) + >>> jnp.atleast_1d(x) + Array([1.], dtype=float32) + + Higher dimensional inputs are returned unchanged: + + >>> y = jnp.arange(4) + >>> jnp.atleast_1d(y) + Array([0, 1, 2, 3], dtype=int32) + + Multiple arguments can be passed to the function at once, in which + case a list of results is returned: + + >>> jnp.atleast_1d(x, y) + [Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)] + """ util.check_arraylike("atleast_1d", *arys, emit_warning=True) if len(arys) == 1: return array(arys[0], copy=False, ndmin=1) @@ -4324,9 +4358,52 @@ def atleast_2d(x: ArrayLike, /) -> Array: @overload def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... -@util.implements(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: + """Convert inputs to arrays with at least 2 dimensions. + + JAX implementation of :func:`numpy.atleast_2d`. + + Args: + zero or more arraylike arguments. + + Returns: + an array or list of arrays corresponding to the input values. Arrays + of shape ``()`` are converted to shape ``(1, 1)``, 1D arrays of shape + ``(N,)`` are converted to shape ``(1, N)``, and arrays of all other + shapes are returned unchanged. + + See also: + - :func:`jax.numpy.asarray` + - :func:`jax.numpy.atleast_1d` + - :func:`jax.numpy.atleast_3d` + + Examples: + Scalar arguments are converted to 2D, size-1 arrays: + + >>> x = jnp.float32(1.0) + >>> jnp.atleast_2d(x) + Array([[1.]], dtype=float32) + + One-dimensional arguments have a unit dimension prepended to the shape: + + >>> y = jnp.arange(4) + >>> jnp.atleast_2d(y) + Array([[0, 1, 2, 3]], dtype=int32) + + Higher dimensional inputs are returned unchanged: + + >>> z = jnp.ones((2, 3)) + >>> jnp.atleast_2d(z) + Array([[1., 1., 1.], + [1., 1., 1.]], dtype=float32) + + Multiple arguments can be passed to the function at once, in which + case a list of results is returned: + + >>> jnp.atleast_2d(x, y) + [Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)] + """ # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("atleast_2d", *arys, emit_warning=True) if len(arys) == 1: @@ -4344,9 +4421,58 @@ def atleast_3d(x: ArrayLike, /) -> Array: @overload def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... -@util.implements(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: + """Convert inputs to arrays with at least 3 dimensions. + + JAX implementation of :func:`numpy.atleast_3d`. + + Args: + zero or more arraylike arguments. + + Returns: + an array or list of arrays corresponding to the input values. Arrays + of shape ``()`` are converted to shape ``(1, 1, 1)``, 1D arrays of + shape ``(N,)`` are converted to shape ``(1, N, 1)``, 2D arrays of + shape ``(M, N)`` are converted to shape ``(M, N, 1)``, and arrays + of all other shapes are returned unchanged. + + See also: + - :func:`jax.numpy.asarray` + - :func:`jax.numpy.atleast_1d` + - :func:`jax.numpy.atleast_2d` + + Examples: + Scalar arguments are converted to 3D, size-1 arrays: + + >>> x = jnp.float32(1.0) + >>> jnp.atleast_3d(x) + Array([[[1.]]], dtype=float32) + + 1D arrays have a unit dimension prepended and appended: + + >>> y = jnp.arange(4) + >>> jnp.atleast_3d(y).shape + (1, 4, 1) + + 2D arrays have a unit dimension appended: + + >>> z = jnp.ones((2, 3)) + >>> jnp.atleast_3d(z).shape + (2, 3, 1) + + Multiple arguments can be passed to the function at once, in which + case a list of results is returned: + + >>> x3, y3 = jnp.atleast_3d(x, y) + >>> print(x3) + [[[1.]]] + >>> print(y3) + [[[0] + [1] + [2] + [3]]] + """ # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("atleast_3d", *arys, emit_warning=True) if len(arys) == 1: From 80e1c94de63e7f89667cdf35f38d8fe298e97a50 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 16 Sep 2024 13:20:56 +0000 Subject: [PATCH 504/702] Prepare for v0.4.33 release. This release is branched off the v0.4.32 release, with two changes: a) a fixed libtpu pin, and b) a patch to revert an F64 tanh issue on CPU. --- CHANGELOG.md | 24 ++++++++++++++++++++++-- jax/version.py | 4 ++-- setup.py | 4 ++-- third_party/xla/tanh.patch | 14 ++++++++++++++ third_party/xla/workspace.bzl | 3 +++ 5 files changed, 43 insertions(+), 6 deletions(-) create mode 100644 third_party/xla/tanh.patch diff --git a/CHANGELOG.md b/CHANGELOG.md index 869b9dfdd196..affd94c156a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,24 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.4.32 +## jax 0.4.33 + +This is a patch release on top of jax 0.4.32, that fixes two bugs found in that +release. + +A TPU-only data corruption bug was found in the version of libtpu pinned by +JAX 0.4.32, which manifested only if multiple TPU slices were present in the +same job, for example, if training on multiple v5e slices. +This release fixes that issue by pinning a fixed version of `libtpu`. + +## jaxlib 0.4.33 + +This release fixes an inaccurate result for F64 tanh on CPU (#23590). + +## jax 0.4.32 (September 11, 2024) + +Note: This release was yanked from PyPi because of a data corruption bug on TPU. +See the 0.4.33 release notes for more details. * New Functionality * Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering` @@ -65,7 +82,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. The argument to {func}`jax.dlpack.from_dlpack` should be an array from another framework that implements the ``__dlpack__`` protocol. -## jaxlib 0.4.32 +## jaxlib 0.4.32 (September 11, 2024) + +Note: This release was yanked from PyPi because of a data corruption bug on TPU. +See the 0.4.33 release notes for more details. * Breaking changes * Hermetic CUDA support is added. diff --git a/jax/version.py b/jax/version.py index f2c34d275b01..8047895a640e 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.32" +_version = "0.4.33" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.32" +_minimum_jaxlib_version = "0.4.33" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index 027e5aefbc2f..a7e9bb3df7ba 100644 --- a/setup.py +++ b/setup.py @@ -19,10 +19,10 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.32' +_current_jaxlib_version = '0.4.33' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.31' -_libtpu_version = '0.1.dev20240911' +_libtpu_version = '0.1.dev20240916' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( diff --git a/third_party/xla/tanh.patch b/third_party/xla/tanh.patch new file mode 100644 index 000000000000..fce0ee6fdb68 --- /dev/null +++ b/third_party/xla/tanh.patch @@ -0,0 +1,14 @@ +diff --git a/xla/service/cpu/llvm_ir_runtime.cc b/xla/service/cpu/llvm_ir_runtime.cc +index 89b40b915caa3..25541c16bfd61 100644 +--- a/xla/service/cpu/llvm_ir_runtime.cc ++++ b/xla/service/cpu/llvm_ir_runtime.cc +@@ -410,7 +410,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, + rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8); + rewrite_calls(kTanhV16F32SymbolName, GenerateVF32Tanh, /*vector_width=*/16); + +- rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1); ++ // TODO(penporn): Re-enable after fixing JAX issue #23590. ++ // rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1); + + rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1); + rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1); \ No newline at end of file diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8f4accca508c..78b3736972e5 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -30,6 +30,9 @@ def repo(): sha256 = XLA_SHA256, strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), + patch_file = [ + "//third_party/xla:tanh.patch", + ], ) # For development, one often wants to make changes to the TF repository as well From 8c39d0373afaf67ecbaf9295fcf67780405e913c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 16 Sep 2024 08:18:01 -0700 Subject: [PATCH 505/702] Added a new primitive for copying GMEM<->SMEM in Pallas Mosaic GPU kernels The copy is async and needs to be awaited via `plgpu.wait_inflight(...)` for SMEM->GMEM copies and via `plgpu.wait(barrier)` for GMEM->SMEM copies. I decided to have distinct functions for SMEM->GMEM and GMEM->SMEM copies and for the ways to await the result, because the underlying Mosaic GPU APIs (and PTX ISA) *are* in fact very different. PiperOrigin-RevId: 675155317 --- jax/_src/pallas/mosaic_gpu/BUILD | 15 +++ jax/_src/pallas/mosaic_gpu/__init__.py | 16 ++++ jax/_src/pallas/mosaic_gpu/core.py | 40 ++++++-- jax/_src/pallas/mosaic_gpu/lowering.py | 115 ++++++++++++++++------- jax/_src/pallas/mosaic_gpu/primitives.py | 105 +++++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 41 +++++++- 6 files changed, 292 insertions(+), 40 deletions(-) create mode 100644 jax/_src/pallas/mosaic_gpu/primitives.py diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index c3e8fc8b83de..171ff0439085 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -34,6 +34,7 @@ py_library( deps = [ ":core", ":pallas_call_registration", + ":primitives", ], ) @@ -72,8 +73,22 @@ pytype_strict_library( deps = [ "//jax", "//jax:core", + "//jax:dtypes", "//jax:mosaic_gpu", "//jax:tree_util", "//jax/_src/pallas", ] + py_deps("numpy"), ) + +pytype_strict_library( + name = "primitives", + srcs = ["primitives.py"], + deps = [ + ":core", + ":lowering", + "//jax", + "//jax:core", + "//jax:mosaic_gpu", + "//jax/_src/pallas", + ], +) diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index 862a661e24b9..97732acbd830 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -11,3 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# TODO(slebedev): Move these imports to ``jax.experimental.pallas``. + +from jax._src.pallas.mosaic_gpu.core import Barrier +from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec +from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams +from jax._src.pallas.mosaic_gpu.core import GPUGridSpec +from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace +from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem +from jax._src.pallas.mosaic_gpu.primitives import wait +from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem + +GMEM = GPUMemorySpace.GMEM +SMEM = GPUMemorySpace.SMEM +REGS = GPUMemorySpace.REGS diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 025b9f1b57d0..3ef205d336d0 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -18,8 +18,9 @@ import dataclasses import enum from typing import Any, ClassVar, Literal, Protocol -from jax import core as jax_core -from jax._src import core + +from jax._src import core as jax_core +from jax._src import dtypes from jax._src import tree_util from jax._src.pallas import core as pallas_core from jax.experimental.mosaic import gpu as mosaic_gpu @@ -118,9 +119,9 @@ class GPUBlockSpec(pallas_core.BlockSpec): def to_block_mapping( self, origin: pallas_core.OriginStr, - array_aval: core.ShapedArray, + array_aval: jax_core.ShapedArray, *, - index_map_avals: Sequence[core.AbstractValue], + index_map_avals: Sequence[jax_core.AbstractValue], index_map_tree: tree_util.PyTreeDef, grid: pallas_core.GridMappingGrid, mapped_dims: tuple[int, ...], @@ -163,8 +164,8 @@ def __init__( super().__init__(grid, in_specs, out_specs) self.scratch_shapes = tuple(scratch_shapes) - def _make_scratch_aval(self, obj: object) -> core.AbstractValue: - if isinstance(obj, MemoryRef): + def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: + if isinstance(obj, (MemoryRef, Barrier)): return obj.get_aval() raise TypeError(f"Cannot convert {obj} to an abstract value") @@ -186,3 +187,30 @@ def get_aval(self) -> AbstractMemoryRef: GMEM = GPUMemorySpace.GMEM SMEM = GPUMemorySpace.SMEM REGS = GPUMemorySpace.REGS + + +class barrier_dtype(dtypes.extended): + pass + + +@dataclasses.dataclass(frozen=True) +class BarrierType(dtypes.ExtendedDType): + type: ClassVar[Any] = barrier_dtype + name: ClassVar[str] = "barrier" + + num_arrivals: int + + def __str__(self): + return self.name + + +@dataclasses.dataclass(frozen=True) +class Barrier: + num_arrivals: int + num_barriers: int = 1 + + def get_aval(self) -> AbstractMemoryRef: + aval = jax_core.ShapedArray( + [self.num_barriers], BarrierType(self.num_arrivals) + ) + return AbstractMemoryRef(aval, SMEM) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 7c8f2e85f27b..39483e674681 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -158,7 +158,8 @@ def stack_free_smem(self, bytes: int): @dataclasses.dataclass(frozen=True) class LoweringRuleContext: - module_context: ModuleContext + module_ctx: ModuleContext + launch_ctx: mosaic_gpu.LaunchContext avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] @@ -177,10 +178,13 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name def _eval_index_map( - ctx: ModuleContext, idx: ir.Value, block_mapping: pallas_core.BlockMapping + module_ctx: ModuleContext, + launch_ctx: mosaic_gpu.LaunchContext, + idx: ir.Value, + block_mapping: pallas_core.BlockMapping, ) -> Sequence[ir.Value]: block_indices = lower_jaxpr_to_mosaic_gpu( - ctx, block_mapping.index_map_jaxpr.jaxpr, idx + module_ctx, launch_ctx, block_mapping.index_map_jaxpr.jaxpr, idx ) result = [] for i, b in zip(block_indices, block_mapping.block_shape): @@ -241,8 +245,8 @@ def lower_jaxpr_to_module( dimension_semantics = ["parallel"] * len(grid_mapping.grid) elif len(dimension_semantics) != len(grid_mapping.grid): raise ValueError( - "dimension_semantics must have an entrey for each grid dimension:" - f" {len(dimension_semantics)=}, but len(grid={grid_mapping.grid})." + "dimension_semantics must have an entry for each grid dimension:" + f" {len(dimension_semantics)=}, but len(grid) is {grid_mapping.grid})." ) sequential_axes = tuple( i for i, s in enumerate(dimension_semantics) if s == "sequential" @@ -250,6 +254,14 @@ def lower_jaxpr_to_module( assert all(grid[axis] for axis in sequential_axes) assert all(block[axis] == 1 for axis in sequential_axes) + in_in_smem, out_in_smem = util.split_list( + [ + bm.block_aval.memory_space in (None, gpu_core.SMEM) + for bm in block_mappings + ], + [grid_mapping.num_inputs], + ) + in_structs_gmem = [*grid_mapping.in_shapes] in_block_shapes = [ bm.block_shape @@ -260,7 +272,11 @@ def lower_jaxpr_to_module( [num_stages, *bm.ref_aval.inner_aval.shape], bm.ref_aval.inner_aval.dtype, ) - for bm in block_mappings[: grid_mapping.num_inputs] + if in_smem + else None + for bm, in_smem in zip( + block_mappings[: grid_mapping.num_inputs], in_in_smem + ) ] in_gmem_transforms = [ cast(gpu_core.MemoryRefTransform, bm.transforms) @@ -277,31 +293,37 @@ def lower_jaxpr_to_module( # TODO(justinfu): Implement output Memref transforms out_structs_smem = [ jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype) - for bm, s in zip( + if in_smem + else None + for bm, in_smem, s in zip( block_mappings[grid_mapping.num_inputs :], + out_in_smem, grid_mapping.out_shapes, ) ] def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value): - *buffers_gmem, (*buffers_smem, runtime_smem, barriers) = buffers - assert ( - len(buffers_gmem) - == len(buffers_smem) - grid_mapping.num_scratch_operands - ) + *buffers_gmem, ( + buffers_smem, + *scratch_buffers_smem, + runtime_smem, + barriers, + ) = buffers + assert len(buffers_gmem) == len(buffers_smem) in_buffers_gmem, out_buffers_gmem = util.split_list( buffers_gmem, [grid_mapping.num_inputs] ) - in_buffers_smem, out_buffers_smem, scratch_buffers_smem = util.split_list( - buffers_smem, [grid_mapping.num_inputs, grid_mapping.num_outputs] + in_buffers_smem, out_buffers_smem = util.split_list( + buffers_smem, [grid_mapping.num_inputs] ) + barriers, *extra_barriers = barriers module_ctx = ModuleContext( name_and_src_info.name, grid_mapping, approx_math, runtime_smem ) program_ids = map(_program_id, range(len(grid_mapping.grid))) start_indices = map( - functools.partial(_eval_index_map, module_ctx, program_ids), + functools.partial(_eval_index_map, module_ctx, launch_ctx, program_ids), block_mappings, ) in_start_indices, out_start_indices = util.split_list( @@ -311,7 +333,9 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value): # Precompute the total number of bytes transferred from GMEM to SMEM, # so that we can do a single arrive instruction for all of the inputs. in_transfer_bytes = 0 - for b_smem in in_buffers_smem: + for in_smem, b_smem in zip(in_in_smem, in_buffers_smem): + if not in_smem: + continue b_smem_type = ir.MemRefType(b_smem.type) in_transfer_bytes += math.prod(b_smem_type.shape[1:]) * mgpu.bytewidth( b_smem_type.element_type @@ -335,6 +359,9 @@ def gmem_slice( ) def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: + if not in_in_smem[idx]: + return + # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. gmem_transforms = (x.to_gpu_transform() for x in in_gmem_transforms[idx]) launch_ctx.async_copy( @@ -353,6 +380,9 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: ) def store(idx: int, step: ir.Value, slot: ir.Value) -> None: + if not out_in_smem[idx]: + return + # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. launch_ctx.async_copy( src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot), @@ -377,6 +407,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: if any( b_gmem.shape[sequential_axis] % b_smem.shape[1 + sequential_axis] for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) + if b_smem ): raise ValueError( "Array dimensions along the sequential axis must be divisible by" @@ -385,6 +416,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: num_steps, *rest = { b_gmem.shape[sequential_axis] // b_smem.shape[1 + sequential_axis] for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) + if b_smem } if rest: raise ValueError( @@ -409,11 +441,14 @@ def _(step, _): barriers[slot].wait() args = [ - mgpu.memref_slice(b_smem, slot) - for b_smem in it.chain(in_buffers_smem, out_buffers_smem) + mgpu.memref_slice(buffers_smem[idx], slot) + if in_smem + else buffers_gmem[idx] + for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem)) ] args.extend(scratch_buffers_smem) - _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, args) + args.extend(extra_barriers) + _ = lower_jaxpr_to_mosaic_gpu(module_ctx, launch_ctx, jaxpr, args) mgpu.commit_shared() with mgpu.single_thread(): @@ -444,8 +479,15 @@ def _(step, _): raise TypeError( f"All scratch operands must be in SMEM, but got: {scratch_avals}" ) + extra_barriers = [ + mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape) + for aval in scratch_avals + if isinstance(aval.dtype, gpu_core.BarrierType) + ] extra_smem_scratch = [ - jax.ShapeDtypeStruct(aval.shape, aval.dtype) for aval in scratch_avals + jax.ShapeDtypeStruct(aval.shape, aval.dtype) + for aval in scratch_avals + if not isinstance(aval.dtype, gpu_core.BarrierType) ] smem_scratch_bytes = compiler_params.get("smem_scratch_bytes") if smem_scratch_bytes is None: @@ -462,10 +504,12 @@ def _(step, _): in_shapes=in_structs_gmem, out_shape=out_structs_gmem, smem_scratch_shape=( - *in_structs_smem, - *out_structs_smem, + (*in_structs_smem, *out_structs_smem), *extra_smem_scratch, - mgpu.Barrier(arrival_count=1, num_barriers=num_stages), + ( + mgpu.Barrier(arrival_count=1, num_barriers=num_stages), + *extra_barriers, + ), ), module_name=name_and_src_info.name, ) @@ -485,7 +529,8 @@ def deco(fn): def lower_jaxpr_to_mosaic_gpu( - ctx: ModuleContext, + module_ctx: ModuleContext, + launch_ctx: mosaic_gpu.LaunchContext, jaxpr: jax_core.Jaxpr, args: Sequence[ir.Value], consts=(), @@ -510,7 +555,8 @@ def write_env(var: jax_core.Var, val): ) rule = mosaic_lowering_rules[eqn.primitive] rule_ctx = LoweringRuleContext( - ctx, + module_ctx, + launch_ctx, avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], ) @@ -577,7 +623,9 @@ def _swap_lowering_rule( def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): if jaxpr.consts: raise NotImplementedError - return lower_jaxpr_to_mosaic_gpu(ctx.module_context, jaxpr.jaxpr, args) + return lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args + ) @register_lowering_rule(lax.broadcast_in_dim_p) @@ -624,7 +672,7 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): @register_lowering_rule(lax.rsqrt_p) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): - return _ensure_fa(x, *ctx.avals_in).rsqrt(ctx.module_context.approx_math) + return _ensure_fa(x, *ctx.avals_in).rsqrt(ctx.module_ctx.approx_math) @register_lowering_rule(lax.reduce_sum_p) @@ -632,7 +680,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): if axes != (0,): raise NotImplementedError("No support for axes other than 0 yet") [x_aval] = ctx.avals_in - _, [scratch] = ctx.module_context.scratch_view( + _, [scratch] = ctx.module_ctx.scratch_view( [jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)] ) return mgpu.FragmentedArray.splat(x.reduce_sum(scratch), ()) @@ -657,13 +705,14 @@ def _run_scoped_lowering_rule( ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr ): in_avals = [v.aval.inner_aval for v in jaxpr.invars] - bytes_allocated, input_refs = ctx.module_context.scratch_view( - [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype) for aval in in_avals] - ) + bytes_allocated, input_refs = ctx.module_ctx.scratch_view([ + jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype) + for aval in in_avals + ]) outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_context, jaxpr, input_refs, consts + ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts ) - ctx.module_context.stack_free_smem(bytes_allocated) + ctx.module_ctx.stack_free_smem(bytes_allocated) return outs diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py new file mode 100644 index 000000000000..b53901e30612 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -0,0 +1,105 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU-specific Pallas primitives.""" + +from __future__ import annotations + +from jax._src import core as jax_core +from jax._src import state +from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas.mosaic_gpu import lowering + + +async_copy_p = jax_core.Primitive("async_copy") +async_copy_p.multiple_results = True + + +@async_copy_p.def_effectful_abstract_eval +def _async_copy_abstract_eval(*avals): + del avals # Unused. + return (), {state.ReadEffect(0), state.WriteEffect(1)} + + +@lowering.register_lowering_rule(async_copy_p) +def _async_copy_lowering_rule( + ctx: lowering.LoweringRuleContext, src, dst, barrier=None +): + ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, barrier=barrier) + return () + + +def async_copy_smem_to_gmem( + src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef +) -> None: + if src.memory_space is not gpu_core.SMEM: + raise TypeError(f"src must be a SMEM reference, got {src.memory_space}") + if dst.memory_space is not gpu_core.GMEM: + raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}") + async_copy_p.bind(src, dst) + return None + + +def async_copy_gmem_to_smem( + src: pallas_core.AbstractMemoryRef, + dst: pallas_core.AbstractMemoryRef, + *, + barrier: pallas_core.AbstractMemoryRef, +) -> None: + if src.memory_space is not gpu_core.GMEM: + raise TypeError(f"src must be a GMEM reference, got {src.memory_space}") + if dst.memory_space is not gpu_core.SMEM: + raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}") + async_copy_p.bind(src, dst, barrier) + return None + + +class WaitEffect(jax_core.Effect): + ... + + +wait_effect = WaitEffect() + + +wait_p = jax_core.Primitive("wait") +wait_p.multiple_results = True + + +@wait_p.def_effectful_abstract_eval +def _wait_abstract_eval(*avals, **params): + del avals, params # Unused. + return (), {wait_effect} + + +@lowering.register_lowering_rule(wait_p) +def _wait_lowering_rule( + ctx: lowering.LoweringRuleContext, barrier=None, allow_groups=None, +): + if barrier is not None: + barrier.wait() + else: + assert allow_groups is not None + ctx.launch_ctx.await_async_copy(allow_groups=allow_groups) + return () + + +def wait_smem_to_gmem(allow_groups: int) -> None: + """Waits until there are no more than the given number of SMEM->GMEM copies in flight.""" + wait_p.bind(allow_groups=allow_groups) + + +def wait(barrier: pallas_core.AbstractMemoryRef) -> None: + """Waits on the given barrier.""" + wait_p.bind(barrier) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0ee12edd883e..ee8858282a33 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -19,7 +19,7 @@ import jax from jax._src import config from jax._src import test_util as jtu -import jax._src.pallas.mosaic_gpu.core as plgpu +import jax._src.pallas.mosaic_gpu as plgpu from jax.experimental import pallas as pl import jax.numpy as jnp import numpy as np @@ -116,6 +116,45 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_add_one_with_async_copy_smem_to_gmem(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + grid_spec=plgpu.GPUGridSpec( + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], + ), + ) + def kernel(x_ref, o_ref_gmem, scratch_ref): + scratch_ref[...] = x_ref[...] + 1 + plgpu.async_copy_smem_to_gmem(scratch_ref, o_ref_gmem) + plgpu.wait_smem_to_gmem(0) + + x = jnp.arange(128).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + def test_add_one_with_async_copy_gmem_to_smem(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + grid_spec=plgpu.GPUGridSpec( + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((128,), jnp.float32), + plgpu.Barrier(num_arrivals=1), + ], + ), + ) + def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): + plgpu.async_copy_gmem_to_smem( + x_ref_gmem, scratch_ref, barrier=barrier_ref + ) + plgpu.wait(barrier_ref) + o_ref[...] = scratch_ref[...] + 1 + + x = jnp.arange(128).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_add_doubled_sum(self): @functools.partial( pl.pallas_call, From 0942458d7135b15e5281fc8df085b7888c673d32 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 16 Sep 2024 23:16:50 +0530 Subject: [PATCH 506/702] Improve doc for jnp.logaddexp2 --- jax/_src/numpy/ufuncs.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index d54795f02b5d..b45b3370fe53 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -1744,11 +1744,39 @@ def _wrap_between(x, _a): return lax.sub(rem, a) -@custom_jvp -@implements(np.logaddexp2, module='numpy') @jit def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. + + JAX implementation of :obj:`numpy.logaddexp2`. + + Args: + x1: input array or scalar. + x2: input array or scalar. ``x1`` and ``x2`` should either have same shape or + be broadcast compatible. + + Returns: + An array containing the result, :math:`log_2(2^{x1}+2^{x2})`, element-wise. + + See also: + - :func:`jax.numpy.logaddexp`: Computes ``log(exp(x1) + exp(x2))``, element-wise. + - :func:`jax.numpy.log2`: Calculates the base-2 logarithm of ``x`` element-wise. + + Examples: + >>> x1 = jnp.array([[3, -1, 4], + ... [8, 5, -2]]) + >>> x2 = jnp.array([2, 3, -5]) + >>> result1 = jnp.logaddexp2(x1, x2) + >>> result2 = jnp.log2(jnp.exp2(x1) + jnp.exp2(x2)) + >>> jnp.allclose(result1, result2) + Array(True, dtype=bool) + """ x1, x2 = promote_args_inexact("logaddexp2", x1, x2) + return _logaddexp2(x1, x2) + + +@custom_jvp +def _logaddexp2(x1, x2): amax = lax.max(x1, x2) if dtypes.issubdtype(x1.dtype, np.floating): delta = lax.sub(x1, x2) @@ -1762,7 +1790,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) -@logaddexp2.defjvp +@_logaddexp2.defjvp def _logaddexp2_jvp(primals, tangents): x1, x2 = primals t1, t2 = tangents @@ -1775,7 +1803,7 @@ def _logaddexp2_jvp(primals, tangents): @partial(jit, inline=True) def log2(x: ArrayLike, /) -> Array: - """Calculates the base-2 logarithm of x element-wise + """Calculates the base-2 logarithm of ``x`` element-wise. JAX implementation of :obj:`numpy.log2`. From 321b4fbfbf2ace28ad32c83e9890691751fcd46e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 16 Sep 2024 11:03:11 -0700 Subject: [PATCH 507/702] Remove unused string global --- jax/_src/numpy/lax_numpy.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 5137b20bc898..1fd67302eaa6 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4497,14 +4497,6 @@ def _supports_buffer_protocol(obj): return True -_ARRAY_DOC = """ -This function will create arrays on JAX's default device. For control of the -device placement of data, see :func:`jax.device_put`. More information is -available in the JAX FAQ at :ref:`faq-data-placement` (full FAQ at -https://jax.readthedocs.io/en/latest/faq.html). -""" - - def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0, *, device: xc.Device | Sharding | None = None) -> Array: From 7dde9b29098c587b47d2718f7fe5a0a01633e471 Mon Sep 17 00:00:00 2001 From: selamw1 Date: Tue, 10 Sep 2024 16:16:02 -0700 Subject: [PATCH 508/702] frombuffer_docstring_added description_changed_examp_added doc_byte_fixed discription_modified --- jax/_src/numpy/lax_numpy.py | 43 ++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6c6cacef673e..ef6a30400d30 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5194,9 +5194,50 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: # General np.from* style functions mostly delegate to numpy. -@util.implements(np.frombuffer) def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: + r"""Convert a buffer into a 1-D JAX array. + + JAX implementation of :func:`numpy.frombuffer`. + + Args: + buffer: an object containing the data. It must be either a bytes object with + a length that is an integer multiple of the dtype element size, or + it must be an object exporting the `Python buffer interface`_. + dtype: optional. Desired data type for the array. Default is ``float64``. + This specifes the dtype used to parse the buffer, but note that after parsing, + 64-bit values will be cast to 32-bit JAX arrays if the ``jax_enable_x64`` + flag is set to ``False``. + count: optional integer specifying the number of items to read from the buffer. + If -1 (default), all items from the buffer are read. + offset: optional integer specifying the number of bytes to skip at the beginning + of the buffer. Default is 0. + + Returns: + A 1-D JAX array representing the interpreted data from the buffer. + + See also: + - :func:`jax.numpy.fromstring`: convert a string of text into 1-D JAX array. + + Examples: + Using a bytes buffer: + + >>> buf = b"\x00\x01\x02\x03\x04" + >>> jnp.frombuffer(buf, dtype=jnp.uint8) + Array([0, 1, 2, 3, 4], dtype=uint8) + >>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1) + Array([1, 2, 3, 4], dtype=uint8) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [0, 1, 2, 3, 4]) + >>> jnp.frombuffer(pybuffer, dtype=jnp.int32) + Array([0, 1, 2, 3, 4], dtype=int32) + + .. _Python buffer interface: https://docs.python.org/3/c-api/buffer.html + """ return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) From 8ab66c8103f4e8f86f1786c54076b5d332cd5535 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 16 Sep 2024 11:46:23 -0700 Subject: [PATCH 509/702] Fix the TPU and GPU nightly install instructions. PiperOrigin-RevId: 675233702 --- docs/installation.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 7a12f7c541a2..93df4a240a55 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -282,13 +282,13 @@ pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/j - Google Cloud TPU: ```bash -pip install -U --pre jax[tpu] jaxlib libtpu-nightly -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` - NVIDIA GPU (CUDA 12): ```bash -pip install -U --pre jax[cuda12] jaxlib jax-cuda12-plugin jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ``` - NVIDIA GPU (CUDA 12) legacy: From 8804be02295d341d5ad7c080a83b186e557b5922 Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Mon, 16 Sep 2024 12:13:47 -0700 Subject: [PATCH 510/702] Add Python 3.130rc2 support to the build. This PR depends on https://github.com/openxla/xla/pull/17169. The change does not fail existing builds, but to be able to use python 3.13 functionality in jax the corresponding XLA pr needs to land first and get integrated with JAX (happens automatically). PiperOrigin-RevId: 675243989 --- WORKSPACE | 7 +- build/requirements.in | 5 +- build/requirements_lock_3_13.txt | 722 +++++++++++++++++++++++++++++-- 3 files changed, 684 insertions(+), 50 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 383adf810766..ed284acadf81 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -37,9 +37,10 @@ install_deps() load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter") custom_python_interpreter( name = "python_dev", - urls = ["https://www.python.org/ftp/python/3.13.0/Python-{version}.tgz"], - strip_prefix = "Python-{version}", - version = "3.13.0a6", + urls = ["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"], + strip_prefix = "Python-{version_variant}", + version = "3.13.0", + version_variant = "3.13.0rc2", ) load("@xla//:workspace4.bzl", "xla_workspace4") diff --git a/build/requirements.in b/build/requirements.in index add6b8577350..f6b5b18b2660 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -11,12 +11,13 @@ matplotlib; python_version>="3.11" # # build deps # -numpy~=2.0.0 +numpy~=2.0.0; python_version<="3.12" +numpy~=2.1.0; python_version>="3.13" # # runtime deps # -scipy~=1.13.1 +scipy>=1.13.1 ml_dtypes>=0.4.0 opt_einsum diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 62b5e14e65b4..ef121b73713b 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -1,52 +1,423 @@ # -# This file is autogenerated by pip-compile with Python 3.12 +# This file is autogenerated by pip-compile with Python 3.13 # by the following command: # -# bazel run //build:requirements_dev.update +# bazel run //build:requirements.update # -absl-py==2.1.0 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via -r build/test-requirements.txt -attrs==23.2.0 +attrs==24.2.0 \ + --hash=sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346 \ + --hash=sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2 # via hypothesis -build==1.2.1 +build==1.2.2 \ + --hash=sha256:119b2fb462adef986483438377a13b2f42064a2a3a4161f24a0cca698a07ac8c \ + --hash=sha256:277ccc71619d98afdd841a0e96ac9fe1593b823af481d3b0cea748e8894e0613 # via -r build/test-requirements.txt -cloudpickle==3.0.0 +cloudpickle==3.0.0 \ + --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ + --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt -colorama==0.4.6 +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 # via -r build/test-requirements.txt -contourpy==1.2.1 +contourpy==1.3.0 \ + --hash=sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0 \ + --hash=sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639 \ + --hash=sha256:0dce35502151b6bd35027ac39ba6e5a44be13a68f55735c3612c568cac3805fd \ + --hash=sha256:0fa4c02abe6c446ba70d96ece336e621efa4aecae43eaa9b030ae5fb92b309ad \ + --hash=sha256:14e262f67bd7e6eb6880bc564dcda30b15e351a594657e55b7eec94b6ef72843 \ + --hash=sha256:167d6c890815e1dac9536dca00828b445d5d0df4d6a8c6adb4a7ec3166812fa8 \ + --hash=sha256:1ec4dc6bf570f5b22ed0d7efba0dfa9c5b9e0431aeea7581aa217542d9e809a4 \ + --hash=sha256:303c252947ab4b14c08afeb52375b26781ccd6a5ccd81abcdfc1fafd14cf93c1 \ + --hash=sha256:31cd3a85dbdf1fc002280c65caa7e2b5f65e4a973fcdf70dd2fdcb9868069294 \ + --hash=sha256:32b238b3b3b649e09ce9aaf51f0c261d38644bdfa35cbaf7b263457850957a84 \ + --hash=sha256:33c92cdae89ec5135d036e7218e69b0bb2851206077251f04a6c4e0e21f03927 \ + --hash=sha256:345af746d7766821d05d72cb8f3845dfd08dd137101a2cb9b24de277d716def8 \ + --hash=sha256:3634b5385c6716c258d0419c46d05c8aa7dc8cb70326c9a4fb66b69ad2b52e09 \ + --hash=sha256:364174c2a76057feef647c802652f00953b575723062560498dc7930fc9b1cb7 \ + --hash=sha256:36e0cff201bcb17a0a8ecc7f454fe078437fa6bda730e695a92f2d9932bd507f \ + --hash=sha256:36f965570cff02b874773c49bfe85562b47030805d7d8360748f3eca570f4cab \ + --hash=sha256:3bb3808858a9dc68f6f03d319acd5f1b8a337e6cdda197f02f4b8ff67ad2057b \ + --hash=sha256:3e1c7fa44aaae40a2247e2e8e0627f4bea3dd257014764aa644f319a5f8600e3 \ + --hash=sha256:3faeb2998e4fcb256542e8a926d08da08977f7f5e62cf733f3c211c2a5586223 \ + --hash=sha256:420d39daa61aab1221567b42eecb01112908b2cab7f1b4106a52caaec8d36973 \ + --hash=sha256:4553c421929ec95fb07b3aaca0fae668b2eb5a5203d1217ca7c34c063c53d087 \ + --hash=sha256:4865cd1d419e0c7a7bf6de1777b185eebdc51470800a9f42b9e9decf17762081 \ + --hash=sha256:4cfb5c62ce023dfc410d6059c936dcf96442ba40814aefbfa575425a3a7f19dc \ + --hash=sha256:4d63ee447261e963af02642ffcb864e5a2ee4cbfd78080657a9880b8b1868e18 \ + --hash=sha256:570ef7cf892f0afbe5b2ee410c507ce12e15a5fa91017a0009f79f7d93a1268f \ + --hash=sha256:637f674226be46f6ba372fd29d9523dd977a291f66ab2a74fbeb5530bb3f445d \ + --hash=sha256:68a32389b06b82c2fdd68276148d7b9275b5f5cf13e5417e4252f6d1a34f72a2 \ + --hash=sha256:69375194457ad0fad3a839b9e29aa0b0ed53bb54db1bfb6c3ae43d111c31ce41 \ + --hash=sha256:6cb6cc968059db9c62cb35fbf70248f40994dfcd7aa10444bbf8b3faeb7c2d67 \ + --hash=sha256:710a26b3dc80c0e4febf04555de66f5fd17e9cf7170a7b08000601a10570bda6 \ + --hash=sha256:732896af21716b29ab3e988d4ce14bc5133733b85956316fb0c56355f398099b \ + --hash=sha256:75ee7cb1a14c617f34a51d11fa7524173e56551646828353c4af859c56b766e2 \ + --hash=sha256:76a896b2f195b57db25d6b44e7e03f221d32fe318d03ede41f8b4d9ba1bff53c \ + --hash=sha256:76c905ef940a4474a6289c71d53122a4f77766eef23c03cd57016ce19d0f7b42 \ + --hash=sha256:7a52040312b1a858b5e31ef28c2e865376a386c60c0e248370bbea2d3f3b760d \ + --hash=sha256:7ffa0db17717a8ffb127efd0c95a4362d996b892c2904db72428d5b52e1938a4 \ + --hash=sha256:81cb5ed4952aae6014bc9d0421dec7c5835c9c8c31cdf51910b708f548cf58e5 \ + --hash=sha256:834e0cfe17ba12f79963861e0f908556b2cedd52e1f75e6578801febcc6a9f49 \ + --hash=sha256:87ddffef1dbe5e669b5c2440b643d3fdd8622a348fe1983fad7a0f0ccb1cd67b \ + --hash=sha256:880ea32e5c774634f9fcd46504bf9f080a41ad855f4fef54f5380f5133d343c7 \ + --hash=sha256:8ca947601224119117f7c19c9cdf6b3ab54c5726ef1d906aa4a69dfb6dd58102 \ + --hash=sha256:90f73a5116ad1ba7174341ef3ea5c3150ddf20b024b98fb0c3b29034752c8aeb \ + --hash=sha256:92f8557cbb07415a4d6fa191f20fd9d2d9eb9c0b61d1b2f52a8926e43c6e9af7 \ + --hash=sha256:94e848a6b83da10898cbf1311a815f770acc9b6a3f2d646f330d57eb4e87592e \ + --hash=sha256:9c0da700bf58f6e0b65312d0a5e695179a71d0163957fa381bb3c1f72972537c \ + --hash=sha256:a11077e395f67ffc2c44ec2418cfebed032cd6da3022a94fc227b6faf8e2acb8 \ + --hash=sha256:aea348f053c645100612b333adc5983d87be69acdc6d77d3169c090d3b01dc35 \ + --hash=sha256:b11b39aea6be6764f84360fce6c82211a9db32a7c7de8fa6dd5397cf1d079c3b \ + --hash=sha256:c6c7c2408b7048082932cf4e641fa3b8ca848259212f51c8c59c45aa7ac18f14 \ + --hash=sha256:c6ec93afeb848a0845a18989da3beca3eec2c0f852322efe21af1931147d12cb \ + --hash=sha256:cacd81e2d4b6f89c9f8a5b69b86490152ff39afc58a95af002a398273e5ce589 \ + --hash=sha256:d402880b84df3bec6eab53cd0cf802cae6a2ef9537e70cf75e91618a3801c20c \ + --hash=sha256:d51fca85f9f7ad0b65b4b9fe800406d0d77017d7270d31ec3fb1cc07358fdea0 \ + --hash=sha256:d73f659398a0904e125280836ae6f88ba9b178b2fed6884f3b1f95b989d2c8da \ + --hash=sha256:d78ab28a03c854a873787a0a42254a0ccb3cb133c672f645c9f9c8f3ae9d0800 \ + --hash=sha256:da84c537cb8b97d153e9fb208c221c45605f73147bd4cadd23bdae915042aad6 \ + --hash=sha256:dbc4c3217eee163fa3984fd1567632b48d6dfd29216da3ded3d7b844a8014a66 \ + --hash=sha256:e12968fdfd5bb45ffdf6192a590bd8ddd3ba9e58360b29683c6bb71a7b41edca \ + --hash=sha256:e1fd23e9d01591bab45546c089ae89d926917a66dceb3abcf01f6105d927e2cb \ + --hash=sha256:e8134301d7e204c88ed7ab50028ba06c683000040ede1d617298611f9dc6240c \ + --hash=sha256:eb8b141bb00fa977d9122636b16aa67d37fd40a3d8b52dd837e536d64b9a4d06 \ + --hash=sha256:eca7e17a65f72a5133bdbec9ecf22401c62bcf4821361ef7811faee695799779 \ + --hash=sha256:f317576606de89da6b7e0861cf6061f6146ead3528acabff9236458a6ba467f8 \ + --hash=sha256:fd2a0fc506eccaaa7595b7e1418951f213cf8255be2600f1ea1b61e46a60c55f \ + --hash=sha256:fe41b41505a5a33aeaed2a613dccaeaa74e0e3ead6dd6fd3a118fb471644fd6c # via matplotlib -cycler==0.12.1 +cycler==0.12.1 \ + --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ + --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c # via matplotlib -etils[epath,epy]==1.8.0 +etils[epath,epy]==1.9.4 \ + --hash=sha256:4387e7a4911a3b5cc4b92b99a9211386d176b43bae1dac8e2fe345fc2cb95e4b \ + --hash=sha256:fad950414f0a1ca58c70c70915b0014f9953dd9bcf8aa951a0f75ff9becbeb24 # via -r build/requirements.in -execnet==2.1.1 +execnet==2.1.1 \ + --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ + --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist -flatbuffers==24.3.25 +filelock==3.16.0 \ + --hash=sha256:81de9eb8453c769b63369f87f11131a7ab04e367f8d97ad39dc230daa07e3bec \ + --hash=sha256:f6ed4c963184f4c84dd5557ce8fece759a3724b37b80c6c4f20a2f63a4dc6609 # via -r build/test-requirements.txt -fonttools==4.51.0 +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r build/test-requirements.txt +fonttools==4.53.1 \ + --hash=sha256:02569e9a810f9d11f4ae82c391ebc6fb5730d95a0657d24d754ed7763fb2d122 \ + --hash=sha256:0679a30b59d74b6242909945429dbddb08496935b82f91ea9bf6ad240ec23397 \ + --hash=sha256:10f5e6c3510b79ea27bb1ebfcc67048cde9ec67afa87c7dd7efa5c700491ac7f \ + --hash=sha256:2af40ae9cdcb204fc1d8f26b190aa16534fcd4f0df756268df674a270eab575d \ + --hash=sha256:32f029c095ad66c425b0ee85553d0dc326d45d7059dbc227330fc29b43e8ba60 \ + --hash=sha256:35250099b0cfb32d799fb5d6c651220a642fe2e3c7d2560490e6f1d3f9ae9169 \ + --hash=sha256:3b3c8ebafbee8d9002bd8f1195d09ed2bd9ff134ddec37ee8f6a6375e6a4f0e8 \ + --hash=sha256:4824c198f714ab5559c5be10fd1adf876712aa7989882a4ec887bf1ef3e00e31 \ + --hash=sha256:5ff7e5e9bad94e3a70c5cd2fa27f20b9bb9385e10cddab567b85ce5d306ea923 \ + --hash=sha256:651390c3b26b0c7d1f4407cad281ee7a5a85a31a110cbac5269de72a51551ba2 \ + --hash=sha256:6e08f572625a1ee682115223eabebc4c6a2035a6917eac6f60350aba297ccadb \ + --hash=sha256:6ed170b5e17da0264b9f6fae86073be3db15fa1bd74061c8331022bca6d09bab \ + --hash=sha256:73379d3ffdeecb376640cd8ed03e9d2d0e568c9d1a4e9b16504a834ebadc2dfb \ + --hash=sha256:75a157d8d26c06e64ace9df037ee93a4938a4606a38cb7ffaf6635e60e253b7a \ + --hash=sha256:791b31ebbc05197d7aa096bbc7bd76d591f05905d2fd908bf103af4488e60670 \ + --hash=sha256:7b6b35e52ddc8fb0db562133894e6ef5b4e54e1283dff606fda3eed938c36fc8 \ + --hash=sha256:84ec3fb43befb54be490147b4a922b5314e16372a643004f182babee9f9c3407 \ + --hash=sha256:8959a59de5af6d2bec27489e98ef25a397cfa1774b375d5787509c06659b3671 \ + --hash=sha256:9dfdae43b7996af46ff9da520998a32b105c7f098aeea06b2226b30e74fbba88 \ + --hash=sha256:9e6ceba2a01b448e36754983d376064730690401da1dd104ddb543519470a15f \ + --hash=sha256:9efd176f874cb6402e607e4cc9b4a9cd584d82fc34a4b0c811970b32ba62501f \ + --hash=sha256:a1c7c5aa18dd3b17995898b4a9b5929d69ef6ae2af5b96d585ff4005033d82f0 \ + --hash=sha256:aae7bd54187e8bf7fd69f8ab87b2885253d3575163ad4d669a262fe97f0136cb \ + --hash=sha256:b21952c092ffd827504de7e66b62aba26fdb5f9d1e435c52477e6486e9d128b2 \ + --hash=sha256:b96cd370a61f4d083c9c0053bf634279b094308d52fdc2dd9a22d8372fdd590d \ + --hash=sha256:becc5d7cb89c7b7afa8321b6bb3dbee0eec2b57855c90b3e9bf5fb816671fa7c \ + --hash=sha256:bee32ea8765e859670c4447b0817514ca79054463b6b79784b08a8df3a4d78e3 \ + --hash=sha256:c6e7170d675d12eac12ad1a981d90f118c06cf680b42a2d74c6c931e54b50719 \ + --hash=sha256:c818c058404eb2bba05e728d38049438afd649e3c409796723dfc17cd3f08749 \ + --hash=sha256:c8696544c964500aa9439efb6761947393b70b17ef4e82d73277413f291260a4 \ + --hash=sha256:c9cd19cf4fe0595ebdd1d4915882b9440c3a6d30b008f3cc7587c1da7b95be5f \ + --hash=sha256:d4d0096cb1ac7a77b3b41cd78c9b6bc4a400550e21dc7a92f2b5ab53ed74eb02 \ + --hash=sha256:d92d3c2a1b39631a6131c2fa25b5406855f97969b068e7e08413325bc0afba58 \ + --hash=sha256:da33440b1413bad53a8674393c5d29ce64d8c1a15ef8a77c642ffd900d07bfe1 \ + --hash=sha256:e013aae589c1c12505da64a7d8d023e584987e51e62006e1bb30d72f26522c41 \ + --hash=sha256:e128778a8e9bc11159ce5447f76766cefbd876f44bd79aff030287254e4752c4 \ + --hash=sha256:e54f1bba2f655924c1138bbc7fa91abd61f45c68bd65ab5ed985942712864bbb \ + --hash=sha256:e5b708073ea3d684235648786f5f6153a48dc8762cdfe5563c57e80787c29fbb \ + --hash=sha256:e8bf06b94694251861ba7fdeea15c8ec0967f84c3d4143ae9daf42bbc7717fe3 \ + --hash=sha256:f08df60fbd8d289152079a65da4e66a447efc1d5d5a4d3f299cdd39e3b2e4a7d \ + --hash=sha256:f1f8758a2ad110bd6432203a344269f445a2907dc24ef6bccfd0ac4e14e0d71d \ + --hash=sha256:f677ce218976496a587ab17140da141557beb91d2a5c1a14212c994093f2eae2 # via matplotlib -fsspec==2024.3.1 +fsspec==2024.9.0 \ + --hash=sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8 \ + --hash=sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b # via etils -hypothesis==6.100.1 +hypothesis==6.112.1 \ + --hash=sha256:93631b1498b20d2c205ed304cbd41d50e9c069d78a9c773c1324ca094c5e30ce \ + --hash=sha256:b070d7a1bb9bd84706c31885c9aeddc138e2b36a9c112a91984f49501c567856 # via -r build/test-requirements.txt -importlib-resources==6.4.0 +importlib-resources==6.4.5 \ + --hash=sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065 \ + --hash=sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717 # via etils -iniconfig==2.0.0 +iniconfig==2.0.0 \ + --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ + --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -kiwisolver==1.4.5 +kiwisolver==1.4.7 \ + --hash=sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a \ + --hash=sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95 \ + --hash=sha256:0c18ec74c0472de033e1bebb2911c3c310eef5649133dd0bedf2a169a1b269e5 \ + --hash=sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0 \ + --hash=sha256:10849fb2c1ecbfae45a693c070e0320a91b35dd4bcf58172c023b994283a124d \ + --hash=sha256:18077b53dc3bb490e330669a99920c5e6a496889ae8c63b58fbc57c3d7f33a18 \ + --hash=sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b \ + --hash=sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258 \ + --hash=sha256:2a8781ac3edc42ea4b90bc23e7d37b665d89423818e26eb6df90698aa2287c95 \ + --hash=sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e \ + --hash=sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383 \ + --hash=sha256:3ab58c12a2cd0fc769089e6d38466c46d7f76aced0a1f54c77652446733d2d02 \ + --hash=sha256:3abc5b19d24af4b77d1598a585b8a719beb8569a71568b66f4ebe1fb0449460b \ + --hash=sha256:3bf1ed55088f214ba6427484c59553123fdd9b218a42bbc8c6496d6754b1e523 \ + --hash=sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee \ + --hash=sha256:3da53da805b71e41053dc670f9a820d1157aae77b6b944e08024d17bcd51ef88 \ + --hash=sha256:3f9362ecfca44c863569d3d3c033dbe8ba452ff8eed6f6b5806382741a1334bd \ + --hash=sha256:409afdfe1e2e90e6ee7fc896f3df9a7fec8e793e58bfa0d052c8a82f99c37abb \ + --hash=sha256:40fa14dbd66b8b8f470d5fc79c089a66185619d31645f9b0773b88b19f7223c4 \ + --hash=sha256:4322872d5772cae7369f8351da1edf255a604ea7087fe295411397d0cfd9655e \ + --hash=sha256:44756f9fd339de0fb6ee4f8c1696cfd19b2422e0d70b4cefc1cc7f1f64045a8c \ + --hash=sha256:46707a10836894b559e04b0fd143e343945c97fd170d69a2d26d640b4e297935 \ + --hash=sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee \ + --hash=sha256:48be928f59a1f5c8207154f935334d374e79f2b5d212826307d072595ad76a2e \ + --hash=sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038 \ + --hash=sha256:4c00336b9dd5ad96d0a558fd18a8b6f711b7449acce4c157e7343ba92dd0cf3d \ + --hash=sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b \ + --hash=sha256:4d05d81ecb47d11e7f8932bd8b61b720bf0b41199358f3f5e36d38e28f0532c5 \ + --hash=sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107 \ + --hash=sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f \ + --hash=sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2 \ + --hash=sha256:58370b1ffbd35407444d57057b57da5d6549d2d854fa30249771775c63b5fe17 \ + --hash=sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb \ + --hash=sha256:599b5c873c63a1f6ed7eead644a8a380cfbdf5db91dcb6f85707aaab213b1674 \ + --hash=sha256:5b7dfa3b546da08a9f622bb6becdb14b3e24aaa30adba66749d38f3cc7ea9706 \ + --hash=sha256:5b9c3f4ee0b9a439d2415012bd1b1cc2df59e4d6a9939f4d669241d30b414327 \ + --hash=sha256:5d34eb8494bea691a1a450141ebb5385e4b69d38bb8403b5146ad279f4b30fa3 \ + --hash=sha256:5d5abf8f8ec1f4e22882273c423e16cae834c36856cac348cfbfa68e01c40f3a \ + --hash=sha256:5e3bc157fed2a4c02ec468de4ecd12a6e22818d4f09cde2c31ee3226ffbefab2 \ + --hash=sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f \ + --hash=sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948 \ + --hash=sha256:65e720d2ab2b53f1f72fb5da5fb477455905ce2c88aaa671ff0a447c2c80e8e3 \ + --hash=sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e \ + --hash=sha256:6af936f79086a89b3680a280c47ea90b4df7047b5bdf3aa5c524bbedddb9e545 \ + --hash=sha256:71bb308552200fb2c195e35ef05de12f0c878c07fc91c270eb3d6e41698c3bcc \ + --hash=sha256:764202cc7e70f767dab49e8df52c7455e8de0df5d858fa801a11aa0d882ccf3f \ + --hash=sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650 \ + --hash=sha256:78a42513018c41c2ffd262eb676442315cbfe3c44eed82385c2ed043bc63210a \ + --hash=sha256:79849239c39b5e1fd906556c474d9b0439ea6792b637511f3fe3a41158d89ca8 \ + --hash=sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750 \ + --hash=sha256:7bbfcb7165ce3d54a3dfbe731e470f65739c4c1f85bb1018ee912bae139e263b \ + --hash=sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34 \ + --hash=sha256:801fa7802e5cfabe3ab0c81a34c323a319b097dfb5004be950482d882f3d7225 \ + --hash=sha256:803b8e1459341c1bb56d1c5c010406d5edec8a0713a0945851290a7930679b51 \ + --hash=sha256:82a5c2f4b87c26bb1a0ef3d16b5c4753434633b83d365cc0ddf2770c93829e3c \ + --hash=sha256:84ec80df401cfee1457063732d90022f93951944b5b58975d34ab56bb150dfb3 \ + --hash=sha256:8705f17dfeb43139a692298cb6637ee2e59c0194538153e83e9ee0c75c2eddde \ + --hash=sha256:88a9ca9c710d598fd75ee5de59d5bda2684d9db36a9f50b6125eaea3969c2599 \ + --hash=sha256:88f17c5ffa8e9462fb79f62746428dd57b46eb931698e42e990ad63103f35e6c \ + --hash=sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76 \ + --hash=sha256:8a9c83f75223d5e48b0bc9cb1bf2776cf01563e00ade8775ffe13b0b6e1af3a6 \ + --hash=sha256:8b01aac285f91ca889c800042c35ad3b239e704b150cfd3382adfc9dcc780e39 \ + --hash=sha256:8d53103597a252fb3ab8b5845af04c7a26d5e7ea8122303dd7a021176a87e8b9 \ + --hash=sha256:8e045731a5416357638d1700927529e2b8ab304811671f665b225f8bf8d8f933 \ + --hash=sha256:8f0ea6da6d393d8b2e187e6a5e3fb81f5862010a40c3945e2c6d12ae45cfb2ad \ + --hash=sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520 \ + --hash=sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1 \ + --hash=sha256:9242795d174daa40105c1d86aba618e8eab7bf96ba8c3ee614da8302a9f95503 \ + --hash=sha256:929e294c1ac1e9f615c62a4e4313ca1823ba37326c164ec720a803287c4c499b \ + --hash=sha256:933d4de052939d90afbe6e9d5273ae05fb836cc86c15b686edd4b3560cc0ee36 \ + --hash=sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a \ + --hash=sha256:94252291e3fe68001b1dd747b4c0b3be12582839b95ad4d1b641924d68fd4643 \ + --hash=sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60 \ + --hash=sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483 \ + --hash=sha256:a0f64a48bb81af7450e641e3fe0b0394d7381e342805479178b3d335d60ca7cf \ + --hash=sha256:a17f6a29cf8935e587cc8a4dbfc8368c55edc645283db0ce9801016f83526c2d \ + --hash=sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6 \ + --hash=sha256:a79ae34384df2b615eefca647a2873842ac3b596418032bef9a7283675962644 \ + --hash=sha256:a91b5f9f1205845d488c928e8570dcb62b893372f63b8b6e98b863ebd2368ff2 \ + --hash=sha256:aa0abdf853e09aff551db11fce173e2177d00786c688203f52c87ad7fcd91ef9 \ + --hash=sha256:ac542bf38a8a4be2dc6b15248d36315ccc65f0743f7b1a76688ffb6b5129a5c2 \ + --hash=sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640 \ + --hash=sha256:aeb3531b196ef6f11776c21674dba836aeea9d5bd1cf630f869e3d90b16cfade \ + --hash=sha256:b38ac83d5f04b15e515fd86f312479d950d05ce2368d5413d46c088dda7de90a \ + --hash=sha256:b7d755065e4e866a8086c9bdada157133ff466476a2ad7861828e17b6026e22c \ + --hash=sha256:bd3de6481f4ed8b734da5df134cd5a6a64fe32124fe83dde1e5b5f29fe30b1e6 \ + --hash=sha256:bfa1acfa0c54932d5607e19a2c24646fb4c1ae2694437789129cf099789a3b00 \ + --hash=sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27 \ + --hash=sha256:ce8be0466f4c0d585cdb6c1e2ed07232221df101a4c6f28821d2aa754ca2d9e2 \ + --hash=sha256:cf0438b42121a66a3a667de17e779330fc0f20b0d97d59d2f2121e182b0505e4 \ + --hash=sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379 \ + --hash=sha256:d2b0e12a42fb4e72d509fc994713d099cbb15ebf1103545e8a45f14da2dfca54 \ + --hash=sha256:d83db7cde68459fc803052a55ace60bea2bae361fc3b7a6d5da07e11954e4b09 \ + --hash=sha256:dda56c24d869b1193fcc763f1284b9126550eaf84b88bbc7256e15028f19188a \ + --hash=sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c \ + --hash=sha256:e1631290ee9271dffe3062d2634c3ecac02c83890ada077d225e081aca8aab89 \ + --hash=sha256:e28c7fea2196bf4c2f8d46a0415c77a1c480cc0724722f23d7410ffe9842c407 \ + --hash=sha256:e2e6c39bd7b9372b0be21456caab138e8e69cc0fc1190a9dfa92bd45a1e6e904 \ + --hash=sha256:e33e8fbd440c917106b237ef1a2f1449dfbb9b6f6e1ce17c94cd6a1e0d438376 \ + --hash=sha256:e8df2eb9b2bac43ef8b082e06f750350fbbaf2887534a5be97f6cf07b19d9583 \ + --hash=sha256:e968b84db54f9d42046cf154e02911e39c0435c9801681e3fc9ce8a3c4130278 \ + --hash=sha256:eb542fe7933aa09d8d8f9d9097ef37532a7df6497819d16efe4359890a2f417a \ + --hash=sha256:edcfc407e4eb17e037bca59be0e85a2031a2ac87e4fed26d3e9df88b4165f92d \ + --hash=sha256:eee3ea935c3d227d49b4eb85660ff631556841f6e567f0f7bda972df6c2c9935 \ + --hash=sha256:ef97b8df011141c9b0f6caf23b29379f87dd13183c978a30a3c546d2c47314cb \ + --hash=sha256:f106407dda69ae456dd1227966bf445b157ccc80ba0dff3802bb63f30b74e895 \ + --hash=sha256:f3160309af4396e0ed04db259c3ccbfdc3621b5559b5453075e5de555e1f3a1b \ + --hash=sha256:f32d6edbc638cde7652bd690c3e728b25332acbadd7cad670cc4a02558d9c417 \ + --hash=sha256:f37cfe618a117e50d8c240555331160d73d0411422b59b5ee217843d7b693608 \ + --hash=sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07 \ + --hash=sha256:f4d742cb7af1c28303a51b7a27aaee540e71bb8e24f68c736f6f2ffc82f2bf05 \ + --hash=sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a \ + --hash=sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d \ + --hash=sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052 # via matplotlib -markdown-it-py==3.0.0 +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.8.3 +matplotlib==3.9.2 ; python_version >= "3.11" \ + --hash=sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21 \ + --hash=sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5 \ + --hash=sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697 \ + --hash=sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9 \ + --hash=sha256:1cd93b91ab47a3616b4d3c42b52f8363b88ca021e340804c6ab2536344fad9ca \ + --hash=sha256:1d94ff717eb2bd0b58fe66380bd8b14ac35f48a98e7c6765117fe67fb7684e64 \ + --hash=sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e \ + --hash=sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03 \ + --hash=sha256:3fd595f34aa8a55b7fc8bf9ebea8aa665a84c82d275190a61118d33fbc82ccae \ + --hash=sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa \ + --hash=sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3 \ + --hash=sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e \ + --hash=sha256:65aacf95b62272d568044531e41de26285d54aec8cb859031f511f84bd8b495a \ + --hash=sha256:6758baae2ed64f2331d4fd19be38b7b4eae3ecec210049a26b6a4f3ae1c85dcc \ + --hash=sha256:6d1ce5ed2aefcdce11904fc5bbea7d9c21fff3d5f543841edf3dea84451a09ea \ + --hash=sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b \ + --hash=sha256:7741f26a58a240f43bee74965c4882b6c93df3e7eb3de160126d8c8f53a6ae6e \ + --hash=sha256:8912ef7c2362f7193b5819d17dae8629b34a95c58603d781329712ada83f9447 \ + --hash=sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b \ + --hash=sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92 \ + --hash=sha256:9d78bbc0cbc891ad55b4f39a48c22182e9bdaea7fc0e5dbd364f49f729ca1bbb \ + --hash=sha256:ab68d50c06938ef28681073327795c5db99bb4666214d2d5f880ed11aeaded66 \ + --hash=sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9 \ + --hash=sha256:ae82a14dab96fbfad7965403c643cafe6515e386de723e498cf3eeb1e0b70cc7 \ + --hash=sha256:b2696efdc08648536efd4e1601b5fd491fd47f4db97a5fbfd175549a7365c1b2 \ + --hash=sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30 \ + --hash=sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d \ + --hash=sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7 \ + --hash=sha256:c375cc72229614632c87355366bdf2570c2dac01ac66b8ad048d2dabadf2d0d4 \ + --hash=sha256:c797dac8bb9c7a3fd3382b16fe8f215b4cf0f22adccea36f1545a6d7be310b41 \ + --hash=sha256:cef2a73d06601437be399908cf13aee74e86932a5ccc6ccdf173408ebc5f6bb2 \ + --hash=sha256:d52a3b618cb1cbb769ce2ee1dcdb333c3ab6e823944e9a2d36e37253815f9556 \ + --hash=sha256:d719465db13267bcef19ea8954a971db03b9f48b4647e3860e4bc8e6ed86610f \ + --hash=sha256:d8dd059447824eec055e829258ab092b56bb0579fc3164fa09c64f3acd478772 \ + --hash=sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c \ + --hash=sha256:e0830e188029c14e891fadd99702fd90d317df294c3298aad682739c5533721a \ + --hash=sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51 \ + --hash=sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49 \ + --hash=sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c \ + --hash=sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413 # via -r build/requirements.in -mdurl==0.1.2 +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.4.0 +ml-dtypes==0.5.0 \ + --hash=sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe \ + --hash=sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f \ + --hash=sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128 \ + --hash=sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab \ + --hash=sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15 \ + --hash=sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215 \ + --hash=sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f \ + --hash=sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf \ + --hash=sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4 \ + --hash=sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8 \ + --hash=sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859 \ + --hash=sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a \ + --hash=sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff \ + --hash=sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8 \ + --hash=sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02 \ + --hash=sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856 \ + --hash=sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07 \ + --hash=sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f \ + --hash=sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0 \ + --hash=sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7 \ + --hash=sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599 # via -r build/requirements.in -mpmath==1.3.0 +mpmath==1.4.0a1 \ + --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ + --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==1.26.4 +numpy==2.1.1 ; python_version >= "3.13" \ + --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \ + --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \ + --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \ + --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \ + --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \ + --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \ + --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \ + --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \ + --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \ + --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \ + --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \ + --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \ + --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \ + --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \ + --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \ + --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \ + --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \ + --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \ + --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \ + --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \ + --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \ + --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \ + --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \ + --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \ + --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \ + --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \ + --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \ + --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \ + --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \ + --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \ + --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \ + --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \ + --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \ + --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \ + --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 \ + --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \ + --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \ + --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \ + --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \ + --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \ + --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \ + --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \ + --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \ + --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \ + --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \ + --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \ + --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \ + --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \ + --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \ + --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \ + --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \ + --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \ + --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b # via # -r build/requirements.in # -r build/test-requirements.txt @@ -55,52 +426,313 @@ numpy==1.26.4 # ml-dtypes # opt-einsum # scipy -opt-einsum==3.3.0 +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 # via -r build/requirements.in -packaging==24.0 +packaging==24.1 \ + --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ + --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 # via # build # matplotlib # pytest -pillow==10.3.0 +pillow==10.4.0 \ + --hash=sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885 \ + --hash=sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea \ + --hash=sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df \ + --hash=sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5 \ + --hash=sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c \ + --hash=sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d \ + --hash=sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd \ + --hash=sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06 \ + --hash=sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908 \ + --hash=sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a \ + --hash=sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be \ + --hash=sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0 \ + --hash=sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b \ + --hash=sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80 \ + --hash=sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a \ + --hash=sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e \ + --hash=sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9 \ + --hash=sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696 \ + --hash=sha256:43efea75eb06b95d1631cb784aa40156177bf9dd5b4b03ff38979e048258bc6b \ + --hash=sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309 \ + --hash=sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e \ + --hash=sha256:5161eef006d335e46895297f642341111945e2c1c899eb406882a6c61a4357ab \ + --hash=sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d \ + --hash=sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060 \ + --hash=sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d \ + --hash=sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d \ + --hash=sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4 \ + --hash=sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3 \ + --hash=sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6 \ + --hash=sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb \ + --hash=sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94 \ + --hash=sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b \ + --hash=sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496 \ + --hash=sha256:73664fe514b34c8f02452ffb73b7a92c6774e39a647087f83d67f010eb9a0cf0 \ + --hash=sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319 \ + --hash=sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b \ + --hash=sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856 \ + --hash=sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef \ + --hash=sha256:7a8d4bade9952ea9a77d0c3e49cbd8b2890a399422258a77f357b9cc9be8d680 \ + --hash=sha256:7c1ee6f42250df403c5f103cbd2768a28fe1a0ea1f0f03fe151c8741e1469c8b \ + --hash=sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42 \ + --hash=sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e \ + --hash=sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597 \ + --hash=sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a \ + --hash=sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8 \ + --hash=sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3 \ + --hash=sha256:8d4d5063501b6dd4024b8ac2f04962d661222d120381272deea52e3fc52d3736 \ + --hash=sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da \ + --hash=sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126 \ + --hash=sha256:950be4d8ba92aca4b2bb0741285a46bfae3ca699ef913ec8416c1b78eadd64cd \ + --hash=sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5 \ + --hash=sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b \ + --hash=sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026 \ + --hash=sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b \ + --hash=sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc \ + --hash=sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46 \ + --hash=sha256:b15e02e9bb4c21e39876698abf233c8c579127986f8207200bc8a8f6bb27acf2 \ + --hash=sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c \ + --hash=sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe \ + --hash=sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984 \ + --hash=sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a \ + --hash=sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70 \ + --hash=sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca \ + --hash=sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b \ + --hash=sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91 \ + --hash=sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3 \ + --hash=sha256:d7480af14364494365e89d6fddc510a13e5a2c3584cb19ef65415ca57252fb84 \ + --hash=sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1 \ + --hash=sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5 \ + --hash=sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be \ + --hash=sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f \ + --hash=sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc \ + --hash=sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9 \ + --hash=sha256:e88d5e6ad0d026fba7bdab8c3f225a69f063f116462c49892b0149e21b6c0a0e \ + --hash=sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141 \ + --hash=sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef \ + --hash=sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22 \ + --hash=sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27 \ + --hash=sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e \ + --hash=sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1 # via # -r build/test-requirements.txt # matplotlib -pluggy==1.4.0 +pluggy==1.5.0 \ + --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ + --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -portpicker==1.6.0 +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r build/test-requirements.txt -psutil==5.9.8 +psutil==6.0.0 \ + --hash=sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35 \ + --hash=sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0 \ + --hash=sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c \ + --hash=sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1 \ + --hash=sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3 \ + --hash=sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c \ + --hash=sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd \ + --hash=sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3 \ + --hash=sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0 \ + --hash=sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2 \ + --hash=sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6 \ + --hash=sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d \ + --hash=sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c \ + --hash=sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0 \ + --hash=sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132 \ + --hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \ + --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0 # via portpicker -pygments==2.17.2 +pygments==2.18.0 \ + --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ + --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a # via rich -pyparsing==3.1.2 +pyparsing==3.2.0b1 \ + --hash=sha256:51e00c907f7b2ac2d2c35c4d431e944c525ddcfd58b09517f308f40d70e0ddca \ + --hash=sha256:ecf0805530839936196a802cd6d6d65ffa9328eebdc8ee5b8f4b358be5f16666 # via matplotlib -pyproject-hooks==1.0.0 +pyproject-hooks==1.1.0 \ + --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ + --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 # via build -pytest==8.1.1 +pytest==8.3.3 \ + --hash=sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181 \ + --hash=sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2 # via pytest-xdist -pytest-xdist==3.5.0 +pytest-xdist==3.6.1 \ + --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ + --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d # via -r build/test-requirements.txt -python-dateutil==2.9.0.post0 +python-dateutil==2.9.0.post0 \ + --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ + --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 # via matplotlib -rich==13.7.1 +rich==13.8.1 \ + --hash=sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06 \ + --hash=sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a # via -r build/test-requirements.txt -scipy==1.13.1 +scipy==1.14.1 \ + --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ + --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ + --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ + --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ + --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ + --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ + --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ + --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ + --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ + --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ + --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ + --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ + --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ + --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ + --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ + --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ + --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ + --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ + --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ + --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ + --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ + --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ + --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ + --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ + --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ + --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ + --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ + --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ + --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ + --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ + --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \ + --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ + --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 # via -r build/requirements.in -six==1.16.0 +six==1.16.0 \ + --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ + --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via python-dateutil -sortedcontainers==2.4.0 +sortedcontainers==2.4.0 \ + --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ + --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis -typing-extensions==4.11.0 +typing-extensions==4.12.2 \ + --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ + --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 # via etils -wheel==0.43.0 +wheel==0.44.0 \ + --hash=sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f \ + --hash=sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49 # via -r build/test-requirements.txt -zipp==3.18.1 +zipp==3.20.2 \ + --hash=sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350 \ + --hash=sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29 # via etils -zstandard==0.22.0 +zstandard==0.23.0 \ + --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ + --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ + --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \ + --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \ + --hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \ + --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \ + --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \ + --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \ + --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \ + --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \ + --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c \ + --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \ + --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \ + --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \ + --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \ + --hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \ + --hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \ + --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \ + --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \ + --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \ + --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \ + --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \ + --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \ + --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \ + --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \ + --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \ + --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \ + --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \ + --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \ + --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \ + --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \ + --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \ + --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \ + --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \ + --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \ + --hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \ + --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \ + --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \ + --hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \ + --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \ + --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \ + --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \ + --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \ + --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \ + --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \ + --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \ + --hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \ + --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \ + --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \ + --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \ + --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \ + --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \ + --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \ + --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \ + --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \ + --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \ + --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \ + --hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \ + --hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \ + --hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \ + --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \ + --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \ + --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \ + --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \ + --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \ + --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \ + --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \ + --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \ + --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \ + --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \ + --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \ + --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \ + --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \ + --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \ + --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \ + --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \ + --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \ + --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \ + --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \ + --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \ + --hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \ + --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \ + --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \ + --hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \ + --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \ + --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \ + --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \ + --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \ + --hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \ + --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \ + --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \ + --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \ + --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \ + --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \ + --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ + --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ + --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==69.2.0 +setuptools==75.1.0 \ + --hash=sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2 \ + --hash=sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538 # via -r build/test-requirements.txt From 2ff26ff3e0aa16133982c64a5c2866c70d9ef05f Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 16 Sep 2024 09:55:17 +0200 Subject: [PATCH 511/702] Add `scalar_first` argument to `jax.scipy.spatial.transform.Rotation.as_quat` --- jax/_src/scipy/spatial/transform.py | 10 +++++----- tests/scipy_spatial_test.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/jax/_src/scipy/spatial/transform.py b/jax/_src/scipy/spatial/transform.py index 46bd873bd029..debd37dde64f 100644 --- a/jax/_src/scipy/spatial/transform.py +++ b/jax/_src/scipy/spatial/transform.py @@ -167,12 +167,12 @@ def as_rotvec(self, degrees: bool = False) -> jax.Array: """Represent as rotation vectors.""" return _as_rotvec(self.quat, degrees) - def as_quat(self, canonical: bool=False) -> jax.Array: + def as_quat(self, canonical: bool=False, scalar_first: bool=False) -> jax.Array: """Represent as quaternions.""" - if canonical: - return _make_canonical(self.quat) - else: - return self.quat + quat = _make_canonical(self.quat) if canonical else self.quat + if scalar_first: + return jnp.roll(quat, shift=1, axis=-1) + return quat def inv(self): """Invert this rotation.""" diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index c02653dd171b..540136b33870 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -132,6 +132,20 @@ def testRotationAsQuatCanonical(self, shape, dtype): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) + @jtu.sample_product( + dtype=float_dtypes, + shape=[(4,), (num_samples, 4)], + ) + def testRotationAsQuatScalarFirst(self, shape, dtype): + if scipy_version < (1, 14, 0): + self.skipTest("Scipy 1.14.0 added the `scalar_first` arg.") + rng = jtu.rand_default(self.rng()) + args_maker = lambda: (rng(shape, dtype),) + jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat(scalar_first=True) + np_fn = lambda q: osp_Rotation.from_quat(q).as_quat(scalar_first=True).astype(dtype) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) + @jtu.sample_product( dtype=float_dtypes, shape=[(num_samples, 4)], From 4fac852a908609fd25f1154b9cd6392a963e9f7f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 16 Sep 2024 20:05:25 +0000 Subject: [PATCH 512/702] Remove XLA tanh fix cherry-pick, to avoid CI breakages when the XLA commit is bumped. --- third_party/xla/tanh.patch | 14 -------------- third_party/xla/workspace.bzl | 3 --- 2 files changed, 17 deletions(-) delete mode 100644 third_party/xla/tanh.patch diff --git a/third_party/xla/tanh.patch b/third_party/xla/tanh.patch deleted file mode 100644 index fce0ee6fdb68..000000000000 --- a/third_party/xla/tanh.patch +++ /dev/null @@ -1,14 +0,0 @@ -diff --git a/xla/service/cpu/llvm_ir_runtime.cc b/xla/service/cpu/llvm_ir_runtime.cc -index 89b40b915caa3..25541c16bfd61 100644 ---- a/xla/service/cpu/llvm_ir_runtime.cc -+++ b/xla/service/cpu/llvm_ir_runtime.cc -@@ -410,7 +410,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, - rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8); - rewrite_calls(kTanhV16F32SymbolName, GenerateVF32Tanh, /*vector_width=*/16); - -- rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1); -+ // TODO(penporn): Re-enable after fixing JAX issue #23590. -+ // rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1); - - rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1); - rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1); \ No newline at end of file diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 88cfbfbac89f..efd799616237 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -30,9 +30,6 @@ def repo(): sha256 = XLA_SHA256, strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), - patch_file = [ - "//third_party/xla:tanh.patch", - ], ) # For development, one often wants to make changes to the TF repository as well From 29163fcefdc46b3ec1f48051301b264886629ce7 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 16 Sep 2024 13:21:31 -0700 Subject: [PATCH 513/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/90be6e3a11b3489451dcacf918febda2a32f7b10. PiperOrigin-RevId: 675267262 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index efd799616237..31aaf75f4504 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "af733ec6fb9885ddebffdac13acf94b839e049df" -XLA_SHA256 = "7ee9cd3a28d18b3fa2a699b6f5baae0d65351f3c631486b816d45296b1e3328a" +XLA_COMMIT = "90be6e3a11b3489451dcacf918febda2a32f7b10" +XLA_SHA256 = "e924c732353ab0fa3fcee5c29b316118678d8dcbac0d16fa06c72e2cb8fb96b1" def repo(): tf_http_archive( From 940860625e14d0f0601382298b968799b885792f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 16 Sep 2024 14:29:21 -0700 Subject: [PATCH 514/702] Remove code that existed to support jaxlib < 0.4.32. New minimum versions: * jaxlib 0.4.32 * xla_extension_version 283 * mlir_api_version 57 PiperOrigin-RevId: 675291231 --- jax/_src/compiler.py | 4 +- jax/_src/interpreters/pxla.py | 9 +--- jax/_src/lax/linalg.py | 29 ++++------- tests/export_back_compat_test.py | 85 ++++++++++++-------------------- tests/layout_test.py | 3 -- tests/linalg_test.py | 4 -- tests/pjit_test.py | 11 ----- tests/shape_poly_test.py | 5 -- tests/shard_map_test.py | 7 --- 9 files changed, 45 insertions(+), 112 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 81457f1cbd07..108741b5f8fd 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -33,7 +33,6 @@ from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir import numpy as np @@ -157,8 +156,7 @@ def get_compile_options( build_options = compile_options.executable_build_options build_options.use_spmd_partitioning = use_spmd_partitioning build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning - if xla_extension_version >= 280: - build_options.use_shardy_partitioner = use_shardy_partitioner + build_options.use_shardy_partitioner = use_shardy_partitioner if fdo_profile is not None: build_options.fdo_profile = fdo_profile if use_auto_spmd_partitioning: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 882f71d58671..b7d68f73c2a4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -61,7 +61,6 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -3022,12 +3021,8 @@ def aot_cache_miss(*args, **kwargs): self.unsafe_call.name, None, aot_cache_miss, [], [], [], tree_util.dispatch_registry, cc_shard_arg) -if xla_extension_version < 282: - def cc_shard_arg(x, sharding): - return shard_args([sharding], [None], [x])[0] -else: - def cc_shard_arg(x, sharding, layout): # type: ignore - return shard_args([sharding], [layout], [x])[0] +def cc_shard_arg(x, sharding, layout): + return shard_args([sharding], [layout], [x])[0] def check_arg_avals_for_call(ref_avals, arg_avals, diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0cc0e774af53..8752e0b6d1de 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -514,11 +514,7 @@ def _cholesky_cpu_lowering(ctx, operand): out_aval, = ctx.avals_out batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - # TODO(b/344892332): Remove the check after the compatibility period. - if jaxlib_version < (0, 4, 31): - ctx_arg = () - else: - ctx_arg = (ctx,) + ctx_arg = (ctx,) result, info = lapack.potrf_hlo(*ctx_arg, operand_aval.dtype, operand, lower=True, a_shape_vals=op_shape_vals) @@ -556,7 +552,7 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector): def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector): # TODO(b/360781533): Remove guard after 3 week forward compatibility period. - if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32): + if ctx.is_forward_compat(): r_matrix_aval, _ = ctx.avals_in try: [platform] = ctx.module_context.platforms @@ -726,8 +722,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - # TODO(b/344892332): Remove the conditional after the compatibility period. - ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + ctx_args = (ctx,) w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand, input_shape_vals=op_shape_vals, jobvl=compute_left_eigenvectors, @@ -937,8 +932,7 @@ def _eigh_cpu_gpu_lowering( op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) cpu_args = [] if platform == "cpu": - # TODO(b/344892332): Remove the conditional after the compatibility period. - ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + ctx_args = (ctx,) cpu_args.extend(ctx_args) v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand, a_shape_vals=op_shape_vals, lower=lower) @@ -1511,9 +1505,9 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str, info_aval = ShapedArray(batch_dims, np.dtype(np.int32)) m = operand_aval.shape[-2] - # TODO(b/357034884): Remove version gate once jaxlib 0.4.32 is the minimum - # version and the forward compat flag after the 3 week compatibility window. - if jaxlib_version < (0, 4, 32) or ctx.is_forward_compat(): + # TODO(b/357034884): Remove version gate on the forward compat flag after the + # 3 week compatibility window. + if ctx.is_forward_compat(): if not is_constant_shape(operand_aval.shape[-2:]): raise NotImplementedError( "Shape polymorphism for native lowering for lu on CPU and GPU is " @@ -1757,9 +1751,8 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *, a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) - # TODO(b/344892332): Remove the conditional after the compatibility period ctx_args = ( - (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else () + (ctx,) if platform == "cpu" else () ) a_out, taus, *maybe_info_geqrf = geqrf_impl( *ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals @@ -1881,9 +1874,8 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *, f"on GPU is not implemented; b/261671778; {a_aval.shape}") a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus) else: - # TODO(b/344892332): Remove the conditional after the compatibility period ctx_args = ( - (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else () + (ctx,) if platform == "cpu" else () ) a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape) @@ -2152,8 +2144,7 @@ def _svd_cpu_gpu_lowering( compute_uv=compute_uv) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - # TODO(b/344892332): Remove the conditional after the compatibility period. - ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else () + ctx_args = (ctx,) s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand, full_matrices=full_matrices, compute_uv=compute_uv, diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 4e7898d57fe0..103357ac18ac 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -66,7 +66,6 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.lib import cuda_versions -from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -190,14 +189,11 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"): atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name]) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 31) self.run_one_test(func, data, rtol=rtol, atol=atol) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -258,14 +254,11 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_eig_results) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_eig_results) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=check_eig_results) @staticmethod def eigh_input(shape, dtype): @@ -316,14 +309,11 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - # FFI Kernel test - with config.export_ignore_forward_compatibility(True): - data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand)) + # FFI Kernel test + with config.export_ignore_forward_compatibility(True): + data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_eigh_results, operand)) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{variant}", @@ -385,8 +375,6 @@ def test_cuda_lu_pivots_to_permutation(self): def test_cuda_lu_lapack_getrf(self, dtype_name:str): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") - if jaxlib_version < (0, 4, 32): - self.skipTest("Not implemented in older versions of jaxlib") dtype = dict(f32=np.float32, f64=np.float64, c64=np.complex64, c128=np.complex128)[dtype_name] shape = (3, 4) @@ -416,15 +404,12 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] self.run_one_test(func, data, rtol=rtol) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata( - cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{batched}", @@ -502,14 +487,11 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_lu_results, operand, dtype=dtype)) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_lu_results, operand, + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_lu_results, operand, dtype=dtype)) def check_svd_results(self, input, res_run, res_exp, @@ -629,16 +611,13 @@ def func(input): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_svd_results, input)) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata( - cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_svd_results, input)) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_svd_results, input)) @jtu.parameterized_filterable( kwargs=[ diff --git a/tests/layout_test.py b/tests/layout_test.py index f14120e46116..1d18179ccfee 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -15,7 +15,6 @@ import contextlib import math from functools import partial -import unittest from absl.testing import absltest import numpy as np @@ -511,8 +510,6 @@ def g(x): 'Layout passed to jit does not match the layout on the respective arg'): g(arr) - @unittest.skipIf(xla_extension_version < 282, - "Requires xla_extension_version >= 282") def test_in_layouts_jit_jnp_input(self): major_last_layout = DLL(major_to_minor=(1, 0)) sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 4dcdeb19e1ef..446e10abd097 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,7 +16,6 @@ from functools import partial import itertools -import unittest import numpy as np import scipy @@ -2194,9 +2193,6 @@ def testHilbert(self, n): symmetrize_output=[True, False], ) @jtu.skip_on_devices("tpu") - @unittest.skipIf( - jax._src.lib.version < (0, 4, 32), "requires jaxlib >= 0.4.32" - ) def testSymmetricProduct(self, shape, dtype, symmetrize_output): rng = jtu.rand_default(self.rng()) batch_size = 10 diff --git a/tests/pjit_test.py b/tests/pjit_test.py index dbb867ab9a39..6c022653581d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -56,7 +56,6 @@ from jax._src.lib.mlir import dialects from jax._src import xla_bridge from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib import xla_extension from jax._src.util import curry, unzip2 @@ -4433,8 +4432,6 @@ def f(x): "Compiled object called with input sharding.*does not match"): compiled(cpu_arr) - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_different_devices_wsc_abstract_mesh_cache_hit(self): if jax.device_count() < 4: self.skipTest('Requires >=4 devices') @@ -4463,8 +4460,6 @@ def f(x): self.assertEqual(lowering_count[0], 1) self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_wsc_abstract_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) @@ -4484,8 +4479,6 @@ def f(x): self.assertArraysEqual(out_eager, np_inp * 2) self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x'))) - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_wsc_sds_abstract_mesh(self): mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P()) @@ -4499,8 +4492,6 @@ def f(x): sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s) f.eval_shape(sds) # doesn't crash - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_wsc_vmap_abstract_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) @@ -4517,8 +4508,6 @@ def f(x): out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr) self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x'))) - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_wsc_abstract_mesh_errors(self): mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 323d44b542d6..27199c874332 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2843,11 +2843,6 @@ def test_vmap_error(self): ((2, 3, 8, 4), "b1, b2, ..."), ((2, 3, 4, 5), "b1, b2, m, n"), ] - # TODO(danfm): Remove once jaxlib v0.4.32 is the minimum version. - # jaxlib versions before 0.4.32 require a static shape for the non-batch - # dimensions because these are used for computing the "permuation_size" - # which is passed to lu_pivots_to_permutation. - if jaxlib_version >= (0, 4, 32) or not poly.endswith("m, n") ], [ # The random primitive tests, with threefry (both partitionable and diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 0c6848f94c03..3d9b567e2ef4 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -45,7 +45,6 @@ from jax._src import linear_util as lu from jax._src import tree_util import jax.numpy as jnp -from jax._src.lib import xla_extension_version from jax.experimental.custom_partitioning import custom_partitioning from jax.experimental.shard_map import shard_map @@ -777,8 +776,6 @@ def with_capture(y_slice): # is over an axis of size 2. This is a problem at the moment. jax.make_jaxpr(mapped)(x, y).jaxpr - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_shard_map_abstract_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) @@ -803,8 +800,6 @@ def f(x): self.assertArraysEqual(out2, np_inp) self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_different_devices_shmap_abstract_mesh_cache_hit(self): if jax.device_count() < 4: self.skipTest('Requires >=4 devices') @@ -835,8 +830,6 @@ def f(x): self.assertEqual(lowering_count[0], 1) self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. - @unittest.skipIf(xla_extension_version < 281, - 'Requires xla_extension_version >= 281') def test_shmap_abstract_mesh_errors(self): mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) From cd083377177e3fddc70815873fd107ab9266edfb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 16 Sep 2024 14:55:44 -0700 Subject: [PATCH 515/702] CI: set concurrency for workflows --- .github/workflows/ci-build.yaml | 29 ++++----------------------- .github/workflows/jax-array-api.yml | 4 ++++ .github/workflows/metal_plugin_ci.yml | 4 ++++ .github/workflows/wheel_win_x64.yml | 9 ++++----- .github/workflows/windows_ci.yml | 8 ++++---- 5 files changed, 20 insertions(+), 34 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 5d46f8fbf0d8..0f90cd72e463 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -20,16 +20,15 @@ permissions: contents: read # to fetch code actions: write # to cancel previous workflows +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: lint_and_typecheck: runs-on: ubuntu-latest timeout-minutes: 5 steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python 3.11 uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 @@ -58,11 +57,6 @@ jobs: prng-upgrade: 0 num_generated_cases: 1 steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 @@ -110,11 +104,6 @@ jobs: matrix: python-version: ['3.10'] steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 @@ -152,11 +141,6 @@ jobs: matrix: python-version: ['3.10'] steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 @@ -193,11 +177,6 @@ jobs: enable-x64: 0 num_generated_cases: 10 steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index cdba39b3642a..cbe383f21ffe 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -9,6 +9,10 @@ on: - '**workflows/jax-array-api.yml' - '**experimental/array_api/**' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: build: diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index 0c739619df1a..75f4bba1a367 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -11,6 +11,10 @@ on: paths: - '**workflows/metal_plugin_ci.yml' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: jax-metal-plugin-test: diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 03b5de37b85b..447ccba4f8c2 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -2,6 +2,10 @@ name: Wheel build - Windows CPU x86_64 on: workflow_dispatch: # allows triggering the workflow run manually +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + env: DISTUTILS_USE_SDK: 1 MSSdk: 1 @@ -18,11 +22,6 @@ jobs: runs-on: ${{ matrix.os }} steps: - - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 42083f1d087d..194cac6fa79a 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -6,6 +6,10 @@ on: pull_request: types: [ labeled ] # allow force-windows-run label +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + env: DISTUTILS_USE_SDK: 1 MSSdk: 1 @@ -23,10 +27,6 @@ jobs: runs-on: ${{ matrix.os }} steps: - - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade From d27fce6981e8a6d27a9e2ef683f024cbf62ae2bc Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Mon, 16 Sep 2024 17:59:28 -0700 Subject: [PATCH 516/702] [Pallas TPU] Fix dtype_bitwidth for int in util. PiperOrigin-RevId: 675357560 --- jax/_src/pallas/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index cfca0769d13d..e1fbbde61c56 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -72,7 +72,7 @@ def next_power_of_2(x: int) -> int: return 1 if x == 0 else 2 ** (x - 1).bit_length() def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int: - if isinstance(dtype, jnp.integer): + if jnp.issubdtype(dtype, jnp.integer): return jnp.iinfo(dtype).bits return np.dtype(dtype).itemsize * 8 From 3555b2b2c13b04e8af42b21aee1a10b67c260b70 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 17 Sep 2024 05:33:47 -0700 Subject: [PATCH 517/702] Renamed `plgpu.wait` to `plgpu.wait_barrier` This avoid a potential ambiguity with waiting for a WGMMA to complete. PiperOrigin-RevId: 675528768 --- jax/_src/pallas/mosaic_gpu/__init__.py | 2 +- jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- tests/pallas/mosaic_gpu_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index 97732acbd830..11258f741b7f 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -21,7 +21,7 @@ from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem -from jax._src.pallas.mosaic_gpu.primitives import wait +from jax._src.pallas.mosaic_gpu.primitives import wait_barrier from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem GMEM = GPUMemorySpace.GMEM diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index b53901e30612..e96574612bfa 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -100,6 +100,6 @@ def wait_smem_to_gmem(allow_groups: int) -> None: wait_p.bind(allow_groups=allow_groups) -def wait(barrier: pallas_core.AbstractMemoryRef) -> None: +def wait_barrier(barrier: pallas_core.AbstractMemoryRef) -> None: """Waits on the given barrier.""" wait_p.bind(barrier) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index ee8858282a33..bd9df6182793 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -149,7 +149,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.async_copy_gmem_to_smem( x_ref_gmem, scratch_ref, barrier=barrier_ref ) - plgpu.wait(barrier_ref) + plgpu.wait_barrier(barrier_ref) o_ref[...] = scratch_ref[...] + 1 x = jnp.arange(128).astype(jnp.float32) From 2e73d507f2c57d9cef28c3e41d548696dd5adcb8 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Tue, 17 Sep 2024 12:04:35 -0400 Subject: [PATCH 518/702] Fix minor typo in unique_indices error for _scatter_transpose_rule. --- jax/_src/lax/slicing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 2a3a63e89a35..39d4b31588c1 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -2384,7 +2384,7 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): if not unique_indices: - raise NotImplementedError("scatter transpose is only implemented where" + raise NotImplementedError("scatter transpose is only implemented where " "unique_indices=True") assert not ad.is_undefined_primal(indices) if ad.is_undefined_primal(updates): From 187eeb9e989e02ddda0e13f508b21ded74338345 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 17 Sep 2024 09:25:27 -0700 Subject: [PATCH 519/702] Update docs for jnp.argmax & related functions --- jax/_src/numpy/lax_numpy.py | 160 +++++++++++++++++++++++++++++++++--- 1 file changed, 150 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 72dcbe872058..818476414ed6 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -8519,9 +8519,41 @@ def argwhere( return result.reshape(result.shape[0], ndim(a)) -@util.implements(np.argmax, skip_params=['out']) def argmax(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: + """Return the index of the maximum value of an array. + + JAX implementation of :func:`numpy.argmax`. + + Args: + a: input array + axis: optional integer specifying the axis along which to find the maximum + value. If ``axis`` is not specified, ``a`` will be flattened. + out: unused by JAX + keepdims: if True, then return an array with the same number of dimensions + as ``a``. + + Returns: + an array containing the index of the maximum value along the specified axis. + + See also: + - :func:`jax.numpy.argmin`: return the index of the minimum value. + - :func:`jax.numpy.nanargmax`: compute ``argmax`` while ignoring NaN values. + + Examples: + >>> x = jnp.array([1, 3, 5, 4, 2]) + >>> jnp.argmax(x) + Array(2, dtype=int32) + + >>> x = jnp.array([[1, 3, 2], + ... [5, 4, 1]]) + >>> jnp.argmax(x, axis=1) + Array([1, 0], dtype=int32) + + >>> jnp.argmax(x, axis=1, keepdims=True) + Array([[1], + [0]], dtype=int32) + """ util.check_arraylike("argmax", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.") @@ -8541,9 +8573,42 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: result = lax.argmax(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(int_)) return expand_dims(result, dims) if keepdims else result -@util.implements(np.argmin, skip_params=['out']) + def argmin(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: + """Return the index of the minimum value of an array. + + JAX implementation of :func:`numpy.argmax`. + + Args: + a: input array + axis: optional integer specifying the axis along which to find the maximum + value. If ``axis`` is not specified, ``a`` will be flattened. + out: unused by JAX + keepdims: if True, then return an array with the same number of dimensions + as ``a``. + + Returns: + an array containing the index of the maximum value along the specified axis. + + See also: + - :func:`jax.numpy.argmax`: return the index of the maximum value. + - :func:`jax.numpy.nanargmin`: compute ``argmin`` while ignoring NaN values. + + Examples: + >>> x = jnp.array([1, 3, 5, 4, 2]) + >>> jnp.argmin(x) + Array(0, dtype=int32) + + >>> x = jnp.array([[1, 3, 2], + ... [5, 4, 1]]) + >>> jnp.argmin(x, axis=1) + Array([0, 2], dtype=int32) + + >>> jnp.argmin(x, axis=1, keepdims=True) + Array([[0], + [2]], dtype=int32) + """ util.check_arraylike("argmin", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.") @@ -8564,19 +8629,57 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: return expand_dims(result, dims) if keepdims else result -_NANARG_DOC = """\ -Warning: jax.numpy.arg{} returns -1 for all-NaN slices and does not raise -an error. -""" - - -@util.implements(np.nanargmax, lax_description=_NANARG_DOC.format("max"), skip_params=['out']) def nanargmax( a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None, ) -> Array: + """Return the index of the maximum value of an array, ignoring NaNs. + + JAX implementation of :func:`numpy.nanargmax`. + + Args: + a: input array + axis: optional integer specifying the axis along which to find the maximum + value. If ``axis`` is not specified, ``a`` will be flattened. + out: unused by JAX + keepdims: if True, then return an array with the same number of dimensions + as ``a``. + + Returns: + an array containing the index of the maximum value along the specified axis. + + Note: + In the case of an axis with all-NaN values, the returned index will be -1. + This differs from the behavior of :func:`numpy.nanargmax`, which raises an error. + + See also: + - :func:`jax.numpy.argmax`: return the index of the maximum value. + - :func:`jax.numpy.nanargmin`: compute ``argmin`` while ignoring NaN values. + + Examples: + >>> x = jnp.array([1, 3, 5, 4, jnp.nan]) + + Using a standard :func:`~jax.numpy.argmax` leads to potentially unexpected results: + + >>> jnp.argmax(x) + Array(4, dtype=int32) + + Using ``nanargmax`` returns the index of the maximum non-NaN value. + + >>> jnp.nanargmax(x) + Array(2, dtype=int32) + + >>> x = jnp.array([[1, 3, jnp.nan], + ... [5, 4, jnp.nan]]) + >>> jnp.nanargmax(x, axis=1) + Array([1, 0], dtype=int32) + + >>> jnp.nanargmax(x, axis=1, keepdims=True) + Array([[1], + [0]], dtype=int32) + """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmax is not supported.") return _nanargmax(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) @@ -8593,13 +8696,50 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) -@util.implements(np.nanargmin, lax_description=_NANARG_DOC.format("min"), skip_params=['out']) def nanargmin( a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None, ) -> Array: + + """Return the index of the minimum value of an array, ignoring NaNs. + + JAX implementation of :func:`numpy.nanargmin`. + + Args: + a: input array + axis: optional integer specifying the axis along which to find the maximum + value. If ``axis`` is not specified, ``a`` will be flattened. + out: unused by JAX + keepdims: if True, then return an array with the same number of dimensions + as ``a``. + + Returns: + an array containing the index of the minimum value along the specified axis. + + Note: + In the case of an axis with all-NaN values, the returned index will be -1. + This differs from the behavior of :func:`numpy.nanargmin`, which raises an error. + + See also: + - :func:`jax.numpy.argmin`: return the index of the minimum value. + - :func:`jax.numpy.nanargmax`: compute ``argmax`` while ignoring NaN values. + + Examples: + >>> x = jnp.array([jnp.nan, 3, 5, 4, 2]) + >>> jnp.nanargmin(x) + Array(4, dtype=int32) + + >>> x = jnp.array([[1, 3, jnp.nan], + ... [5, 4, jnp.nan]]) + >>> jnp.nanargmin(x, axis=1) + Array([0, 1], dtype=int32) + + >>> jnp.nanargmin(x, axis=1, keepdims=True) + Array([[0], + [1]], dtype=int32) + """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmin is not supported.") return _nanargmin(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) From c61e49cd4a6c58b3b9823a32fe1320d65c98c45d Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 17 Sep 2024 11:22:49 -0700 Subject: [PATCH 520/702] Simplify logic in jaxlib FFI_ASSIGN_OR_RETURN macro, and fix gcc build. In https://github.com/google/jax/issues/23687, it was reported that recent jaxlib changes introduced issues when building from source using gcc, instead of the clang build that we test. I'm not 100% sure why the previous macro didn't work, but in investigating I found a version that seems to work on both clang and gcc with simpler logic. PiperOrigin-RevId: 675641259 --- jaxlib/ffi_helpers.h | 38 +++++++++----------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index fba57d11b9f2..47505020f3b8 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -62,35 +62,15 @@ namespace jax { FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) // All the macros below here are to handle the case in FFI_ASSIGN_OR_RETURN -// where the LHS is wrapped in parentheses. -#define FFI_ASSIGN_OR_RETURN_EAT(...) -#define FFI_ASSIGN_OR_RETURN_REM(...) __VA_ARGS__ -#define FFI_ASSIGN_OR_RETURN_EMPTY() - -#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(...) \ - FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER((__VA_ARGS__, 0, 1)) -#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER(args) \ - FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I args -#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty - -#define FFI_ASSIGN_OR_RETURN_IS_EMPTY(...) \ - FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(__VA_ARGS__) -#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(...) \ - FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(_, ##__VA_ARGS__) - -#define FFI_ASSIGN_OR_RETURN_IF_1(_Then, _Else) _Then -#define FFI_ASSIGN_OR_RETURN_IF_0(_Then, _Else) _Else -#define FFI_ASSIGN_OR_RETURN_IF(_Cond, _Then, _Else) \ - FFI_ASSIGN_OR_RETURN_CONCAT_(FFI_ASSIGN_OR_RETURN_IF_, _Cond)(_Then, _Else) - -#define FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(...) \ - FFI_ASSIGN_OR_RETURN_IS_EMPTY(FFI_ASSIGN_OR_RETURN_EAT __VA_ARGS__) - -#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(...) \ - FFI_ASSIGN_OR_RETURN_IF(FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(__VA_ARGS__), \ - FFI_ASSIGN_OR_RETURN_REM, \ - FFI_ASSIGN_OR_RETURN_EMPTY()) \ - __VA_ARGS__ +// where the LHS is wrapped in parentheses. See a more detailed discussion at +// https://stackoverflow.com/a/62984543 +#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(X) \ + FFI_ASSIGN_OR_RETURN_ESCAPE(FFI_ASSIGN_OR_RETURN_EMPTY X) +#define FFI_ASSIGN_OR_RETURN_EMPTY(...) FFI_ASSIGN_OR_RETURN_EMPTY __VA_ARGS__ +#define FFI_ASSIGN_OR_RETURN_ESCAPE(...) \ + FFI_ASSIGN_OR_RETURN_ESCAPE_(__VA_ARGS__) +#define FFI_ASSIGN_OR_RETURN_ESCAPE_(...) FFI_ASSIGN_OR_RETURN_##__VA_ARGS__ +#define FFI_ASSIGN_OR_RETURN_FFI_ASSIGN_OR_RETURN_EMPTY template inline absl::StatusOr MaybeCastNoOverflow( From affdca91e6c304f1dc96dafbb851503a154498be Mon Sep 17 00:00:00 2001 From: tchatow Date: Mon, 16 Sep 2024 14:29:10 -0400 Subject: [PATCH 521/702] Add underlying method argument to jax.numpy.digitize --- jax/_src/numpy/lax_numpy.py | 18 ++++++++++-------- tests/lax_numpy_test.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 5137b20bc898..bd9576ea52f6 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -10493,7 +10493,7 @@ def body_fun(state, _): def _searchsorted_via_sort(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: working_dtype = int32 if sorted_arr.size + query.size < np.iinfo(np.int32).max else int64 def _rank(x): - idx = lax.iota(working_dtype, len(x)) + idx = lax.iota(working_dtype, x.shape[0]) return zeros_like(idx).at[argsort(x)].set(idx) query_flat = query.ravel() if side == 'left': @@ -10586,8 +10586,8 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', a, v = util.promote_dtypes(a, v) if sorter is not None: a = a[sorter] - dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64 - if len(a) == 0: + dtype = int32 if a.shape[0] <= np.iinfo(np.int32).max else int64 + if a.shape[0] == 0: return zeros_like(v, dtype=dtype) impl = { 'scan': partial(_searchsorted_via_scan, False), @@ -10597,9 +10597,11 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', }[method] return impl(asarray(a), asarray(v), side, dtype) # type: ignore -@util.implements(np.digitize) -@partial(jit, static_argnames=('right',)) -def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: +@util.implements(np.digitize, lax_description=_dedent(""" + Optionally, the ``method`` argument can be used to configure the + underlying :func:`jax.numpy.searchsorted` algorithm.""")) +@partial(jit, static_argnames=('right', 'method')) +def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str = 'scan') -> Array: util.check_arraylike("digitize", x, bins) right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()") bins_arr = asarray(bins) @@ -10610,8 +10612,8 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: side = 'right' if not right else 'left' return where( bins_arr[-1] >= bins_arr[0], - searchsorted(bins_arr, x, side=side), - len(bins_arr) - searchsorted(bins_arr[::-1], x, side=side) + searchsorted(bins_arr, x, side=side, method=method), + bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, method=method) ) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 704bb90116ea..01c89caf7a22 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2800,6 +2800,23 @@ def testDigitize(self, xshape, binshape, right, reverse, dtype): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + xshape=[(20,), (5, 4)], + binshape=[(0,), (1,), (5,)], + right=[True, False], + method=['scan', 'scan_unrolled', 'sort', 'compare_all'], + reverse=[True, False], + dtype=default_dtypes, + ) + def testDigitizeMethod(self, xshape, binshape, right, method, reverse, dtype): + order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:] + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]] + np_fun = lambda x, bins: np.digitize(x, bins, right=right).astype('int32') + jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right, method=method) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( dtypes=[ [np.float32], From dbb34f56dd6ac9c7a82d20fe220692ce8196a9ab Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 17 Sep 2024 15:04:09 -0400 Subject: [PATCH 522/702] Raise a clearer error message when `closure_convert`ed function is called with inputs with the wrong structure. Fixes https://github.com/google/jax/issues/23588 --- jax/_src/custom_derivatives.py | 7 ++++++- tests/api_test.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 64a37b782358..019948c36683 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1176,7 +1176,12 @@ def converted_fun(*args_hconsts): args, hoisted_consts = split_list(args_hconsts, [num_args]) consts = merge(closure_consts, hoisted_consts) all_args, in_tree2 = tree_flatten(tuple(args)) - assert in_tree == in_tree2 + if in_tree != in_tree2: + msg = ("The inputs to the closure produced by closure_convert must have " + "the same Pytree structure as the example arguments passed when " + f"closure_convert was called. Expected {in_tree}, but got " + f"{in_tree2}") + raise TypeError(msg) out_flat = core.eval_jaxpr(jaxpr, consts, *all_args) return tree_unflatten(out_tree, out_flat) diff --git a/tests/api_test.py b/tests/api_test.py index 0390e2e4b636..8b75cb624f1b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -8988,6 +8988,19 @@ def closure(x): self.assertAllClose(g_c, 42. * c, check_dtypes=False) self.assertAllClose(g_x, 17. * x, check_dtypes=False) + def test_closure_convert_pytree_mismatch(self): + # See https://github.com/google/jax/issues/23588 + def f(x, z): + return z * x + + x, z = 2.0, 3.0 + _, vjp = api.vjp(f, x, z) + vjp_pure, vjp_aux_args = jax.closure_convert(vjp, x) + vjp_pure(x, *vjp_aux_args) + with self.assertRaisesRegex( + TypeError, "The inputs to the closure produced by closure_convert"): + vjp_pure(x, vjp_aux_args) + def test_float0_cotangents_automatically_handled(self): @jax.custom_vjp def f(x, y): From 9d3762bd476b95a187bab22284e62525901255f7 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 16 Sep 2024 19:18:22 -0700 Subject: [PATCH 523/702] [Pallas] Add design note for async ops on TPU --- docs/pallas/async_note.md | 675 ++++++++++++++++++++++++++++++++++++++ docs/pallas/index.rst | 7 + 2 files changed, 682 insertions(+) create mode 100644 docs/pallas/async_note.md diff --git a/docs/pallas/async_note.md b/docs/pallas/async_note.md new file mode 100644 index 000000000000..96370ee48625 --- /dev/null +++ b/docs/pallas/async_note.md @@ -0,0 +1,675 @@ +# Pallas Async Operations + +## Background \+ Motivation + +We’d like to expose APIs in Pallas to explicitly overlap computation and communication *across multiple kernels*. + +### XLA Async Decomposition + +As motivation, consider the following JAX pseudocode: + +```py +def f(x): + y = ppermute(x) + z = x + 1 + return y, z +``` + +In this function, we could perform the `ppermute` at the same time as the `x + 1`. This is an optimization XLA does automatically by: + +1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future. +2. scheduling the `x + 1` between the `ppermute_start` and `ppermute_done`, + +resulting in the following program: + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 # happens at the same time as ppermute + y = ppermute_done(fut) + return y, z +``` + +### Async ops inside kernels + +Now imagine we aren’t using XLA’s `ppermute` but have our own custom Pallas `ppermute`. + +```py +def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem): + right_neighbor = ... + descriptor = pltpu.make_remote_async_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) + descriptor.start() + descriptor.wait_send() + descriptor.wait_recv() + +def ppermute(x): + return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x) +``` + +Currently, we cannot decompose `ppermute` into a `start/done` pair as XLA does, so instead we explicitly **fuse** the `x + 1` into the kernel. + +```py +def add_one(x_ref, z_ref): + z_ref[...] = x_ref[...] + 1 + +def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem): + right_neighbor = ... + descriptor = pltpu.make_remote_async_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) + descriptor.start() + + # Explicitly schedule inner kernel between start/wait + pltpu.emit_pipeline(add_one)(x_ref, z_ref) + + descriptor.wait_send() + descriptor.wait_recv() + +def ppermute_and_add_one(x): + return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x) + +``` + +The goal is to enable writing separate kernels for starting the `ppermute` and waiting on it to complete, so that we can use a regular old `x + 1` in between (or whatever compute we want). This makes the code more readable, maintainable, and less bug-prone. + +## How do we implement decomposed Pallas async operations (on TPU)? + +The main thing to figure out when implementing decomposed async operations in Pallas is what the `future` that is passed between them contains. Specifically, it must contain some important state about the operation happening in the background. + +If we look at the Pallas code, we can see that we need a “descriptor” to both start and wait on a remote copy. Can we plumb this descriptor out of the Pallas kernel, and then pass it into another one? Well kinda. The underlying TPU hardware tracks async op progress via a pair of semaphores: `send_sem` enables us to wait on when a device is done sending data to its neighbor and `recv_sem` tracks the data transfer sent to a device from their neighbor. If we imagine writing a start kernel and a done kernel, all we’d need to pass from the start to the done would be the semaphores and some information about how much to wait on those semaphores. + +We can do this via extending Pallas to support returning semaphores from kernels. + +```py +def ppermute_start_kernel( + in_ref, send_sem, recv_sem, out_ref, *, axis_name, +): + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size) + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_wait(barrier_sem, 1) + pltpu.make_async_remote_copy( + in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor + ).start() + +def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]: + send_sem, recv_sem, out = pl.pallas_call( + functools.partial(ppermute_start_kernel, axis_name=axis_name), + out_shape=( + pltpu.SemaphoreType.DMA(()), + pltpu.SemaphoreType.DMA(()), + jax.ShapeDtypeStruct( + x.shape, + dtype=x.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.ANY), + ), + )(x) + return send_sem, recv_sem, out +``` + +Note that something subtle is happening here. Pallas is telling XLA that it would like some outputs to be semaphores (a.k.a. sync flags) and XLA will treat them as “reserved” (e.g. while they are alive in the XLA program, those sync flags cannot be allocated by other kernels). They behave similarly to barrier semaphores, which are reserved semaphores managed by XLA. + +Another thing to notice is that we return the output buffer `out` from the start kernel *while it’s being actively copied into*. + +Now we write the `done` kernel that performs the blocking operation. We pass `out` into the kernel to compute the shape needed to block on the semaphore. + +```py +def ppermute_done_kernel(ref, send_sem, recv_sem, _): + pltpu.make_async_copy(ref, ref, send_sem).wait() + pltpu.make_async_copy(ref, ref, recv_sem).wait() + +def ppermute_done(send_sem, recv_sem, out) ->Array: + out = pl.pallas_call( + ppermute_done_kernel, + out_shape=( + jax.ShapeDtypeStruct( + out.shape, + dtype=out.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0:0} + )(out, send_sem, recv_sem) + return out +``` + +Note: we i/o alias the output buffer here to guarantee that the consumers are downstream of the `ppermute_done`. + +We now can implement the decomposed collective permute. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 # happens at the same time as ppermute + y = ppermute_done(fut) + return y, z +``` + +***OR CAN WE?*** + +## Why *doesn’t* this work? + +There are three remaining issues with this, each of which exists outside of Pallas to some degree. Here they are at a high level. + +1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX. +2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory. +3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness. + +We will go over these issues one by one and suggest fixes. + +### Scheduling + +How do we explicitly force ops to happen in a particular order in JAX? Note that this is not a Pallas specific problem, and if we had async ops implemented using an alternative method, we’d still run into this. + +One way is to introduce an optimization barrier into the XLA program. The optimization barrier will prevent XLA moving ops around it. + +Here’s our original code: + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z +``` + +XLA could choose to execute `x + 1` in any of three places: + +```py +def f(x): + z = x + 1 + fut = ppermute_start(x) + y = ppermute_done(fut) + return y, z + +# OR + +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z + +# OR + +def f(x): + fut = ppermute_start(x) + y = ppermute_done(fut) + z = x + 1 + return y, z +``` + +To force the `x + 1` to happen between the `ppermute` ops, we can use `optimization_barrier`, which is semantically the identity function (i.e. `lambda x: x`) but introduces an explicit data dependency between values. Specifically, if we make the `x` that is used in `x + 1` dependent on the `fut` returned by `ppermute_start`, it must happen after `ppermute_start`. + +We also introduce a dependency that forces the output value `y` to depend on `z`. + +```py +def f(x): + fut = ppermute_start(x) + x, fut = optimization_barrier((x, fut)) # x now depends on fut + z = x + 1 + z, fut = optimization_barrier((z, fut)) # fut now depends on z + y = ppermute_done(fut) + return y, z +``` + +`optimization_barrier` is a good enough hammer for us to explicitly write out schedules. + +### Lifetimes + +Let’s look at our original code again and assume the ops are happening in the correct order. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z +``` + +Let’s look at which point in the program XLA believes it is okay to free the buffer for `x`. It would be the point after which `x` is no longer used, specifically after `z = x + 1`. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + # XLA can free x here! + y = ppermute_done(fut) + return y, z +``` + +If XLA frees `x` after `z = x + 1` has completed, we run into a very bad problem. The `ppermute` could still be actively copying `x` to the neighbor after `z = x + 1` which means if `x` is freed, the `ppermute` will be reading from garbage memory\! + +How do we extend `x`’s lifetime to the `ppermute_done`? Well we can introduce a data dependency\! We need to modify our kernels a little bit to make this happen. + +First, we rewrite `ppermute_start` to return `x`, aliasing it through the kernel. + +```py +def ppermute_start_kernel( + in_ref, send_sem, recv_sem, out_ref, _, *, axis_name, +): + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size) + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_wait(barrier_sem, 1) + pltpu.make_async_remote_copy( + in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor + ).start() + +def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]: + send_sem, recv_sem, x, out = pl.pallas_call( + functools.partial(ppermute_start_kernel, axis_name=axis_name), + out_shape=( + pltpu.SemaphoreType.DMA(()), + pltpu.SemaphoreType.DMA(()), + jax.ShapeDtypeStruct( + x.shape, + dtype=x.dtype, + ), + jax.ShapeDtypeStruct( + x.shape, + dtype=x.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ), + input_output_aliases={0:2} + )(x) + return send_sem, recv_sem, x, out +``` + +We then have `ppermute_done` take in `x` and do nothing with it. + +```py +def ppermute_done_kernel(_, ref, send_sem, recv_sem, _): + pltpu.make_async_copy(ref, ref, send_sem).wait() + pltpu.make_async_copy(ref, ref, recv_sem).wait() + +def ppermute_done(send_sem, recv_sem, x, out) ->Array: + out = pl.pallas_call( + ppermute_done_kernel, + out_shape=( + jax.ShapeDtypeStruct( + out.shape, + dtype=out.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={1:0} + )(x, out, send_sem, recv_sem) + return out + +``` + +Now when we write + +```py +def f(x): + *sems, x ,out = ppermute_start(x) + z = x + 1 + y = ppermute_done(*sems, x, out) + return y, z +``` + +XLA can no longer free `x` because it is an input to `ppermute_done`\! This means that `x`’s lifetime is tied to the `ppermute` and this code is now correct. + +### Defensive copies + +XLA, in its buffer assignment pass, analyzes which buffers are aliased to each other and inserts copies whenever an operation that aliases one of its inputs is not the final consumer of that input. + +#### Background + +Here’s a simple example. Let’s say we have an op `add_one_inplace` which takes in an array and adds one, but promises to do it in-place. + +The following code would be legal. + +```py +def f(): + x = jnp.arange(...) + y = add_one_inplace(x) return y +``` + +However, if `x` had a separate consumer as well, the program may not execute correctly. + +```py +def f(): + x = jnp.arange(...) + y = add_one_inplace(x) + return y, x * 2 # another x consumer! +``` + +This is because `x * 2` operates on the original `x` but `add_one_inplace` clobbers the value in `x`. `x * 2` needs to make sure to read the original values of `x`, not the ones after we’ve incremented it by 1\. XLA notices this and inserts a `copy` op (which is semantically the identity but the input and output buffers will be different). + +```py +def f(x): + x2 = copy(x) + y = add_one_inplace(x2) + return y, x * 2 +``` + +This pass in XLA ensures correctness in the presence of ops that perform in-place updates by forcing them to effectively be out-of-place with `copy` ops. + +#### Copies with downstream ops + +Let’s revisit our example where we add 1 while `ppermute`ing. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z +``` + +If we unpack the future into its components, we’ll see the the aliasing patterns: + +```py +def f(x): + *sems, x2, y = ppermute_start(x) + z = x + 1 + y = ppermute_done((*sems, x2, y)) + return y, z +``` + +We know that `x` is left unchanged by `ppermute_start` (that is, `x` is identical to `x2`), but XLA does not. In fact, it looks like our `add_one_inplace` example to XLA, where it conservatively assumes that `ppermute_start` mutated `x` and `x2` is the new aliased result. Therefore, when we do `z = x + 1`, we run into a consumer of the original buffer. XLA therefore introduces a copy\! + +```py +def f(x): + x2 = copy(x) + *sems, x2, y = ppermute_start(x2) + z = x + 1 + y = ppermute_done((*sems, x2, y)) + return y, z +``` + +This copy is unnecessary because we know that `x2` is unchanged from `x`. In order to remove this copy, we’d need some mechanism to inform XLA we are just forwarding a value. However, in the absence of that we can rewrite our program a bit to explicitly use `x2` instead of `x`. + +```py +def f(x): + *sems, x2, y = ppermute_start(x) + z = x2 + 1 + y = ppermute_done((*sems, x2, y)) + return y, z +``` + +Now, XLA doesn’t see a separate consumer of `x` so no more copy is introduced. However, this comes at a major downside in that it forces us to unpack the future coming from `ppermute_start`. It couples the lifetime problem to the copying problem. + +#### Loop aliasing + +Let’s consider a slightly more advanced example. Let’s implement a function that uses a `while_loop` with `ppermute` to send values around a ring. + +```py +def f(x): + def body(i, x): + fut = ppermute_start(x) + y = ppermute_done(fut) + return y + return fori_loop(0, 8, body, x) +``` + +One implementation detail of `fori_loop` is that the inputs and outputs buffers are automatically aliased to each other. Note that we are setting up some additional aliasing in the `ppermute_start` and `ppermute_done` ops. Let’s run our own “buffer assignment” by coloring each of the values in the program to determine how many unique buffers we need. + +First, we’ll unpack the `fut` tuple that has the aliased `x` and `out` buffers. + +```py +def f(x): + def body(i, x): + *sems, x, y = ppermute_start(x) + y = ppermute_done(*sems, x, y) + return y + return fori_loop(0, 8, body, x) +``` + +Let’s now color each of the values according to the unique buffer they are assigned. We have the input/output aliasing coming from `fori_loop`, the `x` aliasing coming from `ppermute_start` and the `y` aliasing coming from `ppermute_done`. + +```py +def f(x): + def body(i, x): + *sems, x, y = ppermute_start(x) + y = ppermute_done((*sems, x, y)) + return y + return fori_loop(0, 8, body, x) +``` + +If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t write into the same buffer we are sending into. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is that it will observe the buffer re-use and defensively insert a copy. + +```py +def f(x): + def body(i, x): + x = copy(x) + *sems, x, y = ppermute_start(x) + y = ppermute_done((*sems, x, y)) + return y + return fori_loop(0, 8, body, x) +``` + +This copy means `x` and `y` are no longer aliased to each other and the program will be correct. However, do we need this copy? How do we introduce a double buffer to avoid expensive copies each iteration? The answer is unrolling\! + +We’ll manually unroll our code. + +```py +def f(x): + def body(i, x): + *sems, x, x2 = ppermute_start(x) + x2 = ppermute_done((*sems, x, x2)) + + *sems, x2, y = ppermute_start(x2) + y = ppermute_done((*sems, x2, y)) + return y + return fori_loop(0, 4, body, x) +``` + +Now if we were to run the same alias analysis, we’ll find that the buffers all no longer alias to each other and that we won’t need to insert defensive copies to be correct. + +Therefore, the simple solution to removing these copies is to use `fori_loop` with `unroll >= 2`. + +```py +def f(x): + def body(i, x): + fut = ppermute_start(x) + y = ppermute_done(fut) + return y + return fori_loop(0, 8, body, x, unroll=2) +``` + +That’s sufficient to implement this loop without extra copies\! + +#### Passing futures across loop boundaries + +Let’s now look at an even more advanced example. We’ll implement the same program as before but stagger the loop, where we begin the `ppermute` in a prologue before the loop, and wait on the `ppermute` at the beginning of the loop. + +```py +def f(x): + fut = ppermute_start(x) + def body(i, fut): + x = ppermute_done(fut) + fut = ppermute_start(x) + return fut + fut = fori_loop(0, 7, body, fut) + return ppermute_done(fut) +``` + +In this example, rather than passing a value `x` from one loop to another we are passing a future value. + +Let’s unpack the future again to see what’s happening. + +```py +def f(x): + fut = ppermute_start(x) + def body(i, fut): + *sems, x, out = fut + x = ppermute_done((*sems, x, out)) + (*sems, x, out) = ppermute_start(x) + return (*sems, x, out) + (*sems, x, out) = fori_loop(0, 7, body, x) + return ppermute_done((*sems, x, out)) +``` + +So we’re explicitly threading the semaphores, the input buffer, and the target output buffer as a loop carry. What happens if we run alias analysis now? Well, we’ll run into the same aliasing issue as in the previous section where `x` and `out` will be aliased to each other. XLA will introduce a copy. + +```py +def f(x): + fut = ppermute_start(x) + def body(i, fut): + *sems, x, out = fut + out = copy(out) + x = ppermute_done((*sems, x, out)) + (*sems, x, out) = ppermute_start(x) + return (*sems, x, out) + (*sems, x, out) = fori_loop(0, 7, body, x) + return ppermute_done((*sems, x, out)) +``` + +In this case, we inserted a copy on `out`. However, this is a really bad scenario because `out` is being actively copied into\! Even if we insert a copy on `x`, we will also run into issues because then `x`’s lifetime will not extend to the `ppermute_done`. This is very very bad\! We will not only get copies, but we will also get incorrect results\! + +The solution, as we observed before, is to avoid the copies by avoiding aliasing all the buffers via unrolling. So, if we do: + +```py +def f(x): + fut = ppermute_start(x) + def body(i, fut): + x = ppermute_done(fut) + fut = ppermute_start(x) + return fut + fut = fori_loop(0, 7, body, x, unroll=2) + return ppermute_done(fut) +``` + +our program should now be correct. + +### Putting it all together + +So we’ve come up with some rules of thumb: + +1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value. +2. Use `unroll >= 2` when doing `ppermute`s in a loop body. + +Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result. + +```py +def f(x): + out = jnp.zeros_like(x) + fut = (*sems, x, out) = ppermute_start(x) + out = out + x + def body(i, carry): + out, fut = carry + x = ppermute_done(fut) + fut = (*sems, x, out) = ppermute_start(x) + out = out + x + return out, fut + out, fut = fori_loop(0, 7, body, (out, fut), unroll=2) + return out, ppermute_done(fut) +``` + +Note that in this example, we don’t need `optimization_barrier`s because the loop boundary acts as a scheduling barrier, splitting up the `start`s and `done`s. + +That’s it, we are done\! This will be the official API for doing async ops in Pallas. Thank you everyone\! Mission accomplished\! + +***OR IS IT?*** + +## Revenge of the State + +While it seems we have worked around copies and incorrectness issues by using some clever tricks, we are still in an awkward position. This API is powerful, but has many many footguns and caveats. There are likely far many more edge cases we will need to deal with that even require deep knowledge of XLA to predict or understand. Should we release an API like this? Or is there an alternative? + +Well, the answer may have been in front of us this whole time. + +Let’s run through this whole exercise one more time, *except*, let’s write the stateful version. This means each of our custom async ops now operate on `Ref`s instead of values. + +```py +def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]: + ... + +def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None: + ... +``` + +Let’s assume we can implement these in Pallas and see what our new programs will look like. Let’s start with a basic collective permute: + +```py +def f(x): + x_ref = make_ref(x) + y_ref = make_ref(zeros_like(x)) + fut = ppermute_start_stateful(x_ref, y_ref) + ppermute_done_stateful(*fut, x_ref, y_ref) + return y_ref[...] +``` + +It’s a little bit more verbose than our original value-based version, but it has a few key differences. The first is that we create an “empty” `Ref` to receive the result of the `ppermute`, unlike the value-based version, which creates a value for us. One neat thing is that the lifetime of `x_ref` is clear here: it lives until `ppermute_done_stateful`. We don’t need to “sneak” the `x` value into the op like we did before. + +Another difference becomes more clear when we try adding an op between the `start/done`. + +```py +def f(x): + x_ref = make_ref(x) + y_ref = make_ref(zeros_like(x)) + fut = ppermute_start_stateful(x_ref, y_ref) + x_ref[...] += 1 + ppermute_done_stateful(*fut, x_ref, y_ref) + return y_ref[...] +``` + +Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved wrt to `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO. + +The final key difference is evident when we try our loop examples. + +```py +def f(x): + x_ref = make_ref(x) + y_ref = make_ref(zeros_like(x)) + def body(i, _): + fut = ppermute_start_stateful(x_ref, y_ref) + ppermute_done_stateful(*fut, x_ref, y_ref) + # Now switch to y_ref -> x_ref + fut = ppermute_start_stateful(y_ref, x_ref) + ppermute_done_stateful(*fut, y_ref, x_ref) + fori_loop(0, 8 // 2, body, None) + return x_ref[...] +``` + +Because of the requirement that we have a separate buffer ready to receive the `ppermute`, we were forced to write our code in such a way that unrolls it\! There is no way to write the version in XLA that requires copying because that would involve a `ppermute` that sends from a `Ref` into itself, which doesn’t really make sense. + +To handle this without the manual unrolling, we’d create a scratch buffer with a leading `2` dimension that acts as the send/recv target across iterations, switching each one. This is the same pattern we use internally in Pallas kernels when writing manually overlapped kernels. + +The realization here is that being stateful forces us to deal with a lot of the issues that pop up with value semantics earlier on. We define them away\! + +1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` wrt to each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints. +2. Lifetimes \- `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops. +3. Defensive copies \- Using `Ref`s forces us to handle buffer assignment “manually” and the lowering can ensure the aliasing works out to avoid any copies. + +Another important fundamental limitation is that we eventually stage out an HLO program where the live buffers and semaphores are represented as array value types. XLA does not provide guarantees about buffer lifetimes or which memory spaces they live in for these intermediate values. *Therefore, it is possible XLA can copy array values even if they are actively being copied into by Pallas kernels.* This is easy to verify in HLO but it is a sharp edge of using custom calls to represent asynchronous operations in HLO. + +## Conclusion + +We’ve gone over some tricky challenges when it comes to async ops in Pallas and JAX. `Ref`s seem like a promising way of representing these ops that circumvents some of the issues that come up with value semantics. However, a downside is that it puts stateful JAX front and center, which we haven’t done yet outside of Pallas. It’s worth thinking whether we should educate users about stateful ops, or provide a more dangerous API. We also don’t know if everything we want to do is expressible via `Ref`s as well. We should also brainstorm alternatives to state to flesh out the design space. For example, what if XLA offered a first-class futures API that respected lifetimes, and it could automatically do things like double buffer loops with futures in them? That might be a viable alternative but the tradeoff would be giving more control to the compiler vs explicit control from the user. diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index 467f375d0e43..5969349c962a 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -33,6 +33,13 @@ See also the :class:`jax.experimental.pallas` module API documentation. tpu/index .. toctree:: + :caption: Design Notes + :maxdepth: 1 + + async_note + +.. toctree:: + :caption: Other :maxdepth: 1 CHANGELOG From 242cb7cbc711c63adcbd83d890e88e6419c6fd1b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 17 Sep 2024 12:50:04 -0700 Subject: [PATCH 524/702] Fix the __repr__ of JaxprEqnContext PiperOrigin-RevId: 675673700 --- jax/_src/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 6c3b4093b071..74d03b8d9464 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -283,9 +283,9 @@ def manager(self): def __repr__(self): return ( - f"JaxprEqnContext(compute_type={self.compute_type}," - f"threefry_partitionable={self.threefry_partitionable})," - f"xla_metadata={self.xla_metadata}" + f"JaxprEqnContext(compute_type={self.compute_type}, " + f"threefry_partitionable={self.threefry_partitionable}, " + f"xla_metadata={self.xla_metadata})" ) From 3f2c58b9c65e9406486731a5ea24f4c7496e6c1b Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 17 Sep 2024 13:23:04 -0700 Subject: [PATCH 525/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f6b6175735336f6bdf0ec4af79a3314e6673ccd6. PiperOrigin-RevId: 675686033 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 31aaf75f4504..0df1b77fbb39 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "90be6e3a11b3489451dcacf918febda2a32f7b10" -XLA_SHA256 = "e924c732353ab0fa3fcee5c29b316118678d8dcbac0d16fa06c72e2cb8fb96b1" +XLA_COMMIT = "f6b6175735336f6bdf0ec4af79a3314e6673ccd6" +XLA_SHA256 = "7033fba5ae9cb701173cf534825a7aa95425c0f4d174b6611293d0d08962492e" def repo(): tf_http_archive( From 83a7555ffd355545c1f3d4642eaf4ab5d18ebcc8 Mon Sep 17 00:00:00 2001 From: selamw1 Date: Mon, 16 Sep 2024 16:47:52 -0700 Subject: [PATCH 526/702] docstring_sort_complex_added input_array_modified --- jax/_src/numpy/lax_numpy.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ef6a30400d30..69b3e1023fa3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -8561,9 +8561,29 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result -@util.implements(np.sort_complex) @jit def sort_complex(a: ArrayLike) -> Array: + """Return a sorted copy of complex array. + + JAX implementation of :func:`numpy.sort_complex`. + + Complex numbers are sorted lexicographically, meaning by their real part + first, and then by their imaginary part if real parts are equal. + + Args: + a: input array. If dtype is not complex, the array will be upcast to complex. + + Returns: + A sorted array of the same shape and complex dtype as the input. + + See also: + - :func:`jax.numpy.sort`: Return a sorted copy of an array. + + Examples: + >>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j]) + >>> jnp.sort_complex(a) + Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64) + """ util.check_arraylike("sort_complex", a) a = lax.sort(asarray(a), dimension=0) return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) From e92a599a96374064d53a7230086992e989af542a Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Tue, 17 Sep 2024 15:26:42 -0700 Subject: [PATCH 527/702] [mosaic_gpu] Better error message for misaligned tma_transpose with dtype. PiperOrigin-RevId: 675731295 --- jax/experimental/mosaic/gpu/examples/matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index 52d403cd0131..775b7c2ea898 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -132,7 +132,7 @@ def build_kernel( if stages < 2: raise ValueError(f"Need at least 2 stages, but got {stages=}") if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2: - raise ValueError("Transpose only supported for only happen for 16bit types") + raise ValueError(f"Transpose only supported for 16bit types (got: {rhs_transpose=}, {rhs_dtype=})") if swizzle not in {32, 64, 128}: raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}") From 3c37d4f20e5f2f9370181263fd02952648c7aae0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 17 Sep 2024 15:58:14 -0700 Subject: [PATCH 528/702] Improve documentation for jax.lax.stop_gradient --- docs/_tutorials/advanced-autodiff.md | 2 +- jax/_src/lax/lax.py | 49 +++++++++++++++++++++------- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index 180f65f5d492..da5cd0feaa1a 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -77,7 +77,7 @@ def meta_loss_fn(params, data): meta_grads = jax.grad(meta_loss_fn)(params, data) ``` - +(stopping-gradients)= ### Stopping gradients Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph. diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8d2c24d6e64c..c791c668e68b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1372,18 +1372,43 @@ def stop_gradient(x: T) -> T: argument `x` unchanged. However, ``stop_gradient`` prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, ``stop_gradient`` stops gradients - for all of them. - - For example: - - >>> jax.grad(lambda x: x**2)(3.) - Array(6., dtype=float32, weak_type=True) - >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.) - Array(0., dtype=float32, weak_type=True) - >>> jax.grad(jax.grad(lambda x: x**2))(3.) - Array(2., dtype=float32, weak_type=True) - >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.) - Array(0., dtype=float32, weak_type=True) + for all of them. For some discussion of where this is useful, refer to + :ref:`stopping-gradients`. + + Args: + x: array or pytree of arrays + + Returns: + input value is returned unchanged, but within autodiff will be treated as + a constant. + + Examples: + Consider a simple function that returns the square of the input value: + + >>> def f1(x): + ... return x ** 2 + >>> x = jnp.float32(3.0) + >>> f1(x) + Array(9.0, dtype=float32) + >>> jax.grad(f1)(x) + Array(6.0, dtype=float32) + + The same function with ``stop_gradient`` around ``x`` will be equivalent + under normal evaluation, but return a zero gradient because ``x`` is + effectively treated as a constant: + + >>> def f2(x): + ... return jax.lax.stop_gradient(x) ** 2 + >>> f2(x) + Array(9.0, dtype=float32) + >>> jax.grad(f2)(x) + Array(0.0, dtype=float32) + + This is used in a number of places within the JAX codebase; for example + :func:`jax.nn.softmax` internally normalizes the input by its maximum + value, and this maximum value is wrapped in ``stop_gradient`` for + efficiency. Refer to :ref:`stopping-gradients` for more discussion of + the applicability of ``stop_gradient``. """ def stop(x): # only bind primitive on inexact dtypes, to avoid some staging From 86fe463ad7221de7f4078fcbba9c6bf1af0b19ba Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 17 Sep 2024 16:10:41 -0700 Subject: [PATCH 529/702] [Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums. This allows us to get more cache hits globally. For example: Before: jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache miss After: jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache hit Reverts b615266175effe4aefeb903620a19f3719a604da PiperOrigin-RevId: 675746175 --- jax/_src/api.py | 6 +- jax/_src/interpreters/pxla.py | 42 +++++++++- jax/_src/pjit.py | 117 +++++++++++++++++++--------- jax/experimental/multihost_utils.py | 11 ++- tests/pjit_test.py | 33 +++++--- 5 files changed, 153 insertions(+), 56 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 8ca3803aec35..b548cc43fb3b 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2726,7 +2726,8 @@ def clear_backends(): pjit._infer_params_cached.cache_clear() pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() xc._xla.PjitFunctionCache.clear_all() @atexit.register @@ -2755,7 +2756,8 @@ def clear_caches(): util.clear_all_weakref_lru_caches() # Clear all C++ compiled executable caches for pjit - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() pjit._infer_params_cached.cache_clear() xc._xla.PjitFunctionCache.clear_all() diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index b7d68f73c2a4..944e20fa7faa 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -22,6 +22,7 @@ from collections.abc import Callable, Sequence, Iterable, Iterator import dataclasses from functools import partial, lru_cache, cached_property +import functools import itertools as it import logging import math @@ -61,6 +62,7 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -88,6 +90,7 @@ class WeakRefList(list): logger = logging.getLogger(__name__) Index = Union[int, slice, tuple[Union[int, slice], ...]] +PyTreeDef = tree_util.PyTreeDef NoSharding = sharding_specs.NoSharding Chunked = sharding_specs.Chunked @@ -2904,6 +2907,34 @@ class MeshExecutableFastpathData(NamedTuple): in_device_local_layouts: Sequence[DeviceLocalLayout | None] +@dataclasses.dataclass(frozen=True, kw_only=True) +class JitGlobalCppCacheKeys: + donate_argnums: tuple[int, ...] | None = None + donate_argnames: tuple[str, ...] | None = None + device: xc.Device | None = None + backend: str | None = None + in_shardings_treedef: PyTreeDef | None = None + in_shardings_leaves: tuple[Any, ...] | None = None + out_shardings_treedef: PyTreeDef | None = None + out_shardings_leaves: tuple[Any, ...] | None = None + in_layouts_treedef: PyTreeDef | None = None + in_layouts_leaves: tuple[Any, ...] | None = None + out_layouts_treedef: PyTreeDef | None = None + out_layouts_leaves: tuple[Any, ...] | None = None + use_resource_env: bool = False + + @functools.cached_property + def contains_explicit_attributes(self): + return (self.donate_argnums is not None or + self.donate_argnames is not None or + self.device is not None or + self.backend is not None or + any(not is_unspecified(i) for i in self.in_shardings_leaves) or + any(not is_unspecified(o) for o in self.out_shardings_leaves) or + any(i is not None for i in self.in_layouts_leaves) or + any(o is not None for o in self.out_layouts_leaves)) + + def reflatten_outputs_for_dispatch(out_tree, out_flat): # We arrive at dispatch having flattened according to the default # pytree registry, but we want to re-flatten according to our @@ -3017,9 +3048,14 @@ def aot_cache_miss(*args, **kwargs): fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, cc_shard_arg) + if xla_extension_version >= 286: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], + JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) + else: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], [], + tree_util.dispatch_registry, cc_shard_arg) def cc_shard_arg(x, sharding, layout): return shard_args([sharding], [layout], [x])[0] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index fb76f7931c01..42a7c966b4d6 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -62,6 +62,7 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src import sharding from jax._src.mesh import AbstractMesh from jax._src.sharding_impls import ( @@ -164,7 +165,6 @@ class PjitInfo(NamedTuple): keep_unused: bool inline: bool abstracted_axes: Any | None - has_explicit_sharding: bool use_resource_env: bool # False for jit, True for pjit # Hash and compare PjitInfo by identity when used as a cache key. @@ -311,14 +311,39 @@ def _cpp_pjit_evict_fn(self): # The entries are doubled here from the default 4096 because _pjit_call_impl # also has a cpp dispatch path and that would double the number of entries in # the global shared cache. -_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is only used for jit's with only fun. For example: jax.jit(f) +_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is used for jit where extra arguments are defined other than the +# fun. For example: jax.jit(f, donate_argnums=...) OR +# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the +# capacity might get full very fast because of all the jitted function in JAX +# which might evict train_step for example. +_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192) -def _get_cpp_global_cache(pjit_has_explicit_sharding): - if pjit_has_explicit_sharding: - return xc._xla.PjitFunctionCache() - else: - return _cpp_pjit_cache + +if xla_extension_version < 286: + def _get_cpp_global_cache(pjit_has_explicit_sharding): + if pjit_has_explicit_sharding: + return xc._xla.PjitFunctionCache() + else: + return _cpp_pjit_cache_fun_only + + def _pjit_explicit_sharding_and_layout( + in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, + device, backend) -> bool: + return (device is not None or + backend is not None or + any(not is_unspecified(i) for i in in_shardings_flat) or + any(not is_unspecified(o) for o in out_shardings_flat) or + any(i is not None for i in in_layouts_flat) or + any(o is not None for o in out_layouts_flat)) +else: + def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore + if contains_explicit_attributes: + return _cpp_pjit_cache_explicit_attributes + else: + return _cpp_pjit_cache_fun_only def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @@ -339,11 +364,35 @@ def cache_miss(*args, **kwargs): return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), - fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, - jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(jit_info.has_explicit_sharding)) + if xla_extension_version >= 286: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=jit_info.donate_argnums, + donate_argnames=jit_info.donate_argnames, + device=jit_info.device, backend=jit_info.backend, + in_shardings_treedef=jit_info.in_shardings_treedef, + in_shardings_leaves=jit_info.in_shardings_leaves, + out_shardings_treedef=jit_info.out_shardings_treedef, + out_shardings_leaves=jit_info.out_shardings_leaves, + in_layouts_treedef=jit_info.in_layouts_treedef, + in_layouts_leaves=jit_info.in_layouts_leaves, + out_layouts_treedef=jit_info.out_layouts_treedef, + out_layouts_leaves=jit_info.out_layouts_leaves, + use_resource_env=jit_info.use_resource_env) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore + pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes)) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + jit_info.in_shardings_leaves, jit_info.out_shardings_leaves, + jit_info.in_layouts_leaves, jit_info.out_layouts_leaves, + jit_info.device, jit_info.backend) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, jit_info.donate_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -351,17 +400,6 @@ def cache_miss(*args, **kwargs): return cpp_pjitted_f -def _pjit_explicit_sharding_and_layout( - in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, - device, backend) -> bool: - return (device is not None or - backend is not None or - any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(o) for o in out_shardings_flat) or - any(i is not None for i in in_layouts_flat) or - any(o is not None for o in out_layouts_flat)) - - def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) layouts, shardings = [], [] @@ -445,10 +483,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings_leaves, out_shardings_leaves, in_layouts_leaves, - out_layouts_leaves, device, backend) - return PjitInfo( fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, @@ -466,7 +500,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, abstracted_axes=abstracted_axes, - has_explicit_sharding=has_explicit_sharding, use_resource_env=use_resource_env) @@ -1706,13 +1739,27 @@ def call_impl_cache_miss(*args_, **kwargs_): f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) - donated_argnums = [i for i, d in enumerate(donated_invars) if d] - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings, out_shardings, in_layouts, out_layouts, None, None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding))(*args) + donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) + if xla_extension_version >= 286: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=donated_argnums, donate_argnames=None, + device=None, backend=None, + in_shardings_treedef=None, in_shardings_leaves=in_shardings, + out_shardings_treedef=None, out_shardings_leaves=out_shardings, + in_layouts_treedef=None, in_layouts_leaves=in_layouts, + out_layouts_treedef=None, out_layouts_leaves=out_layouts, + use_resource_env=resource_env is not None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], cache_key, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings, out_shardings, in_layouts, out_layouts, None, None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], donated_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 554bf2641769..56003ea7af5d 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -90,19 +90,17 @@ def sync_global_devices(name: str): assert_equal(h, f"sync_global_devices name mismatch ('{name}')") +# Identity function is at the top level so that `process_allgather` doesn't +# recompile on every invocation. def _identity_fn(x): return x -@lru_cache(maxsize=128) -def _jitted_identity_fn(sharding): - return jax.jit(_identity_fn, out_shardings=sharding) - def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: reps = sharding_impls.GSPMDSharding.get_replicated( inp.sharding._device_assignment) - out = _jitted_identity_fn(reps)(inp) + out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. if jax.process_count() == 1: @@ -125,7 +123,8 @@ def _handle_array_process_allgather(inp, tiled): bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( global_aval.shape, s, bufs) - out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr) + out = jax.jit(_identity_fn, + out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr) return np.asarray(out.addressable_data(0)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6c022653581d..11a541f2e5f5 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -57,6 +57,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -652,18 +653,16 @@ def testAutodiff(self, mesh, resources): @jtu.with_mesh([('x', 2), ('y', 1)]) def testAutodiffCache(self): - f = pjit( - lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None - ) + f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None) x = jnp.arange(16, dtype=jnp.float32) - jax.grad(f)(x) # Warm up the cache. - before = pjit_lib._pjit_lower_cached.cache_info() - jax.grad(f)(x) - after = pjit_lib._pjit_lower_cached.cache_info() - # One hit for the forward pass, one hit for backward. - self.assertEqual(after.hits, before.hits + 2) - self.assertEqual(after.misses, before.misses) + jax.grad(f)(x) # Warm up the cache. + with jtu.count_pjit_cpp_cache_miss() as count: + jax.grad(f)(x) + if xla_extension_version >= 286: + self.assertEqual(count[0], 0) # no cache miss i.e. cache hit + else: + self.assertEqual(count[0], 2) @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): @@ -4531,6 +4530,20 @@ def test_wsc_abstract_mesh_errors(self): ' match the mesh shape of the target sharding.*'): with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y'))) + @unittest.skipIf(xla_extension_version < 286, + "Requires xla_extension_version >= 286") + def test_global_jit_cpp_cache_hit_out_shardings(self): + mesh = jtu.create_mesh((2,), 'x') + s = NamedSharding(mesh, P('x')) + + def f(x): + return x * 2 + + with jtu.count_pjit_cpp_cache_miss() as count: + jax.jit(f, out_shardings=s)(np.arange(8)) + jax.jit(f, out_shardings=s)(np.arange(8)) + self.assertEqual(count[0], 1) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") From 8b5b71750b009fdd979dfd0abeb43a359a60c664 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 17 Sep 2024 16:39:55 -0700 Subject: [PATCH 530/702] Fix jaxpr equation context propagation in jaxpr equations when `inline=True`. PiperOrigin-RevId: 675754808 --- jax/_src/core.py | 3 +-- jax/_src/interpreters/partial_eval.py | 3 +-- jax/_src/pjit.py | 2 -- tests/memories_test.py | 23 +++++++++++++++++++++++ 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 74d03b8d9464..51933a9f8bbf 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -343,8 +343,7 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, ctx = ctx or JaxprEqnContext( compute_on.current_compute_type(), config.threefry_partitionable.value, - xla_metadata_lib.current_xla_metadata(), - ) + xla_metadata_lib.current_xla_metadata()) if config.enable_checks.value: assert all(isinstance(x, (Var, Literal)) for x in invars) assert all(isinstance(v, Var) for v in outvars) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 2d27bf064fce..374816e001ec 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2828,8 +2828,7 @@ def inline_jaxpr_into_trace( outvars = [Var('', v.aval) for v in eqn.outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) - trace.frame.add_eqn(core.new_jaxpr_eqn(invars, outvars, eqn.primitive, - eqn.params, eqn.effects, src_)) + trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) # type: ignore map(env.setdefault, eqn.outvars, outvars) tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 42a7c966b4d6..34bf257f639e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1800,13 +1800,11 @@ def pjit_staging_rule(trace, *args, **params): params['jaxpr'], params['out_shardings'], params['out_layouts']) params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) - if (params["inline"] and all(is_unspecified(i) for i in params["in_shardings"]) and all(is_unspecified(o) for o in params["out_shardings"]) and all(i is None for i in params["in_layouts"]) and all(o is None for o in params["out_layouts"])): - if config.dynamic_shapes.value: # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, diff --git a/tests/memories_test.py b/tests/memories_test.py index 68aecfdf669f..3e0f444a1e66 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -742,6 +742,29 @@ def h(x): self.assertArraysEqual(out2, inp * 6) self.assertEqual(out2.sharding.memory_kind, 'pinned_host') + def test_compute_on_basic_inline(self): + @compute_on('device_host') + @jax.jit + def g(x): + return x * 2 + + @functools.partial(jax.jit, inline=True) + def h(x): + y = g(x) + return y * 3 + + @jax.jit + def f(x): + return h(x) + + inp = jnp.arange(8) + out = f(inp) + self.assertArraysEqual(out, inp * 6) + + lowered_text = f.lower(jnp.arange(8)).as_text('hlo') + self.assertRegex(lowered_text, + 'to_apply=g.*frontend_attributes={_xla_compute_type="host"}') + def test_compute_on_reduction(self): out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host') From 8bcdb1285218d42e051882b33abf65a75649488b Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 17 Sep 2024 16:50:55 -0700 Subject: [PATCH 531/702] Add CI jobs for python 3.13.0rc2. PiperOrigin-RevId: 675758096 --- .bazelrc | 4 ++++ .github/workflows/wheel_win_x64.yml | 2 +- build/requirements.in | 2 ++ build/requirements_lock_3_13.txt | 10 ++++++---- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/.bazelrc b/.bazelrc index 9d5d9664939e..948d92c29c26 100644 --- a/.bazelrc +++ b/.bazelrc @@ -215,6 +215,8 @@ build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" +build:rbe_cpu_linux_py3.13 --config=rbe_cpu_linux_base +build:rbe_cpu_linux_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" build:rbe_linux_cuda_base --config=rbe_linux build:rbe_linux_cuda_base --config=cuda @@ -237,6 +239,8 @@ build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" +build:rbe_linux_cuda12.3_nvcc_py3.13 --config=rbe_linux_cuda12.3_nvcc_base +build:rbe_linux_cuda12.3_nvcc_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" # These you may need to change for your own GCP project. build:tensorflow_testing_rbe --project_id=tensorflow-testing diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 447ccba4f8c2..bae1edec0214 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -17,7 +17,7 @@ jobs: matrix: os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.10', '3.11', '3.12'] + pyver: ['3.10', '3.11', '3.12', '3.13.0rc2'] name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build runs-on: ${{ matrix.os }} diff --git a/build/requirements.in b/build/requirements.in index f6b5b18b2660..a8d81fa5c670 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -23,3 +23,5 @@ ml_dtypes>=0.4.0 opt_einsum zstandard etils[epath] +# TODO(ybaturina): remove setuptools version +setuptools<71.0.0 diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index ef121b73713b..e2369a8001bb 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -732,7 +732,9 @@ zstandard==0.23.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==75.1.0 \ - --hash=sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2 \ - --hash=sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538 - # via -r build/test-requirements.txt +setuptools==70.3.0 \ + --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ + --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc + # via + # -r build/requirements.in + # -r build/test-requirements.txt From 988ed2bd75df5fe25b74eaf38075aadff19be207 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 17 Sep 2024 21:09:26 -0700 Subject: [PATCH 532/702] Add support for SMEM windows in Pallas custom pipeline. PiperOrigin-RevId: 675822640 --- jax/_src/pallas/mosaic/pipeline.py | 66 +++++++++++++++++++----------- 1 file changed, 41 insertions(+), 25 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index fca9ee471e6a..e8f2384784eb 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -189,7 +189,7 @@ class BufferedRef: dtype: dtype for buffers. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. - vmem_ref: a double-buffer to hold a working buffer and a dirty buffer used + window_ref: a double-buffer to hold a working buffer and a dirty buffer used to copy into and out of. In the case of a BufferedRef targeting a VMEM reference, this simply points to the existing ref. accum_ref: accumulating buffer used by accumulator BufferedRefs. @@ -210,7 +210,7 @@ class BufferedRef: spec: pl.BlockSpec # static metadata dtype: Any # static metadata buffer_type: BufferType # static metadata - vmem_ref: REF | None + window_ref: REF | None accum_ref: REF | None current_slot: ArrayRef | None next_slot: ArrayRef | None @@ -218,9 +218,17 @@ class BufferedRef: sem_sends: SemaphoreTuple | None def tree_flatten(self): - return ((self.vmem_ref, self.accum_ref, self.current_slot, - self.next_slot, self.sem_recvs, self.sem_sends), - (self.spec, self.dtype, self.buffer_type)) + return ( + ( + self.window_ref, + self.accum_ref, + self.current_slot, + self.next_slot, + self.sem_recvs, + self.sem_sends, + ), + (self.spec, self.dtype, self.buffer_type), + ) @classmethod def tree_unflatten(cls, meta, data): @@ -252,7 +260,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef: spec=spec, dtype=dtype, buffer_type=buffer_type, - vmem_ref=None, # to be bound to existing ref by the pipeline routine + window_ref=None, # to be bound to existing ref by the pipeline routine accum_ref=accum_ref, current_slot=None, next_slot=None, @@ -260,11 +268,12 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef: sem_sends=None, ) else: + memory_space = SMEM if spec.memory_space == SMEM else VMEM return cls( spec=spec, dtype=dtype, buffer_type=buffer_type, - vmem_ref=VMEM((2,) + block_shape, dtype), + window_ref=memory_space((2,) + block_shape, dtype), accum_ref=accum_ref, current_slot=SMEM((1,), jnp.int32), next_slot=SMEM((1,), jnp.int32), @@ -313,9 +322,9 @@ def current_ref(self): buffer_slice = tuple( 0 if x is None else slice(None) for x in self.block_shape) if self.memory_space == VMEM: - return self.vmem_ref.at[buffer_slice] + return self.window_ref.at[buffer_slice] else: - return self.vmem_ref.at[(self.current_slot[0], *buffer_slice)] + return self.window_ref.at[(self.current_slot[0], *buffer_slice)] @property def is_input(self): @@ -341,11 +350,12 @@ def is_accumulator(self): def is_input_output(self): return self.buffer_type == BufferType.INPUT_OUTPUT - def bind_existing_ref(self, vmem_ref, indices): + def bind_existing_ref(self, window_ref, indices): """For handling VMEM references, the pipeline aliases the existing ref.""" if self.memory_space == VMEM: return dataclasses.replace( - self, vmem_ref=vmem_ref.at[self.compute_slice(indices)]) + self, window_ref=window_ref.at[self.compute_slice(indices)] + ) return self def compute_slice(self, grid_indices): @@ -432,8 +442,9 @@ def copy_in(self, src_ref, grid_indices): dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( src_ref.at[src_slice], - self.vmem_ref.at[next_slot].at[dst_slice], - self.sem_recvs.at[next_slot]).start() + self.window_ref.at[next_slot].at[dst_slice], + self.sem_recvs.at[next_slot], + ).start() def copy_out(self, dst_ref, grid_indices): """Starts copy of HBM dma slice from the current slot.""" @@ -444,9 +455,10 @@ def copy_out(self, dst_ref, grid_indices): dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.vmem_ref.at[slot].at[src_slice], + self.window_ref.at[slot].at[src_slice], dst_ref.at[dst_slice], - self.sem_sends.at[slot]).start() + self.sem_sends.at[slot], + ).start() def wait_in(self, src_ref, grid_indices): """Waits for input copy to finish.""" @@ -456,9 +468,12 @@ def wait_in(self, src_ref, grid_indices): dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) current_slot = self.current_slot[0] tpu_primitives.make_async_copy( - src_ref.at[src_slice], # nb: doesn't matter - self.vmem_ref.at[current_slot].at[dst_slice], # only dst shape is important - self.sem_recvs.at[current_slot]).wait() + src_ref.at[src_slice], # nb: doesn't matter + self.window_ref.at[current_slot].at[ + dst_slice + ], # only dst shape is important + self.sem_recvs.at[current_slot], + ).wait() def wait_out(self, dst_ref, grid_indices): """Waits for output copy to finish.""" @@ -468,9 +483,10 @@ def wait_out(self, dst_ref, grid_indices): dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.vmem_ref.at[prev_slot].at[src_slice], # nb: doesn't matter - dst_ref.at[dst_slice], # only dst shape is important - self.sem_sends.at[prev_slot]).wait() + self.window_ref.at[prev_slot].at[src_slice], # nb: doesn't matter + dst_ref.at[dst_slice], # only dst shape is important + self.sem_sends.at[prev_slot], + ).wait() # Accumulator methods # @@ -498,14 +514,14 @@ def accumulate(self): assert self.is_accumulator if self.accum_ref is not None: accum_dtype = jnp.float32 - if self.vmem_ref.dtype == jnp.int32: + if self.window_ref.dtype == jnp.int32: accum_dtype = jnp.int32 # TODO(levskaya): we could generalize init and reduction functions, # could it ever be useful to support more generic monoids? self.current_ref[...] = ( - self.current_ref[...].astype(accum_dtype) + - self.accum_ref[...].astype(accum_dtype) - ).astype(self.vmem_ref.dtype) + self.current_ref[...].astype(accum_dtype) + + self.accum_ref[...].astype(accum_dtype) + ).astype(self.window_ref.dtype) # Helper to tree map over BufferedRefs as leaves. From b904599b98cca5fb73a387911afc685b290b623b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 18 Sep 2024 04:23:25 -0700 Subject: [PATCH 533/702] `pl.debug_print` no longer restricts values to be scalars This allows printing arrays on Triton and soon on Mosaic GPU. PiperOrigin-RevId: 675935666 --- docs/pallas/CHANGELOG.md | 18 +++++++++++++----- jax/_src/pallas/mosaic/lowering.py | 3 +++ jax/_src/pallas/mosaic_gpu/lowering.py | 5 +++-- jax/_src/pallas/primitives.py | 12 +++++------- jax/_src/pallas/triton/lowering.py | 9 ++++++++- 5 files changed, 32 insertions(+), 15 deletions(-) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index c1ed1385bbbc..b39d0211c761 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -11,6 +11,18 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c Remember to align the itemized text with the first line of an item within a list. --> +## Released with jax 0.4.34 + +* Changes + + * {func}`jax.experimental.pallas.debug_print` no longer requires all arguments + to be scalars. The restrictions on the arguments are backend-specific: + Non-scalar arguments are currently only supported on GPU, when using Triton. + +## Released with jax 0.4.33 (September 16, 2024) + +## Released with jax 0.4.32 (September 11, 2024) + ## Released with jax 0.4.32 * Changes @@ -19,7 +31,7 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations -* New functionality: +* New functionality * Improved error messages for mistakes in the signature of the index map functions, to include the name and source location of the index map. @@ -73,7 +85,3 @@ Remember to align the itemized text with the first line of an item within a list * Added checkify support for {func}`jax.experimental.pallas.pallas_call` in interpret mode ({jax-issue}`#21862`). * Improved support for PRNG keys for TPU kernels ({jax-issue}`#21773`). - - - - diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 13d861033e90..f76a4d86616a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2737,6 +2737,9 @@ def _delay_rule(ctx: LoweringRuleContext, nanos: int): def _debug_print_rule( ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool ): + if any(aval.shape for aval in ctx.avals_in): + raise NotImplementedError("Only scalar values are supported") + primitives.check_debug_print_format(fmt, *args) if has_placeholders: if not all( diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 39483e674681..ef10236e2fdd 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -693,8 +693,9 @@ def _debug_print_lowering_rule( fmt, has_placeholders: bool, ): - del ctx - del has_placeholders + del has_placeholders # Unused. + if any(aval.shape for aval in ctx.avals_in): + raise NotImplementedError("Only scalar values are supported") primitives.check_debug_print_format(fmt, *args) mgpu.debug_print(fmt, *args) return () diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index fbc389aae3fb..8cba0a36c6e4 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -714,7 +714,7 @@ class PrintEffect(effects.Effect): def debug_print(fmt: str, *args: jax.typing.ArrayLike): - """Prints scalar values from inside a Pallas kernel. + """Prints values from inside a Pallas kernel. Args: fmt: A format string to be included in the output. The restrictions on the @@ -724,11 +724,11 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike): (``{...}``), since it is always printed before any of the values. * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must contain a placeholder for each value to be printed. Format specs and - conversions are not supported. + conversions are not supported. All values must be scalars. * In TPU, if ``fmt`` contains placeholders, all values must be 32-bit integers. If there are no placeholders, the values are printed after - the format string. - *args: The scalar values to print. + the format string. All values must be scalars. + *args: The values to print. """ # fmt: skip has_placeholders = False if fmt: @@ -771,9 +771,7 @@ def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool): @debug_print_p.def_effectful_abstract_eval def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool): - del fmt, has_placeholders - if any(aval.shape for aval in avals): - raise ValueError("Only scalar values are supported") + del avals, fmt, has_placeholders # Unused. return [], {debug_print_effect} diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 5e495f4bef3e..856bcae97fcf 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1202,7 +1202,14 @@ def debug_print_lowering_rule( "pl.debug_print() does not support placeholders when lowering to Triton" ) - tt_dialect.print_(f" {fmt} ", hex=False, args=args) + tt_dialect.print_( + f" {fmt} ", + hex=False, + args=args, + is_signed=ir.DenseI32ArrayAttr.get([ + jnp.issubdtype(aval.dtype, jnp.signedinteger) for aval in ctx.avals_in + ]), + ) return () From 2714469397c18041d6c5696448abb7abb916ba89 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 18 Sep 2024 17:06:28 +0530 Subject: [PATCH 534/702] Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros --- CHANGELOG.md | 4 ++++ jax/_src/deprecations.py | 1 + jax/_src/numpy/lax_numpy.py | 28 ++++++++++++++++++++-------- tests/lax_numpy_test.py | 6 ++++++ 4 files changed, 31 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b34eee046856..d507c7e01385 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.34 +* Deprecations + * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments + with `ndim != 1` are now deprecated, and in the future will result in an error. + * Deletion: * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation in 0.4.30 JAX release. diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 10850357f677..5f1d132bcbb3 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -132,3 +132,4 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') +register('jax-numpy-trimzeros-not-1d-array') diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4e2a20c92ed8..1513270a9dff 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -7018,7 +7018,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: return res -def trim_zeros(filt, trim='fb'): +def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: """Trim leading and/or trailing zeros of the input array. JAX implementation of :func:`numpy.trim_zeros`. @@ -7040,14 +7040,26 @@ def trim_zeros(filt, trim='fb'): >>> jnp.trim_zeros(x) Array([2, 0, 1, 4, 3], dtype=int32) """ - filt = core.concrete_or_error(asarray, filt, - "Error arose in the `filt` argument of trim_zeros()") - nz = (filt == 0) + # Non-array inputs are deprecated 2024-09-11 + util.check_arraylike("trim_zeros", filt, emit_warning=True) + core.concrete_or_error(None, filt, + "Error arose in the `filt` argument of trim_zeros()") + filt_arr = jax.numpy.asarray(filt) + del filt + if filt_arr.ndim != 1: + # Added on 2024-09-11 + if deprecations.is_accelerated("jax-numpy-trimzeros-not-1d-array"): + raise TypeError(f"'filt' must be 1-D array, but received {filt_arr.ndim}-D array.") + warnings.warn( + "Passing arrays with ndim != 1 to jnp.trim_zeros() is deprecated. Currently, it " + "works with Arrays having ndim != 1. In the future this will result in an error.", + DeprecationWarning, stacklevel=2) + nz = (filt_arr == 0) if reductions.all(nz): - return empty(0, _dtype(filt)) - start = argmin(nz) if 'f' in trim.lower() else 0 - end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 - return filt[start:len(filt) - end] + return empty(0, filt_arr.dtype) + start: Array | int = argmin(nz) if 'f' in trim.lower() else 0 + end: Array | int = argmin(nz[::-1]) if 'b' in trim.lower() else 0 + return filt_arr[start:len(filt_arr) - end] def trim_zeros_tol(filt, tol, trim='fb'): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 01c89caf7a22..6415b31e7014 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1478,6 +1478,12 @@ def testTrimZeros(self, a_shape, dtype, trim): jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + def testTrimZerosNotOneDArray(self): + # TODO: make this an error after the deprecation period. + with self.assertWarnsRegex(DeprecationWarning, + r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"): + jnp.trim_zeros(jnp.array([[0.0, 1.0, 0.0],[2.0, 4.5, 0.0]])) + @jtu.sample_product( rank=(1, 2), dtype=default_dtypes, From e90336947a7f763226e8609ea96bc49a64fdb2c9 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 18 Sep 2024 05:25:37 -0700 Subject: [PATCH 535/702] Pulled `scratch_shapes` into `GridSpec` It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton. PiperOrigin-RevId: 675950199 --- docs/pallas/CHANGELOG.md | 16 ++++++++-------- jax/_src/pallas/core.py | 25 ++++++++++++++++--------- jax/_src/pallas/mosaic/core.py | 22 +++++----------------- jax/_src/pallas/mosaic_gpu/__init__.py | 1 - jax/_src/pallas/mosaic_gpu/core.py | 20 -------------------- jax/_src/pallas/pallas_call.py | 16 +++++++++++++--- jax/_src/pallas/triton/lowering.py | 4 ++++ tests/pallas/mosaic_gpu_test.py | 13 ++++++------- 8 files changed, 52 insertions(+), 65 deletions(-) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index b39d0211c761..43ba3ebd6afb 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -19,18 +19,22 @@ Remember to align the itemized text with the first line of an item within a list to be scalars. The restrictions on the arguments are backend-specific: Non-scalar arguments are currently only supported on GPU, when using Triton. +* Deprecations + +* New functionality + + * {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`, + a PyTree specifying backend-specific temporary objects needed by the + kernel, for example, buffers, synchronization primitives etc. + ## Released with jax 0.4.33 (September 16, 2024) ## Released with jax 0.4.32 (September 11, 2024) -## Released with jax 0.4.32 - * Changes * The kernel function is not allowed to close over constants. Instead, all the needed arrays must be passed as inputs, with proper block specs ({jax-issue}`#22746`). -* Deprecations - * New functionality * Improved error messages for mistakes in the signature of the index map functions, to include the name and source location of the index map. @@ -56,10 +60,6 @@ Remember to align the itemized text with the first line of an item within a list * Previously it was possible to import many APIs that are meant to be private, as `jax.experimental.pallas.pallas`. This is not possible anymore. - -* Deprecations - - * New Functionality * Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`. * Improved error messages for the {func}`jax.experimental.pallas.pallas_call` diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 56c47b9401cc..f354dd83f315 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -728,7 +728,16 @@ def _convert_block_spec_to_block_mapping( index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) -@dataclasses.dataclass(init=False) + +class ScratchShape(Protocol): + def get_aval(self) -> jax_core.AbstractValue: + ... + + +ScratchShapeTree = Sequence[Union[ScratchShape, "ScratchShapeTree"]] + + +@dataclasses.dataclass(init=False, kw_only=True) class GridSpec: """Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`. @@ -741,12 +750,14 @@ class GridSpec: grid_names: tuple[Hashable, ...] | None in_specs: BlockSpecTree out_specs: BlockSpecTree + scratch_shapes: ScratchShapeTree = () def __init__( self, grid: Grid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, + scratch_shapes: ScratchShapeTree = (), ): # Be more lenient for in/out_specs if isinstance(in_specs, list): @@ -758,6 +769,7 @@ def __init__( self.in_specs = in_specs self.out_specs = out_specs + self.scratch_shapes = tuple(scratch_shapes) grid_names = None if isinstance(grid, int): @@ -773,9 +785,6 @@ def __init__( self.grid = grid # type: ignore self.grid_names = grid_names - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - assert False # Not needed in GridSpec - def _make_scalar_ref_aval(self, aval): assert False # Not needed in GridSpec @@ -820,12 +829,10 @@ def get_grid_mapping( else: num_flat_scalar_prefetch = 0 jaxpr_scalar_ref_avals = () - - scratch_shapes: tuple[Any, ...] = getattr(grid_spec, "scratch_shapes", ()) - if scratch_shapes: + if grid_spec.scratch_shapes: flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( - scratch_shapes) - flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes) + grid_spec.scratch_shapes) + flat_scratch_avals = map(lambda s: s.get_aval(), flat_scratch_shapes) num_flat_scratch_operands = len(flat_scratch_avals) jaxpr_scratch_avals = tree_util.tree_unflatten( scratch_tree, flat_scratch_avals) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 61b1dc435e72..b2b892a64f90 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -19,7 +19,7 @@ import dataclasses import enum import functools -from typing import Any, ClassVar, Hashable, Literal +from typing import Any, ClassVar, Literal import jax from jax._src import core as jax_core @@ -39,6 +39,7 @@ BlockSpecTree = pallas_core.BlockSpecTree GridMapping = pallas_core.GridMapping NoBlockSpec = pallas_core.NoBlockSpec +ScratchShapeTree = pallas_core.ScratchShapeTree AbstractMemoryRef = pallas_core.AbstractMemoryRef no_block_spec = pallas_core.no_block_spec _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping @@ -174,14 +175,9 @@ def get_aval(self) -> AbstractMemoryRef: jax_core.ShapedArray(self.shape, self.dtype), self.memory_space) -@dataclasses.dataclass(init=False, unsafe_hash=True) +@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): - grid: TupleGrid - grid_names: tuple[Hashable, ...] | None num_scalar_prefetch: int - in_specs: pallas_core.BlockSpecTree - out_specs: pallas_core.BlockSpecTree - scratch_shapes: tuple[Any, ...] def __init__( self, @@ -189,9 +185,9 @@ def __init__( grid: Grid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, - scratch_shapes: Any | Sequence[Any] = () + scratch_shapes: ScratchShapeTree = () ): - super().__init__(grid, in_specs, out_specs) + super().__init__(grid, in_specs, out_specs, scratch_shapes) self.num_scalar_prefetch = num_scalar_prefetch self.scratch_shapes = tuple(scratch_shapes) @@ -199,14 +195,6 @@ def _make_scalar_ref_aval(self, aval): return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype), TPUMemorySpace.SMEM) - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - if isinstance(obj, MemoryRef): - return obj.get_aval() - if isinstance(obj, SemaphoreType): - return obj.get_aval() - raise ValueError(f"No registered conversion for {type(obj)}. " - "Only VMEM and SemaphoreType are supported.") - @dataclasses.dataclass(frozen=True) class TensorCore: diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index 11258f741b7f..1bd512834ce5 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -17,7 +17,6 @@ from jax._src.pallas.mosaic_gpu.core import Barrier from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUGridSpec from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 3ef205d336d0..5a046afead72 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -150,26 +150,6 @@ def to_block_mapping( ) -@dataclasses.dataclass(init=False, kw_only=True) -class GPUGridSpec(pallas_core.GridSpec): - scratch_shapes: Sequence[Any] - - def __init__( - self, - grid: pallas_core.Grid = (), - in_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec, - out_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec, - scratch_shapes: Sequence[Any] = () - ): - super().__init__(grid, in_specs, out_specs) - self.scratch_shapes = tuple(scratch_shapes) - - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - if isinstance(obj, (MemoryRef, Barrier)): - return obj.get_aval() - raise TypeError(f"Cannot convert {obj} to an abstract value") - - # TODO(b/354568887): Cosolidate this with TPU's MemoryRef. @dataclasses.dataclass(frozen=True) class MemoryRef: diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index b69fb03f0951..206c0cdee876 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -62,6 +62,7 @@ BlockSpecTree = pallas_core.BlockSpecTree NoBlockSpec = pallas_core.NoBlockSpec no_block_spec = pallas_core.no_block_spec +ScratchShapeTree = pallas_core.ScratchShapeTree CostEstimate = pallas_core.CostEstimate # See the docstring for GridMapping for the calling convention @@ -1233,6 +1234,7 @@ def pallas_call( grid: TupleGrid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, + scratch_shapes: ScratchShapeTree = (), input_output_aliases: dict[int, int] = {}, debug: bool = False, interpret: bool = False, @@ -1250,8 +1252,9 @@ def pallas_call( corresponding ``in_specs`` and ``out_specs``. out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape and dtypes of the outputs. - grid_spec: An alternative way to specify ``grid``, ``in_specs``, and - ``out_specs``. If given, those other parameters must not be also given. + grid_spec: An alternative way to specify ``grid``, ``in_specs``, + ``out_specs`` and ``scratch_shapes``. If given, those other parameters + must not be also given. grid: the iteration space, as a tuple of integers. The kernel is executed as many times as ``prod(grid)``. See details at :ref:`pallas_grid`. @@ -1265,6 +1268,9 @@ def pallas_call( The default value for ``out_specs`` specifies the whole array, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``. See details at :ref:`pallas_blockspec`. + scratch_shapes: a PyTree of backend-specific temporary objects required + by the kernel, such as temporary buffers, synchronization primitives, + etc. input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs. @@ -1305,7 +1311,7 @@ def pallas_call( } if grid_spec is None: - grid_spec = GridSpec(grid, in_specs, out_specs) + grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes) else: if grid: raise ValueError( @@ -1319,6 +1325,10 @@ def pallas_call( raise ValueError( "If `grid_spec` is specified, then `out_specs` must " f"be `no_block_spec`. It is {out_specs}") + if scratch_shapes: + raise ValueError( + "If `grid_spec` is specified, then `scratch_shapes` must " + f"be `()`. It is {scratch_shapes}") del grid, in_specs, out_specs grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 856bcae97fcf..0a23e512dfb3 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -277,6 +277,10 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "scalar prefetch not implemented in the Triton backend" ) + if jaxpr.invars[grid_mapping.slice_scratch_ops]: + raise NotImplementedError( + "scratch memory not implemented in the Triton backend" + ) with grid_mapping.trace_env(): jaxpr, _ = pe.dce_jaxpr( jaxpr, [True] * len(jaxpr.outvars), instantiate=True diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index bd9df6182793..4810e780813d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -134,16 +134,15 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_with_async_copy_gmem_to_smem(self): + @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - grid_spec=plgpu.GPUGridSpec( - in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), - scratch_shapes=[ - plgpu.SMEM((128,), jnp.float32), - plgpu.Barrier(num_arrivals=1), - ], - ), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((128,), jnp.float32), + plgpu.Barrier(num_arrivals=1), + ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.async_copy_gmem_to_smem( From 611ad630603cffa88aa714bf876340af315dd819 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 6 Sep 2024 16:09:58 +0000 Subject: [PATCH 536/702] Add basic PyTorch integration for Mosaic GPU We have already had most of the relevant pieces and we only needed to connect them together. The most sensitive change is perhaps that I needed to expose one more symbol from the XLA GPU plugin, but I don't think it should be a problem. --- jax/experimental/mosaic/gpu/__init__.py | 121 +++++++++++++++++++++--- jaxlib/mosaic/gpu/custom_call.cc | 94 ++++++++++++------ jaxlib/tools/BUILD.bazel | 3 +- jaxlib/tools/gpu_version_script.lds | 11 +++ tests/mosaic/gpu_test.py | 24 +++++ 5 files changed, 212 insertions(+), 41 deletions(-) create mode 100644 jaxlib/tools/gpu_version_script.lds diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 2e2941fca5b1..0e263844b18e 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -27,6 +27,7 @@ import tempfile import time from typing import Any, Generic, TypeVar +import weakref import jax from jax._src import config @@ -800,6 +801,21 @@ def main(token_ptr, buffers): return module, out_shape, unwrap_output_tuple +def _declare_runtime_functions(): + """Declares the runtime functions that can be used by the generated code.""" + ptr_ty = ir.Type.parse("!llvm.ptr") + i64 = ir.IntegerType.get_signless(64) + arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] + init_tma_desc_type = ir.FunctionType.get(arg_tys, []) + func.FuncOp( + "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" + ) + memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) + func.FuncOp( + "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" + ) + + def as_gpu_kernel( body, grid: tuple[int, int, int], @@ -867,16 +883,97 @@ def kernel(*args): return kernel -def _declare_runtime_functions(): - """Declares the runtime functions that can be used by the generated code.""" - ptr_ty = ir.Type.parse("!llvm.ptr") - i64 = ir.IntegerType.get_signless(64) - arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] - init_tma_desc_type = ir.FunctionType.get(arg_tys, []) - func.FuncOp( - "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" - ) - memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) - func.FuncOp( - "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" +def as_torch_gpu_kernel( + body, + grid: tuple[int, int, int], + block: tuple[int, int, int], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, + cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", +): + try: + import torch + except ImportError: + raise RuntimeError("as_torch_gpu_kernel requires PyTorch") + torch.cuda.init() # Make sure CUDA context is set up. + + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + + flat_out_types, out_treedef = jax.tree.flatten(out_shape) + expected_arg_treedef = jax.tree.structure(in_shape) + + module, out_shape, unwrap_output_tuple = ( + _lower_as_gpu_kernel( + body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, + module_name, prof_spec + ) ) + + # Get our hands on the compilation and unload functions + try: + import jax_plugins.xla_cuda12 as cuda_plugin + except ImportError: + raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds " + "that use backend plugins") + dll = ctypes.CDLL(cuda_plugin._get_library_path()) + compile_func = dll.MosaicGpuCompile + compile_func.argtypes = [ctypes.c_void_p] + compile_func.restype = ctypes.POINTER(ctypes.c_void_p) + unload_func = dll.MosaicGpuUnload + unload_func.argtypes = [compile_func.restype] + unload_func.restype = None + + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + compiled = compile_func(ctypes.c_char_p(module_asm)) + if compiled is None: + raise RuntimeError("Failed to compile the module") + ctx, launch_ptr = compiled[0], compiled[1] + ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx)) + launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) + + def as_torch_dtype(dtype): + # torch contains NumPy-compatible dtypes in its top namespace + return getattr(torch, np.dtype(dtype).name) + + def apply(*args): + flat_args, arg_treedef = jax.tree.flatten(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + # Construct a device pointer list like in the XLA calling convention + buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))() + i = -1 # Define i in case there are no args + device = 'cuda' + for i, arg in enumerate(flat_args): + buffers[i] = arg.data_ptr() + device = arg.device + flat_outs = [] + for i, t in enumerate(flat_out_types, i + 1): + out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device) + flat_outs.append(out) + buffers[i] = out.data_ptr() + # Allocate another buffer for args of the host-side program. This is sadly + # the default MLIR calling convention. + args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() + args_ptr[0] = ctx_ptr_ptr + args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_) + args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)), + ctypes.POINTER(ctypes.c_void_p)) + launch(args_ptr) + return jax.tree.unflatten(out_treedef, flat_outs) + + # Unload the compiled code when the Python function is destroyed. + def unload(_): + unload_func(compiled) + apply.destructor = weakref.ref(apply, unload) + + return apply diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 2e5723b184a8..103f9f78c32f 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -377,10 +377,40 @@ GetKernelCache() { return std::make_pair(&context_cache, &mutex); } + +absl::StatusOr CompileAndInit(const char* module) { + mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); + InitContext(&context); + mlir::ParserConfig parse_config(&context); + auto module_op = + mlir::parseSourceString(module, parse_config); + if (!module_op) { + return absl::InternalError("Failed to parse module"); + } + auto maybe_engine = Compile(*module_op); + if (!maybe_engine.ok()) { + return maybe_engine.status(); + } + mlir::ExecutionEngine* execution_engine = maybe_engine->get(); + auto main = execution_engine->lookupPacked("_mlir_ciface_main"); + auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); + if (!init || !main) { + return absl::InternalError("Failed to retrieve kernel function"); + } + void* module_ptr = nullptr; + void* kernel_ptr = nullptr; + void** module_ptr_ptr = &module_ptr; + void** kernel_ptr_ptr = &kernel_ptr; + void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; + reinterpret_cast(*init)(init_args); + return CompiledKernel(std::move(*maybe_engine), kernel_ptr, + reinterpret_cast(*main)); +} + // Each compiled kernel has a unique init func, and each kernel is used from // a single HLO module. So it should be safe to not include the CUDA context // in the key. -absl::StatusOr> CompileAndInit( +absl::StatusOr> CachedCompileAndInit( CacheKey key, const char* module) { auto cache_and_mutex = GetKernelCache(); auto* cache = cache_and_mutex.first; @@ -397,33 +427,11 @@ absl::StatusOr> CompileAndInit( absl::MutexLock lock(mutex); // We released the reader lock, another thread might have initialized it. if (cache->find(key) == cache->end()) { - mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); - InitContext(&context); - mlir::ParserConfig parse_config(&context); - auto module_op = - mlir::parseSourceString(module, parse_config); - if (!module_op) { - return absl::InternalError("Failed to parse module"); - } - auto maybe_engine = Compile(*module_op); - if (!maybe_engine.ok()) { - return maybe_engine.status(); + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return compiled.status(); } - mlir::ExecutionEngine* execution_engine = maybe_engine->get(); - auto main = execution_engine->lookupPacked("_mlir_ciface_main"); - auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); - if (!init || !main) { - return absl::InternalError("Failed to retrieve kernel function"); - } - void* module_ptr = nullptr; - void* kernel_ptr = nullptr; - void** module_ptr_ptr = &module_ptr; - void** kernel_ptr_ptr = &kernel_ptr; - void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; - reinterpret_cast(*init)(init_args); - cache->insert_or_assign( - key, CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(*main))); + cache->insert_or_assign(key, std::move(*compiled)); } return cache->at(key).GetHostLaunch(); } @@ -441,7 +449,7 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, abort(); } CacheKey key(hash, reinterpret_cast(ctx)); - auto ctx_and_kernel = CompileAndInit(key, opaque + sizeof(KernelHash)); + auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); if (!ctx_and_kernel.ok()) { XlaCustomCallStatusSetFailure(status, ctx_and_kernel.status().message().data(), @@ -456,3 +464,33 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, "CUDA"); } // namespace + +extern "C" { + +__attribute__((visibility("default"))) +void** MosaicGpuCompile(const char* module) { + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return nullptr; + } + auto [ctx, launch] = compiled->GetHostLaunch(); + auto tuple_ptr = std::unique_ptr(new void*[3]); + if (!tuple_ptr) { + return nullptr; + } + tuple_ptr.get()[0] = ctx; + tuple_ptr.get()[1] = reinterpret_cast(launch); + tuple_ptr.get()[2] = new CompiledKernel(std::move(*compiled)); + if (!tuple_ptr.get()[2]) { + return nullptr; + } + return tuple_ptr.release(); +} + +__attribute__((visibility("default"))) +void MosaicGpuUnload(void** tuple_ptr) { + delete reinterpret_cast(tuple_ptr[2]); + delete[] tuple_ptr; +} + +} // extern "C" diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 8463cba08c5f..4642af12011d 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -64,11 +64,12 @@ py_test( cc_binary( name = "pjrt_c_api_gpu_plugin.so", linkopts = [ - "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)", + "-Wl,--version-script,$(location :gpu_version_script.lds)", "-Wl,--no-undefined", ], linkshared = True, deps = [ + ":gpu_version_script.lds", "@xla//xla/pjrt/c:pjrt_c_api_gpu", "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds", "@xla//xla/service:gpu_plugin", diff --git a/jaxlib/tools/gpu_version_script.lds b/jaxlib/tools/gpu_version_script.lds new file mode 100644 index 000000000000..8e46b2c590b2 --- /dev/null +++ b/jaxlib/tools/gpu_version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + global: + extern "C" { + GetPjrtApi; + MosaicGpuCompile; + MosaicGpuUnload; + }; + + local: + *; +}; diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index ec9a7cd8b64e..27dc1c984bee 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -19,6 +19,7 @@ import itertools import math import operator +import unittest from absl.testing import absltest, parameterized import jax @@ -1387,5 +1388,28 @@ def kernel(ctx, src, dst, _): jax.block_until_ready(f(xd)) +class TorchTest(TestCase): + + @classmethod + def setUpClass(cls): + try: + import torch + except ImportError: + raise unittest.SkipTest("Test requires PyTorch") + cls.torch = torch + + def test_basic(self): + def kernel(ctx, i_gmem, o_gmem, _): + x = mgpu.FragmentedArray.load_strided(i_gmem) + (x + x).store_untiled(o_gmem) + + ty = jax.ShapeDtypeStruct((128, 128), jnp.float32) + x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda') + f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) + y = f(x) + np.testing.assert_allclose(y.cpu(), x.cpu() * 2) + del y # Make sure the destructor runs successfully. + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From b7c91e90c2f3645f5270b6eb7aea5882852eb7b1 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 18 Sep 2024 06:22:14 -0700 Subject: [PATCH 537/702] Lookup `shape` and `dtype` directly on `state.AbstractRef` instead of going through `inner_aval` This is just a cleanup. No behavior changes are expected. PiperOrigin-RevId: 675964703 --- jax/_src/pallas/core.py | 7 ++++--- jax/_src/pallas/mosaic_gpu/core.py | 2 +- jax/_src/pallas/mosaic_gpu/lowering.py | 5 +---- jax/_src/state/discharge.py | 3 +-- jax/_src/state/types.py | 18 ++++++++++++------ 5 files changed, 19 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index f354dd83f315..f8ec3b63339a 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -659,9 +659,10 @@ def slice_scratch_ops(self): @property def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: """The shapes of *index, *inputs.""" - index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape, - ia.inner_aval.dtype) - for ia in self.index_map_avals[len(self.grid):]) + index_shapes = ( + jax.ShapeDtypeStruct(ia.shape, ia.dtype) + for ia in self.index_map_avals[len(self.grid) :] + ) inputs_shapes = ( bm.array_shape_dtype for bm in self.block_mappings[:self.num_inputs]) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 5a046afead72..34ad5acf34d6 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -82,7 +82,7 @@ def __init__(self, tiling: tuple[int, ...]): def __call__( self, block_aval: pallas_core.AbstractMemoryRef ) -> pallas_core.AbstractMemoryRef: - block_shape = block_aval.inner_aval.shape # pytype: disable=attribute-error + block_shape = block_aval.shape old_tiled_dims = block_shape[-len(self.tiling) :] num_tiles = tuple( block_dim // tiling_dim diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ef10236e2fdd..1d76dc8405d5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -268,10 +268,7 @@ def lower_jaxpr_to_module( for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] ] in_structs_smem = [ - jax.ShapeDtypeStruct( - [num_stages, *bm.ref_aval.inner_aval.shape], - bm.ref_aval.inner_aval.dtype, - ) + jax.ShapeDtypeStruct([num_stages, *bm.ref_aval.shape], bm.ref_aval.dtype) if in_smem else None for bm, in_smem in zip( diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 6a912abf215b..4231822965b1 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -516,8 +516,7 @@ def eval_jaxpr(*refs): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( eval_jaxpr, [*in_avals, *res_ref_avals]) assert not consts - return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype) # pytype: disable=attribute-error - for a in res_ref_avals] + return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr: assert not jaxpr.constvars, "Jaxpr should not have constvars" diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 05368e978593..8289f858498b 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -196,15 +196,21 @@ def join(self, other): @property def shape(self): - if not isinstance(self.inner_aval, core.ShapedArray): - raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.") - return self.inner_aval.shape + try: + return self.inner_aval.shape # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`." + ) from None @property def dtype(self): - if not isinstance(self.inner_aval, core.UnshapedArray): - raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.") - return self.inner_aval.dtype + try: + return self.inner_aval.dtype # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`." + ) from None @core.aval_property def at(self): From 73c38cb7009b52706d3769aaec6d32046aced508 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 17 Sep 2024 14:00:21 -0400 Subject: [PATCH 538/702] Add a note to the developer docs making it clear that clang is the only toolchain that is actively supported for source compilation. As discussed in https://github.com/google/jax/issues/23687 --- docs/developer.md | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/docs/developer.md b/docs/developer.md index 53b6f0cf0f45..40ad51e873ca 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -31,23 +31,33 @@ guidance on pip installation (e.g., for GPU and TPU support). ### Building `jaxlib` from source +```{warning} +While it should typically be possible to compile `jaxlib` from source using +most modern compilers, the builds are only tested using clang. Pull requests +are welcomed to improve support for different toolchains, but other compilers +are not actively supported. +``` + To build `jaxlib` from source, you must also install some prerequisites: -- a C++ compiler (g++, clang, or MSVC) +- A C++ compiler: - On Ubuntu or Debian you can install the necessary prerequisites with: + As mentioned in the box above, it is best to use a recent version of clang + (at the time of writing, the version we test is 18), but other compilers (e.g. + g++ or MSVC) may work. - ``` - sudo apt install g++ python python3-dev - ``` + On Ubuntu or Debian you can follow the instructions from the + [LLVM](https://apt.llvm.org/) documentation to install the latest stable + version of clang. If you are building on a Mac, make sure XCode and the XCode command line tools are installed. See below for Windows build instructions. -- there is no need to install Python dependencies locally, as your system - Python will be ignored during the build; please check +- Python: for running the build helper script. Note that there is no need to + install Python dependencies locally, as your system Python will be ignored + during the build; please check [Managing hermetic Python](#managing-hermetic-python) for details. To build `jaxlib` for CPU or TPU, you can run: @@ -86,7 +96,7 @@ the `build/build.py` script itself will be processed by your system Python interpreter. By default, the wheel is written to the `dist/` subdirectory of the current directory. -* JAX versions starting from v.0.4.32: you can provide custom CUDA and CUDNN +* JAX versions starting from v.0.4.32: you can provide custom CUDA and CUDNN versions in the configuration options. Bazel will download them and use as target dependencies. @@ -259,8 +269,8 @@ together with their corresponding hashes are specified in `build/requirements_lock_.txt` files ( e.g. `build/requirements_lock_3_12.txt` for `Python 3.12`). -To update the lock files, make sure `build/requirements.in` contains the desired -direct dependencies list and then execute the following command (which will call +To update the lock files, make sure `build/requirements.in` contains the desired +direct dependencies list and then execute the following command (which will call [pip-compile](https://pypi.org/project/pip-tools/) under the hood): ``` @@ -382,7 +392,7 @@ sudo apt-get install libopenblas-dev -y example for `Python 3.13` it should have something like `"3.13": "//build:requirements_lock_3_13.txt"`. Note, the key in the `requirements` parameter must always be in `"major.minor"` version format, so - even if you are building Python version `3.13.0rc1` the corresponding + even if you are building Python version `3.13.0rc1` the corresponding `requirements` entry must still be `"3.13": "//build:requirements_lock_3_13.txt"`, **not** `"3.13.0rc1": "//build:requirements_lock_3_13_0rc1.txt"`. From 2834c135a34636af9de73b421e8fbb6731b20c4e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 17 Sep 2024 15:32:25 -0700 Subject: [PATCH 539/702] jnp.sort_complex: fix output for N-dimensional inputs --- jax/_src/numpy/lax_numpy.py | 13 +++++++++++-- tests/lax_numpy_test.py | 10 ++-------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 1513270a9dff..dd4c54fe7efa 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -8844,7 +8844,8 @@ def sort_complex(a: ArrayLike) -> Array: a: input array. If dtype is not complex, the array will be upcast to complex. Returns: - A sorted array of the same shape and complex dtype as the input. + A sorted array of the same shape and complex dtype as the input. If ``a`` + is multi-dimensional, it is sorted along the last axis. See also: - :func:`jax.numpy.sort`: Return a sorted copy of an array. @@ -8853,9 +8854,17 @@ def sort_complex(a: ArrayLike) -> Array: >>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j]) >>> jnp.sort_complex(a) Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64) + + Multi-dimensional arrays are sorted along the last axis: + + >>> a = jnp.array([[5, 3, 4], + ... [6, 9, 2]]) + >>> jnp.sort_complex(a) + Array([[3.+0.j, 4.+0.j, 5.+0.j], + [2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64) """ util.check_arraylike("sort_complex", a) - a = lax.sort(asarray(a), dimension=0) + a = lax.sort(asarray(a)) return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) @util.implements(np.lexsort) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6415b31e7014..f93e28dada71 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4295,14 +4295,8 @@ def testSortStableDescending(self): self.assertArraysEqual(jnp.argsort(x), argsorted_stable) self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable) - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in one_dim_array_shapes - for axis in [None] - ], - dtype=all_dtypes, - ) - def testSortComplex(self, dtype, shape, axis): + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) + def testSortComplex(self, shape, dtype): rng = jtu.rand_some_equal(self.rng()) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker, From 69ba060957529fc9babf838af8d47a2626615e62 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 18 Sep 2024 07:40:58 -0700 Subject: [PATCH 540/702] Reverts e15ec1e8abe3732d747731c15a36facf4169739e PiperOrigin-RevId: 675987338 --- jax/_src/lax/lax.py | 10 +++++----- tests/filecheck/math.filecheck.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 745958239806..8d2c24d6e64c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -340,10 +340,6 @@ def cos(x: ArrayLike) -> Array: r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.""" return cos_p.bind(x) -def tan(x: ArrayLike) -> Array: - r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.""" - return tan_p.bind(x) - def atan2(x: ArrayLike, y: ArrayLike) -> Array: r"""Elementwise arc tangent of two variables: :math:`\mathrm{atan}({x \over y})`.""" @@ -1553,6 +1549,10 @@ def f_wrapped(x): return f_wrapped +def tan(x: ArrayLike) -> Array: + r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.""" + return tan_p.bind(x) + def asin(x: ArrayLike) -> Array: r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`.""" return asin_p.bind(x) @@ -2014,7 +2014,7 @@ def _tan_impl(x): tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) +mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index f34b8211eb33..e75e8e7d735f 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -419,7 +419,7 @@ def integer_pow(x): return lax.integer_pow(x, 3) print_ir(jnp.bfloat16(0))(lax.sqrt) # CHECK-LABEL: TEST: tan float16[] - # CHECK: hlo.tan + # CHECK: chlo.tan # CHECK-SAME: tensor print_ir(np.float16(0))(lax.tan) From 922e652c05654b9e2278a88a22cb70b94fdd6f46 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Sep 2024 15:17:49 +0000 Subject: [PATCH 541/702] Replace plat-name with plat_name. The former seems to elicit a deprecation warning from setuptools recently. --- jaxlib/tools/build_gpu_kernels_wheel.py | 2 +- jaxlib/tools/build_gpu_plugin_wheel.py | 2 +- jaxlib/tools/build_wheel.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 28d2806a7da9..ced0b76c344c 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -74,7 +74,7 @@ def write_setup_cfg(sources_path, cpu): license_files = LICENSE.txt [bdist_wheel] -plat-name={tag} +plat_name={tag} """) diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 73cb8a9e020d..0e2bba0c74d0 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -80,7 +80,7 @@ def write_setup_cfg(sources_path, cpu): license_files = LICENSE.txt [bdist_wheel] -plat-name={tag} +plat_name={tag} python-tag=py3 """ ) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 48aab847f3fb..6305b0c24aa8 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -164,7 +164,7 @@ def write_setup_cfg(sources_path, cpu): license_files = LICENSE.txt [bdist_wheel] -plat-name={tag} +plat_name={tag} """ ) From 1cc96616baab7394117c2ecb9b64db22fd82dc44 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 16 Sep 2024 14:18:29 -0400 Subject: [PATCH 542/702] Unconditionally lower jnp.dot to lax.dot_general. https://github.com/google/jax/pull/16721 added a condition to lower calls to `jnp.dot` with scalar inputs to `lax.mul` instead of `lax.dot_general`. AFAICT, https://github.com/google/jax/pull/16826 fixed the issue that this was solving, so this condition should no longer be necessary. Removing this condition simplifies the addition of new arguments to `dot` and `dot_general`, including the `algorithm` parameter that I am currently working on in https://github.com/google/jax/pull/23574, so now seemed like a good time to remove it! --- jax/_src/numpy/lax_numpy.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index dd4c54fe7efa..c1b74205e140 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -7409,20 +7409,17 @@ def dot(a: ArrayLike, b: ArrayLike, *, batch_dims = ((), ()) a_ndim, b_ndim = ndim(a), ndim(b) if a_ndim == 0 or b_ndim == 0: - # TODO(jakevdp): lower this case to dot_general as well? - # Currently, doing so causes issues in remat tests due to #16805 - if preferred_element_type is not None: - a = a.astype(preferred_element_type) - b = b.astype(preferred_element_type) - result = lax.mul(a, b) + contract_dims: tuple[tuple[int, ...], tuple[int, ...]] = ((), ()) else: if b_ndim == 1: contract_dims = ((a_ndim - 1,), (0,)) else: contract_dims = ((a_ndim - 1,), (b_ndim - 2,)) - result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), - precision=precision, preferred_element_type=preferred_element_type) - return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) + result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), + precision=precision, + preferred_element_type=preferred_element_type) + return lax_internal._convert_element_type(result, preferred_element_type, + output_weak_type) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) From c191bbcdb162bec58494e818f42feb457cd0a287 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 18 Sep 2024 08:40:30 -0700 Subject: [PATCH 543/702] Make `debug.print` work with static args. Fixes: https://github.com/google/jax/issues/23600 PiperOrigin-RevId: 676005582 --- jax/_src/debugging.py | 26 ++++++++++++++++++++++---- tests/debugging_primitives_test.py | 12 ++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 3e7082ab10ec..3373496940e2 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -46,6 +46,7 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.sharding import Sharding from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding +from jax._src.api_util import shaped_abstractify from jax._src.state import discharge as state_discharge logger = logging.getLogger(__name__) @@ -256,12 +257,29 @@ def debug_callback(callback: Callable[..., None], *args: Any, raise TypeError("first argument to jax.debug.callback must be callable, " f"but got an object of type {type(callback)}") flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - effect = ordered_debug_effect if ordered else debug_effect - def _flat_callback(*flat_args): - args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) + static_args, dyn_args = {}, [] + for i, a in enumerate(flat_args): + try: + shaped_abstractify(a) + dyn_args.append(a) + except (AssertionError, TypeError): + static_args[i] = a + + def _flat_callback(*dyn_args): + all_args = [None] * (len(static_args) + len(dyn_args)) + di = iter(dyn_args) + for i in range(len(all_args)): + if i in static_args: + all_args[i] = static_args[i] + else: + all_args[i] = next(di) + assert next(di, None) is None + args, kwargs = tree_util.tree_unflatten(in_tree, all_args) callback(*args, **kwargs) return () - debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect) + + effect = ordered_debug_effect if ordered else debug_effect + debug_callback_p.bind(*dyn_args, callback=_flat_callback, effect=effect) class _DebugPrintFormatChecker(string.Formatter): diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 273c12f1b13c..5532fdf0303f 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -80,6 +80,18 @@ def f(x): jax.effects_barrier() self.assertEqual(output(), "x: 2\n") + def test_static_args(self): + @jax.jit + def f(arr): + jax.debug.print("arr {array}, dtype: {dtype}, arr {array2}", + array=arr, dtype=arr.dtype, array2=arr) + arr = jnp.array([1, 2, 3], dtype=jnp.float32) + with jtu.capture_stdout() as output: + f(arr) + jax.effects_barrier() + self.assertEqual( + output(), "arr [1. 2. 3.], dtype: float32, arr [1. 2. 3.]\n") + def test_debug_print_works_with_named_format_strings(self): def f(x): debug_print('x: {x}', x=x) From c756d9b7033d1dcf3a389bcb4c9e65f9f5201019 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Sep 2024 15:44:45 +0000 Subject: [PATCH 544/702] Fix error in debugger tests that is showing up in CI. I'm unsure why this started happening now, but sometimes we get an invalid offset for a frame. Be tolerant of that case. --- jax/_src/debugger/core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py index f6b0a81baf92..1efeed73cbc8 100644 --- a/jax/_src/debugger/core.py +++ b/jax/_src/debugger/core.py @@ -112,6 +112,11 @@ def from_frameinfo(cls, frame_info) -> DebuggerFrame: # then we subtract it off from the `lineno` and don't need to subtract 1 # since both start and lineno are 1-indexed. offset = frame_info.lineno - max(start, 1) + if offset >= len(source): + # Sometimes we don't get a valid source/offset pair. This seems to + # happen sometimes when code uses eval(). If that happens, give up. + source = [] + offset = None except OSError: source = [] offset = None From 442e8630deff3f89d3ae756aac30bb71e7ba7cf2 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 18 Sep 2024 08:56:49 -0700 Subject: [PATCH 545/702] Added a missing branch to `mgpu.FragmentedArray.astype` Previously, an unsupported cast produced a `NameError` instead. PiperOrigin-RevId: 676010161 --- jax/experimental/mosaic/gpu/fragmented_array.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 0b228833cbdb..502373bdc91e 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -494,6 +494,8 @@ def astype(self, new_dtype: ir.Type): convert = arith.sitofp elif from_float and to_integer: convert = arith.fptosi + else: + raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}") new_registers = np.empty_like(self.registers) match self.layout: case WGMMAFragLayout(): From e27f1e9b3a8af39d2791b95a8106106229cad238 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 18 Sep 2024 09:03:55 -0700 Subject: [PATCH 546/702] Change Python version 3.13.0rc2 to 3.13.0-rc.2. The value is taken from [the versions manifest](https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json). PiperOrigin-RevId: 676012255 --- .github/workflows/wheel_win_x64.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index bae1edec0214..367f8e05bf56 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -17,7 +17,7 @@ jobs: matrix: os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.10', '3.11', '3.12', '3.13.0rc2'] + pyver: ['3.10', '3.11', '3.12', '3.13.0-rc.2'] name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build runs-on: ${{ matrix.os }} From 9dd363da1298e4810b693a918fc2e8199094acdb Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Wed, 18 Sep 2024 09:28:25 -0700 Subject: [PATCH 547/702] Export `jax.lib.xla_extension.ifrt_programs`. PiperOrigin-RevId: 676020419 --- jax/extend/BUILD | 6 ++++++ jax/extend/ifrt_programs.py | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 jax/extend/ifrt_programs.py diff --git a/jax/extend/BUILD b/jax/extend/BUILD index babe0c8b10d2..59958c1da389 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -80,3 +80,9 @@ pytype_strict_library( srcs = ["ffi.py"], deps = ["//jax"], ) + +pytype_strict_library( + name = "ifrt_programs", + srcs = ["ifrt_programs.py"], + deps = ["//jax/_src/lib"], +) diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py new file mode 100644 index 000000000000..d5fb9245af91 --- /dev/null +++ b/jax/extend/ifrt_programs.py @@ -0,0 +1,22 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.lib import xla_extension as _xe + +ifrt_programs = _xe.ifrt_programs + +del _xe From 016c49951f670256ce4750cdfea182e3a2a15325 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 18 Sep 2024 09:56:44 -0700 Subject: [PATCH 548/702] Removed leftover usages of GPUGridSpec from Pallas Mosaic GPU tests PiperOrigin-RevId: 676029854 --- tests/pallas/mosaic_gpu_test.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 4810e780813d..746c4e93387b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -79,15 +79,14 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_with_scratch(self): + @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), - grid_spec=plgpu.GPUGridSpec( - in_specs=[pl.BlockSpec((128,), lambda *i: i)], - out_specs=pl.BlockSpec((128,), lambda *i: i), - scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], - grid=2, - ), + in_specs=[pl.BlockSpec((128,), lambda *i: i)], + out_specs=pl.BlockSpec((128,), lambda *i: i), + scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], + grid=2, ) def kernel(x_ref, o_ref, scratch_ref): scratch_ref[...] = x_ref[...] + 1 @@ -120,10 +119,8 @@ def test_add_one_with_async_copy_smem_to_gmem(self): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - grid_spec=plgpu.GPUGridSpec( - out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], - ), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] + 1 From bef36c431d752b91372ac58d4cf6a84277dc600e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Sep 2024 18:57:03 +0000 Subject: [PATCH 549/702] Add Python 3.13 wheels to changelog. --- CHANGELOG.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d507c7e01385..659a8ee04db0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,9 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.34 +* New Functionality + * This release includes wheels for Python 3.13. Free-threading mode is not yet + supported. + * Deprecations - * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments - with `ndim != 1` are now deprecated, and in the future will result in an error. + * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike + arguments with `ndim != 1` are now deprecated, and in the future will result + in an error. * Deletion: * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation From 57a4b76d09fb1eac160242bf1f31bc8b3841f82d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 18 Sep 2024 11:59:00 -0700 Subject: [PATCH 550/702] Improve documentation for jnp.digitize --- jax/_src/numpy/lax_numpy.py | 48 ++++++++++++++++++++++++++++++++----- jax/numpy/__init__.pyi | 3 ++- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c1b74205e140..47a36a2b83ff 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -10808,11 +10808,46 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', }[method] return impl(asarray(a), asarray(v), side, dtype) # type: ignore -@util.implements(np.digitize, lax_description=_dedent(""" - Optionally, the ``method`` argument can be used to configure the - underlying :func:`jax.numpy.searchsorted` algorithm.""")) + @partial(jit, static_argnames=('right', 'method')) -def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str = 'scan') -> Array: +def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, + *, method: str | None = None) -> Array: + """Convert an array to bin indices. + + JAX implementation of :func:`numpy.digitize`. + + Args: + x: array of values to digitize. + bins: 1D array of bin edges. Must be monotonically increasing or decreasing. + right: if true, the intervals include the right bin edges. If false (default) + the intervals include the left bin edges. + method: optional method argument to be passed to :func:`~jax.numpy.searchsorted`. + See that function for available options. + + Returns: + An integer array of the same shape as ``x`` indicating the bin number that + the values are in. + + See also: + - :func:`jax.numpy.searchsorted`: find insertion indices for values in a + sorted array. + - :func:`jax.numpy.histogram`: compute frequency of array values within + specified bins. + + Examples: + >>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5]) + >>> bins = jnp.array([1, 2, 3]) + >>> jnp.digitize(x, bins) + Array([1, 2, 2, 1, 3, 3], dtype=int32) + >>> jnp.digitize(x, bins, right=True) + Array([0, 1, 2, 1, 2, 3], dtype=int32) + + ``digitize`` supports reverse-ordered bins as well: + + >>> bins = jnp.array([3, 2, 1]) + >>> jnp.digitize(x, bins) + Array([2, 1, 1, 2, 0, 0], dtype=int32) + """ util.check_arraylike("digitize", x, bins) right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()") bins_arr = asarray(bins) @@ -10821,10 +10856,11 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str if bins_arr.shape[0] == 0: return zeros_like(x, dtype=int32) side = 'right' if not right else 'left' + kwds: dict[str, str] = {} if method is None else {'method': method} return where( bins_arr[-1] >= bins_arr[0], - searchsorted(bins_arr, x, side=side, method=method), - bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, method=method) + searchsorted(bins_arr, x, side=side, **kwds), + bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds) ) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index d5b66c1b3b32..c23f659bd3f9 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -300,7 +300,8 @@ def diagonal( def diff(a: ArrayLike, n: int = ..., axis: int = ..., prepend: ArrayLike | None = ..., append: ArrayLike | None = ...) -> Array: ... -def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ...) -> Array: ... +def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ..., *, + method: str | None = ...) -> Array: ... divide = true_divide def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... def dot( From dbc03cf8e5a7ac8e1e6e8e593e063eb7f54990d1 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 18 Sep 2024 12:39:58 -0700 Subject: [PATCH 551/702] Re-land #23261 with appropriate compatibility checks. PiperOrigin-RevId: 676092618 --- jax/_src/lax/lax.py | 12 +++++++++++- tests/filecheck/math.filecheck.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8d2c24d6e64c..e356756cd3e1 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2012,9 +2012,19 @@ def _cos_lowering(ctx, x): def _tan_impl(x): return div(sin(x), cos(x)) +def _tan_lowering(ctx, x): + # TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this + # lowering is supported, but export doesn't target a sufficiently up-to-date + # StableHLO version, and the compatibility updates from + # https://github.com/openxla/xla/pull/16649 aren't included in the 0.4.33 + # release. + if ctx.is_forward_compat(): + return _nary_lower_hlo(chlo.tan, ctx, x) + return _nary_lower_hlo(hlo.tan, ctx, x) + tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) +mlir.register_lowering(tan_p, _tan_lowering) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index e75e8e7d735f..f34b8211eb33 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -419,7 +419,7 @@ def integer_pow(x): return lax.integer_pow(x, 3) print_ir(jnp.bfloat16(0))(lax.sqrt) # CHECK-LABEL: TEST: tan float16[] - # CHECK: chlo.tan + # CHECK: hlo.tan # CHECK-SAME: tensor print_ir(np.float16(0))(lax.tan) From 018189491bde26fe9c7ade1213c5cbbad8bca1c6 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 18 Sep 2024 13:43:14 -0700 Subject: [PATCH 552/702] Clean up and fix primal type to tangent type mapping This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types. Changes: 1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself. 2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion. 3. Add `to_tangent_type` calls in various other places they're missing. 4. Remove non-support for float0 in custom deriviatives? 5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.) PiperOrigin-RevId: 676115753 --- jax/_src/ad_checkpoint.py | 2 +- jax/_src/ad_util.py | 9 ++++- jax/_src/api.py | 4 +- jax/_src/checkify.py | 2 +- jax/_src/core.py | 20 ++++++--- jax/_src/custom_derivatives.py | 49 +++++++++++------------ jax/_src/dtypes.py | 2 +- jax/_src/export/_export.py | 2 +- jax/_src/interpreters/ad.py | 47 ++++++++-------------- jax/_src/interpreters/partial_eval.py | 3 +- jax/_src/lax/ann.py | 4 +- jax/_src/lax/control_flow/conditionals.py | 2 +- jax/_src/lax/control_flow/for_loop.py | 2 +- jax/_src/lax/control_flow/loops.py | 4 +- jax/_src/lax/control_flow/solves.py | 4 +- jax/_src/lax/lax.py | 12 +++--- jax/_src/lax/linalg.py | 4 +- jax/_src/lax/slicing.py | 8 ++-- jax/_src/lax/windowed_reductions.py | 4 +- jax/_src/pallas/core.py | 4 +- jax/_src/state/discharge.py | 2 +- jax/_src/state/types.py | 4 +- jax/core.py | 1 + jax/experimental/attrs.py | 4 +- jax/experimental/shard_map.py | 2 +- jax/experimental/sparse/bcoo.py | 16 ++++---- jax/experimental/sparse/bcsr.py | 4 +- jax/experimental/sparse/coo.py | 4 +- jax/experimental/sparse/csr.py | 4 +- jax/interpreters/ad.py | 2 - tests/api_test.py | 23 ++++++----- tests/export_test.py | 3 +- 32 files changed, 130 insertions(+), 127 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index fd30119882e7..8c7fe2f489d5 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -514,7 +514,7 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): prevent_cse=prevent_cse, differentiated=differentiated, policy=policy) out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)]) out_tangents_ = iter(out_tangents_) - out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents ad.primitive_jvps[remat_p] = remat_jvp diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 57e881c34f82..c69ff3754dc6 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -65,8 +65,8 @@ def __init__(self, aval: core.AbstractValue): def __repr__(self) -> str: return f'Zero({self.aval})' @staticmethod - def from_value(val: Any) -> Zero: - return Zero(raise_to_shaped(get_aval(val))) + def from_primal_value(val: Any) -> Zero: + return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval()) register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval)) @@ -82,6 +82,7 @@ def _stop_gradient_impl(x: T) -> T: stop_gradient_p.def_abstract_eval(lambda x: x) +# User-facing version of `Zero` class SymbolicZero: def __init__(self, aval: core.AbstractValue) -> None: self.aval = aval @@ -108,6 +109,10 @@ def __getattr__(self, name): else: return attr + @staticmethod + def from_primal_value(val: Any) -> SymbolicZero: + return SymbolicZero(get_aval(val).to_tangent_aval()) + JaxTypeOrTracer = Any def replace_internal_symbolic_zeros( diff --git a/jax/_src/api.py b/jax/_src/api.py index b548cc43fb3b..aae99a28bbea 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1826,7 +1826,7 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args): def fun(*tangents): tangent_avals = list(map(core.get_aval, tangents)) for primal_aval, tangent_aval in zip(primal_avals, tangent_avals): - if not core.typecompat(primal_aval.at_least_vspace(), tangent_aval): + if not core.typecompat(primal_aval.to_tangent_aval(), tangent_aval): raise ValueError("linearized function called on tangent values inconsistent with " "the original primal values: " f"got {tangent_aval} for primal aval {primal_aval}") @@ -1869,7 +1869,7 @@ def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): f"got {in_tree}, but expected to match {in_tree_expected}") for arg, aval in zip(args, out_primal_avals): ct_aval = shaped_abstractify(arg) - ct_aval_expected = aval.at_least_vspace() + ct_aval_expected = aval.to_tangent_aval() if (not core.typecompat(ct_aval, ct_aval_expected) and not _temporary_dtype_exception(ct_aval, ct_aval_expected)): raise ValueError( diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 1167914e51c9..e67f624fc32e 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -980,7 +980,7 @@ def jvp(*xs): out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents) out_primals, nz_out_tangents = split_list(out, [len(out_zeros)]) nz_out_tangents_ = iter(nz_out_tangents) - out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace()) + out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval()) if z else next(nz_out_tangents_) for p, z in zip(out_primals, out_zeros)] assert next(nz_out_tangents_, None) is None diff --git a/jax/_src/core.py b/jax/_src/core.py index 51933a9f8bbf..057a79925e2e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1414,9 +1414,13 @@ def definitely_equal(x, y): class AbstractValue: __slots__: list[str] = [] - def at_least_vspace(self): + def to_tangent_aval(self): raise NotImplementedError("must override") + # TODO(dougalm): deprecate this alias + def at_least_vspace(self): + return self.to_tangent_aval() + def __repr__(self): try: kv_pairs = (f'{k}={v}' for k, v in self.__dict__.items()) @@ -1524,6 +1528,12 @@ def get_aval(x): else: return concrete_aval(x) +def get_type(x): + aval = get_aval(x) + if isinstance(aval, ConcreteArray): + return raise_to_shaped(aval) + else: + return aval def concretization_function_error(fun, suggest_astype=False): fname = getattr(fun, "__name__", fun) @@ -1647,7 +1657,7 @@ def __repr__(self): _oct = concretization_function_error(oct) _index = concretization_function_error(operator.index) - def at_least_vspace(self) -> AbstractValue: + def to_tangent_aval(self) -> AbstractValue: return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1786,7 +1796,7 @@ def __hash__(self): return hash((self.shape, self.dtype, self.weak_type, getattr(self, 'sharding', None))) - def at_least_vspace(self): + def to_tangent_aval(self): return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1945,7 +1955,7 @@ def join(self, other): else: raise TypeError(self, other) - def at_least_vspace(self): + def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -2076,7 +2086,7 @@ def join(self, other): else: assert False, f"Cannot join {self} with {other}" def str_short(self, short_dtypes=False): return 'Tok' - def at_least_vspace(self): return self + def to_tangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() # Singleton shaped array used by all abstract tokens when shape/dtype is needed. diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 019948c36683..05ede08d219c 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -67,7 +67,7 @@ def _sum_tangents(_, x, *xs): return reduce(ad.add_tangents, xs, x) def _zeros_like_pytree(x): - return tree_map(Zero.from_value, x) + return tree_map(Zero.from_primal_value, x) _stop_gradient = partial( tree_map, @@ -327,24 +327,27 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - # TODO(mattjj): compare primals' tangent types to tangent objects' types - primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) - for x in primals_out] + primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out] + expected_tangent_avals_out = [ + raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval() + for x in primals_out] tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) if type(t) is not SymbolicZero else t.aval.strip_weak_type() for t in tangents_out] - if primal_avals_out != tangent_avals_out: - if len(primal_avals_out) == 1: - (av1,), (av2,) = primal_avals_out, tangent_avals_out + if expected_tangent_avals_out != tangent_avals_out: + if len(expected_tangent_avals_out) == 1: + (av_p,), (av_et,), (av_t,) = primal_avals_out, expected_tangent_avals_out, tangent_avals_out msg = ("Custom JVP rule must produce primal and tangent outputs with " - "equal shapes and dtypes, but got {} and {} respectively.") - raise TypeError(msg.format(av1.str_short(), av2.str_short())) + "corresponding shapes and dtypes. Expected {} (tangent type of {}) but got {}.") + raise TypeError(msg.format(av_et.str_short(), av_p.str_short(), av_t.str_short())) else: msg = ("Custom JVP rule must produce primal and tangent outputs with " - "equal shapes and dtypes, but got:\n{}") + "corresponding shapes and dtypes, but got:\n{}") disagreements = ( - f" primal {av1.str_short()} for tangent {av2.str_short()}" - for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2) + f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}" + for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out) + if av_et != av_t) + raise TypeError(msg.format('\n'.join(disagreements))) yield primals_out + tangents_out, (out_tree, primal_avals) @@ -392,7 +395,7 @@ def jvp(*xs): out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents) out_primals, nz_out_tangents = split_list(out, [len(out_zeros)]) nz_out_tangents_ = iter(nz_out_tangents) - out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace()) + out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval()) if z else next(nz_out_tangents_) for p, z in zip(out_primals, out_zeros)] assert next(nz_out_tangents_, None) is None @@ -780,10 +783,10 @@ def append(x, d): raise TypeError(msg.format(in_tree2, in_tree)) from None results = [] for kp, a, ct in zip(keypaths, in_avals, cts_in_flat): - if ct is zero or a != a.at_least_vspace(): - results.append(Zero(a.at_least_vspace())) + if ct is zero or a != a.to_tangent_aval(): + results.append(Zero(a.to_tangent_aval())) elif type(ct) is SymbolicZero: - if not core.typecompat(a.at_least_vspace(), a_ := ct.aval): + if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval): msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype " "that does not match the corresponding input tangent shape/dtype: " f"at output{keystr(kp)} the SymbolicZero had shape/dtype " @@ -794,7 +797,7 @@ def append(x, d): raise ValueError(msg) results.append(Zero(ct.aval)) else: - if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct)) + if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct)) and not (_temporary_dtype_exception(a, a_) or _temporary_shape_exception(a, a_))): msg = ("Custom VJP bwd rule must produce an output with the same " @@ -908,16 +911,12 @@ def _custom_vjp_call_jaxpr_jvp( _, res_tree = out_trees() res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] + avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] args_dot = map(ad.instantiate_zeros, args_dot) - # Cast float0 to zeros with the primal dtype because custom vjp rules don't - # currently handle float0s - args_dot = map(ad.replace_float0s, args, args_dot) tangents_out = ad.custom_lin_p.bind( *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - tangents_out = map(ad.recast_to_float0, primals_out, tangents_out) return primals_out, tangents_out ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp @@ -1039,7 +1038,7 @@ def fwd(*args, **kwargs): ans, rule = fun(*args, **kwargs) ans_flat, out_tree = tree_flatten((ans,)) rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree) - ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat] + ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals) return ans, Residuals(jaxpr, in_tree(), out_tree, consts) @@ -1153,7 +1152,7 @@ def _maybe_perturbed(x: Any) -> bool: elif isinstance(x, pe.DynamicJaxprTracer): # If x is a DynamicJaxprTracer then we're staging out; differentiation could # happen later, but some types always have trivial tangents. - vspace = x.aval.at_least_vspace() + vspace = x.aval.to_tangent_aval() return not (vspace is core.abstract_token or getattr(vspace, 'dtype', None) == dtypes.float0) elif not isinstance(x, ad.JVPTracer): @@ -1425,7 +1424,7 @@ def custom_vjp_by_custom_transpose(fun, fwd, bwd): @fun.defjvp def jvp(primals, tangents): outs, residuals = fwd(*primals) - tan_out_types = tree_map(lambda o: core.get_aval(o).at_least_vspace(), outs) + tan_out_types = tree_map(lambda o: core.get_aval(o).to_tangent_aval(), outs) tan_fn = custom_transpose(partial(disallow_jvp, out_avals=tan_out_types)) tan_fn.def_transpose(bwd) return outs, tan_fn(tan_out_types, residuals, tangents) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 81f4180a1c12..d76b80ad3a89 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -784,7 +784,7 @@ def check_user_dtype_supported(dtype, fun_name=None): uint2, uint4, ] - if np_dtype.kind not in "biufc" and not is_custom_dtype: + if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0: msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" msg += f" in {fun_name}" if fun_name else "" raise TypeError(msg) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index d0159f7a4334..7f7773acbd39 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1127,7 +1127,7 @@ def flattened_primal_fun_jax(*args_flat): vjp_in_avals = list( itertools.chain(in_avals, - map(lambda a: a.at_least_vspace(), out_avals))) + map(lambda a: a.to_tangent_aval(), out_avals))) if apply_jit: assert device_assignment is not None diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index f1b25cf96a95..f1f46a5c18f7 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -57,7 +57,7 @@ def _update_annotation( # Implicit arguments never have tangents, so generate the tangent part of the # type annotation from explicit arguments only. explicit_avals = [aval for aval, explicit in orig_type if explicit] - tan_types = [(aval.at_least_vspace(), True) + tan_types = [(aval.to_tangent_aval(), True) for nz, aval in zip(explicit_nonzeros, explicit_avals) if nz] return lu.annotate(f, (*orig_type, *tan_types)) @@ -72,7 +72,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, @lu.transformation def jvpfun(instantiate, transform_stack, primals, tangents): - tangents = [Zero.from_value(t) if not isinstance(t, Zero) + tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) @@ -124,7 +124,7 @@ def linearize(traceable, *primals, **kwargs): jvpfun, aux = jvp(traceable, has_aux=True) in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) - + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace()) + + tuple(pe.PartialVal.unknown(get_aval(p).to_tangent_aval()) for p in primals)) _, in_tree = tree_flatten(((primals, primals), {})) jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree) @@ -166,18 +166,6 @@ def unpair_pval(pval): aval_1, aval_2 = aval return (aval_1, const_1), (aval_2, const_2) -def replace_float0s(primal, tangent): - if dtype(tangent) == float0: - return zeros_like_jaxval(primal) - else: - return tangent - -def recast_to_float0(primal, tangent): - if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0: - return Zero(get_aval(primal).at_least_vspace()) - else: - return tangent - # NOTE: The FIXMEs below are caused by primal/tangent mixups (type # errors if you will) @@ -203,7 +191,7 @@ def write_cotangent(prim, v, ct): # assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval) def read_cotangent(v): - return ct_env.pop(v, Zero(v.aval.at_least_vspace())) + return ct_env.pop(v, Zero(v.aval.to_tangent_aval())) def read_primal(v): if type(v) is Literal: @@ -295,11 +283,11 @@ def nonzero_tangent_outputs(*args, **kwargs): class JVPTrace(Trace): def pure(self, val): - tangent_zero = Zero(get_aval(val).at_least_vspace()) + tangent_zero = Zero.from_primal_value(val) return JVPTracer(self, val, tangent_zero) def lift(self, val): - tangent_zero = Zero(get_aval(val).at_least_vspace()) + tangent_zero = Zero.from_primal_value(val) return JVPTracer(self, val, tangent_zero) def sublift(self, val): @@ -343,7 +331,7 @@ def new_out_axes_thunk(): result = call_primitive.bind(_update_annotation(f_jvp, f.in_type, which_nz), *args, **new_params) primal_out, tangent_out = tree_unflatten(out_tree(), result) - tangent_out = [Zero(get_aval(p).at_least_vspace()) if t is None else t + tangent_out = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primal_out, tangent_out)] return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)] @@ -374,13 +362,11 @@ def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros): primals_in = map(core.full_lower, primals_in) if not symbolic_zeros: tangents_in = map(instantiate_zeros, tangents_in) - tangents_in = map(replace_float0s, primals_in, tangents_in) else: tangents_in = map(replace_internal_symbolic_zeros, tangents_in) outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in)) primals_out, tangents_out = split_list(outs, [len(outs) // 2]) tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out) - tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) def post_process_custom_jvp_call(self, out_tracers, _): @@ -398,14 +384,13 @@ def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, res_and_primals_out = fwd.call_wrapped(*fwd_in) _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] + avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! tangents_in = map(instantiate_zeros, tangents_in) tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) def post_process_custom_vjp_call(self, out_tracers, _): @@ -505,8 +490,8 @@ def linear_jvp(primitive, primals, tangents, **params): val_out = primitive.bind(*primals, **params) if all(type(tangent) is Zero for tangent in tangents): if primitive.multiple_results: - return val_out, map(Zero.from_value, val_out) - return val_out, Zero.from_value(val_out) + return val_out, map(Zero.from_primal_value, val_out) + return val_out, Zero.from_primal_value(val_out) else: tangents = map(instantiate_zeros, tangents) return val_out, primitive.bind(*tangents, **params) @@ -533,7 +518,7 @@ def standard_jvp(jvprules, primitive, primals, tangents, **params): val_out = primitive.bind(*primals, **params) tangents_out = [rule(t, *primals, **params) for rule, t in zip(jvprules, tangents) if rule is not None and type(t) is not Zero] - return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out)) + return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_primal_value(val_out)) def defjvp2(primitive, *jvprules): assert isinstance(primitive, Primitive) @@ -545,7 +530,7 @@ def standard_jvp2(jvprules, primitive, primals, tangents, **params): tangents_out = (rule(t, val_out, *primals, **params) for rule, t in zip(jvprules, tangents) if rule is not None and type(t) is not Zero) tangents_out = list(tangents_out) - return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out)) + return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_primal_value(val_out)) def add_tangents(x, y): if type(x) is Zero: @@ -580,7 +565,7 @@ def defjvp_zero(primitive): def zero_jvp(primitive, primals, tangents, **params): r = primitive.bind(*primals, **params) - return r, Zero.from_value(r) + return r, Zero.from_primal_value(r) deflinear2(add_jaxvals_p, lambda t, *args: (t, t)) @@ -591,7 +576,7 @@ def instantiate_zeros(tangent): @lu.transformation_with_aux def traceable(in_tree, *primals_and_tangents): primals, tangents = tree_unflatten(in_tree, primals_and_tangents) - tangents = [Zero(get_aval(p).at_least_vspace()) if t is None else t + tangents = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primals, tangents)] primals_out, tangents_out = yield (primals, tangents), {} tangents_out = [None if type(t) is Zero else t for t in tangents_out] @@ -695,7 +680,7 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate): f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) - tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] + tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() @@ -705,7 +690,7 @@ def f_jvp_traceable(nonzeros, *primals_and_nztangents): num_primals = len(nonzeros) primals = list(primals_and_nztangents[:num_primals]) nonzero_tangents = iter(primals_and_nztangents[num_primals:]) - tangents = [next(nonzero_tangents) if nz else Zero.from_value(p) + tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p) for p, nz in zip(primals, nonzeros)] primals_out, tangents_out = yield (primals, tangents), {} out_nonzeros = [type(t) is not Zero for t in tangents_out] diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 374816e001ec..fc2214aaf29f 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2158,6 +2158,7 @@ def post_process_map(self, map_primitive, out_tracers, params): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): in_avals = [t.aval for t in tracers] + in_tangent_avals = [t.to_tangent_aval() for t in in_avals] with core.new_sublevel(): fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) @@ -2166,7 +2167,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): @_memoize def jvp_jaxpr_thunk(*in_zeros): for store in jvp.stores: store and store.reset() - nz_tangent_avals, zero_avals = partition_list(in_zeros, in_avals) + nz_tangent_avals, zero_avals = partition_list(in_zeros, in_tangent_avals) jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals)) in_avals_ = (*in_avals, *nz_tangent_avals) jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_) diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index f2dbd8d4fa0e..0e037ec774b5 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -373,7 +373,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension, reduction_input_size_override, aggregate_to_topk) if type(tangent) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: arg_shape = arg_out.shape rank = len(arg_shape) @@ -385,7 +385,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension, idx = tuple( arg_out if i == reduction_dimension else iotas[i] for i in range(rank)) tangent_out = tangent[idx] - return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out)) + return (val_out, arg_out), (tangent_out, ad_util.Zero.from_primal_value(arg_out)) approx_top_k_p = core.Primitive('approx_top_k') diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index b96f9e8c6e40..4cb38d28c36f 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -434,7 +434,7 @@ def _cond_jvp(primals, tangents, branches): out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp) out_primals, out_tangents = split_list(out, [len(out_nz)]) out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 61b9a24644ce..21b522b3d8bb 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -340,7 +340,7 @@ def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear, # into outputs as well. We don't care about these in AD so we throw them out. out_primals, out_tangents = split_list(out_flat, [len(primals)]) out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, nonzero_tangents)] return out_primals, out_tangents ad.primitive_jvps[for_p] = _for_jvp diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 828728ebdbd2..41d809f8d688 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -547,7 +547,7 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) primals_out = carry + ys tangents_out_iter = iter(carry_dot + ys_dot) - tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p) + tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(primals_out, nonzeros_out)] return primals_out, tangents_out @@ -1518,7 +1518,7 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts, out_carry, out_carry_dot = split_list(out, [num_carry]) out_tangents_iter = iter(out_carry_dot) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_carry, nonzeros_out)] return out_carry, out_tangents diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 21105e20aaf8..4e0f5086b121 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -316,7 +316,7 @@ def _tangent_linear_map(func, params, params_dot, *x): this function computes ``∂A @ x``. """ assert any(type(p) is not ad_util.Zero for p in params_dot) - zeros = _map(ad_util.Zero.from_value, x) + zeros = _map(ad_util.Zero.from_primal_value, x) _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped( params + list(x), params_dot + zeros) return out_tangent @@ -352,7 +352,7 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs): # split into x tangents and aux tangents (these become zero) dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves]) - daux_leaves = _map(ad_util.Zero.from_value, daux_leaves) + daux_leaves = _map(ad_util.Zero.from_primal_value, daux_leaves) x_dot = dx_leaves + daux_leaves diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e356756cd3e1..28ad429b0367 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2310,7 +2310,7 @@ def _add_jvp(primals, tangents): xdot, ydot = tangents primal_out = add(x, y) if type(xdot) is type(ydot) is ad_util.Zero: - return primal_out, ad_util.Zero.from_value(primal_out) + return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: return primal_out, _maybe_broadcast(primal_out.shape, ydot) elif type(ydot) is ad_util.Zero: @@ -2341,7 +2341,7 @@ def _sub_jvp(primals, tangents): xdot, ydot = tangents primal_out = sub(x, y) if type(xdot) is type(ydot) is ad_util.Zero: - return primal_out, ad_util.Zero.from_value(primal_out) + return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: return primal_out, _maybe_broadcast(primal_out.shape, neg(ydot)) elif type(ydot) is ad_util.Zero: @@ -3365,7 +3365,7 @@ def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape, broadcast_dimensions=broadcast_dimensions) if type(operand_dot) is ad_util.Zero: - y_dot = ad_util.Zero.from_value(y) + y_dot = ad_util.Zero.from_primal_value(y) else: y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape, broadcast_dimensions=broadcast_dimensions) @@ -4535,7 +4535,7 @@ def _top_k_jvp(primals, tangents, *, k): tangent, = tangents primals_out = top_k(operand, k) if type(tangent) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(primals_out[0]) + tangent_out = ad_util.Zero.from_primal_value(primals_out[0]) else: _, k_idxs = primals_out idx_shape = k_idxs.shape @@ -4554,7 +4554,7 @@ def _top_k_jvp(primals, tangents, *, k): collapsed_slice_dims=tuple(range(rank)), start_index_map=tuple(range(rank))) tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes) - return primals_out, (tangent_out, ad_util.Zero.from_value(primals_out[1])) + return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1])) def _top_k_batch_rule(batched_args, batch_dims, *, k): operand, = batched_args @@ -4590,7 +4590,7 @@ def _top_k_lower(ctx, operand, k): def _stop_gradient_jvp_rule(primals, tangents): # if we don't call stop_gradient here, we'd only peel off one autodiff tracer x, = primals - return stop_gradient(x), ad_util.Zero.from_value(x) + return stop_gradient(x), ad_util.Zero.from_primal_value(x) def _stop_gradient_batch_rule(batched_args, batch_dims): x, = batched_args diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 8752e0b6d1de..ec0a075dae1b 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1487,8 +1487,8 @@ def _lu_jvp_rule(primals, tangents): l_dot = l @ _tril(lau, -1) u_dot = _triu(lau) @ u lu_dot = l_dot + u_dot - return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots), - ad_util.Zero.from_value(permutation)) + return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_primal_value(pivots), + ad_util.Zero.from_primal_value(permutation)) def _lu_batching_rule(batched_args, batch_dims): diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 39d4b31588c1..5ed1945ecb96 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1362,7 +1362,7 @@ def _dynamic_update_slice_jvp(primals, tangents): g_operand, g_update = tangents[:2] val_out = dynamic_update_slice_p.bind(operand, update, *start_indices) if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_update = ad.instantiate_zeros(g_update) @@ -2000,7 +2000,7 @@ def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) @@ -2180,7 +2180,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, unique_indices=unique_indices, mode=mode) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) @@ -2294,7 +2294,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, update_consts=update_consts, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - return val_out, ad_util.Zero.from_value(val_out) + return val_out, ad_util.Zero.from_primal_value(val_out) g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index dd8e664a095a..089a77de2949 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -707,7 +707,7 @@ def _select_and_scatter_add_jvp( padding) del g_operand if type(g_source) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: tangent_out = _select_and_scatter_add( g_source, operand, select_prim, window_dimensions, @@ -952,7 +952,7 @@ def _select_and_gather_add_jvp( padding, base_dilation, window_dilation) del g_operand if type(g_source) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: tangent_out = _select_and_gather_add( g_source, operand, select_prim, window_dimensions, diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index f8ec3b63339a..7e5768c04092 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -145,9 +145,9 @@ def update(self, inner_aval=None, memory_space=None): memory_space = self.memory_space if memory_space is None else memory_space return AbstractMemoryRef(inner_aval, memory_space) - def at_least_vspace(self): + def to_tangent_aval(self): return AbstractMemoryRef( - self.inner_aval.at_least_vspace(), self.memory_space) + self.inner_aval.to_tangent_aval(), self.memory_space) def __eq__(self, other): return (type(self) is type(other) and self.inner_aval == other.inner_aval diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 4231822965b1..7970440d29a6 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -490,7 +490,7 @@ def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, len(primals)]) del out_consts out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, nonzero_tangents)] return out_primals, out_tangents ad.primitive_jvps[run_state_p] = _run_state_jvp diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 8289f858498b..e64d6258a808 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -243,8 +243,8 @@ def _setitem(self, tracer, idx, value) -> None: def __repr__(self) -> str: return f'Ref{{{self.inner_aval.str_short()}}}' - def at_least_vspace(self): - return AbstractRef(self.inner_aval.at_least_vspace()) + def to_tangent_aval(self): + return AbstractRef(self.inner_aval.to_tangent_aval()) def __eq__(self, other): return (type(self) is type(other) and self.inner_aval == other.inner_aval) diff --git a/jax/core.py b/jax/core.py index 1f433d6f5c29..9857fcf88c02 100644 --- a/jax/core.py +++ b/jax/core.py @@ -85,6 +85,7 @@ full_lower as full_lower, gensym as gensym, get_aval as get_aval, + get_type as get_type, get_referent as get_referent, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 8176465c1470..62da0f231d50 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -169,7 +169,7 @@ def linearize(f, *primals, attrs: list[tuple[Any, str]] = []): def _linearize(traceable: lu.WrappedFun, *primals): jvpfun, attrs = _split_attrs(_jvp(traceable)) in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) - + tuple(pe.PartialVal.unknown(core.get_aval(p).at_least_vspace()) + + tuple(pe.PartialVal.unknown(core.get_aval(p).to_tangent_aval()) for p in primals)) _, in_tree = tree_flatten((primals, primals)) jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree) @@ -211,7 +211,7 @@ def vjp(f, *primals, attrs: list[tuple[Any, str]] = []): f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree) primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( f_, *attr_primals, *primals_flat) - attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).at_least_vspace() + attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).to_tangent_aval() for o, a in attrs_out] f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), attrs, attrs_out) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 8319e3fba70f..fabd45ca069a 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1405,7 +1405,7 @@ def new_out_names_thunk(): f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind(f_jvp, *args, **params) primal_out, tangent_out = tree_unflatten(out_tree(), result) - tangent_out = [ad.Zero(core.get_aval(p).at_least_vspace()) if t is None else t + tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t for p, t in zip(primal_out, tangent_out)] return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] ad.JVPTrace.process_shard_map = _shard_map_jvp diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 9eafa0db0fc2..d200577c2416 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -332,11 +332,11 @@ def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype data, indices = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = _bcoo_extract(indices, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(indices)) + tangents_out = (data_dot, ad.Zero.from_primal_value(indices)) return primals_out, tangents_out @@ -571,7 +571,7 @@ def _bcoo_transpose_jvp(primals, tangents, *, permutation: Sequence[int], spinfo data_dot, _ = tangents primals_out = _bcoo_transpose(data, indices, permutation=permutation, spinfo=spinfo) data_dot_out, _ = _bcoo_transpose(data_dot, indices, permutation=permutation, spinfo=spinfo) - return primals_out, (data_dot_out, ad.Zero.from_value(indices)) + return primals_out, (data_dot_out, ad.Zero.from_primal_value(indices)) def _bcoo_transpose_transpose(ct, data, indices, *, permutation: Sequence[int], spinfo: SparseInfo): data_ct, indices_ct = ct @@ -1277,7 +1277,7 @@ def _bcoo_spdot_general_jvp(primals, tangents, **kwds): data_dot_out += _bcoo_spdot_general(lhs_data_dot, lhs_indices, rhs_data, rhs_indices, **kwds)[0] if type(rhs_data_dot) is not ad.Zero: data_dot_out += _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data_dot, rhs_indices, **kwds)[0] - return primals_out, [data_dot_out, ad.Zero.from_value(primals_out[1])] + return primals_out, [data_dot_out, ad.Zero.from_primal_value(primals_out[1])] # TODO(JVP): transpose rule batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_rule @@ -1358,8 +1358,8 @@ def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo): permute = nfold_vmap(lambda d, p: d[p], props.n_batch) data_out = permute(data, perm) - indices_dot_out = ad.Zero.from_value(indices) - data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm) + indices_dot_out = ad.Zero.from_primal_value(indices) + data_dot_out = ad.Zero.from_primal_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm) return (data_out, indices_out), (data_dot_out, indices_dot_out) _bcoo_sort_indices_hlo = mlir.lower_fun( @@ -1544,8 +1544,8 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse): permute = lambda x, i, y: x permute = nfold_vmap(permute, props.n_batch) data_out = permute(data_out, mapping, data) - indices_dot_out = ad.Zero.from_value(indices_out) - data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot) + indices_dot_out = ad.Zero.from_primal_value(indices_out) + data_dot_out = ad.Zero.from_primal_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot) return (data_out, indices_out), (data_dot_out, indices_dot_out) _bcoo_sum_duplicates_hlo = mlir.lower_fun( diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7f3ebb43c0ec..7275d6bb20aa 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -272,11 +272,11 @@ def _bcsr_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype data, indices, indptr = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = bcsr_extract(indices, indptr, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(indices), ad.Zero.from_value(indptr)) + tangents_out = (data_dot, ad.Zero.from_primal_value(indices), ad.Zero.from_primal_value(indptr)) return primals_out, tangents_out diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index 8863478df4d3..c65bc87235d6 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -348,11 +348,11 @@ def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype): data, row, col = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = _coo_extract(row, col, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(row), ad.Zero.from_value(col)) + tangents_out = (data_dot, ad.Zero.from_primal_value(row), ad.Zero.from_primal_value(col)) return primals_out, tangents_out diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index c1178943c02a..89d08f109d68 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -380,11 +380,11 @@ def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype): data, indices, indptr = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = _csr_extract(indices, indptr, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(indices), ad.Zero.from_value(indptr)) + tangents_out = (data_dot, ad.Zero.from_primal_value(indices), ad.Zero.from_primal_value(indptr)) return primals_out, tangents_out diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 6663df3ac473..6bfc3473ff50 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -59,9 +59,7 @@ primitive_jvps as primitive_jvps, primitive_transposes as primitive_transposes, rearrange_binders as rearrange_binders, - recast_to_float0 as recast_to_float0, reducing_transposes as reducing_transposes, - replace_float0s as replace_float0s, standard_jvp as standard_jvp, standard_jvp2 as standard_jvp2, traceable as traceable, diff --git a/tests/api_test.py b/tests/api_test.py index 8b75cb624f1b..b0915a1df44b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7203,10 +7203,11 @@ def foo_jvp(primals, tangents): TypeError, re.escape( "Custom JVP rule must produce primal and tangent outputs " - "with equal shapes and dtypes, but got float32[] and float32[1] " - "respectively."), + "with corresponding shapes and dtypes. " + "Expected float32[] (tangent type of float32[]) but got float32[1]."), lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),))) + def test_jvp_rule_doesnt_return_pair_error_message(self): # https://github.com/google/jax/issues/2516 @@ -7536,12 +7537,12 @@ def g_jvp(primals, tangents): self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32')) def test_float0(self): + scalar_float0 = jnp.zeros((), dtype=float0) @jax.custom_jvp def f(x, y): return x, y def f_jvp(primals, _): - # we need a defined (non-float0) tangent to trigger the rule - return primals, (2., 1) + return primals, (2., scalar_float0) f.defjvp(f_jvp) primals = (2., 3) @@ -7551,12 +7552,13 @@ def f_jvp(primals, _): (primals, expected_tangents)) def test_float0_initial_style(self): + scalar_float0 = jnp.zeros((), dtype=float0) @jax.custom_jvp def f(x, y): return x, y def f_jvp(primals, _): x, y = primals - return (x, y), (2., 1) + return (x, y), (2., scalar_float0) f.defjvp(f_jvp) def foo(x, y): @@ -7564,8 +7566,9 @@ def foo(x, y): return out primals = (2., 3) - tangents = (np.ones(()), np.zeros((), float0),) - expected_tangents = (2., np.zeros((), float0)) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) + self.assertAllClose(api.jvp(foo, primals, tangents), (primals, expected_tangents)) @@ -8730,7 +8733,7 @@ def f(x): def f_fwd(x): return x, (2., x) def f_rev(*_): - return ((2., 1),) + return ((2., jnp.zeros(shape=(), dtype=float0)),) f.defvjp(f_fwd, f_rev) def foo(x, y): @@ -9670,12 +9673,12 @@ def __call__(self, *args): # an option of inferring output types. def custom_transpose(example_out): if isinstance(example_out, Callable): - out_type = core.get_aval(0.).at_least_vspace() + out_type = core.get_aval(0.).to_tangent_aval() return _custom_transpose(out_type, example_out) return partial( _custom_transpose, jax.tree.map( - lambda x: core.get_aval(x).at_least_vspace(), example_out)) + lambda x: core.get_aval(x).to_tangent_aval(), example_out)) class CustomTransposeTest(jtu.JaxTestCase): diff --git a/tests/export_test.py b/tests/export_test.py index b269aef28d79..d5884b7e6b16 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -473,7 +473,8 @@ def f(xi, xf): # Native JAX 1st order vjp (f_outi, f_outf), f_vjp = jax.vjp(f, xi, xf) - f_outi_ct = np.ones(f_outi.shape, dtype=f_outi.dtype) + f_outi_ct = np.ones(f_outi.shape, + dtype=core.primal_dtype_to_tangent_dtype(f_outi.dtype)) f_outf_ct = np.ones(f_outf.shape, dtype=f_outf.dtype) xi_ct, xf_ct = f_vjp((f_outi_ct, f_outf_ct)) From 1d8462189b3542fa78202af4ae9b75bac2d113ec Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 18 Sep 2024 13:59:56 -0700 Subject: [PATCH 553/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d2434f289c130c9d87c05a1e7086abf7922519fc. PiperOrigin-RevId: 676121487 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 0df1b77fbb39..a1e1fc505455 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f6b6175735336f6bdf0ec4af79a3314e6673ccd6" -XLA_SHA256 = "7033fba5ae9cb701173cf534825a7aa95425c0f4d174b6611293d0d08962492e" +XLA_COMMIT = "d2434f289c130c9d87c05a1e7086abf7922519fc" +XLA_SHA256 = "5264285791bda5c123cda881b44cba8fe404cf334d5843c54110d8678e319872" def repo(): tf_http_archive( From 0c7c71e640541709bcbb7173a7623724ee47e8a7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 18 Sep 2024 14:48:52 -0700 Subject: [PATCH 554/702] Update python version from 3.12 to 3.13.0rc2 in Github presubmit jobs. PiperOrigin-RevId: 676140293 --- build/test-requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 4f9d19e76ba2..0c9aa086f109 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -12,4 +12,5 @@ portpicker pytest-xdist wheel rich -setuptools +# TODO(ybaturina): remove setuptools version +setuptools<71.0.0 From ba06bd5aaa80abd0d65cc91da5eaea0748e57a89 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 18 Sep 2024 16:45:20 -0700 Subject: [PATCH 555/702] Reduced duplication between `_bcast` and `_ensure_fa` in Pallas Mosaic GPU lowering PiperOrigin-RevId: 676180945 --- jax/_src/pallas/mosaic_gpu/lowering.py | 32 ++++++++++++-------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 1d76dc8405d5..3afc62aebfcf 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -635,7 +635,8 @@ def _broadcast_in_dim_lowering_rule( ): if broadcast_dimensions: raise NotImplementedError - return _ensure_fa(x, ctx.avals_in[0]).broadcast(shape) + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).broadcast(shape) @register_lowering_rule(lax.convert_element_type_p) @@ -643,7 +644,8 @@ def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): del weak_type, sharding - return _ensure_fa(x, *ctx.avals_in).astype(mlir.dtype_to_ir_type(new_dtype)) + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).astype(mlir.dtype_to_ir_type(new_dtype)) def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): @@ -661,7 +663,8 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): @register_lowering_rule(lax.integer_pow_p) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): - x = _ensure_fa(x, *ctx.avals_in) + [x_aval] = ctx.avals_in + x = _ensure_fa(x, x_aval.dtype) if y == 2: return x * x return NotImplementedError @@ -669,7 +672,8 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): @register_lowering_rule(lax.rsqrt_p) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): - return _ensure_fa(x, *ctx.avals_in).rsqrt(ctx.module_ctx.approx_math) + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).rsqrt(ctx.module_ctx.approx_math) @register_lowering_rule(lax.reduce_sum_p) @@ -721,22 +725,16 @@ def _bcast( y_aval: jax_core.ShapedArray, out_aval: jax_core.ShapedArray, ) -> ir.Value: - if isinstance(x, (np.ndarray, np.number, int, float)): + if not isinstance(x, mgpu.FragmentedArray): x_dtype = x_aval.dtype if x_aval.weak_type: x_dtype = y_aval.dtype - x = mgpu.FragmentedArray.splat( - _ir_constant(x, mlir.dtype_to_ir_type(x_dtype)), () - ) - if isinstance(y, (np.ndarray, np.number, int, float)): + x = _ensure_fa(x, x_dtype) + if not isinstance(y, mgpu.FragmentedArray): y_dtype = y_aval.dtype if y_aval.weak_type: y_dtype = x_aval.dtype - y = mgpu.FragmentedArray.splat( - _ir_constant(y, mlir.dtype_to_ir_type(y_dtype)), () - ) - assert isinstance(x, mgpu.FragmentedArray) - assert isinstance(y, mgpu.FragmentedArray) + y = _ensure_fa(y, y_dtype) if x_aval.shape != out_aval.shape: x = x.broadcast(out_aval.shape) if y_aval.shape != out_aval.shape: @@ -744,17 +742,17 @@ def _bcast( return x, y -def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray: +def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: if isinstance(x, mgpu.FragmentedArray): return x elif isinstance(x, (np.number, np.ndarray, int, float)): return mgpu.FragmentedArray.splat( - _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)), () + _ir_constant(x, mlir.dtype_to_ir_type(dtype)), () ) elif isinstance(x, ir.Value): if isinstance(x.type, (ir.IntegerType, ir.FloatType)): return mgpu.FragmentedArray.splat(x, ()) - raise NotImplementedError + raise NotImplementedError(f"Unsupported type: {type(x)}") def _ir_constant(v: object, t: ir.Type) -> ir.Value: From 727b79a608b145831385ecc9201b3a7bcba51a47 Mon Sep 17 00:00:00 2001 From: Haichen Li Date: Wed, 18 Sep 2024 14:59:38 -0700 Subject: [PATCH 556/702] declare buffer donation and compilation cache support for platform "neuron" more at https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/jax/index.html --- jax/_src/compilation_cache.py | 2 +- jax/_src/interpreters/mlir.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index b946dc0a2897..c75d1783f356 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -72,7 +72,7 @@ def is_cache_used(backend: xla_client.Client) -> bool: # backend that supports serialization of executables. # TODO(skye): add warning when initializing cache on unsupported default # platform - supported_platforms = ["tpu", "gpu", "cpu"] + supported_platforms = ["tpu", "gpu", "cpu", "neuron"] if not _is_cache_enabled(): monitoring.record_event('/jax/compilation_cache/task_disabled_cache') diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index c4c77c72b88b..8df4176ccb13 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -951,7 +951,7 @@ class LoweringResult(NamedTuple): shape_poly_state: ShapePolyLoweringState -_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"] +_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu", "neuron"] def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim): From 9d2e9c688c4e8b733e68467d713091436a672ac0 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 18 Sep 2024 20:38:54 -0700 Subject: [PATCH 557/702] [Pallas TPU] Add support for passing in and returning semaphores This change enables writing async ops using Pallas. However, there are *extremely sharp edges* using this API. Please read the design note here: https://jax.readthedocs.io/en/latest/pallas/async_note.html. Followup CLs will investigate safer APIs for writing async ops. PiperOrigin-RevId: 676243335 --- jax/_src/pallas/core.py | 107 ++- jax/_src/pallas/mosaic/core.py | 27 +- jax/_src/pallas/mosaic/lowering.py | 46 +- .../pallas/mosaic/pallas_call_registration.py | 51 +- jax/_src/pallas/mosaic/pipeline.py | 2 +- .../mosaic_gpu/pallas_call_registration.py | 3 +- jax/_src/pallas/pallas_call.py | 59 +- jax/_src/pallas/primitives.py | 2 +- .../pallas/triton/pallas_call_registration.py | 3 +- jax/experimental/pallas/__init__.py | 2 +- jax/experimental/pallas/tpu.py | 1 + tests/pallas/BUILD | 14 + tests/pallas/tpu_pallas_async_test.py | 759 ++++++++++++++++++ 13 files changed, 1016 insertions(+), 60 deletions(-) create mode 100644 tests/pallas/tpu_pallas_async_test.py diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 7e5768c04092..00bbbbe888d5 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -31,6 +31,7 @@ from jax._src import config from jax._src import core as jax_core from jax._src import deprecations +from jax._src import dtypes from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import state @@ -114,21 +115,113 @@ def from_pallas_call(pallas_call_name: str | None, " ".join(src_info_parts[1:])) -# Pytrees of jax.ShapeDtypeStruct -ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...] - split_list = util.split_list map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip +class ShapedArrayWithMemorySpace(jax_core.ShapedArray): + __slots__ = ["memory_space"] + + def __init__(self, shape, dtype, weak_type=False, sharding=None, + memory_space=None): + super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding) + self.memory_space = memory_space + + def __eq__(self, other): + return super().__eq__(other) and self.memory_space == other.memory_space + + def __hash__(self): + return hash(( + self.shape, + self.dtype, + self.weak_type, + getattr(self, "sharding", None), + self.memory_space, + )) + + def at_least_vspace(self): + """Vector space method needed for AD.""" + raise NotImplementedError + + def join(self, other): + raise NotImplementedError + + def str_short(self, short_dtypes=False): + dt_str = ( + jax_core._short_dtype_name(self.dtype) + if short_dtypes + else self.dtype.name + ) + dt_str = dt_str.replace("void", "float0") + shapestr = ",".join(map(str, self.shape)) + if hasattr(self, "sharding"): + sharding_str = f"{dt_str}[{shapestr}]({self.sharding})" + else: + sharding_str = "" + memoryspace_str = ( + "" if self.memory_space is None else f"{self.memory_space}>" + ) + return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}" + + def update( + self, + shape=None, + dtype=None, + weak_type=None, + sharding=None, + memory_space=None, + ): + if shape is None: + shape = self.shape + if dtype is None: + dtype = self.dtype + if weak_type is None: + weak_type = self.weak_type + if sharding is None: + sharding = getattr(self, "sharding", None) + if memory_space is None: + memory_space = self.memory_space + return ShapedArrayWithMemorySpace( + shape, dtype, weak_type, sharding=sharding, memory_space=memory_space + ) +mlir.ir_type_handlers[ShapedArrayWithMemorySpace] = mlir._array_ir_types + + +@dataclasses.dataclass(frozen=True) +class MemoryRef: + """Like jax.ShapeDtypeStruct but with memory spaces.""" + shape: tuple[int, ...] + dtype: jnp.dtype + # TODO(b/368122763): Unify memory space types across backends + memory_space: Any + + def get_array_aval(self) -> jax_core.ShapedArray: + dtype = self.dtype + if not isinstance(dtype, (jnp.dtype, dtypes.ExtendedDType)): + dtype = jnp.dtype(dtype) + return ShapedArrayWithMemorySpace( + self.shape, dtype, memory_space=self.memory_space + ) + + def get_ref_aval(self) -> AbstractMemoryRef: + return AbstractMemoryRef( + ShapedArrayWithMemorySpace(self.shape, self.dtype), self.memory_space) + + class AbstractMemoryRef(state.AbstractRef): __slots__ = ["inner_aval", "memory_space"] inner_aval: jax_core.ShapedArray def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): + if isinstance(inner_aval, ShapedArrayWithMemorySpace): + if inner_aval.memory_space is not None: + assert inner_aval.memory_space == memory_space, ( + f"Mismatched memory spaces: {inner_aval.memory_space=}," + f" {memory_space=}" + ) self.inner_aval = inner_aval self.memory_space = memory_space @@ -158,7 +251,7 @@ def __hash__(self): class MemorySpace(enum.Enum): - """ Logical, device-agnostic memory spaces. + """Logical, device-agnostic memory spaces. Each memory space will be translated to a device-specific memory type during lowering. @@ -731,7 +824,9 @@ def _convert_block_spec_to_block_mapping( class ScratchShape(Protocol): - def get_aval(self) -> jax_core.AbstractValue: + def get_array_aval(self) -> jax_core.AbstractValue: + ... + def get_ref_aval(self) -> state.AbstractRef: ... @@ -833,7 +928,7 @@ def get_grid_mapping( if grid_spec.scratch_shapes: flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( grid_spec.scratch_shapes) - flat_scratch_avals = map(lambda s: s.get_aval(), flat_scratch_shapes) + flat_scratch_avals = map(lambda s: s.get_ref_aval(), flat_scratch_shapes) num_flat_scratch_operands = len(flat_scratch_avals) jaxpr_scratch_avals = tree_util.tree_unflatten( scratch_tree, flat_scratch_avals) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index b2b892a64f90..76166ae61963 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -90,7 +90,7 @@ def __str__(self) -> str: def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. - return MemoryRef(shape, dtype, self) + return pallas_core.MemoryRef(shape, dtype, self) class semaphore_dtype(dtypes.extended): pass class semaphore(semaphore_dtype): pass @@ -102,6 +102,10 @@ class AbstractSemaphoreTyRules: def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE) + @staticmethod + def physical_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), jnp.int32) + class AbstractSemaphoreTy(dtypes.ExtendedDType): name: str _rules = AbstractSemaphoreTyRules @@ -144,10 +148,13 @@ def __call__(self, shape: tuple[int, ...]): dtype = SemaphoreTy() if pallas_core.is_interpret_mode(): dtype = pallas_core.SEMAPHORE_INTERPRET_DTYPE - return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) + return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) + + def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: + return self(()).get_array_aval() - def get_aval(self) -> AbstractMemoryRef: - return self(()).get_aval() + def get_ref_aval(self) -> AbstractMemoryRef: + return self(()).get_ref_aval() @dataclasses.dataclass(frozen=True) class AbstractSemaphore(jax_core.AbstractValue): @@ -163,18 +170,6 @@ def join(self, other): jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval -@dataclasses.dataclass(frozen=True) -class MemoryRef: - """Like jax.ShapeDtypeStruct but with memory spaces.""" - shape: tuple[int, ...] - dtype: jnp.dtype - memory_space: TPUMemorySpace = TPUMemorySpace.ANY - - def get_aval(self) -> AbstractMemoryRef: - return AbstractMemoryRef( - jax_core.ShapedArray(self.shape, self.dtype), self.memory_space) - - @dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): num_scalar_prefetch: int diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f76a4d86616a..7cc8de90b6b8 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -424,24 +424,23 @@ class MeshInfo: axis_names: list[str] mesh_strides: tuple[int, ...] -def lower_jaxpr_to_module( + +def _check_block_mappings( + block_mappings: tuple[pallas_core.BlockMapping, ...], lowering_context: mlir.LoweringRuleContext, - ctx: ir.Context, - grid_mapping: pallas_core.GridMapping, - jaxpr: jax_core.Jaxpr, - *, - dimension_semantics: tuple[str | None, ...] | None, name_and_src_info: pallas_core.NameAndSrcInfo, - mesh: mesh_lib.Mesh | None = None, - for_verification: bool = False, -) -> tuple[Module, tuple[Any, ...]]: - for bm in grid_mapping.block_mappings: +) -> None: + del lowering_context # originally needed for forward compat + for bm in block_mappings: rank = len(bm.block_shape) # TODO(necula): add tests for SMEM blocks with trivial windowing # We support scalars too if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and bm.has_trivial_window()): continue + if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE: + continue + def err_details(): return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} " "has block shape " @@ -482,11 +481,28 @@ def err_details(): if not evenly_divisible: raise ValueError( - "The Pallas TPU lowering currently requires that the last two " - "dimensions of your block shape are divisible by 8 and 128 " - "respectively, or be equal to the respective dimensions of the " - "overall array. " - + err_details()) + "The Pallas TPU lowering currently requires that the last two " + "dimensions of your block shape are divisible by 8 and 128 " + "respectively, or be equal to the respective dimensions of the " + "overall array. " + + err_details() + ) + + +def lower_jaxpr_to_module( + lowering_context: mlir.LoweringRuleContext, + ctx: ir.Context, + grid_mapping: pallas_core.GridMapping, + jaxpr: jax_core.Jaxpr, + *, + dimension_semantics: tuple[str | None, ...] | None, + name_and_src_info: pallas_core.NameAndSrcInfo, + mesh: mesh_lib.Mesh | None = None, + for_verification: bool = False, +) -> tuple[Module, tuple[Any, ...]]: + # Verify that we have legal block mappings to catch errors early. + _check_block_mappings(grid_mapping.block_mappings, lowering_context, + name_and_src_info) mosaic_grid_mapping = MosaicGridMapping( jaxpr, grid_mapping, dimension_semantics, mesh) diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 71091af27ca3..b09d36a9d3b2 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -30,19 +30,24 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.pallas import core +from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import lowering from jax._src.pallas.mosaic import verification +from jax._src import tpu_custom_call from jax.experimental import mosaic from jax.experimental.mosaic.dialects import tpu from jax.experimental.pallas import tpu as pltpu -def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray): +def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue): """Casts boolean values to integers. We perform this cast because Mosaic does not directly support bool values for Memrefs. Instead, we load bools as integers and cast them to bools after loading from a memref inside of the kernel. """ + assert isinstance( + x, (jax.Array, jax_core.ShapedArray, jax_core.DShapedArray) + ), type(x) if isinstance(x, jax.Array): if dtypes.issubdtype(x.dtype, jax.numpy.bool_): return x.astype(lowering.BOOL_MEMREF_TYPE) @@ -63,6 +68,41 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray): ) +def _get_memory_space_from_aval( + out_aval: jax_core.AbstractValue, +) -> tpu_custom_call.MemorySpace | None: + if not isinstance(out_aval, jax_core.ShapedArray): + raise ValueError('Memory spaces not defined for non-ShapedArrays') + if not isinstance(out_aval, core.ShapedArrayWithMemorySpace): + # If we are passed a regular old ShapedArray, we don't constrain the + # memory space + return None + # If we are passed an aval with an explicit memory space tag, we use it + # to constrain the memory space. + match out_aval.memory_space: + case None: + return None + case tpu_core.TPUMemorySpace.ANY: + return None + case tpu_core.TPUMemorySpace.VMEM: + return tpu_custom_call.MemorySpace.VMEM + case tpu_core.TPUMemorySpace.SEMAPHORE: + return tpu_custom_call.MemorySpace.SEMAPHORE_MEM + return None + + +def _get_memory_spaces_from_avals( + out_avals: tuple[jax_core.AbstractValue, ...], +) -> tuple[tpu_custom_call.MemorySpace | None, ...] | None: + output_memory_spaces = None + if any( + isinstance(out_aval, core.ShapedArrayWithMemorySpace) + for out_aval in out_avals + ): + output_memory_spaces = tuple(map(_get_memory_space_from_aval, out_avals)) + return output_memory_spaces + + def pallas_call_tpu_lowering_rule( ctx: mlir.LoweringRuleContext, *in_nodes, @@ -74,6 +114,7 @@ def pallas_call_tpu_lowering_rule( interpret: bool, compiler_params: dict[str, Any], cost_estimate: core.CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): """Lowers a pallas_call to a Mosaic TPU custom call.""" del interpret @@ -129,9 +170,6 @@ def lower_module(for_verification: bool): (a[0] + num_dyn_bounds + num_extra_args, a[1]) for a in input_output_aliases ) - out_avals = [jax_core.ShapedArray(bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype) - for bm in grid_mapping.block_mappings_output] if promela_dump_path := _DUMP_PROMELA_TO.value: num_devices = 1 if mesh is None else mesh.devices.size @@ -174,7 +212,7 @@ def lower_module(for_verification: bool): def _maybe_cast_inputs(*args): args = [_maybe_cast_to_int(x) for x in args] return args - kernel_in_avals = [_maybe_cast_to_int(x) for x in ctx.avals_in] # type: ignore + kernel_in_avals = [_maybe_cast_to_int(x) for x in ctx.avals_in] kernel_out_avals = [_maybe_cast_to_int(x) for x in out_avals] cast_ctx = ctx.replace(avals_out=kernel_in_avals) in_nodes = mlir.lower_fun(_maybe_cast_inputs)(cast_ctx, *in_nodes) @@ -182,6 +220,7 @@ def _maybe_cast_inputs(*args): # Dynamic grid bounds have to go at the front. dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:] kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals) + output_memory_spaces = _get_memory_spaces_from_avals(out_avals) if cost_estimate is not None: mosaic_cost_estimate = pltpu.CostEstimate( flops=cost_estimate.flops, @@ -208,7 +247,7 @@ def _maybe_cast_inputs(*args): device_type=mosaic_params.get("device_type"), internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"), collective_id=mosaic_params.get("collective_id", None), - output_memory_spaces=None, # TODO(apaszke,sharadmv): Implement this. + output_memory_spaces=output_memory_spaces, ) _maybe_cast_to_bool = lambda x, aval: x.astype( jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index e8f2384784eb..005e4acdd106 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -40,7 +40,7 @@ SMEM = tpu_core.TPUMemorySpace.SMEM VMEM = tpu_core.TPUMemorySpace.VMEM DMA = tpu_core.SemaphoreType.DMA -REF = tpu_core.MemoryRef +REF = pallas_core.MemoryRef SemaphoreType = tpu_core.SemaphoreType SemaphoreTuple = jax.Array ArrayRef = Union[REF, jax.Array] diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 5b46caf1553a..5b09cad176a6 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -37,8 +37,9 @@ def pallas_call_lowering( grid_mapping: pallas_core.GridMapping, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): - del interpret + del interpret, out_avals if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "dynamic grid bounds not supported in the Mosaic GPU backend" diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 206c0cdee876..a3ca823a1fa6 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -168,8 +168,12 @@ def _get_next_indices(grid, indices): next_indices.append(jnp.where(carry, 0, i)) return tuple(reversed(next_indices)) -def _pallas_call_impl(*args, **kwargs): - assert False # We always jit a pallas call, we only need the lowering rule +def _pallas_call_impl(*args, **params): + # Call the lowering path + @partial(jax.jit, inline=True) + def _jit_run(*args): + return pallas_call_p.bind(*args, **params) + return _jit_run(*args) def _pallas_call_impl_interpret( @@ -181,8 +185,9 @@ def _pallas_call_impl_interpret( grid_mapping: GridMapping, compiler_params: Any, cost_estimate: CostEstimate, + out_avals: tuple[jax_core.AbstractValue, ...], ): - del compiler_params, cost_estimate + del compiler_params, cost_estimate, out_avals # If we're in interpret mode, we *scan* over the grid and eval the # discharged jaxpr. dynamic_grid_args, args = split_list( # type: ignore @@ -324,10 +329,14 @@ def body(carry): pallas_call_p.def_impl(_pallas_call_impl) -def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_): - return tuple(jax_core.ShapedArray(bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype) - for bm in grid_mapping.block_mappings_output) + +def _pallas_call_abstract_eval( + *avals, out_avals: tuple[jax_core.AbstractValue, ...], **_ +): + del avals + return out_avals + + pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) @@ -343,6 +352,7 @@ def _pallas_call_jvp_rule( interpret, compiler_params: Any, cost_estimate: CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError("interpret with dynamic grid bounds unsupported") @@ -406,6 +416,7 @@ def _pallas_call_jvp_rule( input_output_aliases=(), compiler_params=compiler_params, cost_estimate=jvp_cost_estimate, + out_avals=(*out_avals, *out_avals) ) out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2]) return out_primals, out_tangents @@ -539,6 +550,7 @@ def _batch_with_explicit_loop( interpret: bool, compiler_params: Any, cost_estimate: CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): """Batch the pallas_call by calling it in loop over the batch size. @@ -605,6 +617,7 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: interpret=interpret, compiler_params=compiler_params, cost_estimate=cost_estimate, + out_avals=out_avals, ) for i, batch_out_array in enumerate(batch_out): state[i] = jax.lax.dynamic_update_index_in_dim( @@ -643,6 +656,7 @@ def _pallas_call_batching_rule( interpret: bool, compiler_params: Any, cost_estimate: CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped @@ -685,6 +699,7 @@ def get_size(i, x, d): interpret=interpret, compiler_params=compiler_params, cost_estimate=cost_estimate, + out_avals=out_avals, ) return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out) @@ -717,6 +732,7 @@ def get_size(i, x, d): interpret=interpret, compiler_params=compiler_params, cost_estimate=cost_estimate, + out_avals=out_avals, ) else: pass # No dynamic grid dimensions @@ -750,6 +766,7 @@ def get_size(i, x, d): interpret=interpret, compiler_params=compiler_params, cost_estimate=cost_estimate, + out_avals=out_avals, ) if not dims: @@ -923,7 +940,11 @@ def g(): assert ragged_axis_length is not None args = (ragged_axis_length, *args) - + assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals) + batched_out_avals = tuple( + aval.update(shape=tuple_insert(aval.shape, 0, axis_size)) + for aval in out_avals + ) out = pallas_call_p.bind( *dynamic_grid_args, *args, @@ -937,6 +958,7 @@ def g(): interpret=interpret, compiler_params=compiler_params, cost_estimate=batched_cost_estimate, + out_avals=batched_out_avals, ) return out, (0,) * len(out) @@ -966,6 +988,7 @@ def pallas_call_checkify_rule(error: checkify.Error, interpret: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + out_avals: tuple[jax_core.AbstractValue, ...], **kwargs): # We implement the checkify rule in 4 steps: # 1) First, trace the kernel body to get the expected error shapes. @@ -1092,11 +1115,13 @@ def _ensure_2d_error_shape(arg): (i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases new_vals_in = [*scalars, *err_vals, *args] + new_out_avals = (*shaped_err_avals, *out_avals) result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in, jaxpr=final_jaxpr, interpret=interpret, grid_mapping=grid_mapping_with_error, input_output_aliases=input_output_aliases_with_error, + out_avals=new_out_avals, **kwargs) errors, results = split_list(result, [num_err_vals]) # TODO(b/350593266): Remove line below once we support ()-shaped scalars. @@ -1225,6 +1250,17 @@ def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params): ) jax_core.custom_typechecks[pallas_call_p] = _pallas_call_typecheck_rule +def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue: + match out_shape: + case jax.ShapeDtypeStruct(): + return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) + case pallas_core.MemoryRef(): + return out_shape.get_array_aval() + case _: + if not (hasattr(out_shape, "shape") and hasattr(out_shape, "dtype")): + raise ValueError(f"Invalid out_shape type: {type(out_shape)}") + return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) + def pallas_call( kernel: Callable[..., None], @@ -1338,17 +1374,15 @@ def pallas_call( out_shape = tuple(out_shape) flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape) out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths) - flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore - for x in flat_out_shapes] - @jax.jit + @partial(jax.jit, inline=True) def wrapped(*args): flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args) in_paths, flat_args = unzip2(flat_args_with_paths) flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) for a in flat_args) - flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype) + flat_out_avals = tuple(_convert_out_shape_to_aval(v) for v in flat_out_shapes) kernel_fun_sig = api_util.fun_signature(kernel) @@ -1403,6 +1437,7 @@ def wrapped(*args): *dynamic_grid_bounds, *index_args, *rest_args, + out_avals=flat_out_avals, jaxpr=jaxpr, name_and_src_info=name_and_src_info, debug=debug, diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 8cba0a36c6e4..89b6c6e14acd 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -828,7 +828,7 @@ def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any: flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) - avals = [t.get_aval() for t in flat_types] + avals = [t.get_ref_aval() for t in flat_types] # Turn the function into a jaxpr. The body of run_scoped may have # effects (IO) on constvars (i.e. variables inherited from the # parent scope). Jax can't reason about effects to references that diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 5ee7077dcc1f..67b0bd326616 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -49,8 +49,9 @@ def pallas_call_lowering( grid_mapping: pallas_core.GridMapping, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): - del interpret + del interpret, out_avals if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "dynamic grid bounds not supported in the Triton backend" diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 832f7b7d1184..c81b509d70cf 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -23,11 +23,11 @@ from jax._src.pallas.core import BlockSpec from jax._src.pallas.core import CompilerParams from jax._src.pallas.core import CostEstimate +from jax._src.pallas.core import GridSpec from jax._src.pallas.core import IndexingMode from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import Unblocked from jax._src.pallas.core import unblocked -from jax._src.pallas.core import GridSpec from jax._src.pallas.pallas_call import pallas_call from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.primitives import atomic_add diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index e7fa25a3fc0d..8a1a223ae36e 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -68,3 +68,4 @@ CMEM = TPUMemorySpace.CMEM SMEM = TPUMemorySpace.SMEM VMEM = TPUMemorySpace.VMEM +SEMAPHORE = TPUMemorySpace.SEMAPHORE diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 9b8167527b92..fd229e1673d5 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -422,6 +422,20 @@ jax_test( ] + py_deps("hypothesis"), ) +jax_test( + name = "tpu_pallas_async_test", + srcs = ["tpu_pallas_async_test.py"], + disable_backends = [ + "cpu", + "gpu", + ], + tags = [ + ], + deps = [ + "//jax:pallas_tpu", + ], +) + jax_test( name = "tpu_pallas_mesh_test", srcs = ["tpu_pallas_mesh_test.py"], diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py new file mode 100644 index 000000000000..4f9d591dbea4 --- /dev/null +++ b/tests/pallas/tpu_pallas_async_test.py @@ -0,0 +1,759 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TPU-specific uses of Pallas async APIs.""" + +import functools +from typing import Any +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() +P = jax.sharding.PartitionSpec +partial = functools.partial + +Future = Any + + +def make_async_copy(target_memory_space=None): + if target_memory_space is None: + target_memory_space = pltpu.ANY + @jax.named_call + def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: + + def copy_start_kernel(x_ref, aliased_x_ref, o_ref, sem): + del aliased_x_ref + pltpu.make_async_copy(x_ref, o_ref, sem).start() + + x, out, sem = pl.pallas_call( + copy_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + target_memory_space(x.shape, x.dtype), # out + pltpu.SemaphoreType.DMA(()), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=target_memory_space), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ), + input_output_aliases={0: 0}, + )(x) + return x, (out, sem) + + @jax.named_call + def copy_done(x: jax.Array, future: Future) -> jax.Array: + out, sem = future + + def copy_done_kernel(x_ref, o_ref, sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref, o_ref, sem).wait() + + out = pl.pallas_call( + copy_done_kernel, + out_shape=target_memory_space(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=target_memory_space), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=target_memory_space), + input_output_aliases={1: 0}, + )(x, out, sem) + return out + + return copy_start, copy_done + + +def make_async_slice(index: int): + + def async_slice_start_kernel(x_ref, aliased_x_ref, o_ref, sem): + del aliased_x_ref + pltpu.make_async_copy(x_ref.at[index], o_ref, sem).start() + + def async_slice_done_kernel(x_ref, o_ref, sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref.at[index], o_ref, sem).wait() + + @jax.named_call + def async_slice_start(x: jax.Array) -> tuple[jax.Array, Future]: + + x, out, sem = pl.pallas_call( + async_slice_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + jax.ShapeDtypeStruct(x.shape[1:], x.dtype), # out + pltpu.SemaphoreType.DMA(()), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ), + input_output_aliases={0: 0}, + )(x) + return x, (out, sem) + + @jax.named_call + def async_slice_done( + x: jax.Array, future: Future + ) -> tuple[jax.Array, Future]: + out, sem = future + out = pl.pallas_call( + async_slice_done_kernel, + out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)), + input_output_aliases={1: 0}, + )(x, out, sem) + return out + + return async_slice_start, async_slice_done + + +def make_async_dynamic_slice(index: jax.Array): + + def async_dslice_start_kernel(index_ref, x_ref, aliased_x_ref, o_ref, sem): + del aliased_x_ref + pltpu.make_async_copy(x_ref.at[index_ref[0]], o_ref, sem).start() + + def async_dslice_done_kernel(x_ref, o_ref, sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref.at[0], o_ref, sem).wait() + + @jax.named_call + def async_dslice_start(x: jax.Array) -> tuple[jax.Array, Future]: + + x, out, sem = pl.pallas_call( + async_dslice_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + jax.ShapeDtypeStruct(x.shape[1:], x.dtype), # out + pltpu.SemaphoreType.DMA(()), + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ), + ), + input_output_aliases={1: 0}, + )(index[None], x) + return x, (out, sem) + + @jax.named_call + def async_dslice_done( + x: jax.Array, future: Future + ) -> tuple[jax.Array, Future]: + out, sem = future + out = pl.pallas_call( + async_dslice_done_kernel, + out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)), + input_output_aliases={1: 0}, + )(x, out, sem) + return out + + return async_dslice_start, async_dslice_done + + +class PallasCallAsyncCopyTest(parameterized.TestCase): + # TODO(b/368123537): add more tests + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('DMAs only guaranteed to work ou TPU v4+') + + def test_basic_async_copy(self): + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy() + x, fut = copy_start(x) + y = copy_done(x, fut) + return y + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_multiple_async_copy(self): + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy() + x, fut = copy_start(x) + x2, fut2 = copy_start(x) + y = copy_done(x, fut) + y2 = copy_done(x2, fut2) + return y, y2 + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y, y2 = f(x) + np.testing.assert_array_equal(y, x) + np.testing.assert_array_equal(y2, x) + + def test_async_slice(self): + @jax.jit + def f(x): + async_slice_start, async_slice_done = make_async_slice(2) + x, fut = async_slice_start(x) + y = async_slice_done(x, fut) + return y + + x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x[2]) + + def test_async_dynamic_slice(self): + @jax.jit + def f(x, i): + async_slice_start, async_slice_done = make_async_dynamic_slice(i) + x, fut = async_slice_start(x) + y = async_slice_done(x, fut) + return y + + x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32) + y = f(x, 2) + np.testing.assert_array_equal(y, x[2]) + + def test_multi_async_dynamic_slice(self): + @jax.jit + def f(x, i, j): + async_slice_start, async_slice_done = make_async_dynamic_slice(i) + async_slice_start2, async_slice_done2 = make_async_dynamic_slice(j) + x, fut = async_slice_start(x) + x2, fut2 = async_slice_start2(x) + y = async_slice_done(x, fut) + y2 = async_slice_done2(x2, fut2) + return y, y2 + + x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32) + y, y2 = f(x, 2, 3) + np.testing.assert_array_equal(y, x[2]) + np.testing.assert_array_equal(y2, x[3]) + + def test_basic_async_copy_into_vmem(self): + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy(pltpu.VMEM) + x, fut = copy_start(x) + y = copy_done(x, fut) + return y + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_multiple_async_copy_into_vmem(self): + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy(pltpu.VMEM) + x1, fut = copy_start(x) + x2, fut2 = copy_start(x) + y = copy_done(x1, fut) + y2 = copy_done(x2, fut2) + return y, y2 + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y, y2 = f(x) + np.testing.assert_array_equal(y, x) + np.testing.assert_array_equal(y2, x) + + def test_copy_in_a_loop(self): + + @jax.jit + def f(x): + def body(_, carry): + x = carry + copy_start, copy_done = make_async_copy() + x, fut = copy_start(x) + y = copy_done(x, fut) + return y + x = jax.lax.fori_loop(0, x.shape[0], body, x) + return x + + x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_staggered_copy_in_a_loop(self): + + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy() + x, fut = copy_start(x) + def body(_, carry): + x, fut = carry + y = copy_done(x, fut) + y, fut = copy_start(y) + return y, fut + # We *must* use unroll > 2 here because of aliasing constraints. XLA will + # introduce copies of the active buffer with unroll=1. + y, fut = jax.lax.fori_loop(0, x.shape[0] - 1, body, (x, fut), unroll=2) + x = copy_done(y, fut) + return x + + x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_full_copy_in_a_loop(self): + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def body(i, carry): + x, ys = carry + copy_start, copy_done = make_async_dynamic_slice(i) + x, fut = copy_start(x) + y = copy_done(x, fut) + ys = ys.at[i].set(y) + return x, ys + _, y = jax.lax.fori_loop(0, x.shape[0], body, (x, y)) + return y + + x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_staggered_full_copy_in_a_loop(self): + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + copy_start, _ = make_async_dynamic_slice(jnp.array(0)) + x, fut = copy_start(x) + def body(i, carry): + x, fut, ys = carry + _, copy_done = make_async_dynamic_slice(i) + y = copy_done(x, fut) + copy_start, _ = make_async_dynamic_slice(i + 1) + ys = ys.at[i].set(y) + x, fut = copy_start(x) + return x, fut, ys + # We can use unroll=1 here because we have the ys.at[i].set(y) in the + # middle + x, fut, ys = jax.lax.fori_loop(0, x.shape[0] - 1, body, (x, fut, y), + unroll=1) + _, copy_done = make_async_dynamic_slice(x.shape[0] - 1) + y = copy_done(x, fut) + ys = ys.at[x.shape[0] - 1].set(y) + return ys + + x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + +def make_async_remote_copy(axis_name: str, direction: str = 'right', + target_memory_space=None): + if target_memory_space is None: + target_memory_space = pltpu.ANY + @jax.named_call + def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: + + def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): + del aliased_x_ref + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) + 1, axis_size + ) + if direction == 'right': + src_neighbor = left_neighbor + dst_neighbor = right_neighbor + else: + src_neighbor = right_neighbor + dst_neighbor = left_neighbor + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=src_neighbor, core_index=0) + pltpu.semaphore_wait(barrier_sem, 1) + pltpu.make_async_remote_copy( + x_ref, o_ref, send_sem, recv_sem, device_id=dst_neighbor, + ).start() + + x, out, send_sem, recv_sem = pl.pallas_call( + copy_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + target_memory_space(x.shape, x.dtype), # out + pltpu.SemaphoreType.DMA(()), # send_sem + pltpu.SemaphoreType.DMA(()), # recv_sem + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=target_memory_space), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ), + input_output_aliases={0: 0}, + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + )(x) + return x, (out, send_sem, recv_sem) + + @jax.named_call + def send_done(x: jax.Array, future: Future) -> jax.Array: + _, send_sem, _ = future + + def send_done_kernel(x_ref, send_sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref, x_ref, send_sem).wait() + + x = pl.pallas_call( + send_done_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0: 0}, + )(x, send_sem) + return x + + @jax.named_call + def recv_done(x: jax.Array, future: Future) -> jax.Array: + out, _, recv_sem = future + + def send_done_kernel(x_ref, o_ref, send_sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref, o_ref, send_sem).wait() + + out = pl.pallas_call( + send_done_kernel, + out_shape=target_memory_space(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=target_memory_space), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=target_memory_space), + input_output_aliases={1: 0}, + )(x, out, recv_sem) + return out + + return copy_start, send_done, recv_done + + +def make_bidi_collective_permute(axis_name: str): + @jax.named_call + def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: + + def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): + del aliased_x_ref + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) + 1, axis_size + ) + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor, core_index=0) + pltpu.semaphore_signal( + barrier_sem, device_id=right_neighbor, core_index=0 + ) + pltpu.semaphore_wait(barrier_sem, 2) + assert x.shape[0] % 2 == 0, x.shape + pltpu.make_async_remote_copy( + x_ref.at[pl.ds(0, x.shape[0] // 2)], + o_ref.at[pl.ds(0, x.shape[0] // 2)], + right_sems[0], + right_sems[1], + device_id=right_neighbor, + ).start() + pltpu.make_async_remote_copy( + x_ref.at[pl.ds(x.shape[0] // 2, x.shape[0] // 2)], + o_ref.at[pl.ds(x.shape[0] // 2, x.shape[0] // 2)], + left_sems[0], + left_sems[1], + device_id=left_neighbor, + ).start() + + x, out, left_sems, right_sems = pl.pallas_call( + copy_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + pltpu.ANY(x.shape, x.dtype), # out + (pltpu.SemaphoreType.DMA(()),) * 2, # left_sems + (pltpu.SemaphoreType.DMA(()),) * 2, # right_sems + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, + (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, + ), + input_output_aliases={0: 0}, + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + )(x) + return x, (out, left_sems, right_sems) + + @jax.named_call + def send_done(x: jax.Array, future: Future) -> jax.Array: + _, (send_left_sem, _), (send_right_sem, _) = future + + def send_done_kernel(x_ref, send_left_sem, send_right_sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy( + x_ref.at[x_ref.shape[0] // 2 :], + x_ref.at[x_ref.shape[0] // 2 :], + send_left_sem, + ).wait() + pltpu.make_async_copy( + x_ref.at[x_ref.shape[0] // 2 :], + x_ref.at[x_ref.shape[0] // 2 :], + send_right_sem, + ).wait() + + x = pl.pallas_call( + send_done_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0: 0}, + )(x, send_left_sem, send_right_sem) + return x + + @jax.named_call + def recv_done(x: jax.Array, future: Future) -> jax.Array: + out, (_, recv_left_sem), (_, recv_right_sem) = future + + def recv_done_kernel(o_ref, x_ref, recv_left_sem, recv_right_sem, + aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy( + x_ref.at[o_ref.shape[0] // 2 :], + o_ref.at[o_ref.shape[0] // 2 :], + recv_left_sem, + ).wait() + pltpu.make_async_copy( + x_ref.at[o_ref.shape[0] // 2 :], + o_ref.at[o_ref.shape[0] // 2 :], + recv_right_sem, + ).wait() + + out = pl.pallas_call( + recv_done_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0: 0}, + )(out, x, recv_left_sem, recv_right_sem) + return out + return copy_start, send_done, recv_done + + +class PallasCallRemoteAsyncCopyTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('DMAs only guaranteed to work ou TPU v4+') + if jax.device_count() < 2: + self.skipTest('Test only works with >2 devices') + + def test_basic_remote_copy(self): + + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + copy_start, send_done, recv_done = make_async_remote_copy('x') + x, fut = copy_start(x) + x = send_done(x, fut) + y = recv_done(x, fut) + return y + + x = jax.random.normal( + jax.random.key(0), (jax.device_count(), 8, 128), dtype=jnp.float32 + ) + y = f(x) + expected = jnp.roll(x, shift=1, axis=0) + np.testing.assert_array_equal(y, expected) + + def test_multi_remote_copy(self): + + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + copy_start, send_done, recv_done = make_async_remote_copy( + 'x', direction='right' + ) + copy_start2, send_done2, recv_done2 = make_async_remote_copy( + 'x', direction='left' + ) + x, fut = copy_start(x) + x, fut2 = copy_start2(x) + x = send_done(x, fut) + x = send_done2(x, fut2) + y = recv_done(x, fut) + y2 = recv_done2(x, fut2) + return y, y2 + + x = jax.random.normal( + jax.random.key(0), (jax.device_count(), 8, 128), dtype=jnp.float32 + ) + y, y2 = f(x) + y_expected = jnp.roll(x, shift=1, axis=0) + y2_expected = jnp.roll(x, shift=-1, axis=0) + np.testing.assert_array_equal(y, y_expected) + np.testing.assert_array_equal(y2, y2_expected) + + def test_basic_collective_permute_loop(self): + + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + copy_start, send_done, recv_done = make_async_remote_copy('x') + def body(_, x): + x, fut = copy_start(x) + x = send_done(x, fut) + y = recv_done(x, fut) + return y + # Send all the way around except for one step + return jax.lax.fori_loop(0, jax.device_count() - 1, body, x) + x = jax.random.normal( + jax.random.key(0), (jax.device_count(), 8, 128), dtype=jnp.float32 + ) + y = f(x) + expected = jnp.roll(x, shift=-1, axis=0) + np.testing.assert_array_equal(y, expected) + + def test_staggered_collective_permute_loop(self): + + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + assert x.shape[0] == 1 + copy_start, send_done, recv_done = make_async_remote_copy('x') + x, fut = copy_start(x) + def body(_, carry): + x, fut = carry + x = send_done(x, fut) + y = recv_done(x, fut) + y, fut = copy_start(y) + return y, fut + # Send all the way around except for one step + x, fut = jax.lax.fori_loop(0, jax.device_count() - 2, body, (x, fut), + unroll=2) + x = send_done(x, fut) + y = recv_done(x, fut) + return y + + n_devices = jax.device_count() + x = jax.random.normal( + jax.random.key(0), (n_devices, 8, 128), dtype=jnp.float32 + ) + y = f(x) + expected = jnp.roll(x, shift=-1, axis=0) + np.testing.assert_array_equal(y, expected) + + def test_bidi_collective_permute_loop(self): + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + assert x.shape[0] == 1 + x = x[0] + copy_start, send_done, recv_done = make_bidi_collective_permute('x') + def body(_, x): + x, fut = copy_start(x) + x = send_done(x, fut) + y = recv_done(x, fut) + return y + # Send all the way around except for one step + y = jax.lax.fori_loop(0, jax.device_count() - 1, body, x) + return y[None] + x = jax.random.normal( + jax.random.key(0), (jax.device_count(), 16, 128), dtype=jnp.float32 + ) + y = f(x) + expected = jnp.concatenate([ + jnp.roll(x[:, :8], axis=0, shift=-1), + jnp.roll(x[:, 8:], axis=0, shift=1), + ], axis=1) + np.testing.assert_array_equal(y, expected) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 8a8f74663ffcf494420fbff8687d85cd76dad7e7 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 19 Sep 2024 02:53:33 -0700 Subject: [PATCH 558/702] Use `MemoryRef` from Pallas instead of the Mosaic GPU-specific one PiperOrigin-RevId: 676336451 --- jax/_src/pallas/mosaic_gpu/core.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 34ad5acf34d6..dc698b8747d9 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -60,7 +60,7 @@ def __str__(self) -> str: def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. - return MemoryRef(shape, dtype, memory_space=self) + return pallas_core.MemoryRef(shape, dtype, memory_space=self) class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol): @@ -150,20 +150,6 @@ def to_block_mapping( ) -# TODO(b/354568887): Cosolidate this with TPU's MemoryRef. -@dataclasses.dataclass(frozen=True) -class MemoryRef: - """Like jax.ShapeDtypeStruct but with memory spaces.""" - - shape: tuple[int, ...] - dtype: jnp.dtype - memory_space: GPUMemorySpace = dataclasses.field(kw_only=True) - - def get_aval(self) -> AbstractMemoryRef: - return AbstractMemoryRef( - jax_core.ShapedArray(self.shape, self.dtype), self.memory_space - ) - GMEM = GPUMemorySpace.GMEM SMEM = GPUMemorySpace.SMEM REGS = GPUMemorySpace.REGS @@ -189,7 +175,7 @@ class Barrier: num_arrivals: int num_barriers: int = 1 - def get_aval(self) -> AbstractMemoryRef: + def get_ref_aval(self) -> AbstractMemoryRef: aval = jax_core.ShapedArray( [self.num_barriers], BarrierType(self.num_arrivals) ) From 3ccca54e426c7fef772f226319749394ac37a406 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 19 Sep 2024 04:39:11 -0700 Subject: [PATCH 559/702] [Pallas TPU] Fix some issues introduced by the recent changes The new Pallas-specific aval interacts very badly with the default abstract eval rules of most lax ops, causing frequent failures. PiperOrigin-RevId: 676362377 --- jax/_src/pallas/core.py | 6 ++++-- jax/_src/pallas/pallas_call.py | 8 +++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 00bbbbe888d5..1a956de1f7a9 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -161,7 +161,7 @@ def str_short(self, short_dtypes=False): else: sharding_str = "" memoryspace_str = ( - "" if self.memory_space is None else f"{self.memory_space}>" + "" if self.memory_space is None else f"<{self.memory_space}>" ) return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}" @@ -206,8 +206,10 @@ def get_array_aval(self) -> jax_core.ShapedArray: ) def get_ref_aval(self) -> AbstractMemoryRef: + # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we + # try to apply JAX ops to it. return AbstractMemoryRef( - ShapedArrayWithMemorySpace(self.shape, self.dtype), self.memory_space) + jax_core.ShapedArray(self.shape, self.dtype), self.memory_space) class AbstractMemoryRef(state.AbstractRef): diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index a3ca823a1fa6..1c10d2bda9e9 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -334,7 +334,13 @@ def _pallas_call_abstract_eval( *avals, out_avals: tuple[jax_core.AbstractValue, ...], **_ ): del avals - return out_avals + # Make sure we don't return ShapedArrayWithMemorySpace to the outside world. + return [ + jax_core.ShapedArray(a.shape, a.dtype, a.weak_type) + if isinstance(a, pallas_core.ShapedArrayWithMemorySpace) + else a + for a in out_avals + ] pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) From 3f23866f75f3c209f16a40324425c11783161e99 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 19 Sep 2024 05:26:56 -0700 Subject: [PATCH 560/702] Enable Pallas `ops_test` on GPU in 64-bit mode. Previously, the 64-bit tests are skipped in `PallasBaseTest`, which disables both `OpsTest` and `OpsExtraTest`. This PR enables the 64-bit tests for `OpsTest`, and only disables it for `OpsExtraTest`. PiperOrigin-RevId: 676373904 --- tests/pallas/BUILD | 4 ++-- tests/pallas/ops_test.py | 28 +++++++++++++++------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index fd229e1673d5..6804d91675c1 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -98,13 +98,13 @@ jax_test( disable_configs = [ "gpu", "gpu_x32", - "gpu_a100", "gpu_p100", "gpu_p100_x32", - "gpu_h100", ], enable_configs = [ + "gpu_a100", "gpu_a100_x32", + "gpu_h100", "gpu_h100_x32", ], shard_count = { diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 627ee0e8a227..8495d6a0524e 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -31,6 +31,7 @@ from jax import lax from jax import random from jax._src import config +from jax._src import dtypes from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu @@ -59,6 +60,10 @@ jtu.setup_hypothesis(max_examples=50) +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) + + def smem_on_tpu(): if jtu.test_device_matches(["tpu"]): return pltpu.SMEM @@ -245,8 +250,6 @@ class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False def setUp(self): - if jax.config.x64_enabled: - self.skipTest("Only works in 32-bit") if not self.INTERPRET: if jtu.device_under_test() == "cpu": self.skipTest("Only interpret mode supported on CPU") @@ -263,11 +266,6 @@ def pallas_call(cls, *args, **kwargs): class OpsTest(PallasBaseTest): - def setUp(self): - super().setUp() - if jax.config.x64_enabled: - self.skipTest("Only works in 32-bit") - @parameterized.named_parameters( (fn.__name__, fn, dtype) for fn, dtype in [ (lax.pow, jnp.float32), @@ -340,7 +338,7 @@ def kernel(x_ref, y_ref, o_ref): result = self.pallas_call( kernel, - out_shape=jax.ShapeDtypeStruct([1, 128], jnp.int32), + out_shape=jax.ShapeDtypeStruct([1, 128], intx), in_specs=[ pl.BlockSpec(memory_space=smem_on_tpu()), pl.BlockSpec(memory_space=smem_on_tpu()), @@ -435,13 +433,15 @@ def kernel(x_ref, ones_ref, o_ref): float_value = jnp.where(reduced_as_bool, 1.0, 0.0) o_ref[0, 0] = float_value[0, 0] - if input_type == 'all_true': + if input_type == "all_true": x = jnp.ones((8, 128), dtype=jnp.float32) - elif input_type == 'all_false': + elif input_type == "all_false": x = jnp.zeros((8, 128), dtype=jnp.float32) - elif input_type == 'one_false': + elif input_type == "one_false": x = jnp.ones((8, 128), dtype=jnp.float32) x = x.at[0, 0].set(0.0) + else: + raise ValueError(f"Unknown input type: {input_type}") ones = jnp.ones_like(x) result = self.pallas_call( @@ -451,7 +451,7 @@ def kernel(x_ref, ones_ref, o_ref): pl.BlockSpec((8, 128), lambda *_: (0, 0)), ], out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=smem_on_tpu()), - out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32), + out_shape=jax.ShapeDtypeStruct([1, 1], floatx), grid=(1,), )(x, ones) np.testing.assert_array_equal(result[0, 0], float(expected_result)) @@ -473,7 +473,7 @@ def kernel(x_ref, o_ref): pl.BlockSpec((8, 128), lambda *_: (0, 0)), ], out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()), - out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32), + out_shape=jax.ShapeDtypeStruct([1, 1], floatx), grid=(1,), )(x) @@ -746,6 +746,8 @@ class OpsExtraTest(PallasBaseTest): def setUp(self): super().setUp() + if jax.config.x64_enabled: + self.skipTest("Only works in 32-bit") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: # TODO: most tests fail on TPU in non-interpret mode self.skipTest("On TPU the test works only in interpret mode") From 22a7c73d27977d22b1331462d9c44c8b26ecfa1a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 19 Sep 2024 05:29:11 -0700 Subject: [PATCH 561/702] Added support for `lax.fori_loop` in the Pallas Mosaic GPU lowering This, coupled with `plgpu.async_copy` and barriers, should be enough to sketch a simple pipelined loop in the kernel. PiperOrigin-RevId: 676374408 --- jax/_src/pallas/mosaic_gpu/lowering.py | 97 ++++++++++++++++++++++++-- jax/_src/pallas/triton/lowering.py | 14 ++-- jax/experimental/mosaic/gpu/utils.py | 5 +- tests/pallas/mosaic_gpu_test.py | 12 ++++ 4 files changed, 111 insertions(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 3afc62aebfcf..5eaf6e5233cb 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -24,18 +24,19 @@ from typing import Any, cast import jax +from jax import lax from jax._src import core as jax_core from jax._src import pjit from jax._src import util from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe -from jax._src.lax import lax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives +from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import primitives as sp from jax.experimental.mosaic import gpu as mosaic_gpu @@ -320,7 +321,7 @@ def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value): ) program_ids = map(_program_id, range(len(grid_mapping.grid))) start_indices = map( - functools.partial(_eval_index_map, module_ctx, launch_ctx, program_ids), + partial(_eval_index_map, module_ctx, launch_ctx, program_ids), block_mappings, ) in_start_indices, out_start_indices = util.split_list( @@ -718,6 +719,80 @@ def _run_scoped_lowering_rule( return outs +def _lower_jaxpr_to_for_loop( + ctx: LoweringRuleContext, + jaxpr: jax_core.Jaxpr, + start: ir.Value, + length: ir.Value, + consts, + *args, + has_loop_index: bool, +): + + @mgpu.fori(length, [*args]) + def loop(loop_index, body_args): + if has_loop_index: + loop_index = arith_dialect.addi(loop_index, start) + jaxpr_args = [*consts, loop_index, *body_args] + else: + jaxpr_args = [*consts, *body_args] + return lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args + ) + + return loop.results + + +@register_lowering_rule(lax.scan_p) +def _scan_lowering_rule( + ctx: LoweringRuleContext, + *args, + jaxpr: jax_core.ClosedJaxpr, + linear: tuple[bool, ...], + length: int, + reverse: bool, + unroll: bool | int, + num_consts: int, + num_carry: int, + _split_transpose: bool, +): + # Can only handle fori_loop-like scans. + if ( + (num_extensive := len(args) - num_consts - num_carry) + or reverse + or unroll != 1 + ): + raise NotImplementedError + del linear, num_extensive, reverse, unroll + + jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts + if jaxpr_consts: + raise NotImplementedError + del jaxpr_consts + + jaxpr, has_loop_index = pallas_utils.pattern_match_scan_to_fori_loop( + jaxpr, num_consts, num_carry + ) + consts, args = util.split_list(args, [num_consts]) + _consts_avals, arg_avals = util.split_list(ctx.avals_in, [num_consts]) + if has_loop_index: + start, *args = args + index_aval, *_arg_avals = arg_avals + start = _ensure_ir_value(start, index_aval) + length = _ir_constant(length, start.type) + else: + start = _i32_constant(0) + length = _i32_constant(length) + for_out = _lower_jaxpr_to_for_loop( + ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index + ) + if has_loop_index: + # Need to return the final loop index value if the outer scan expects + # it as an output. + return [length, *for_out] + return for_out + + def _bcast( x: ir.Value, y: ir.Value, @@ -750,11 +825,19 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: _ir_constant(x, mlir.dtype_to_ir_type(dtype)), () ) elif isinstance(x, ir.Value): - if isinstance(x.type, (ir.IntegerType, ir.FloatType)): + if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)): return mgpu.FragmentedArray.splat(x, ()) raise NotImplementedError(f"Unsupported type: {type(x)}") +def _ensure_ir_value(x: object, aval: jax_core.ShapedArray) -> ir.Value: + if isinstance(x, ir.Value): + return x + elif isinstance(x, (np.number, np.ndarray, int, float)): + return _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)) + raise NotImplementedError(f"Unsupported type: {type(x)}") + + def _ir_constant(v: object, t: ir.Type) -> ir.Value: if isinstance(v, (np.number, np.ndarray, int, float)): if isinstance(t, (ir.IntegerType, ir.IndexType)): @@ -766,8 +849,12 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value: raise NotImplementedError(f"Unsupported constant: {v!r}") -def _i32_constant(v: object) -> ir.Value: - return _ir_constant(v, ir.IntegerType.get_signless(32)) +def _i32_constant(v: int) -> ir.Value: + return arith_dialect.constant(ir.IntegerType.get_signless(32), v) + + +def _i64_constant(v: int) -> ir.Value: + return arith_dialect.constant(ir.IntegerType.get_signless(64), v) def _as_index(v: int | ir.Value) -> ir.Value: diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 0a23e512dfb3..948ecf74a8e6 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -2365,10 +2365,8 @@ def _lower_jaxpr_to_for_loop( else: jaxpr_args = [*consts, *for_body_args] all_out = lower_jaxpr_to_triton_ir( - ctx.context, - jaxpr, - ctx.block_infos, - *jaxpr_args) + ctx.context, jaxpr, ctx.block_infos, *jaxpr_args + ) scf_dialect.yield_(all_out) return list(for_op.results_) @@ -2405,11 +2403,9 @@ def _scan_lowering_rule( args = map(_ensure_ir_value, args, ctx.avals_in) consts, args = util.split_list(args, [num_consts]) if has_loop_index: - lb, *args = args - lower_bound = lb - ub = _add(lb, _ir_constant(length, lb.type)) - upper_bound = ub - bound_type = ub.type + lower_bound, *args = args + upper_bound = _add(lower_bound, _ir_constant(length, lower_bound.type)) + bound_type = lower_bound.type else: lower_bound = _i32_constant(0) upper_bound = _i32_constant(length) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 546411c82c4c..30b8ca5cfb14 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -154,9 +154,8 @@ def fori(bound, carrys): flat_carrys, carry_treedef = jax.tree.flatten(carrys) def wrapper(f): - index = ir.IndexType.get() - c0 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0)) - c1 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 1)) + c0 = arith.constant(bound.type, 0) + c1 = arith.constant(bound.type, 1) for_op = scf.ForOp(c0, bound, c1, flat_carrys) with ir.InsertionPoint(for_op.body): i = for_op.induction_variable diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 746c4e93387b..17ef26c7f9b3 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -311,6 +311,18 @@ def kernel(x_ref, o_ref): result = kernel(x) self.assertEqual(result.shape, (4, 2, 64, 64)) + def test_fori_loop(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, o_ref): + # Equivalent to x_ref[...] + 2 + 3. + o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...]) + + x = jnp.arange(256).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 2.0 + 3.0) + if __name__ == "__main__": absltest.main() From d0338f5d13075a781ad81600a6776d9584918597 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Thu, 19 Sep 2024 14:51:02 +0000 Subject: [PATCH 562/702] [ffi] Support handler bundles in GPU plugin extension --- jaxlib/cuda_plugin_extension.cc | 73 ++++++++++++++++++++++++++++++--- jaxlib/rocm_plugin_extension.cc | 73 ++++++++++++++++++++++++++++++--- 2 files changed, 134 insertions(+), 12 deletions(-) diff --git a/jaxlib/cuda_plugin_extension.cc b/jaxlib/cuda_plugin_extension.cc index 0bb8cbbace65..ea81109b36c0 100644 --- a/jaxlib/cuda_plugin_extension.cc +++ b/jaxlib/cuda_plugin_extension.cc @@ -38,7 +38,7 @@ namespace xla { namespace { absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, const char* fn_name_c_str, - size_t fn_name_size, nb::capsule fn, + size_t fn_name_size, nb::object fn, int api_version, XLA_FFI_Handler_Traits traits) { if (c_api->extension_start == nullptr) { @@ -54,6 +54,8 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, if (next == nullptr) { return Unimplemented("The plugin does not have a custom call extension."); } + PJRT_Gpu_Register_Custom_Call* register_custom_call = + reinterpret_cast(next)->custom_call; if (traits != 0) { return Unimplemented("The plugin does not support custom call traits."); @@ -63,14 +65,73 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; args.function_name = fn_name_c_str; args.function_name_size = fn_name_size; + #if PJRT_API_GPU_EXTENSION_VERSION >= 1 args.api_version = api_version; #endif - args.custom_call_function = static_cast(fn.data()); - RETURN_STATUS_IF_PJRT_ERROR( - reinterpret_cast(next)->custom_call(&args), - c_api); + + auto as_capsule = [](nb::object obj) -> absl::StatusOr { + nb::capsule capsule; + if (!nb::try_cast(obj, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration requires handlers as PyCapsules"); + } + return capsule; + }; + +#if PJRT_API_GPU_EXTENSION_VERSION <= 1 + TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); + args.custom_call_function = fn_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); return absl::OkStatus(); +#else + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = nullptr; + + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); + args.handler_execute = capsule_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + auto capsule_execute = as_capsule(fn); + if (capsule_execute.ok()) { + args.handler_execute = capsule_execute->data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); + return capsule.data(); + }; + + TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. " + "Supported versions are 0 and 1.", + api_version)); +#endif } nb::dict Registrations() { @@ -97,7 +158,7 @@ NB_MODULE(cuda_plugin_extension, m) { tsl::ImportNumpy(); m.def( "register_custom_call_target", - [](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn, + [](nb::capsule c_api, nb::object fn_name_py, nb::object fn, nb::str xla_platform_name, int api_version, XLA_FFI_Handler_Traits traits) { const char* fn_name_c_str; diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc index 8a732380d1e2..0100b37b22e9 100644 --- a/jaxlib/rocm_plugin_extension.cc +++ b/jaxlib/rocm_plugin_extension.cc @@ -35,7 +35,7 @@ namespace nb = nanobind; namespace xla { namespace { absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, - nb::capsule fn, int api_version, + nb::object fn, int api_version, XLA_FFI_Handler_Traits traits) { if (c_api->extension_start == nullptr) { return Unimplemented("The plugin does not have extension."); @@ -50,6 +50,8 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, if (next == nullptr) { return Unimplemented("The plugin does not have a custom call extension."); } + PJRT_Gpu_Register_Custom_Call* register_custom_call = + reinterpret_cast(next)->custom_call; if (traits != 0) { return Unimplemented("The plugin does not support custom call traits."); @@ -59,14 +61,73 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; args.function_name = fn_name.c_str(); args.function_name_size = nb::len(fn_name); + #if PJRT_API_GPU_EXTENSION_VERSION >= 1 args.api_version = api_version; #endif - args.custom_call_function = static_cast(fn.data()); - RETURN_STATUS_IF_PJRT_ERROR( - reinterpret_cast(next)->custom_call(&args), - c_api); + + auto as_capsule = [](nb::object obj) -> absl::StatusOr { + nb::capsule capsule; + if (!nb::try_cast(obj, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration requires handlers as PyCapsules"); + } + return capsule; + }; + +#if PJRT_API_GPU_EXTENSION_VERSION <= 1 + TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); + args.custom_call_function = fn_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); return absl::OkStatus(); +#else + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = nullptr; + + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); + args.handler_execute = capsule_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + auto capsule_execute = as_capsule(fn); + if (capsule_execute.ok()) { + args.handler_execute = capsule_execute->data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); + return capsule.data(); + }; + + TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. " + "Supported versions are 0 and 1.", + api_version)); +#endif } nb::dict Registrations() { @@ -118,7 +179,7 @@ NB_MODULE(rocm_plugin_extension, m) { tsl::ImportNumpy(); m.def( "register_custom_call_target", - [](nb::capsule c_api, nb::str fn_name, nb::capsule fn, + [](nb::capsule c_api, nb::str fn_name, nb::object fn, nb::str xla_platform_name, int api_version, XLA_FFI_Handler_Traits traits) { xla::ThrowIfError(RegisterCustomCallTarget( From de23fdb5adcc8e9ea9e9ec6bd22ec1425864d5b6 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 9 Jul 2024 00:27:10 +0800 Subject: [PATCH 563/702] [Pallas TPU] Add lowering for 64 bit --- jax/_src/pallas/mosaic/lowering.py | 8 ++- jax/_src/pallas/triton/lowering.py | 2 +- jax/_src/pallas/utils.py | 79 +++++++++++++++++++++++++++++- tests/pallas/ops_test.py | 22 ++++++--- 4 files changed, 96 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 7cc8de90b6b8..775f0c1f8256 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2532,11 +2532,9 @@ def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d): def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): - (x_aval,) = ctx.avals_in - if x_aval.dtype == jnp.float32: - return lower_fun(pallas_utils.erf_inv_32_lowering_helper, multiple_results=False)(ctx, x) - else: - raise NotImplementedError + return lower_fun( + pallas_utils.erf_inv_lowering_helper, multiple_results=False, + )(ctx, x) lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 948ecf74a8e6..ac28bd21a3dc 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1355,7 +1355,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): register_lowering(lax.erf_inv_p)( - lower_fun(pallas_utils.erf_inv_32_lowering_helper, multiple_results=False) + lower_fun(pallas_utils.erf_inv_lowering_helper, multiple_results=False) ) diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index e1fbbde61c56..e485537216ca 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -186,7 +186,7 @@ def pattern_match_while_to_fori_loop( # based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802 -def erf_inv_32_lowering_helper(x): +def _erf_inv_32_lowering_helper(x): k_degree = 9 w_lt_5_constants = [ 2.81022636e-08, 3.43273939e-07, -3.5233877e-06, @@ -212,6 +212,83 @@ def erf_inv_32_lowering_helper(x): return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x) +# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L696-L802 +def _erf_inv_64_lowering_helper(x): + w_lt_625_constants = [ + -3.6444120640178196996e-21, -1.685059138182016589e-19, + 1.2858480715256400167e-18, 1.115787767802518096e-17, + -1.333171662854620906e-16, 2.0972767875968561637e-17, + 6.6376381343583238325e-15, -4.0545662729752068639e-14, + -8.1519341976054721522e-14, 2.6335093153082322977e-12, + -1.2975133253453532498e-11, -5.4154120542946279317e-11, + 1.051212273321532285e-09, -4.1126339803469836976e-09, + -2.9070369957882005086e-08, 4.2347877827932403518e-07, + -1.3654692000834678645e-06, -1.3882523362786468719e-05, + 0.0001867342080340571352, -0.00074070253416626697512, + -0.0060336708714301490533, 0.24015818242558961693, + 1.6536545626831027356 + ] + + w_lt_16_constants = [ + 2.2137376921775787049e-09, 9.0756561938885390979e-08, + -2.7517406297064545428e-07, 1.8239629214389227755e-08, + 1.5027403968909827627e-06, -4.013867526981545969e-06, + 2.9234449089955446044e-06, 1.2475304481671778723e-05, + -4.7318229009055733981e-05, 6.8284851459573175448e-05, + 2.4031110387097893999e-05, -0.0003550375203628474796, + 0.00095328937973738049703, -0.0016882755560235047313, + 0.0024914420961078508066, -0.0037512085075692412107, + 0.005370914553590063617, 1.0052589676941592334, + 3.0838856104922207635, + ] + + w_gt_16_constants = [ + -2.7109920616438573243e-11, -2.5556418169965252055e-10, + 1.5076572693500548083e-09, -3.7894654401267369937e-09, + 7.6157012080783393804e-09, -1.4960026627149240478e-08, + 2.9147953450901080826e-08, -6.7711997758452339498e-08, + 2.2900482228026654717e-07, -9.9298272942317002539e-07, + 4.5260625972231537039e-06, -1.9681778105531670567e-05, + 7.5995277030017761139e-05, -0.00021503011930044477347, + -0.00013871931833623122026, 1.0103004648645343977, + 4.8499064014085844221, + ] # should add "as jnp.float64 array"? + + w = -jnp.log1p(x * -x) + w_lt_625 = w < 6.25 + w_lt_16 = w < 16.0 + + def get_coefficient(i): + c = w_lt_625_constants[i] + if i < 19: + c = jnp.where(w_lt_625, c, w_lt_16_constants[i]) + if i < 17: + c = jnp.where(w_lt_16, c, w_gt_16_constants[i]) + return c + + select2 = jnp.where(w_lt_16, 3.25, 5.0) + select2_result = jnp.sqrt(w) - select2 + w = jnp.where(w_lt_625, w - 3.125, select2_result) + + p = get_coefficient(0) + for i in range(1, 17): + p = get_coefficient(i) + p * w + for i in range(17, 19): + p = jnp.where(w_lt_16, get_coefficient(i) + p * w, p) + for i in range(19, 23): + p = jnp.where(w_lt_625, get_coefficient(i) + p * w, p) + + return jnp.where(jnp.abs(x) == 1.0, np.inf * x, p * x) + + +def erf_inv_lowering_helper(x): + if x.dtype == jnp.float32: + return _erf_inv_32_lowering_helper(x) + if x.dtype == jnp.float64: + return _erf_inv_64_lowering_helper(x) + raise NotImplementedError(f"erf_inv_lowering_helper not implemented for {x.dtype}") + + def sign_lowering_helper(x): if jnp.issubdtype(x.dtype, jnp.unsignedinteger): return (x != 0).astype(x.dtype) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 8495d6a0524e..63c3148e8108 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -702,22 +702,28 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(out, expected) @parameterized.product( - dtype=[jnp.float32], - value=[-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4], + dtype=[jnp.float32, jnp.float64], + value=[-3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4], ) def test_erf_inv(self, dtype, value): + if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64: + self.skipTest("float64 is not supported on TPU") + @functools.partial( self.pallas_call, - # TODO(ayx): add float64 support for `erf_inv` - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + out_shape=jax.ShapeDtypeStruct((8, 128), dtype), ) def kernel(x_ref, o_ref): o_ref[...] = lax.erf_inv(x_ref[...]) - x = jnp.full((8, 128), value, dtype=dtype) - out = kernel(x) - expected = lax.erf_inv(x) - np.testing.assert_array_equal(out, expected) + with contextlib.ExitStack() as stack: + if jnp.dtype(dtype).itemsize == 8: + stack.enter_context(config.enable_x64(True)) + + x = jnp.full((8, 128), value, dtype=dtype) + out = kernel(x) + expected = lax.erf_inv(x) + np.testing.assert_array_equal(out, expected) class OpsInterpretTest(OpsTest): From 56d0c695c91dc7ca6ddcde3c901c820c4388bd81 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 19 Sep 2024 09:03:20 -0700 Subject: [PATCH 564/702] Condition tan lowering on jaxlib version rather than forward compatibility mode. PiperOrigin-RevId: 676436269 --- jax/_src/lax/lax.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 28ad429b0367..48af9c64ffc9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -62,6 +62,7 @@ standard_multi_result_abstract_eval, standard_primitive) from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2012,19 +2013,17 @@ def _cos_lowering(ctx, x): def _tan_impl(x): return div(sin(x), cos(x)) -def _tan_lowering(ctx, x): - # TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this - # lowering is supported, but export doesn't target a sufficiently up-to-date - # StableHLO version, and the compatibility updates from - # https://github.com/openxla/xla/pull/16649 aren't included in the 0.4.33 - # release. - if ctx.is_forward_compat(): - return _nary_lower_hlo(chlo.tan, ctx, x) - return _nary_lower_hlo(hlo.tan, ctx, x) - tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -mlir.register_lowering(tan_p, _tan_lowering) +# TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this +# lowering is mostly supported, but it fails on export or with the PJRT plugin +# because those modes target an older StableHLO version, and the +# compatibility updates from https://github.com/openxla/xla/pull/16649 aren't +# included in the 0.4.33 release. +if jaxlib_version <= (0, 4, 33): + mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) +else: + mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): From 3b89a2e57369edb7fd0bc41a339b119d47a600f2 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 19 Sep 2024 09:41:28 -0700 Subject: [PATCH 565/702] Add a utility function to create a tangent zero value from a primal value. PiperOrigin-RevId: 676449863 --- jax/_src/ad_util.py | 11 ++++++++++- jax/custom_derivatives.py | 3 ++- tests/api_test.py | 5 +++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index c69ff3754dc6..bd1427f59e01 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -20,7 +20,7 @@ from jax._src import core from jax._src import traceback_util from jax._src.core import Primitive, valid_jaxtype, raise_to_shaped, get_aval -from jax._src.tree_util import register_pytree_node +from jax._src.tree_util import register_pytree_node, tree_map from jax._src.typing import Array, ArrayLike from jax._src.util import safe_map @@ -113,6 +113,15 @@ def __getattr__(self, name): def from_primal_value(val: Any) -> SymbolicZero: return SymbolicZero(get_aval(val).to_tangent_aval()) +def zero_from_primal(val, symbolic_zeros=False): + def f(x): + tangent_aval = get_aval(x).to_tangent_aval() + if symbolic_zeros: + return SymbolicZero(tangent_aval) + else: + return zeros_like_aval(tangent_aval) + return tree_map(f, val) + JaxTypeOrTracer = Any def replace_internal_symbolic_zeros( diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 96dc8898fd8e..8e517f5d4610 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -34,5 +34,6 @@ ) from jax._src.ad_util import ( - SymbolicZero as SymbolicZero + SymbolicZero as SymbolicZero, + zero_from_primal as zero_from_primal ) diff --git a/tests/api_test.py b/tests/api_test.py index b0915a1df44b..1deaa4c08dc8 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7542,7 +7542,8 @@ def test_float0(self): def f(x, y): return x, y def f_jvp(primals, _): - return primals, (2., scalar_float0) + x, y = primals + return (x, y), (2., custom_derivatives_public.zero_from_primal(y)) f.defjvp(f_jvp) primals = (2., 3) @@ -7558,7 +7559,7 @@ def f(x, y): return x, y def f_jvp(primals, _): x, y = primals - return (x, y), (2., scalar_float0) + return (x, y), (2., custom_derivatives_public.zero_from_primal(y)) f.defjvp(f_jvp) def foo(x, y): From cc927dd3227506154444ecdd735bca457290cb15 Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Thu, 19 Sep 2024 10:02:30 -0700 Subject: [PATCH 566/702] Ignore RuntimeWarning "invalid value encountered in cast" for LaxBackedNumpyTests.testUniqueEqualNan This is to fix Mac arm64 pytests on CI. The tests started failing after integrating ml-dtypes-0.5.0. Ignoring warnings is probably Ok, as it is inspired by a similar PR in ml-dtypes repo itself: https://github.com/jax-ml/ml_dtypes/pull/186 PiperOrigin-RevId: 676458202 --- tests/lax_numpy_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f93e28dada71..ddf42a28e2ba 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2134,6 +2134,9 @@ def np_fun(x): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) + @jtu.ignore_warning( + category=RuntimeWarning, message='invalid value encountered in cast' + ) def testUniqueEqualNan(self, dtype, equal_nan): shape = (20,) rng = jtu.rand_some_nan(self.rng()) From ef2f2fff0610833575d1f194ac0dcf5497920062 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Thu, 19 Sep 2024 22:42:56 +0530 Subject: [PATCH 567/702] Improved doc for jnp.vander --- jax/_src/numpy/lax_numpy.py | 42 ++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index dd4c54fe7efa..5503cf87c794 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -8443,11 +8443,51 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array: return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) -@util.implements(np.vander) @partial(jit, static_argnames=('N', 'increasing')) def vander( x: ArrayLike, N: int | None = None, increasing: bool = False ) -> Array: + """Generate a Vandermonde matrix. + + JAX implementation of :func:`numpy.vander`. + + Args: + x: input array. Must have ``x.ndim == 1``. + N: int, optional, default=None. Specifies the number of the columns the + output matrix. If not specified, ``N = len(x)``. + increasing: bool, optional, default=False. Specifies the order of the powers + of the columns. If ``True``, the powers increase from left to right, + :math:`[x^0, x^1, ..., x^{(N-1)}]`. By default, the powers decrease from left to + right :math:`[x^{(N-1)}, ..., x^1, x^0]`. + + Returns: + An array of shape ``[len(x), N]`` containing the generated Vandermonde matrix. + + Examples: + >>> x = jnp.array([1, 2, 3, 4]) + >>> jnp.vander(x) + Array([[ 1, 1, 1, 1], + [ 8, 4, 2, 1], + [27, 9, 3, 1], + [64, 16, 4, 1]], dtype=int32) + + If ``N = 2``, generates a Vandermonde matrix with ``2`` columns. + + >>> jnp.vander(x, N=2) + Array([[1, 1], + [2, 1], + [3, 1], + [4, 1]], dtype=int32) + + Generates the Vandermonde matrix in increaing order of powers, when + ``increasing=True``. + + >>> jnp.vander(x, increasing=True) + Array([[ 1, 1, 1, 1], + [ 1, 2, 4, 8], + [ 1, 3, 9, 27], + [ 1, 4, 16, 64]], dtype=int32) + """ util.check_arraylike("vander", x) x = asarray(x) if x.ndim != 1: From df781e455a5e70894da8f1dc06d719ac39c74f51 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 19 Sep 2024 10:16:34 -0700 Subject: [PATCH 568/702] [JAX] Switch host_callback to use MLIR lowering instead of the older direct HLO translation rules. Change in preparation for removing XlaBuilder from Python bindings. PiperOrigin-RevId: 676465019 --- jax/experimental/host_callback.py | 158 +++++++++++++++++++++++++++--- 1 file changed, 143 insertions(+), 15 deletions(-) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 43e9813d7fac..63c3299c5904 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -536,6 +536,8 @@ def power3_with_cotangents(x): from jax._src import xla_bridge as xb from jax._src.lib import xla_client from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version +from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo import numpy as np @@ -1085,7 +1087,6 @@ def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs): finally: builder.clear_sharding() - def _outside_call_translation_rule(ctx, avals_in, avals_out, @@ -1185,8 +1186,123 @@ def _outside_call_translation_rule(ctx, f"identity = {identity}") return results + [next_token, next_itoken] +if xla_extension_version < 287: + xla.register_translation(outside_call_p, _outside_call_translation_rule) + + +def _outside_call_outfeed_lowering(ctx: mlir.LoweringRuleContext, + *args_op, + identity, + device_index, + flat_results_aval=(), + **params): + # We expect the current tokens at the end, inserted by _rewrite_jaxpr. + current_token = args_op[-2] + current_itoken = args_op[-1] + + args_to_outfeed = args_op[:-2] + # Some platforms refuse to infeed empty arrays. We generate constants + # instead. + non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)), + flat_results_aval)) + need_callback_results_on_device = (not identity and + len(non_empty_flat_results_aval) > 0) + send_infeed = need_callback_results_on_device + generated_infeed = False # Keep track if we emitted an infeed op + for platform in ctx.module_context.platforms: + _raise_if_using_outfeed_with_pjrt_c_api( + xb.get_backend(platform) + ) + callback_id = _register_callback( + functools.partial( + _outside_call_run_callback, + send_infeed=send_infeed, + identity=identity, + flat_results_aval=flat_results_aval, + **params)) -xla.register_translation(outside_call_p, _outside_call_translation_rule) + outfeed_sharding = xla_client.OpSharding() + outfeed_sharding.type = xla_client.OpSharding.Type.MAXIMAL + outfeed_sharding.tile_assignment_dimensions = [1] + outfeed_sharding.tile_assignment_devices = [device_index] + + # next_token = _callback_handler_data.receiver.add_outfeed( + # comp, current_token, callback_id, args_to_outfeed, device_index) + + xla_shapes = util.flatten( + xla.aval_to_xla_shapes(aval) for aval in ctx.avals_in[:-2]) + _callback_handler_data.receiver.register_outfeed(callback_id, xla_shapes) + outfeed_header_start = 271828 # Must match kOutfeedHeaderStart in C++ + header = mlir.ir_constant(np.array([outfeed_header_start, callback_id], + dtype=np.uint32)) + header_outfeed = hlo.OutfeedOp([header], current_token, + outfeed_config=ir.StringAttr.get('')) + mlir.set_sharding(header_outfeed, outfeed_sharding) + next_token, = header_outfeed.results + data_outfeed = hlo.OutfeedOp(args_to_outfeed, next_token, + outfeed_config=ir.StringAttr.get('')) + mlir.set_sharding(data_outfeed, outfeed_sharding) + next_token, = data_outfeed.results + + + if identity: + results = list(args_to_outfeed) + next_itoken = current_itoken + else: + empty_results = [ + mlir.ir_constant(np.zeros(aval.shape, aval.dtype)) + for aval in flat_results_aval + if _aval_is_empty(aval) + ] + if non_empty_flat_results_aval: + assert need_callback_results_on_device + after_outfeed_itoken = hlo.AfterAllOp([current_itoken, next_token]) + # We shard the infeed as AssignedDevice(device_index). This must match the + # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support + # this kind of sharding, we use a custom translation for infeed. + array_sharding_proto = xla_client.OpSharding() + array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL + array_sharding_proto.tile_assignment_dimensions = [1] + array_sharding_proto.tile_assignment_devices = [device_index] + + token_sharding_proto = xla_client.OpSharding() + token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED + infeed_sharding_proto = xla.tuple_sharding_proto( + [array_sharding_proto] * len(non_empty_flat_results_aval) + + [token_sharding_proto]) + + output_types = map(mlir.aval_to_ir_types, non_empty_flat_results_aval) + flat_output_types = util.flatten(output_types) + + layouts = ir.ArrayAttr.get([ + ir.ArrayAttr.get( + [mlir.i64_attr(i) + for i in range(len(aval.shape) - 1, -1, -1)]) + for aval in non_empty_flat_results_aval + ]) + infeed = hlo.InfeedOp(flat_output_types + [hlo.TokenType.get()], + after_outfeed_itoken, + infeed_config=ir.StringAttr.get(''), + layout=layouts) + mlir.set_sharding(infeed, infeed_sharding_proto) + non_empty_results = list(infeed.results[:-1]) + next_itoken = infeed.results[-1] + generated_infeed = True + results = [ + empty_results.pop(0) + if _aval_is_empty(result_aval) else non_empty_results.pop(0) + for result_aval in flat_results_aval + ] + else: + results = empty_results + next_itoken = current_itoken + + assert generated_infeed == send_infeed, ( + f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})") + assert identity or len(results) == len(flat_results_aval), ( + f"got {len(results)} but expected {len(flat_results_aval)}. " + f"identity = {identity}") + return results + [next_token, next_itoken] def _outside_call_lowering(ctx: mlir.LoweringRuleContext, @@ -1202,23 +1318,32 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext, platform = ctx.module_context.platforms[0] use_outfeed = _use_outfeed(platform) if use_outfeed: - # Fall back to XLA path if we are using the outfeed - # TODO(sharadmv): update to use MLIR for this path as well and delete - # XLA lowering - return mlir.xla_fallback_lowering(outside_call_p)( - ctx, - *args, - has_token=has_token, - identity=identity, - flat_results_aval=flat_results_aval, - device_index=device_index, - **params) + if xla_extension_version < 287: + return mlir.xla_fallback_lowering(outside_call_p)( + ctx, + *args, + has_token=has_token, + identity=identity, + device_index=device_index, + flat_results_aval=flat_results_aval, + **params, + ) + else: + return _outside_call_outfeed_lowering( + ctx, *args, + has_token=has_token, + identity=identity, + flat_results_aval=flat_results_aval, + device_index=device_index, + **params, + ) else: # TODO(necula): It seems that on CPU, with custom call, the device_index # does not work, and the callback is always run on device_index=0 if (device_index != 0 and "cpu" in ctx.module_context.platforms): raise ValueError( "The device_index feature on CPU works only when using outfeed.") + # We expect the current tokens at the end, inserted by _rewrite_jaxpr. assert has_token current_token = args[-2] @@ -1280,7 +1405,10 @@ def wrapped_callback(*args): f"identity = {identity}") return list(results) + [next_token, next_itoken] -mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu") +if xla_extension_version < 287: + mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu") +else: + mlir.register_lowering(outside_call_p, _outside_call_lowering) def _outside_call_run_callback( arrays, device, *, @@ -1766,7 +1894,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn], id_p.multiple_results = True id_p.def_impl(lambda *args: args) id_p.def_abstract_eval(lambda *args: args) -xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args) +mlir.register_lowering(id_p, lambda ctx, *args: args) dispatch.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False) From f75c5c6b2d84728faacd8ac7420819393b0b17fe Mon Sep 17 00:00:00 2001 From: Loren Maggiore Date: Thu, 19 Sep 2024 10:41:58 -0700 Subject: [PATCH 569/702] [jax] config option to disable using a mesh as a context manager. PiperOrigin-RevId: 676475039 --- jax/_src/config.py | 9 +++++++++ jax/_src/mesh.py | 2 ++ 2 files changed, 11 insertions(+) diff --git a/jax/_src/config.py b/jax/_src/config.py index 75df6be7d9e7..fe56ec68f6cb 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1378,6 +1378,15 @@ def _update_jax_memories_thread_local(val): update_thread_local_hook=lambda val: \ update_thread_local_jit_state(numpy_dtype_promotion=val)) +disallow_mesh_context_manager = bool_state( + name='jax_disallow_mesh_context_manager', + default=False, + help=( + 'If set to True, trying to use a mesh as a context manager will' + ' result in a RuntimeError.' + ), +) + def _update_x64_global(val): lib.jax_jit.global_state().enable_x64 = val diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index b30286b36a76..20234b678172 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -217,6 +217,8 @@ def __setattr__(self, name, value): super().__setattr__(name, value) def __enter__(self): + if jax_config.disallow_mesh_context_manager.value: + raise RuntimeError("Mesh context manager is disabled.") new_env = thread_resources.stack[-1].with_mesh(self) thread_resources.stack.append(new_env) thread_resources.env = new_env From c9bbf71ec638f4d5b4f6ab44c6cc816290f2a5b7 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 19 Sep 2024 11:38:01 -0700 Subject: [PATCH 570/702] Cleanup `ParsedPartitionSpec` and remove `CanonicalizedParsedPartitionSpec`. Also mark `user_spec` as private. PiperOrigin-RevId: 676498946 --- jax/_src/pjit.py | 4 +- jax/_src/sharding_impls.py | 103 +++++++------------------------------ tests/pjit_test.py | 5 -- 3 files changed, 22 insertions(+), 90 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 34bf257f639e..0abaa3fd0139 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1041,8 +1041,8 @@ def _create_sharding_for_array(mesh, x, name, api_name): ' then the mesh context manager is not required.') # A nice user error is raised in prepare_axis_resources. assert x is None or isinstance(x, ParsedPartitionSpec), x - return (pxla.create_mesh_pspec_sharding(mesh, x) - if x is None else pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x)) + return (pxla.create_mesh_pspec_sharding(mesh, x) if x is None else + pxla.create_mesh_pspec_sharding(mesh, x.get_partition_spec(), x)) def _create_sharding_with_device_backend(device, backend): diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index af86425128c9..310ff38b7247 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -18,7 +18,6 @@ from collections import OrderedDict from collections.abc import Mapping, Sequence import dataclasses -import enum import functools import itertools import math @@ -955,43 +954,20 @@ def get_array_mapping( cast(ArrayMapping, get_array_mapping(p))) -class SpecSync(enum.IntEnum): - """Encodes how much out of sync the real value of partitions is compared to the user specified one. - - We use this to make sure we don't show garbage modified values while claiming - that the users have specified them like that. - """ - OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted - DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes - IN_SYNC = 2 # Entirely in sync - class ParsedPartitionSpec: - __slots__ = ('unsafe_user_spec', 'partitions', 'sync') + __slots__ = ('_user_spec', 'partitions') - def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC): - self.unsafe_user_spec = user_spec + def __init__(self, user_spec, partitions): + self._user_spec = user_spec # None in partitions represents unconstrained dim. # TODO(yashkatariya): May use a sentinel value. self.partitions = tuple(partitions) - self.sync = sync - - @property - def user_spec(self): - return self.unsynced_user_spec(SpecSync.IN_SYNC) def get_partition_spec(self) -> PartitionSpec: - if self.sync < SpecSync.IN_SYNC: - return get_single_pspec(self) + if isinstance(self._user_spec, PartitionSpec): + return self._user_spec else: - if isinstance(self.unsafe_user_spec, PartitionSpec): - return self.unsafe_user_spec - else: - return get_single_pspec(self) - - def unsynced_user_spec(self, min_sync): - if self.sync < min_sync: - raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})") - return self.unsafe_user_spec + return get_single_pspec(self) def insert_axis_partitions(self, dim, val): parts = self.partitions @@ -999,8 +975,7 @@ def insert_axis_partitions(self, dim, val): if too_short > 0: parts += ((),) * too_short new_partitions = util.tuple_insert(parts, dim, val) - new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC - return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync) + return ParsedPartitionSpec(None, new_partitions) @classmethod def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False): @@ -1027,13 +1002,12 @@ def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False): return cls(new_entry, axis_specs) def __hash__(self): - return hash((self.partitions, self.sync)) + return hash(self.partitions) def __eq__(self, other): if not isinstance(other, ParsedPartitionSpec): return False - return (self.partitions == other.partitions and - self.sync == other.sync) + return self.partitions == other.partitions def __len__(self): return len(self.partitions) @@ -1045,58 +1019,19 @@ def __iter__(self): return iter(self.partitions) def __repr__(self): - return (f"ParsedPartitionSpec(partitions={self.partitions}, " - f"unsafe_user_spec={self.unsafe_user_spec}, " - f"sync={self.sync})") - -class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec): - """ParsedPartitionSpecs that are canonicalized. - - ParsedPartitionSpecs may contain trailing empty tuples, that make them - semantically different in general, and yet in some situations we prefer - to regard them as equivalent. For example, partitions of () and ((),) - cannot be always considered equivalent, since the first one is a valid - spec for a scalar value, while the second is not! However, when either of - those are applied to a 2D array, they both mean that the array is fully - replicated. - - So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from - partitions. - """ - - def __init__(self, parsed_pspec: ParsedPartitionSpec): - partitions = list(parsed_pspec.partitions) - while partitions and partitions[-1] == (): - partitions.pop() - - super().__init__(parsed_pspec.unsafe_user_spec, partitions, - parsed_pspec.sync) - - def __repr__(self): - return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, " - f"unsafe_user_spec={self.unsafe_user_spec}, " - f"sync={self.sync})") + return f"ParsedPartitionSpec(partitions={self.partitions})" def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()): - # This split exists because you can pass `_parsed_pspec` that has been - # modified from the original. For example: Adding extra dimension to - # axis_resources for vmap handlers. In such cases you need to preserve the - # `sync` attribute of parsed pspecs. - # PartitionSpec is inferred from the parsed pspec in this case. - # TODO(yaskatariya): Remove this and replace this with a normalized - # representation of Parsed Pspec if parsed_pspec is None: parsed_pspec = prepare_axis_resources( PartitionSpec() if spec is None else spec, "NamedSharding spec", allow_unconstrained_dims=True) - _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) return parsed_pspec -def prepare_axis_resources(axis_resources, - arg_name, +def prepare_axis_resources(axis_resources, arg_name, allow_unconstrained_dims=False): # PyTrees don't treat None values as leaves, so we use an is_leaf function. entries, treedef = tree_util.tree_flatten( @@ -1133,9 +1068,11 @@ def _check_unique_resources(axis_resources, arg_name): if resource_counts.most_common(1)[0][1] > 1: multiple_uses = [r for r, c in resource_counts.items() if c > 1] if multiple_uses: - raise ValueError(f"A single {arg_name} specification can map every mesh axis " - f"to at most one positional dimension, but {arg_axis_resources.user_spec} " - f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}") + raise ValueError( + f'A single {arg_name} specification can map every mesh axis to at' + ' most one positional dimension, but' + f' {arg_axis_resources.get_partition_spec()} has duplicate entries' + f' for {mesh_lib.show_axes(multiple_uses)}') # Axis environments @@ -1314,8 +1251,7 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding, out.extend(parse_flatten_op_sharding(s, mesh)) return out elif hlo_sharding.is_replicated(): - return [CanonicalizedParsedPartitionSpec( - ParsedPartitionSpec(PartitionSpec(), ()))] + return [ParsedPartitionSpec(PartitionSpec(), ())] elif hlo_sharding.is_tiled(): mesh_shape = mesh.shape mesh_axis_order = unflatten_array( @@ -1339,8 +1275,9 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding, ) if hlo_sharding.replicate_on_last_tile_dim(): partitions = partitions[:-1] - return [CanonicalizedParsedPartitionSpec( - ParsedPartitionSpec('', partitions))] + while partitions and partitions[-1] == (): + partitions.pop() + return [ParsedPartitionSpec(None, partitions)] else: raise AssertionError("Unhandled OpSharding type. Please open a bug report!") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 11a541f2e5f5..57106948f7d3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5208,11 +5208,6 @@ def test_get_partition_spec(self): self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(), P('x', 'y')) - out_of_sync_parsed_pspec = sharding_impls.ParsedPartitionSpec( - P('x', 'y'), ('x', 'y'), sharding_impls.SpecSync.OUT_OF_SYNC) - self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(), - P('x', 'y')) - def test_mesh_with_list_devices(self): mesh = jax.sharding.Mesh(jax.devices(), ('x',)) self.assertIsInstance(mesh.devices, np.ndarray) From 4bce4f6452c4c1e5d9a7c8df6a50cd56f3a0b5cc Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 16 Sep 2024 14:25:20 -0700 Subject: [PATCH 571/702] [Pallas] Add block-sparse kernel tutorial --- docs/_static/pallas/sparse/block_coo.svg | 1 + docs/_static/pallas/sparse/prefetch_map.svg | 1 + docs/_static/pallas/sparse/sparse_matmul.svg | 1 + docs/conf.py | 2 + docs/pallas/tpu/index.rst | 2 + docs/pallas/tpu/sparse.ipynb | 724 +++++++++++++++++++ docs/pallas/tpu/sparse.md | 567 +++++++++++++++ 7 files changed, 1298 insertions(+) create mode 100644 docs/_static/pallas/sparse/block_coo.svg create mode 100644 docs/_static/pallas/sparse/prefetch_map.svg create mode 100644 docs/_static/pallas/sparse/sparse_matmul.svg create mode 100644 docs/pallas/tpu/sparse.ipynb create mode 100644 docs/pallas/tpu/sparse.md diff --git a/docs/_static/pallas/sparse/block_coo.svg b/docs/_static/pallas/sparse/block_coo.svg new file mode 100644 index 000000000000..474dfcb64d7a --- /dev/null +++ b/docs/_static/pallas/sparse/block_coo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/sparse/prefetch_map.svg b/docs/_static/pallas/sparse/prefetch_map.svg new file mode 100644 index 000000000000..08fdd2c1cf39 --- /dev/null +++ b/docs/_static/pallas/sparse/prefetch_map.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/sparse/sparse_matmul.svg b/docs/_static/pallas/sparse/sparse_matmul.svg new file mode 100644 index 000000000000..06a24317cfe1 --- /dev/null +++ b/docs/_static/pallas/sparse/sparse_matmul.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 1a7bf32842f0..ed6fcfd0dc8b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -134,6 +134,7 @@ def _do_not_evaluate_in_jax( 'pallas/quickstart.md', 'pallas/tpu/pipelining.md', 'pallas/tpu/distributed.md', + 'pallas/tpu/sparse.md', 'pallas/tpu/matmul.md', 'jep/9407-type-promotion.md', 'autodidax.md', @@ -224,6 +225,7 @@ def _do_not_evaluate_in_jax( 'pallas/quickstart.*', 'pallas/tpu/pipelining.*', 'pallas/tpu/distributed.*', + 'pallas/tpu/sparse.*', 'pallas/tpu/matmul.*', 'sharded-computation.*', 'distributed_data_loading.*' diff --git a/docs/pallas/tpu/index.rst b/docs/pallas/tpu/index.rst index eba986c2cfe8..20abad5f610e 100644 --- a/docs/pallas/tpu/index.rst +++ b/docs/pallas/tpu/index.rst @@ -9,4 +9,6 @@ TPU specific documentation. details pipelining matmul + sparse distributed + diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb new file mode 100644 index 000000000000..909103273e1e --- /dev/null +++ b/docs/pallas/tpu/sparse.ipynb @@ -0,0 +1,724 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ZHuzXqQ-9JUQ" + }, + "source": [ + "# Scalar Prefetch and Block-Sparse Computation\n", + "\n", + "In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 56, + "status": "ok", + "timestamp": 1726001133029, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "ibeIs_6QFMAM", + "outputId": "d72edb91-4529-4650-c9e9-b96788608635" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on TPU v5 lite\n" + ] + } + ], + "source": [ + "import functools\n", + "import timeit\n", + "import numpy as np\n", + "import jax\n", + "from jax import numpy as jnp\n", + "from jax import lax\n", + "from jax.experimental import checkify\n", + "from jax.experimental import pallas as pl\n", + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "assert \"TPU\" in jax.devices()[0].device_kind, \"Please run this notebook with TPU devices.\"\n", + "print(\"Running on\", jax.devices()[0].device_kind)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FIDGpPTEIcOa" + }, + "source": [ + "## Dynamic Block Indexing with Scalar Prefetch\n", + "\n", + "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n", + "\n", + "To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:\n", + "\n", + "```python\n", + "class PrefetchScalarGridSpec:\n", + " def __init__(self,\n", + " num_scalar_prefetch: int,\n", + " grid: tuple[int, ...],\n", + " in_specs: PyTree[BlockSpec],\n", + " out_specs: PyTree[BlockSpec],\n", + " scratch_shapes: tuple[MemorySpace, ...]):\n", + " ...\n", + "```\n", + "\n", + "The `num_scalar_prefetch` parameter indicates the number of scalar prefetch values. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed in to the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks as they do not have a BlockSpec defined. Moreover, the order of arguments to both `index_map` and kernel are always fixed and described below:\n", + "\n", + "- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices:\n", + "```python\n", + "def index_map(*grid_indices, *prefetch_refs):\n", + " ...\n", + "```\n", + "\n", + "- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. Additionally, the scratch refs come after the output `Ref`s.\n", + "```python\n", + "def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):\n", + " ...\n", + "```\n", + "\n", + "- When calling a new kernel using `pallas_call`, the function returned by `pallas_call` also expects the scalar prefetch arguments to come before the inputs, e.g.\n", + "```python\n", + "kernel = pl.pallas_call(...)\n", + "result = kernel(*prefetch_args, *input_args)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pA8RmHEA2HN3" + }, + "source": [ + "## Example: Block Dynamic Slice with Scalar Prefetch\n", + "\n", + "Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of larger array based on user-specified indices:\n", + "\n", + "1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])`\n", + "\n", + "2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`.\n", + "\n", + "3. In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`.\n", + "\n", + "Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size with the slice size and allow non-aligned start indices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 143, + "status": "ok", + "timestamp": 1726003877561, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "FWeTBlEYlCGD", + "outputId": "4b04a441-c97c-4d0d-d167-c60d4d31fd2e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error |result - lax.dynamic_slice| = 0\n" + ] + } + ], + "source": [ + "def dynamic_slice_kernel(indices, x_ref, o_ref):\n", + " del indices\n", + " o_ref[...] = x_ref[...]\n", + "\n", + "@checkify.checkify\n", + "@functools.partial(jax.jit, static_argnums=(2,))\n", + "def block_dynamic_slice(x, starts, sizes):\n", + " grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=1,\n", + " grid=(1, 1),\n", + " in_specs=[pl.BlockSpec(\n", + " sizes,\n", + " lambda i, j, block_idx: (block_idx[0], block_idx[1]))],\n", + " out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),\n", + " )\n", + "\n", + " kernel = pl.pallas_call(\n", + " dynamic_slice_kernel,\n", + " grid_spec=grid_spec,\n", + " out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),\n", + " )\n", + " # Checkify inserts a runtime assert that starts are divisible by block size.\n", + " checkify.check(starts[0] % sizes[0] == 0, \"Starts must be divisible by size.\")\n", + " checkify.check(starts[1] % sizes[1] == 0, \"Starts must be divisible by size.\")\n", + " block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])\n", + " return kernel(block_idx, x)\n", + "\n", + "shape = (512, 512)\n", + "x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)\n", + "err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))\n", + "err.throw()\n", + "ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))\n", + "diff = jnp.max(jnp.abs(result - ref))\n", + "print(\"Error |result - lax.dynamic_slice| =\", diff)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K2dod4lkoifa" + }, + "source": [ + "## Sparse Kernels: Representing Sparse Data\n", + "\n", + "Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix.\n", + "\n", + "The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements.\n", + "\n", + "![block_coo](../../_static/pallas/sparse/block_coo.svg)\n", + "\n", + "We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1gLiSvgIYUEx" + }, + "outputs": [], + "source": [ + "def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):\n", + " \"\"\"Returns a sampled matrix and its block-sparse representation.\n", + "\n", + " Args:\n", + " key: RNG Key.\n", + " M: Major array dimension.\n", + " N: Minor array dimension.\n", + " blk_M: Block size along M dimension.\n", + " blk_N: Block size along N dimension.\n", + " p: Probability that a block will be non-zero.\n", + " dtype: dtype of the sampled matrix.\n", + "\n", + " Returns:\n", + " dense_mat: A (M, N) dense sampled array.\n", + " block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing\n", + " the non-zero blocks of the matrix.\n", + " indices_i: A (num_blocks,) array of block indices for the first axis.\n", + " indices_j: A (num_blocks,) array of block indices for the second axis.\n", + " \"\"\"\n", + " mask_key, blocks_key = jax.random.split(key)\n", + " num_blocks = (M // blk_M, N // blk_N)\n", + " # We first sample a block mask, denoting which blocks are nonzero.\n", + " block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)\n", + " num_blocks = jnp.sum(block_mask)\n", + " indices = jnp.where(block_mask)\n", + " # For each non-zero block, we sample a block of random values.\n", + " block_data = jax.random.uniform(blocks_key,\n", + " shape=(num_blocks, blk_M, blk_N),\n", + " dtype=dtype)\n", + " # For checking purposes, create the dense version of the sparse matrix.\n", + " dense_mat = jnp.zeros((M, N), dtype=dtype)\n", + " for blk in range(num_blocks):\n", + " idx_i = indices[0][blk]\n", + " idx_j = indices[1][blk]\n", + " slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)\n", + " slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)\n", + " dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])\n", + " return dense_mat, block_data, indices[0], indices[1]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eFyoZSTOH9Fk" + }, + "source": [ + "## Example: Sparse @ Dense Matrix Multiplication\n", + "\n", + "In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n", + "\n", + "We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:\n", + "\n", + "![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)\n", + "\n", + "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 673, + "status": "ok", + "timestamp": 1725919879291, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "WfyV2WWhjsyA", + "outputId": "fa4d4fff-bc6b-4dc9-ac14-63276ca14131" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean |result - ref|: 0\n" + ] + } + ], + "source": [ + "M = N = K = 16384\n", + "blk_M = blk_N = blk_K = 512\n", + "\n", + "\n", + "def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.\n", + " x_ref, y_ref, _, o_ref, # Kernel inputs.\n", + " accum_scratch,\n", + " ):\n", + " \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n", + " del idxs_k_ref\n", + " blk_idx = pl.program_id(1)\n", + " is_start = blk_idx == 0\n", + " changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n", + " @pl.when(is_start | changed_blocks)\n", + " def _():\n", + " accum_scratch[...] = jnp.zeros_like(accum_scratch)\n", + " accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)\n", + "\n", + " next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])\n", + " is_end = blk_idx == (num_blocks - 1)\n", + " @pl.when(is_end | next_block_change)\n", + " def _():\n", + " o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n", + "\n", + "\n", + "def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", + " del j, blk_idxs_i, blk_idxs_k\n", + " return (blk_idx, 0, 0)\n", + "def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", + " del blk_idxs_i\n", + " return (blk_idxs_k[blk_idx], j)\n", + "def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", + " del blk_idxs_k\n", + " return (blk_idxs_i[blk_idx], j)\n", + "\n", + "(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(\n", + " jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)\n", + "num_blocks = X_blocks.shape[0]\n", + "Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n", + "zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)\n", + "out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=2,\n", + " # Note that while num_blocks is static here, Pallas does support\n", + " # dynamic grid sizes.\n", + " grid=(M // blk_M, num_blocks),\n", + " in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n", + " pl.BlockSpec((blk_K, blk_N), y_map),\n", + " # Placeholder for a zeros-array used by input_output_aliases.\n", + " pl.BlockSpec((blk_M, blk_N), o_map),\n", + " ],\n", + " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n", + " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n", + ")\n", + "kernel = pl.pallas_call(\n", + " dsd_kernel,\n", + " grid_spec=grid_spec,\n", + " out_shape=out_shape,\n", + " # We use input-output aliases to zero-out o_ref for blocks that we never\n", + " # visit. By passing in an array of zeros we avoid having o_ref start with\n", + " # uninitialized values.\n", + " input_output_aliases={4: 0}, # Map zeros to o_ref.\n", + ")\n", + "args = (indices_i, indices_k, X_blocks, Y, zeros)\n", + "result = kernel(*args)\n", + "\n", + "ref = X_dense @ Y\n", + "diff = jnp.abs(ref - result)\n", + "print('mean |result - ref|:', jnp.mean(diff))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2KDgPKF2tUjq" + }, + "source": [ + "We can do a quick benchmark to compare the performance of our sparse kernel compared to a dense matmul in JAX. On a TPU v5e chip, this kernel achieves a roughly ~6x speed increase compared to the theoretical 10x from the sparsity factor.\n", + "\n", + "There are a few main tips for performance here, mainly centered around reducing the communication overhead between HBM/VMEM:\n", + "- Using `dtype=jnp.bfloat16` is critical for performance since it reduces memory bandwidth by half.\n", + "- Using larger block sizes also helps, since matrix multiply is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument to this in practice is that smaller block sizes also enables data to be more sparse, so this is a parameter that should be selected carefully." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 6576, + "status": "ok", + "timestamp": 1725919886762, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "CkzjqnekpZbx", + "outputId": "1ae9031e-705a-4d05-f8b9-d09623918300" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparse Kernel: 8.136 ms (avg over 100 trials)\n", + "Reference: 46.953 ms (avg over 100 trials)\n" + ] + } + ], + "source": [ + "# Benchmark Sparse Pallas kernel vs reference JAX implementation\n", + "\n", + "def benchmark(f, ntrials: int = 100):\n", + " def run(*args, **kwargs):\n", + " # Compile function first\n", + " jax.block_until_ready(f(*args, **kwargs))\n", + " # Time function\n", + " result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),\n", + " number=ntrials)\n", + " time = result / ntrials\n", + " return time\n", + " return run\n", + "\n", + "\n", + "n_trials = 100\n", + "\n", + "pallas_impl = lambda *args: kernel(*args)\n", + "time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)\n", + "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))\n", + "\n", + "ref_impl = jax.jit(lambda x, y: x @ y)\n", + "time = benchmark(ref_impl, n_trials)(X_dense, Y)\n", + "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q1KKd5vTCwnB" + }, + "source": [ + "## Sparse Access Patterns on Dense Data\n", + "\n", + "In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).\n", + "\n", + "A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n", + "\n", + "The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.\n", + "\n", + "![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg)\n", + "\n", + "*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.*\n", + "\n", + "Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec.\n", + "\n", + "```python\n", + "def mask_index_map(prefetch_map, i, j, ...):\n", + " next_nonzero_block = prefetch_map[i, j]\n", + " return (next_nonzero_block, 0, 0)\n", + "```\n", + "\n", + "We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example will provide an example of using these prefetch maps." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ii7rzL5YIA8-" + }, + "source": [ + "## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ecjiqWfA2RlV" + }, + "source": [ + "In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs.\n", + "\n", + "As we will be working with a sparse mask, we will begin by implementing a function that converts an `N x M` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes:\n", + "- A `block_mask` of shape `(num_N_blocks, num_M_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). If the `block_mask` has a value of 0 we can skip computing the block in the kernel.\n", + "- A `prefetch_mask` array of shape `(num_N_blocks, num_M_blocks)` consisting of indices into `mask_data` for the next non-zero block.\n", + "- A `prefetch_i` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `i` index of the mask.\n", + "- A `prefetch_j` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `j` index of the mask.\n", + "- A `mask_data` array of shape `(num_blocks, blk_N, blk_M)` containing data for non-zero blocks of the mask." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "19zGcliL2SJy" + }, + "outputs": [], + "source": [ + "def sparsify_mask(mask: jax.Array,\n", + " block_shape: tuple[int, int]):\n", + " \"\"\"Preprocesses a mask into a sparse reprentation.\n", + "\n", + " Args:\n", + " mask: A boolean array of shape [M, N]\n", + " block_shape: The size of a single block.\n", + "\n", + " Returns:\n", + " block_mask: A block_shape array of booleans indicating whether a block\n", + " is all-zeros (0) or contains non-zero elements (1).\n", + " prefetch_mask: A block_shape array of integers indicating the index of the\n", + " next non-zero block.\n", + " mask_data: A (num_blocks, block_shape) array containing\n", + " the data for non-zero blocks of the mask.\n", + " \"\"\"\n", + " M, N = mask.shape\n", + " bm, bn = block_shape\n", + "\n", + " block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)\n", + " mask_types_finder = []\n", + " mask_data = []\n", + " mask_type_idxs = []\n", + "\n", + " next_mask_type_idx = 0\n", + " prefetch_mask = jnp.zeros_like(block_mask)\n", + " next_i = (M // bm) - 1\n", + " next_j = (N // bn) - 1\n", + " prefetch_i = jnp.zeros_like(block_mask)\n", + " prefetch_j = jnp.zeros_like(block_mask)\n", + " for i in range(M // bm, -1, -1):\n", + " for j in range(N // bn, -1, -1):\n", + " mask_block = mask[i * bm :(i + 1) * bm,\n", + " j * bn :(j + 1) * bn]\n", + " is_nonzero = jnp.any(mask_block)\n", + " if is_nonzero:\n", + " try:\n", + " type_index = mask_types_finder.index(str(mask_block))\n", + " except ValueError:\n", + " type_index = len(mask_types_finder)\n", + " mask_types_finder.append(str(mask_block))\n", + " mask_data.append(mask_block)\n", + " next_mask_type_idx = type_index\n", + " next_i = i\n", + " next_j = j\n", + " else:\n", + " type_index = -1\n", + " mask_type_idxs.append(type_index)\n", + " block_mask = block_mask.at[i, j].set(is_nonzero)\n", + " prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)\n", + " prefetch_i = prefetch_i.at[i, j].set(next_i)\n", + " prefetch_j = prefetch_j.at[i, j].set(next_j)\n", + " return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w4b7ckKq67Xw" + }, + "source": [ + "In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials with a 3 loops over the `N`, `M`, and `K` dimensions. Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block was all zeros. If the mask is all zeros, we can skip computation and move onto the next block; otherwise we need to compute the matrix multiplication and then mask the result." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 5374, + "status": "ok", + "timestamp": 1725919713252, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "4YQ9OmbTCSjT", + "outputId": "2d752609-34f2-4059-e8ba-4d80afe8cb26" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean |result - ref|: 1.0252e-05\n" + ] + } + ], + "source": [ + "M = N = K = 16384\n", + "blk_M = blk_N = 512\n", + "blk_K = 1024\n", + "\n", + "def sparse_mask_matmul(\n", + " block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.\n", + " x_ref, y_ref, mask_ref, o_ref, # Kernel inputs.\n", + " accum_scratch\n", + " ):\n", + " del prefetch_mask, prefetch_i, prefetch_j\n", + " i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)\n", + " should_compute = block_mask_ref[i, j] != 0\n", + " @pl.when(k == 0)\n", + " def _():\n", + " o_ref[...] = jnp.zeros_like(o_ref)\n", + " accum_scratch[...] = jnp.zeros_like(accum_scratch[...])\n", + "\n", + " # We only compute the output for blocks with non-zero masks.\n", + " # Otherwise we skip the computation entirely.\n", + " @pl.when(should_compute)\n", + " def _():\n", + " result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)\n", + " accum_scratch[...] += result\n", + " @pl.when(k == pl.num_programs(2) - 1)\n", + " def _():\n", + " o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)\n", + "\n", + "X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)\n", + "Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n", + "mask = jnp.ones((M, N), dtype=jnp.int32)\n", + "mask = jnp.tril(mask)\n", + "block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))\n", + "\n", + "def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n", + " del prefetch_mask, prefetch_j\n", + " # Zero-out the k index if the mask is zero, to avoid constantly fetching\n", + " # new blocks in the inner loop for blocks we are skipping.\n", + " k_fetch = (block_mask[i, j] != 0) * k\n", + " return (prefetch_i[i, j], k_fetch)\n", + "\n", + "def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n", + " del prefetch_mask, prefetch_i\n", + " k_fetch = (block_mask[i, j] != 0) * k\n", + " return (k_fetch, prefetch_j[i, j])\n", + "\n", + "def mask_map(i, j, k, block_mask, prefetch_mask, *_):\n", + " del k, block_mask\n", + " return (prefetch_mask[i, j], 0, 0)\n", + "\n", + "def o_map(i, j, k, *_):\n", + " del k\n", + " return (i, j)\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=4,\n", + " grid=(M // blk_M, N // blk_N, K // blk_K),\n", + " in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),\n", + " pl.BlockSpec((blk_K, blk_N), y_map),\n", + " pl.BlockSpec((1, blk_M, blk_N), mask_map)],\n", + " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n", + " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n", + ")\n", + "kernel = pl.pallas_call(\n", + " sparse_mask_matmul,\n", + " grid_spec=grid_spec,\n", + " out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),\n", + ")\n", + "args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n", + "result = kernel(*args)\n", + "\n", + "ref = mask * (X @ Y)\n", + "diff = jnp.abs(ref - result)\n", + "print('mean |result - ref|:', jnp.mean(diff))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uutNGgjZGGhB" + }, + "source": [ + "Now let's compare performance versus a naive dense implementation. On TPU v5e, we achieve around a ~1.8x speed increase with the sparse kernel, compared to a theoretical best-case of 2x from using a lower triangular mask and only visiting half of the possible outputs.\n", + "\n", + "We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are:\n", + "- We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.\n", + "- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 8877, + "status": "ok", + "timestamp": 1725917397452, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "MAT9JjGNvsx8", + "outputId": "a32d56fb-a71b-4007-c6a5-e5270dcaa6cf" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparse Kernel: 28.648 ms (avg over 100 trials)\n", + "Reference: 49.988 ms (avg over 100 trials)\n" + ] + } + ], + "source": [ + "n_trials = 100\n", + "\n", + "pallas_impl = lambda *args: kernel(*args)\n", + "time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n", + "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))\n", + "\n", + "ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))\n", + "time = benchmark(ref_impl, n_trials)(mask, X, Y)\n", + "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md new file mode 100644 index 000000000000..23e14bb9bc0b --- /dev/null +++ b/docs/pallas/tpu/sparse.md @@ -0,0 +1,567 @@ +--- +jupytext: + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + name: python3 +--- + ++++ {"id": "ZHuzXqQ-9JUQ"} + +# Scalar Prefetch and Block-Sparse Computation + +In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory. + +```{code-cell} +--- +executionInfo: + elapsed: 56 + status: ok + timestamp: 1726001133029 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: ibeIs_6QFMAM +outputId: d72edb91-4529-4650-c9e9-b96788608635 +--- +import functools +import timeit +import numpy as np +import jax +from jax import numpy as jnp +from jax import lax +from jax.experimental import checkify +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + +assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices." +print("Running on", jax.devices()[0].device_kind) +``` + ++++ {"id": "FIDGpPTEIcOa"} + +## Dynamic Block Indexing with Scalar Prefetch + +We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature. + +To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`: + +```python +class PrefetchScalarGridSpec: + def __init__(self, + num_scalar_prefetch: int, + grid: tuple[int, ...], + in_specs: PyTree[BlockSpec], + out_specs: PyTree[BlockSpec], + scratch_shapes: tuple[MemorySpace, ...]): + ... +``` + +The `num_scalar_prefetch` parameter indicates the number of scalar prefetch values. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed in to the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks as they do not have a BlockSpec defined. Moreover, the order of arguments to both `index_map` and kernel are always fixed and described below: + +- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices: +```python +def index_map(*grid_indices, *prefetch_refs): + ... +``` + +- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. Additionally, the scratch refs come after the output `Ref`s. +```python +def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs): + ... +``` + +- When calling a new kernel using `pallas_call`, the function returned by `pallas_call` also expects the scalar prefetch arguments to come before the inputs, e.g. +```python +kernel = pl.pallas_call(...) +result = kernel(*prefetch_args, *input_args) +``` + ++++ {"id": "pA8RmHEA2HN3"} + +## Example: Block Dynamic Slice with Scalar Prefetch + +Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of larger array based on user-specified indices: + +1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])` + +2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`. + +3. In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`. + +Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size with the slice size and allow non-aligned start indices. + +```{code-cell} +--- +executionInfo: + elapsed: 143 + status: ok + timestamp: 1726003877561 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: FWeTBlEYlCGD +outputId: 4b04a441-c97c-4d0d-d167-c60d4d31fd2e +--- +def dynamic_slice_kernel(indices, x_ref, o_ref): + del indices + o_ref[...] = x_ref[...] + +@checkify.checkify +@functools.partial(jax.jit, static_argnums=(2,)) +def block_dynamic_slice(x, starts, sizes): + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(1, 1), + in_specs=[pl.BlockSpec( + sizes, + lambda i, j, block_idx: (block_idx[0], block_idx[1]))], + out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), + ) + + kernel = pl.pallas_call( + dynamic_slice_kernel, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), + ) + # Checkify inserts a runtime assert that starts are divisible by block size. + checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.") + checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.") + block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) + return kernel(block_idx, x) + +shape = (512, 512) +x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) +err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128)) +err.throw() +ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128)) +diff = jnp.max(jnp.abs(result - ref)) +print("Error |result - lax.dynamic_slice| =", diff) +``` + ++++ {"id": "K2dod4lkoifa"} + +## Sparse Kernels: Representing Sparse Data + +Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix. + +The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements. + +![block_coo](../../_static/pallas/sparse/block_coo.svg) + +We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis. + +```{code-cell} +:id: 1gLiSvgIYUEx + +def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32): + """Returns a sampled matrix and its block-sparse representation. + + Args: + key: RNG Key. + M: Major array dimension. + N: Minor array dimension. + blk_M: Block size along M dimension. + blk_N: Block size along N dimension. + p: Probability that a block will be non-zero. + dtype: dtype of the sampled matrix. + + Returns: + dense_mat: A (M, N) dense sampled array. + block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing + the non-zero blocks of the matrix. + indices_i: A (num_blocks,) array of block indices for the first axis. + indices_j: A (num_blocks,) array of block indices for the second axis. + """ + mask_key, blocks_key = jax.random.split(key) + num_blocks = (M // blk_M, N // blk_N) + # We first sample a block mask, denoting which blocks are nonzero. + block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks) + num_blocks = jnp.sum(block_mask) + indices = jnp.where(block_mask) + # For each non-zero block, we sample a block of random values. + block_data = jax.random.uniform(blocks_key, + shape=(num_blocks, blk_M, blk_N), + dtype=dtype) + # For checking purposes, create the dense version of the sparse matrix. + dense_mat = jnp.zeros((M, N), dtype=dtype) + for blk in range(num_blocks): + idx_i = indices[0][blk] + idx_j = indices[1][blk] + slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M) + slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N) + dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk]) + return dense_mat, block_data, indices[0], indices[1] +``` + ++++ {"id": "eFyoZSTOH9Fk"} + +## Example: Sparse @ Dense Matrix Multiplication + +In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output. + +We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram: + +![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg) + +It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct. + +```{code-cell} +--- +executionInfo: + elapsed: 673 + status: ok + timestamp: 1725919879291 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: WfyV2WWhjsyA +outputId: fa4d4fff-bc6b-4dc9-ac14-63276ca14131 +--- +M = N = K = 16384 +blk_M = blk_N = blk_K = 512 + + +def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs. + x_ref, y_ref, _, o_ref, # Kernel inputs. + accum_scratch, + ): + """A DSD (Dense = Sparse @ Dense) matmul kernel.""" + del idxs_k_ref + blk_idx = pl.program_id(1) + is_start = blk_idx == 0 + changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)]) + @pl.when(is_start | changed_blocks) + def _(): + accum_scratch[...] = jnp.zeros_like(accum_scratch) + accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32) + + next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)]) + is_end = blk_idx == (num_blocks - 1) + @pl.when(is_end | next_block_change) + def _(): + o_ref[...] = accum_scratch[...].astype(o_ref.dtype) + + +def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k): + del j, blk_idxs_i, blk_idxs_k + return (blk_idx, 0, 0) +def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k): + del blk_idxs_i + return (blk_idxs_k[blk_idx], j) +def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k): + del blk_idxs_k + return (blk_idxs_i[blk_idx], j) + +(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat( + jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16) +num_blocks = X_blocks.shape[0] +Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16) +zeros = jnp.zeros((M, N), dtype=jnp.bfloat16) +out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=2, + # Note that while num_blocks is static here, Pallas does support + # dynamic grid sizes. + grid=(M // blk_M, num_blocks), + in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map), + pl.BlockSpec((blk_K, blk_N), y_map), + # Placeholder for a zeros-array used by input_output_aliases. + pl.BlockSpec((blk_M, blk_N), o_map), + ], + out_specs=pl.BlockSpec((blk_M, blk_N), o_map), + scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)] +) +kernel = pl.pallas_call( + dsd_kernel, + grid_spec=grid_spec, + out_shape=out_shape, + # We use input-output aliases to zero-out o_ref for blocks that we never + # visit. By passing in an array of zeros we avoid having o_ref start with + # uninitialized values. + input_output_aliases={4: 0}, # Map zeros to o_ref. +) +args = (indices_i, indices_k, X_blocks, Y, zeros) +result = kernel(*args) + +ref = X_dense @ Y +diff = jnp.abs(ref - result) +print('mean |result - ref|:', jnp.mean(diff)) +``` + ++++ {"id": "2KDgPKF2tUjq"} + +We can do a quick benchmark to compare the performance of our sparse kernel compared to a dense matmul in JAX. On a TPU v5e chip, this kernel achieves a roughly ~6x speed increase compared to the theoretical 10x from the sparsity factor. + +There are a few main tips for performance here, mainly centered around reducing the communication overhead between HBM/VMEM: +- Using `dtype=jnp.bfloat16` is critical for performance since it reduces memory bandwidth by half. +- Using larger block sizes also helps, since matrix multiply is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument to this in practice is that smaller block sizes also enables data to be more sparse, so this is a parameter that should be selected carefully. + +```{code-cell} +--- +executionInfo: + elapsed: 6576 + status: ok + timestamp: 1725919886762 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: CkzjqnekpZbx +outputId: 1ae9031e-705a-4d05-f8b9-d09623918300 +--- +# Benchmark Sparse Pallas kernel vs reference JAX implementation + +def benchmark(f, ntrials: int = 100): + def run(*args, **kwargs): + # Compile function first + jax.block_until_ready(f(*args, **kwargs)) + # Time function + result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)), + number=ntrials) + time = result / ntrials + return time + return run + + +n_trials = 100 + +pallas_impl = lambda *args: kernel(*args) +time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros) +print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials)) + +ref_impl = jax.jit(lambda x, y: x @ y) +time = benchmark(ref_impl, n_trials)(X_dense, Y) +print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials)) +``` + ++++ {"id": "Q1KKd5vTCwnB"} + +## Sparse Access Patterns on Dense Data + +In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`). + +A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/). + +The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format. + +![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg) + +*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.* + +Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec. + +```python +def mask_index_map(prefetch_map, i, j, ...): + next_nonzero_block = prefetch_map[i, j] + return (next_nonzero_block, 0, 0) +``` + +We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example will provide an example of using these prefetch maps. + ++++ {"id": "ii7rzL5YIA8-"} + +## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask + ++++ {"id": "ecjiqWfA2RlV"} + +In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs. + +As we will be working with a sparse mask, we will begin by implementing a function that converts an `N x M` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes: +- A `block_mask` of shape `(num_N_blocks, num_M_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). If the `block_mask` has a value of 0 we can skip computing the block in the kernel. +- A `prefetch_mask` array of shape `(num_N_blocks, num_M_blocks)` consisting of indices into `mask_data` for the next non-zero block. +- A `prefetch_i` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `i` index of the mask. +- A `prefetch_j` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `j` index of the mask. +- A `mask_data` array of shape `(num_blocks, blk_N, blk_M)` containing data for non-zero blocks of the mask. + +```{code-cell} +:id: 19zGcliL2SJy + +def sparsify_mask(mask: jax.Array, + block_shape: tuple[int, int]): + """Preprocesses a mask into a sparse reprentation. + + Args: + mask: A boolean array of shape [M, N] + block_shape: The size of a single block. + + Returns: + block_mask: A block_shape array of booleans indicating whether a block + is all-zeros (0) or contains non-zero elements (1). + prefetch_mask: A block_shape array of integers indicating the index of the + next non-zero block. + mask_data: A (num_blocks, block_shape) array containing + the data for non-zero blocks of the mask. + """ + M, N = mask.shape + bm, bn = block_shape + + block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype) + mask_types_finder = [] + mask_data = [] + mask_type_idxs = [] + + next_mask_type_idx = 0 + prefetch_mask = jnp.zeros_like(block_mask) + next_i = (M // bm) - 1 + next_j = (N // bn) - 1 + prefetch_i = jnp.zeros_like(block_mask) + prefetch_j = jnp.zeros_like(block_mask) + for i in range(M // bm, -1, -1): + for j in range(N // bn, -1, -1): + mask_block = mask[i * bm :(i + 1) * bm, + j * bn :(j + 1) * bn] + is_nonzero = jnp.any(mask_block) + if is_nonzero: + try: + type_index = mask_types_finder.index(str(mask_block)) + except ValueError: + type_index = len(mask_types_finder) + mask_types_finder.append(str(mask_block)) + mask_data.append(mask_block) + next_mask_type_idx = type_index + next_i = i + next_j = j + else: + type_index = -1 + mask_type_idxs.append(type_index) + block_mask = block_mask.at[i, j].set(is_nonzero) + prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx) + prefetch_i = prefetch_i.at[i, j].set(next_i) + prefetch_j = prefetch_j.at[i, j].set(next_j) + return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data) +``` + ++++ {"id": "w4b7ckKq67Xw"} + +In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials with a 3 loops over the `N`, `M`, and `K` dimensions. Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block was all zeros. If the mask is all zeros, we can skip computation and move onto the next block; otherwise we need to compute the matrix multiplication and then mask the result. + +```{code-cell} +--- +executionInfo: + elapsed: 5374 + status: ok + timestamp: 1725919713252 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: 4YQ9OmbTCSjT +outputId: 2d752609-34f2-4059-e8ba-4d80afe8cb26 +--- +M = N = K = 16384 +blk_M = blk_N = 512 +blk_K = 1024 + +def sparse_mask_matmul( + block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs. + x_ref, y_ref, mask_ref, o_ref, # Kernel inputs. + accum_scratch + ): + del prefetch_mask, prefetch_i, prefetch_j + i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2) + should_compute = block_mask_ref[i, j] != 0 + @pl.when(k == 0) + def _(): + o_ref[...] = jnp.zeros_like(o_ref) + accum_scratch[...] = jnp.zeros_like(accum_scratch[...]) + + # We only compute the output for blocks with non-zero masks. + # Otherwise we skip the computation entirely. + @pl.when(should_compute) + def _(): + result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32) + accum_scratch[...] += result + @pl.when(k == pl.num_programs(2) - 1) + def _(): + o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype) + +X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16) +Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16) +mask = jnp.ones((M, N), dtype=jnp.int32) +mask = jnp.tril(mask) +block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N)) + +def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j): + del prefetch_mask, prefetch_j + # Zero-out the k index if the mask is zero, to avoid constantly fetching + # new blocks in the inner loop for blocks we are skipping. + k_fetch = (block_mask[i, j] != 0) * k + return (prefetch_i[i, j], k_fetch) + +def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j): + del prefetch_mask, prefetch_i + k_fetch = (block_mask[i, j] != 0) * k + return (k_fetch, prefetch_j[i, j]) + +def mask_map(i, j, k, block_mask, prefetch_mask, *_): + del k, block_mask + return (prefetch_mask[i, j], 0, 0) + +def o_map(i, j, k, *_): + del k + return (i, j) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=4, + grid=(M // blk_M, N // blk_N, K // blk_K), + in_specs=[pl.BlockSpec((blk_M, blk_K), x_map), + pl.BlockSpec((blk_K, blk_N), y_map), + pl.BlockSpec((1, blk_M, blk_N), mask_map)], + out_specs=pl.BlockSpec((blk_M, blk_N), o_map), + scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)] +) +kernel = pl.pallas_call( + sparse_mask_matmul, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16), +) +args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data) +result = kernel(*args) + +ref = mask * (X @ Y) +diff = jnp.abs(ref - result) +print('mean |result - ref|:', jnp.mean(diff)) +``` + ++++ {"id": "uutNGgjZGGhB"} + +Now let's compare performance versus a naive dense implementation. On TPU v5e, we achieve around a ~1.8x speed increase with the sparse kernel, compared to a theoretical best-case of 2x from using a lower triangular mask and only visiting half of the possible outputs. + +We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are: +- We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation. +- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger. + +```{code-cell} +--- +executionInfo: + elapsed: 8877 + status: ok + timestamp: 1725917397452 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: MAT9JjGNvsx8 +outputId: a32d56fb-a71b-4007-c6a5-e5270dcaa6cf +--- +n_trials = 100 + +pallas_impl = lambda *args: kernel(*args) +time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data) +print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials)) + +ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y)) +time = benchmark(ref_impl, n_trials)(mask, X, Y) +print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials)) +``` From 6a3736a1d78aec1770ac9bcdc9faf4e7fb7abdfc Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 19 Sep 2024 15:37:52 -0400 Subject: [PATCH 572/702] Add a note to the changelog about the new CPU thunks backend, enabled in 0.4.32. --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 659a8ee04db0..ee782d04a02c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -120,6 +120,12 @@ Note: This release was yanked from PyPi because of a data corruption bug on TPU. See the 0.4.33 release notes for more details. * Breaking changes + * This release of jaxlib switched to a new version of the CPU backend, which + should compile faster and leverage parallelism better. If you experience + any problems due to this change, you can temporarily enable the old CPU + backend by setting the environment variable + `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`. If you need to do this, + please file a JAX bug with instructions to reproduce. * Hermetic CUDA support is added. Hermetic CUDA uses a specific downloadable version of CUDA instead of the user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL From 63e7b7d364ec8a3c33dd2f673879e89311b78f30 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 19 Sep 2024 12:58:59 -0700 Subject: [PATCH 573/702] Remove some untested dynamic shapes paths (prep work for stackless). PiperOrigin-RevId: 676529297 --- jax/_src/interpreters/partial_eval.py | 103 +------------------------- jax/interpreters/partial_eval.py | 1 - 2 files changed, 3 insertions(+), 101 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index fc2214aaf29f..6bc6539b9262 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -168,11 +168,6 @@ def new_instantiated_literal(self, val) -> JaxprTracer: def new_instantiated_const(self, val) -> JaxprTracer: aval = get_aval(val) - if isinstance(aval, DShapedArray): - shape = [self.new_instantiated_const(d) - if isinstance(d, Tracer) and d._trace.level < self.level else d - for d in aval.shape] - aval = aval.update(shape=tuple(shape)) return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(val)) def new_arg(self, pval: PartialVal) -> JaxprTracer: @@ -258,15 +253,9 @@ def process_call(self, primitive, f, tracers, params): # which were unknown to the first call (corresponding to in_avals). # Wrap f to perform the partial evaluation and plumb out aux data. - if not config.dynamic_shapes.value: - f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False) - f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), - tuple(in_avals)) - else: - if f.in_type is None: - f = lu.annotate(f, tuple((a, True) for a in in_avals)) - f_, aux = trace_to_subjaxpr_nounits_dyn(f, self.main, tuple(in_knowns), - f.in_type, False) + f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False) + f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), + tuple(in_avals)) # Adjust parameters (e.g. donated_invars) for the call to be evaluated now. const_params = update_params(params, in_knowns, 0) @@ -569,92 +558,6 @@ def partial_eval_wrapper_nounits( out_knowns, out_avals, out_consts = partition_pvals(out_pvals) yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env) -@lu.transformation_with_aux -def trace_to_subjaxpr_nounits_dyn( - main: core.MainTrace, in_knowns: Sequence[bool], in_type: InputType, - instantiate: bool | Sequence[bool], - *in_consts: Any): - trace = main.with_cur_sublevel() - in_avals, which_explicit = unzip2(in_type) - - # To form input tracers from in_type, we need to first build ConstVar tracers - # for all axis sizes, so that we can then use those tracers in the shapes of - # avals for unknown inputs' tracers. We use ConstVar recipes for on-the-fly - # type agreement checking via get_referent. - in_consts_full: list[JaxprTracer | None] = [None] * len(in_type) - in_consts_iter, in_knowns_iter = iter(in_consts), iter(in_knowns) - for idx, (aval, explicit) in enumerate(in_type): - if explicit and next(in_knowns_iter): - constval = next(in_consts_iter) - if isinstance(aval, DShapedArray): - for i, d in enumerate(aval.shape): - if isinstance(d, DBIdx): - if in_consts_full[d.val] is None: - in_consts_full[d.val] = \ - JaxprTracer(trace, PartialVal.unknown(in_avals[d.val]), - ConstVar(constval.shape[i])) - assert core.same_referent(constval.shape[i], in_consts_full[d.val]) - shape = [in_consts_full[d.val] if type(d) is DBIdx else d - for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - in_consts_full[idx] = JaxprTracer(trace, PartialVal.unknown(aval), - ConstVar(constval)) - # Check that we covered all axis sizes with ConstVar tracers. - for idx, (aval, explicit) in enumerate(in_type): - if not explicit: assert in_consts_full[idx] is not None - if isinstance(aval, DShapedArray): - assert all(type(d) is not DBIdx or in_consts_full[d.val] is not None - for d in aval.shape) - - # Next, build tracers for all unknown inputs, using the in_consts_full list - # for axis size tracers when necessary. - in_tracers = [] - in_knowns_iter = iter(in_knowns) - for aval, explicit in in_type: - if explicit and not next(in_knowns_iter): - if isinstance(aval, DShapedArray): - shape = [in_consts_full[d.val] if type(d) is DBIdx else d - for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - tracer = JaxprTracer(trace, PartialVal.unknown(aval), LambdaBinding()) - in_tracers.append(tracer) - - # Merge in_consts and in_tracers and call wrapped fn with explicit arguments. - in_args = merge_lists(in_knowns, in_tracers, in_consts) - ans = yield in_args, {} - - # Instantiate outputs and build jaxpr. - if isinstance(instantiate, bool): - instantiate = [instantiate] * len(ans) - out_tracers = map(trace.full_raise, map(core.full_lower, ans)) - out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t - for inst, t in zip(instantiate, out_tracers)] - - # Collect known outputs. - out_knowns: list[bool] = [t.is_known() for t in out_tracers] - out_consts: list[Any] = [t.pval.get_known() for t in out_tracers - if t.is_known()] - - # Build the jaxpr. - out_tracers = [t for t in out_tracers if not t.is_known()] - jaxpr, res, env = tracers_to_jaxpr(in_tracers, out_tracers) - out_avals = [v.aval for v in jaxpr.outvars] - idx_map = {v: InDBIdx(i) - for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))} - out_type = [(a.update(shape=tuple(idx_map.get(d, d) for d in a.shape)) # type: ignore - if type(a) is DShapedArray else a, True) for a in out_avals] - - # Which residuals are just forwarded inputs? Check obj id, then prune. - id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore - if c is not None} - fwds: list[int | None] = [id_map.get(id(c)) for c in res] - res = tuple(c for c, fwd in zip(res, fwds) if fwd is None) - - del main, in_consts, trace, in_consts_iter, in_knowns_iter, in_consts_full, \ - in_tracers, in_args, ans, out_tracers, out_avals - yield (*out_consts, *res), (fwds, out_knowns, tuple(out_type), jaxpr, env) - - custom_partial_eval_rules: dict[Primitive, Callable] = {} call_partial_eval_rules: dict[Primitive, Callable] = {} call_param_updaters: dict[Primitive, Callable] = {} diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 706f5a2fe253..3c63948bee63 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -91,7 +91,6 @@ trace_to_subjaxpr_dynamic as trace_to_subjaxpr_dynamic, trace_to_subjaxpr_dynamic2 as trace_to_subjaxpr_dynamic2, trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, - trace_to_subjaxpr_nounits_dyn as trace_to_subjaxpr_nounits_dyn, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, tracers_to_jaxpr as tracers_to_jaxpr, trivial_ctx as trivial_ctx, From 815dc3ba633e6e840514e6d025a805b708b93869 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 19 Sep 2024 13:38:28 -0700 Subject: [PATCH 574/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/a0cb79873742367204ad1386e9ca4fd815b3f860. PiperOrigin-RevId: 676545740 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a1e1fc505455..72bad324e0f0 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "d2434f289c130c9d87c05a1e7086abf7922519fc" -XLA_SHA256 = "5264285791bda5c123cda881b44cba8fe404cf334d5843c54110d8678e319872" +XLA_COMMIT = "a0cb79873742367204ad1386e9ca4fd815b3f860" +XLA_SHA256 = "bcedc70cf3cdcc94159313365b15eb49e25e0d8a9d4713c290ead5a507d2b366" def repo(): tf_http_archive( From 47b177bd03ed19753ac70fc309e7c2f30e2efd7c Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 19 Sep 2024 14:34:56 -0700 Subject: [PATCH 575/702] [Mosaic TPU][NFC] Remove FailureOr in getNativeVregOrVmaskTypeImpl PiperOrigin-RevId: 676566796 --- .../tpu/transforms/apply_vector_layout.cc | 105 ++++++++---------- 1 file changed, 45 insertions(+), 60 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 7f1f9b63bdc8..9b21ec1803c8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -526,8 +526,7 @@ FailureOr appendConstant(RewriteContext &ctx, func::FuncOp func, return argument; } -// TODO(tlongeri): This function and others below never fail, remove FailureOr -FailureOr getNativeVregOrVmaskTypeImpl( +VectorType getNativeVregOrVmaskTypeImpl( Type elem_ty, const int8_t bitwidth, const std::array target_shape) { if (bitwidth == 32) { @@ -537,9 +536,8 @@ FailureOr getNativeVregOrVmaskTypeImpl( elem_ty); } -FailureOr getNativeVregOrVmaskType( - Type elem_ty, const int8_t layout_bitwidth, - const std::array target_shape) { +VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth, + const std::array target_shape) { int8_t bitwidth = elem_ty.getIntOrFloatBitWidth(); if (bitwidth == 1) { bitwidth = layout_bitwidth; @@ -549,8 +547,8 @@ FailureOr getNativeVregOrVmaskType( return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape); } -FailureOr getNativeVregType( - Type elem_ty, const std::array target_shape) { +VectorType getNativeVregType(Type elem_ty, + const std::array target_shape) { return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(), target_shape); } @@ -572,7 +570,7 @@ FailureOr maskOOB(RewriteContext &ctx, OpBuilder &builder, const VRegDataBounds &bounds, const TypedAttr neutral) { auto native_vreg_ty = - *getNativeVregType(value.getType().getElementType(), ctx.target_shape); + getNativeVregType(value.getType().getElementType(), ctx.target_shape); TPU_ASSERT_LOC(value.getLoc(), llvm::equal(value.getType().getShape(), native_vreg_ty.getShape())); if (bounds.isComplete(ctx.target_shape)) { @@ -709,10 +707,8 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, in_vreg_arrays.emplace_back(std::move(tile_array)); } - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType out_vreg_ty, - getNativeVregOrVmaskType(out_ty.getElementType(), layout_out.bitwidth(), - ctx.target_shape)); + const VectorType out_vreg_ty = getNativeVregOrVmaskType( + out_ty.getElementType(), layout_out.bitwidth(), ctx.target_shape); NamedAttrList attributes(op.getAttrDictionary()); attributes.erase("in_layout"); @@ -783,9 +779,8 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape)); } - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType res_vreg_ty, - getNativeVregType(result_ty.getElementType(), ctx.target_shape)); + const VectorType res_vreg_ty = + getNativeVregType(result_ty.getElementType(), ctx.target_shape); if (layout_in.implicit_dim() != layout_out.implicit_dim()) { return op.emitOpError( "Not implemented: Change of implicit dim during the cast"); @@ -891,9 +886,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape)); } - FAILUREOR_ASSIGN_OR_RETURN( - VectorType res_vreg_ty, - getNativeVregType(result_ty.getElementType(), ctx.target_shape)); + VectorType res_vreg_ty = + getNativeVregType(result_ty.getElementType(), ctx.target_shape); if (layout_out.tiling() == ctx.target_shape) { const int packing = layout_out.packing(); output_vregs.Each([&](absl::Span idxs, Value *v) { @@ -1558,9 +1552,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op, } ImplicitLocOpBuilder builder(op.getLoc(), &op); - FAILUREOR_ASSIGN_OR_RETURN( - VectorType vreg_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + VectorType vreg_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); bool is_load_op = true; xla::Array tiles( @@ -1738,9 +1731,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(padded_lhs_rows, acc_vregs.dim(0) * layout_acc.tiling()[0]); TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType i32_vreg, - getNativeVregType(builder.getI32Type(), ctx.target_shape)); + const VectorType i32_vreg = + getNativeVregType(builder.getI32Type(), ctx.target_shape); auto getVmaskByPaddingEnd = [&](int64_t dim, int64_t padding, VectorType vreg_ty) { CHECK(dim == 0 || dim == 1); @@ -2012,9 +2004,8 @@ LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op, } } ImplicitLocOpBuilder builder(op.getLoc(), &op); - FAILUREOR_ASSIGN_OR_RETURN( - const auto native_vreg_ty, - getNativeVregType(out_ty.getElementType(), ctx.target_shape)); + const auto native_vreg_ty = + getNativeVregType(out_ty.getElementType(), ctx.target_shape); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array in_tiles, disassemble(builder, in_layout, bitcast_op.getInput(), ctx.target_shape)); @@ -2064,9 +2055,8 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op, SmallVector layout_shape = layout->tileArrayShape(vty.getShape(), ctx.target_shape); const int64_t num_vectors = ShapedType::getNumElements(layout_shape); - FAILUREOR_ASSIGN_OR_RETURN( - VectorType vreg_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + VectorType vreg_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); // We can not use disassemble here because the val is block argument. auto unrolled_op = builder.create( val.getLoc(), SmallVector(num_vectors, vreg_ty), val); @@ -2104,16 +2094,15 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, } ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); - FAILUREOR_ASSIGN_OR_RETURN( - VectorType res_vreg_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + + VectorType res_vreg_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array in_tiles, disassemble(builder, layout_in, op.getValue(), ctx.target_shape)); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType i32_vreg, - getNativeVregType(builder.getI32Type(), ctx.target_shape)); + const VectorType i32_vreg = + getNativeVregType(builder.getI32Type(), ctx.target_shape); // Some helper functions for math ops. auto mlirI32Const = [&](int d) { @@ -2518,9 +2507,9 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op, if (!layout_out.hasNativeTiling(ctx.target_shape)) { return iota_op.emitOpError("Not implemented: Only native tiling supported"); } - FAILUREOR_ASSIGN_OR_RETURN( - const auto native_vreg_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + + const auto native_vreg_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { return op.emitOpError("Not implemented: Only 2D layouts supported"); } @@ -2807,9 +2796,8 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, auto load_op = cast(op); const auto memref_ty = getMemRefType(load_op.getBase()); const auto vty = cast(load_op.getResult().getType()); - FAILUREOR_ASSIGN_OR_RETURN( - VectorType target_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + VectorType target_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); if (vty.getRank() == 0) { op.emitOpError("Not implemented: scalar loads from vmem"); } @@ -3017,9 +3005,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op, } const VectorLayout &layout_out = *layouts_out.front(); DenseElementsAttr value = cast(constant_op.getValue()); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType target_vty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + const VectorType target_vty = + getNativeVregType(vty.getElementType(), ctx.target_shape); if (value.isSplat()) { if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) { return op.emitOpError( @@ -3270,9 +3257,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, // yields the vmask. auto src_i32 = builder.create( broadcast_op.getLoc(), builder.getI32Type(), broadcast_op.getSource()); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType native_vreg_ty, - getNativeVregType(src_i32.getType(), ctx.target_shape)); + + const VectorType native_vreg_ty = + getNativeVregType(src_i32.getType(), ctx.target_shape); auto tile_i32 = builder.create(native_vreg_ty, src_i32); auto zeros = builder.create( @@ -3313,13 +3300,13 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, loc, src_i32, builder.create(loc, src_i32, shift_width)); } - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType i32_vreg_ty, - getNativeVregType(src_i32.getType(), ctx.target_shape)); + + const VectorType i32_vreg_ty = + getNativeVregType(src_i32.getType(), ctx.target_shape); auto tile_i32 = builder.create(i32_vreg_ty, src_i32); - FAILUREOR_ASSIGN_OR_RETURN(const VectorType native_vreg_ty, - getNativeVregType(src_ty, ctx.target_shape)); + const VectorType native_vreg_ty = + getNativeVregType(src_ty, ctx.target_shape); auto tile = builder.create(native_vreg_ty, tile_i32); const xla::Array dst_tiles(dst_tiles_shape, tile); @@ -3329,9 +3316,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, broadcast_op.erase(); return success(); } else { - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType native_vreg_ty, - getNativeVregType(broadcast_op.getSourceType(), ctx.target_shape)); + const VectorType native_vreg_ty = + getNativeVregType(broadcast_op.getSourceType(), ctx.target_shape); auto tile = builder.create(native_vreg_ty, broadcast_op.getSource()); const xla::Array dst_tiles(dst_tiles_shape, tile); @@ -4816,7 +4802,7 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx, CHECK_EQ(bitwidth, cast(dst_vreg.getType()).getElementTypeBitWidth()); const VectorType vmask_ty = - *getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); + getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); auto sublanes_mask = builder.create( src_vreg.getLoc(), vmask_ty, ValueRange{boundIdxConst(dst_sl_idx), boundIdxConst(0)}, @@ -4864,9 +4850,8 @@ FailureOr> tpu_rotate_with_overflow( SmallVector dst_tiles_shape = layout_out.tileArrayImplicitShape(vty.getShape(), target_shape); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType res_vreg_ty, - getNativeVregType(vty.getElementType(), target_shape)); + const VectorType res_vreg_ty = + getNativeVregType(vty.getElementType(), target_shape); xla::Array out_tiles(dst_tiles_shape); From e209abfb2c50ba458da885f8d3c084236b11e58b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 19 Sep 2024 14:48:33 -0700 Subject: [PATCH 576/702] Improve the coverage of shard map tests for < 8 devices. Due to the skip in SetupModule before this change, we lost a lot of coverage on latest hardware. PiperOrigin-RevId: 676571965 --- tests/shard_map_test.py | 89 ++++++++++++++++------------------------- 1 file changed, 35 insertions(+), 54 deletions(-) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 3d9b567e2ef4..ae22eeca0cd4 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -56,9 +56,7 @@ # Helper for some tests. def create_inputs(a_sharding, b_sharding): - x, y, z = 2, 2, 2 # pylint: disable=invalid-name - devices = np.array(jax.devices()[:x * y * z]).reshape((x, y, z)) - mesh = Mesh(devices, axis_names=('x', 'y', 'z')) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) b, e, f = 8, 8, 8 # pylint: disable=invalid-name m1 = jax.device_put( jnp.arange(b * e).reshape((b, e)), @@ -74,8 +72,6 @@ def create_inputs(a_sharding, b_sharding): def setUpModule(): _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - if len(jax.devices()) < 8: - raise unittest.SkipTest("tests require 8 devices") def tearDownModule(): _exit_stack.close() @@ -93,7 +89,7 @@ def identity(x): @jax.jit def fwd(a): c = shard_map( - lambda x: x, + identity, mesh, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))(a) @@ -219,8 +215,7 @@ def fwd(a): self.assertAllClose(np.squeeze(c.addressable_data(2 * i + 1), -1), sums) def test_collective_permute(self): - devices = np.array(jax.devices()[:8]) # Take up to 8 devices - mesh = Mesh(devices, axis_names=('x')) + mesh = jtu.create_mesh((8,), 'x') a = jax.device_put( jnp.arange(8 * 8).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P('x', None))) @@ -238,10 +233,7 @@ def fwd(a): self.assertAllClose(c[1, :], a[0, :]) def test_collective_permute_with_multiple_axis_names(self): - mesh = Mesh( - np.array(jax.devices()[:8]).reshape((2, 2, 2)), - axis_names=('x', 'y', 'z'), - ) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) a = jax.device_put( jnp.arange(8 * 8).reshape((4, 16)), jax.sharding.NamedSharding(mesh, P('x', ('y', 'z'))), @@ -284,11 +276,7 @@ def fwd(a): ), ) def test_all_to_all(self, axis_name, mesh_axes): - devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))]) - mesh = Mesh( - devices.reshape(tuple(mesh_axes.values())), - axis_names=tuple(mesh_axes.keys()), - ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) a = jax.device_put( jnp.arange(8 * 8).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P(axis_name, None)), @@ -310,12 +298,7 @@ def fwd(a): assert (c == jnp.reshape(a.T, (1, 64))).all() def test_all_to_all_with_axis_index_groups(self): - mesh_axes = dict(x=4) - devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))]) - mesh = Mesh( - devices.reshape(tuple(mesh_axes.values())), - axis_names=tuple(mesh_axes.keys()), - ) + mesh = jtu.create_mesh((4,), ('x',)) a = jax.device_put( jnp.arange(4 * 4).reshape((4, 4)), jax.sharding.NamedSharding(mesh, P('x', None)), @@ -348,12 +331,7 @@ def fwd(a): self.assertAllClose(block, c.addressable_data(2 * i + j)) def test_all_to_all_grad(self): - mesh_axes = dict(x=4) - devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))]) - mesh = Mesh( - devices.reshape(tuple(mesh_axes.values())), - axis_names=tuple(mesh_axes.keys()), - ) + mesh = jtu.create_mesh((4,), 'x') a = jax.device_put( jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P('x', None)), @@ -382,7 +360,7 @@ def loss_and_grad(x): self.assertAllClose(grad, 2 * np.ones_like(a)) def test_eager_repr(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = None @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y')) @@ -396,7 +374,7 @@ def f(x): self.assertIn('at mesh coordinates', s) def test_jvp_basic(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) args = np.arange(4 * 4.).reshape(4, 4), @@ -404,7 +382,7 @@ def test_jvp_basic(self): jtu.check_grads(jax.jit(g), args, 2, ['fwd']) def test_linearize_basic(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) x = np.arange(4 * 4.).reshape(4, 4) @@ -418,7 +396,7 @@ def test_linearize_basic(self): self.assertAllClose(y_dot, y_dot_, check_dtypes=False) def test_linearize_basic_repres(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -432,7 +410,7 @@ def test_linearize_basic_repres(self): self.assertAllClose(y_dot, y_dot_, check_dtypes=False) def test_linearize_basic_repres_jit(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -446,7 +424,7 @@ def test_linearize_basic_repres_jit(self): self.assertAllClose(y_dot, y_dot_, check_dtypes=False) def test_replication_checker_eager(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = np.arange(8 * 8.).reshape(8, 8) def f(x): @@ -464,7 +442,7 @@ def g2(x): _ = g2(x) # doesn't crash def test_replication_checker_jit(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = np.arange(8 * 8.).reshape(8, 8) def f(x): @@ -494,7 +472,7 @@ def g(x): jtu.check_grads(g, (x,), modes=['fwd'], order=2) def test_eager_control_flow(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.arange(2 * 2.).reshape(2, 2) def f(x): @@ -510,12 +488,12 @@ def g(x): self.assertAllClose(y, -x, check_dtypes=False) def test_outer_jit_detects_shard_map_mesh(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x')) _ = jax.jit(f)(jnp.array(2.0)) # doesn't crash def test_vmap_basic(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): @@ -525,7 +503,7 @@ def g(x): self.assertAllClose(y, 2 * x, check_dtypes=False) def test_vmap_basic_axis_name(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): @@ -535,7 +513,7 @@ def g(x): self.assertAllClose(y, 2 * x, check_dtypes=False) def test_vmap_basic_axis_name_reuse_mesh_name(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): @@ -545,7 +523,7 @@ def g(x): self.assertAllClose(y, 2 * x, check_dtypes=False) def test_tree_prefix_error(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(shard_map, mesh=mesh, in_specs=([P('x', 'y')],), out_specs=P('x', 'y')) def f(x): @@ -556,7 +534,7 @@ def f(x): f([x, x]) def test_rank_errors(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) def foo(): return {'hi': [3.]} @@ -577,7 +555,7 @@ def foo(): shard_map(foo, mesh=mesh, in_specs=P(None), out_specs=())(3.) def test_reverse_mode_ad(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @jax.jit @partial(shard_map, mesh=mesh, @@ -591,7 +569,7 @@ def f(x, y): def test_post_process(self): # JVPTrace.post_process_shard_map and JaxprTrace.post_process_shard_map - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) def f(x): @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) @@ -608,7 +586,7 @@ def g(y): @jtu.run_on_devices('gpu', 'tpu') def test_axis_index(self): - mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) + mesh = jtu.create_mesh((4,), 'x') @jax.jit @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('x')) @@ -716,7 +694,7 @@ def f3(): jax.jit(f3)() def test_vmap_spmd_axis_name(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): @@ -731,7 +709,7 @@ def f(x): self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},)) def test_vmap_spmd_axis_name_pair(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) def f(x): @@ -747,7 +725,7 @@ def f(x): def test_nested_vmap_with_capture_spmd_axis_name(self): self.skipTest('https://github.com/google/jax/issues/23476') - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) def to_map_with_capture(x, y): @@ -1545,7 +1523,6 @@ def test_check_rep_false_grads(self): mesh = jtu.create_mesh((4,), ('heads',)) def f(q, k, v): - def body(q, k, v): return q * k[None, :] + v[None, :] @@ -1560,7 +1537,11 @@ def body(q, k, v): k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec)) v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec)) - jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2) + if jtu.device_under_test() == 'tpu': + rtol = 2e-2 + else: + rtol = 1e-2 + jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=rtol) def test_axis_env_extension_regression(self): def foo(x): @@ -1675,7 +1656,7 @@ def test_erf_rules(self): mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) # don't crash def test_error_for_variable_num_args(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) def f(*args): return args[0] @ args[1] @@ -1687,7 +1668,7 @@ def f(*args): shard_f(jnp.ones((8, 8)), jnp.ones((8, 8))) def test_custom_vjp_replication_error_message_hint(self): - mesh = Mesh(np.array(jax.devices()[:4]), ('i',)) + mesh = jtu.create_mesh((4,), 'i') @jax.custom_vjp def f(x): @@ -1710,7 +1691,7 @@ def g(x): def test_repeated_psum_allowed(self): # https://github.com/google/jax/issues/19175 - mesh = Mesh(jax.devices()[:4], ('i',)) + mesh = jtu.create_mesh((4,), 'i') @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) def g(x): From 7571b9e7f82c22c374b2de97d47583b0bcf5e49b Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 19 Sep 2024 23:31:40 +0000 Subject: [PATCH 577/702] custom_vjp: don't drop tangents just because they have a different dtype than the primal instead, drop them when primal_aval.to_tangent_aval().dtype == float0 TODO: don't do that either. we shouldn't drop the user's output on the floor; we should require that their rule produce a value of the correct float0 dtype, or else produce a special symbol that means "zero of whatever type I need" (and that symbol should probably be a None). but i'm not doing that TODO right now... --- jax/_src/custom_derivatives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 05ede08d219c..88be655a0ddd 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -783,7 +783,7 @@ def append(x, d): raise TypeError(msg.format(in_tree2, in_tree)) from None results = [] for kp, a, ct in zip(keypaths, in_avals, cts_in_flat): - if ct is zero or a != a.to_tangent_aval(): + if ct is zero or getattr(a.to_tangent_aval(), 'dtype') == dtypes.float0: results.append(Zero(a.to_tangent_aval())) elif type(ct) is SymbolicZero: if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval): From 1db47fd85dbb38a2c7bd3edec480a705653eaad6 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 19 Sep 2024 19:07:35 -0700 Subject: [PATCH 578/702] [Pallas] Minor cleanup of memory spaces. Also add ANY as a general memory space PiperOrigin-RevId: 676650904 --- jax/_src/pallas/core.py | 1 + jax/_src/pallas/mosaic/core.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 49 ++++++++++++++++--------- jax/experimental/pallas/__init__.py | 4 +++ tests/pallas/tpu_pallas_test.py | 56 ++++++++++++++--------------- 5 files changed, 67 insertions(+), 45 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 1a956de1f7a9..1c3fa10e9e92 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -258,6 +258,7 @@ class MemorySpace(enum.Enum): Each memory space will be translated to a device-specific memory type during lowering. """ + ANY = "any" # Unrestricted memory space (usually HBM) ERROR = "error" # Memory space for checkify errors. INDEX = "index" # Memory space for scalar prefetch arguments. diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 76166ae61963..4ff9d894da8f 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -79,7 +79,7 @@ class TPUCompilerParams(pallas_core.CompilerParams): device_type: str | None = None class TPUMemorySpace(enum.Enum): - ANY = "any" + ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY. VMEM = "vmem" SMEM = "smem" CMEM = "cmem" diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 775f0c1f8256..f120bbabf8a4 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -138,15 +138,29 @@ class LoweringRuleContext: replace = dataclasses.replace -def _memory_space_to_tpu_memspace(memory_space: MemorySpace | None - ) -> ir.Attribute: - if memory_space is None: - memory_space = VMEM - elif memory_space == pallas_core.MemorySpace.ERROR: - memory_space = SMEM - elif memory_space == pallas_core.MemorySpace.INDEX: - memory_space = SMEM - return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>") +def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None + ) -> TPUMemorySpace: + match memory_space: + case None: + # We pick VMEM as the default one when no memory space is + # specified + return TPUMemorySpace.VMEM + case pallas_core.MemorySpace.ANY: + # Map the general ANY memory space to TPU ANY memory space + return TPUMemorySpace.ANY + case pallas_core.MemorySpace.ERROR | pallas_core.MemorySpace.INDEX: + return TPUMemorySpace.SMEM + case TPUMemorySpace(): + # Leave the memory space unchanged + return memory_space + case _: + raise ValueError("Invalid memory space: {memory_space}") + + +def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None + ) -> ir.Attribute: + tpu_memory_space = _memory_space_to_tpu_memory_space(memory_space) + return ir.Attribute.parse(f"#tpu.memory_space<{tpu_memory_space}>") def _dtype_to_ir_type(dtype: jnp.dtype, is_kernel_boundary: bool = False) -> ir.Type: @@ -182,7 +196,7 @@ def aval_to_ir_type(aval, sem_type = ir.Type.parse("!tpu.semaphore") else: raise ValueError(f"Cannot allocate {aval.sem_type}.") - memspace = _memory_space_to_tpu_memspace(TPUMemorySpace.SEMAPHORE) + memspace = _memory_space_to_mosaic_attribute(TPUMemorySpace.SEMAPHORE) return ir.MemRefType.get((), sem_type, memory_space=memspace) if dtypes.issubdtype(aval.dtype, dtypes.prng_key): shape = aval.dtype._impl.key_shape @@ -190,13 +204,13 @@ def aval_to_ir_type(aval, memory_space = TPUMemorySpace.SMEM if memory_space != TPUMemorySpace.SMEM: raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}") - memspace = _memory_space_to_tpu_memspace(memory_space) + memspace = _memory_space_to_mosaic_attribute(memory_space) return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)), memory_space=memspace) if isinstance(aval, state.AbstractRef): if shape is None: shape = aval.shape - memspace = _memory_space_to_tpu_memspace(memory_space) + memspace = _memory_space_to_mosaic_attribute(memory_space) return ir.MemRefType.get(shape, _dtype_to_ir_type(aval.dtype, is_kernel_boundary=True), memory_space=memspace) @@ -524,7 +538,9 @@ def lower_jaxpr_to_module( for i, bm in enumerate(grid_mapping.block_mappings): func_name = f"transform_{i}" # ANY operands don't support windowing and require empty window_params. - if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY: + tpu_memory_space = _memory_space_to_tpu_memory_space( + bm.block_aval.memory_space) + if tpu_memory_space == tpu_core.TPUMemorySpace.ANY: # We checked above that the block does not require windowing. window_params.append(ir.DictAttr.get()) continue @@ -2560,7 +2576,7 @@ def _bitcast_convert_type_lowering_rule( def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value: if isinstance(aval, pallas_core.AbstractMemoryRef): - memspace = _memory_space_to_tpu_memspace(aval.memory_space) + memspace = _memory_space_to_mosaic_attribute(aval.memory_space) if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE memref_type = aval_to_ir_type(aval, memory_space=TPUMemorySpace.SEMAPHORE) @@ -2905,9 +2921,10 @@ def body(*args): out = pallas_call.pallas_call( body, out_shape=in_avals, - in_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)] + in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)] * len(in_avals), - out_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)] + out_specs=[pallas_core.BlockSpec( + memory_space=pallas_core.MemorySpace.ANY)] * len(in_avals), input_output_aliases={i: i for i in range(len(in_avals))}, grid=((core_axis_name, num_cores),), diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index c81b509d70cf..bb733e794c5f 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -28,6 +28,7 @@ from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import Unblocked from jax._src.pallas.core import unblocked +from jax._src.pallas.core import MemorySpace from jax._src.pallas.pallas_call import pallas_call from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.primitives import atomic_add @@ -57,5 +58,8 @@ from jax._src.state.indexing import Slice from jax._src.state.primitives import broadcast_to +ANY = MemorySpace.ANY + + _register_deprecation("pallas-block-spec-order") del _register_deprecation diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index e100a5a39e49..5bcf0964419f 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -726,8 +726,8 @@ def kernel(x_ref, y_ref): x = jnp.ones((8, 128), dtype=jnp.float32) y = self.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) jax.block_until_ready(y) @@ -1043,9 +1043,9 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), ], scratch_shapes=[pltpu.SemaphoreType.DMA], @@ -1069,9 +1069,9 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_array_equal(y, x) @@ -1088,9 +1088,9 @@ def body(sem): self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) @@ -1105,9 +1105,9 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_array_equal(y, x) @@ -1126,9 +1126,9 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), grid=(2,), )(x) @@ -1147,7 +1147,7 @@ def body(x_ref, sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) @@ -1164,7 +1164,7 @@ def body(y_ref, sem): x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1184,8 +1184,8 @@ def body(x_ref, y_ref, sem): x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1202,7 +1202,7 @@ def body(x_ref, sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) @@ -1223,7 +1223,7 @@ def body(y_ref, sem): in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32), )(x) expected = jnp.zeros_like(x[0:1, 0:2]).at[0, 1].set(x[4, 4]) @@ -1261,7 +1261,7 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), @@ -1284,7 +1284,7 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), @@ -1313,7 +1313,7 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=jax.ShapeDtypeStruct((3, 16, 128), jnp.float32), @@ -1340,7 +1340,7 @@ def body(sem): _ = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), @@ -1410,7 +1410,7 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], scratch_shapes=[pltpu.SemaphoreType.REGULAR, pltpu.SemaphoreType.DMA], @@ -1436,9 +1436,9 @@ def kernel(index, x, y, sem): num_scalar_prefetch=1, in_specs=[ pl.BlockSpec( - memory_space=pltpu.TPUMemorySpace.ANY)], + memory_space=pl.ANY)], out_specs=pl.BlockSpec( - memory_space=pltpu.TPUMemorySpace.ANY), + memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA], ), out_shape=jax.ShapeDtypeStruct(x.shape[1:], dtype), @@ -1480,7 +1480,7 @@ def test_kernel(x_ref, grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], scratch_shapes=( [pltpu.SemaphoreType.DMA(2,)] @@ -1949,9 +1949,9 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 2, 128), jnp.int16), )(x) expected = ( From 7f3a90c63b9adfcf3877f0ae0550397ba9b24769 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Fri, 20 Sep 2024 07:31:35 -0700 Subject: [PATCH 579/702] Change references in setup.py and utilities to reference the JAX repo move to the JAX-ML org PiperOrigin-RevId: 676838502 --- .github/workflows/self_hosted_runner_utils/setup_runner.sh | 4 ++-- jax_plugins/cuda/plugin_setup.py | 4 ++-- jax_plugins/cuda/setup.py | 2 +- jax_plugins/rocm/setup.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/self_hosted_runner_utils/setup_runner.sh b/.github/workflows/self_hosted_runner_utils/setup_runner.sh index 79c1224c13cc..ef501784b45e 100755 --- a/.github/workflows/self_hosted_runner_utils/setup_runner.sh +++ b/.github/workflows/self_hosted_runner_utils/setup_runner.sh @@ -31,7 +31,7 @@ runner_token="$3" # - sets empty string as default to avoid unbound variable error from set -u jax_repo_url="${4-}" if [ -z "${jax_repo_url}" ]; then - jax_repo_url="https://github.com/google/jax" + jax_repo_url="https://github.com/jax-ml/jax" fi # Create `runner` user. This user won't have sudo access unless you ssh into the @@ -67,7 +67,7 @@ cd ~/ git clone ${jax_repo_url} -# Based on https://github.com/google/jax/settings/actions/runners/new +# Based on https://github.com/jax-ml/jax/settings/actions/runners/new # (will be 404 for github users with insufficient repo permissions) mkdir actions-runner && cd actions-runner curl -o actions-runner-linux-x64.tar.gz -L ${actions_runner_download} diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index 468c0c48709f..8e99907d7078 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -66,13 +66,13 @@ def has_ext_modules(self): # dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages # do not have a version constraint on their dependencies, so the # package doesn't get upgraded even though not doing that can cause - # problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196) + # problems (https://github.com/jax-ml/jax/issues/18027#issuecomment-1756305196) # Until NVIDIA add version constraints, add a version constraint # here. "nvidia-nvjitlink-cu12>=12.1.105", ], }, - url="https://github.com/google/jax", + url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", diff --git a/jax_plugins/cuda/setup.py b/jax_plugins/cuda/setup.py index 96ce577fc643..1ce555978dac 100644 --- a/jax_plugins/cuda/setup.py +++ b/jax_plugins/cuda/setup.py @@ -48,7 +48,7 @@ def load_version_module(pkg_path): author_email="jax-dev@google.com", packages=packages, install_requires=[], - url="https://github.com/google/jax", + url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", diff --git a/jax_plugins/rocm/setup.py b/jax_plugins/rocm/setup.py index 8782676ce9a2..d131e732c91a 100644 --- a/jax_plugins/rocm/setup.py +++ b/jax_plugins/rocm/setup.py @@ -48,7 +48,7 @@ def load_version_module(pkg_path): author_email="Ruturaj.Vaidya@amd.com", packages=packages, install_requires=[], - url="https://github.com/google/jax", + url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", From afaa3bf43c10304e97e6fd041f22882a8f91ee3d Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 20 Sep 2024 07:34:05 -0700 Subject: [PATCH 580/702] Port GPU kernels for SVD to the FFI. Unlike the other GPU linear algebra kernels that I've ported so far, this one isn't straightforward to implement as a single kernel, and while it does support lowering without access to a GPU (no more descriptor!), it only supports dynamics shapes in the batch dimensions. There are two main technical challenges: 1. The main `gesvd` kernels in cuSolver/hipSolver only support matrices with shape `(m, n)` with `m >= n`. This means that we need to transpose the inputs and outputs as part of the lowering rule when `m < n`. (Note: we actually just use C layouts instead of Fortran layouts to implement this case.) While this could be handled in the kernel, this seemed like a lot of work for somewhat limited benefit, and it would probably have performance implications. 2. The `gesvd` and `gesvdj` kernels return `V^H` and `V` respectively, and the batched version of `gesvdj` doesn't support `full_matrices=False`. This means that we need logic in the lowering rule to handle transposition and slicing. This makes it hard to have the algorithm selection be a parameter to the kernel. Another note: cuSolver has a 64-bit implementation of the SVD, and we always use that implementation on the CUDA backend. The 32-bit interface is included for ROCM support, and I have tested it manually. This was a feature request from https://github.com/jax-ml/jax/issues/23413. PiperOrigin-RevId: 676839182 --- jaxlib/gpu/solver.cc | 7 +- jaxlib/gpu/solver_interface.cc | 85 +++++++ jaxlib/gpu/solver_interface.h | 43 ++++ jaxlib/gpu/solver_kernels_ffi.cc | 393 +++++++++++++++++++++++++++++-- jaxlib/gpu/solver_kernels_ffi.h | 5 + jaxlib/gpu/vendor.h | 39 +++ 6 files changed, 551 insertions(+), 21 deletions(-) diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index c65ad088af21..38936ee497cf 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -20,8 +20,8 @@ limitations under the License. #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/solver_kernels.h" @@ -481,6 +481,11 @@ nb::dict Registrations() { dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi); dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi); dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); + dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi); + +#ifdef JAX_GPU_CUDA + dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi); +#endif // JAX_GPU_CUDA return dict; } diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc index 3c8282ec603a..4d1af3c50d76 100644 --- a/jaxlib/gpu/solver_interface.cc +++ b/jaxlib/gpu/solver_interface.cc @@ -232,6 +232,91 @@ JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk); JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk); #undef JAX_GPU_DEFINE_SYRK +// Singular Value Decomposition: gesvd + +#define JAX_GPU_DEFINE_GESVD(Type, Name) \ + template <> \ + absl::StatusOr GesvdBufferSize(gpusolverDnHandle_t handle, \ + signed char job, int m, int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR( \ + JAX_AS_STATUS(Name##_bufferSize(handle, job, job, m, n, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Gesvd(gpusolverDnHandle_t handle, signed char job, int m, \ + int n, Type *a, RealType::value *s, Type *u, \ + Type *vt, Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS(Name(handle, job, job, m, n, a, m, s, u, m, vt, n, \ + workspace, lwork, /*rwork=*/nullptr, info)); \ + } + +JAX_GPU_DEFINE_GESVD(float, gpusolverDnSgesvd); +JAX_GPU_DEFINE_GESVD(double, gpusolverDnDgesvd); +JAX_GPU_DEFINE_GESVD(gpuComplex, gpusolverDnCgesvd); +JAX_GPU_DEFINE_GESVD(gpuDoubleComplex, gpusolverDnZgesvd); +#undef JAX_GPU_DEFINE_GESVD + +#ifdef JAX_GPU_CUDA + +#define JAX_GPU_DEFINE_GESVDJ(Type, Name) \ + template <> \ + absl::StatusOr GesvdjBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, \ + int n, gpuGesvdjInfo_t params) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \ + handle, job, econ, m, n, /*a=*/nullptr, /*lda=*/m, /*s=*/nullptr, \ + /*u=*/nullptr, /*ldu=*/m, /*v=*/nullptr, /*ldv=*/n, &lwork, params))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Gesvdj( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, \ + int n, Type *a, RealType::value *s, Type *u, Type *v, \ + Type *workspace, int lwork, int *info, gpuGesvdjInfo_t params) { \ + return JAX_AS_STATUS(Name(handle, job, econ, m, n, a, m, s, u, m, v, n, \ + workspace, lwork, info, params)); \ + } + +JAX_GPU_DEFINE_GESVDJ(float, gpusolverDnSgesvdj); +JAX_GPU_DEFINE_GESVDJ(double, gpusolverDnDgesvdj); +JAX_GPU_DEFINE_GESVDJ(gpuComplex, gpusolverDnCgesvdj); +JAX_GPU_DEFINE_GESVDJ(gpuDoubleComplex, gpusolverDnZgesvdj); +#undef JAX_GPU_DEFINE_GESVDJ + +#define JAX_GPU_DEFINE_GESVDJ_BATCHED(Type, Name) \ + template <> \ + absl::StatusOr GesvdjBatchedBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \ + gpuGesvdjInfo_t params, int batch) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, job, m, n, /*a=*/nullptr, /*lda=*/m, \ + /*s=*/nullptr, /*u=*/nullptr, /*ldu=*/m, \ + /*v=*/nullptr, /*ldv=*/n, &lwork, params, batch))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status GesvdjBatched( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \ + Type *a, RealType::value *s, Type *u, Type *v, Type *workspace, \ + int lwork, int *info, gpuGesvdjInfo_t params, int batch) { \ + return JAX_AS_STATUS(Name(handle, job, m, n, a, m, s, u, m, v, n, \ + workspace, lwork, info, params, batch)); \ + } + +JAX_GPU_DEFINE_GESVDJ_BATCHED(float, gpusolverDnSgesvdjBatched); +JAX_GPU_DEFINE_GESVDJ_BATCHED(double, gpusolverDnDgesvdjBatched); +JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched); +JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched); +#undef JAX_GPU_DEFINE_GESVDJ_BATCHED + +#endif // JAX_GPU_CUDA + } // namespace solver } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h index 5072be98489f..336480e2e13b 100644 --- a/jaxlib/gpu/solver_interface.h +++ b/jaxlib/gpu/solver_interface.h @@ -165,6 +165,49 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd); JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk); #undef JAX_GPU_SOLVER_Syrk_ARGS +// Singular Value Decomposition: gesvd + +#define JAX_GPU_SOLVER_GesvdBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, signed char job, int m, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdBufferSize); +#undef JAX_GPU_SOLVER_GesvdBufferSize_ARGS + +#define JAX_GPU_SOLVER_Gesvd_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, signed char job, int m, int n, Type *a, Real *s, \ + Type *u, Type *vt, Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvd); +#undef JAX_GPU_SOLVER_Gesvd_ARGS + +#ifdef JAX_GPU_CUDA + +#define JAX_GPU_SOLVER_GesvdjBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \ + gesvdjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBufferSize); +#undef JAX_GPU_SOLVER_GesvdjBufferSize_ARGS + +#define JAX_GPU_SOLVER_Gesvdj_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \ + Type *a, Real *s, Type *u, Type *v, Type *workspace, \ + int lwork, int *info, gesvdjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj); +#undef JAX_GPU_SOLVER_Gesvdj_ARGS + +#define JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \ + gpuGesvdjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBatchedBufferSize); +#undef JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS + +#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \ + Real *s, Type *u, Type *v, Type *workspace, int lwork, \ + int *info, gpuGesvdjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched); +#undef JAX_GPU_SOLVER_GesvdjBatched_ARGS + +#endif // JAX_GPU_CUDA + #undef JAX_GPU_SOLVER_EXPAND_DEFINITION } // namespace solver diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index e3f63234f538..9191a0ff8dec 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -33,6 +33,14 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" +#if JAX_GPU_64_BIT +#include +#endif + +#ifdef JAX_GPU_CUDA +#include +#endif + #define JAX_FFI_RETURN_IF_GPU_ERROR(...) \ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) @@ -56,26 +64,32 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, return static_cast(maybe_workspace.value()); } -#define SOLVER_DISPATCH_IMPL(impl, ...) \ - if (dataType == ffi::F32) { \ - return impl(__VA_ARGS__); \ - } else if (dataType == ffi::F64) { \ - return impl(__VA_ARGS__); \ - } else if (dataType == ffi::C64) { \ - return impl(__VA_ARGS__); \ - } else if (dataType == ffi::C128) { \ - return impl(__VA_ARGS__); \ +#define SOLVER_DISPATCH_IMPL(impl, ...) \ + switch (dataType) { \ + case ffi::F32: \ + return impl(__VA_ARGS__); \ + case ffi::F64: \ + return impl(__VA_ARGS__); \ + case ffi::C64: \ + return impl(__VA_ARGS__); \ + case ffi::C128: \ + return impl(__VA_ARGS__); \ + default: \ + break; \ } -#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...) \ - if (dataType == ffi::F32) { \ - return impl(__VA_ARGS__); \ - } else if (dataType == ffi::F64) { \ - return impl(__VA_ARGS__); \ - } else if (dataType == ffi::C64) { \ - return impl(__VA_ARGS__); \ - } else if (dataType == ffi::C128) { \ - return impl(__VA_ARGS__); \ +#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...) \ + switch (dataType) { \ + case ffi::F32: \ + return impl(__VA_ARGS__); \ + case ffi::F64: \ + return impl(__VA_ARGS__); \ + case ffi::C64: \ + return impl(__VA_ARGS__); \ + case ffi::C128: \ + return impl(__VA_ARGS__); \ + default: \ + break; \ } // LU decomposition: getrf @@ -445,8 +459,8 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, } ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, - SyevdAlgorithm algorithm, bool lower, ffi::AnyBuffer a, - ffi::Result out, + SyevdAlgorithm algorithm, bool lower, + ffi::AnyBuffer a, ffi::Result out, ffi::Result w, ffi::Result> info) { auto dataType = a.element_type(); @@ -561,6 +575,345 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch, .Ret() // c_out ); +// Singular Value Decomposition: gesvd + +#if JAX_GPU_64_BIT + +ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool full_matrices, + bool compute_uv, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result vt, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N'; + + auto dataType = a.element_type(); + gpuDataType aType, sType; + switch (dataType) { + case ffi::F32: + aType = GPU_R_32F; + sType = GPU_R_32F; + break; + case ffi::F64: + aType = GPU_R_64F; + sType = GPU_R_64F; + break; + case ffi::C64: + aType = GPU_C_32F; + sType = GPU_R_32F; + break; + case ffi::C128: + aType = GPU_C_64F; + sType = GPU_R_64F; + break; + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType))); + } + + gpusolverDnParams_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(¶ms)); + std::unique_ptr + params_cleanup( + params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); }); + + size_t workspaceInBytesOnDevice, workspaceInBytesOnHost; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvd_bufferSize( + handle.get(), params, job, job, m, n, aType, /*a=*/nullptr, m, sType, + /*s=*/nullptr, aType, /*u=*/nullptr, m, aType, /*vt=*/nullptr, n, aType, + &workspaceInBytesOnDevice, &workspaceInBytesOnHost)); + + auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice); + if (!maybe_workspace.has_value()) { + return ffi::Error(ffi::ErrorCode::kResourceExhausted, + "Unable to allocate device workspace for gesvd"); + } + auto workspaceOnDevice = maybe_workspace.value(); + auto workspaceOnHost = + std::unique_ptr(new char[workspaceInBytesOnHost]); + + const char* a_data = static_cast(a.untyped_data()); + char* out_data = static_cast(out->untyped_data()); + char* s_data = static_cast(s->untyped_data()); + char* u_data = static_cast(u->untyped_data()); + char* vt_data = static_cast(vt->untyped_data()); + int* info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + size_t out_step = m * n * ffi::ByteWidth(dataType); + size_t s_step = n * ffi::ByteWidth(ffi::ToReal(dataType)); + size_t u_step = 0; + size_t vt_step = 0; + if (compute_uv) { + u_step = m * (full_matrices ? m : n) * ffi::ByteWidth(dataType); + vt_step = n * n * ffi::ByteWidth(dataType); + } + for (auto i = 0; i < batch; ++i) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvd( + handle.get(), params, job, job, m, n, aType, out_data, m, sType, s_data, + aType, u_data, m, aType, vt_data, n, aType, workspaceOnDevice, + workspaceInBytesOnDevice, workspaceOnHost.get(), workspaceInBytesOnHost, + info_data)); + out_data += out_step; + s_data += s_step; + u_data += u_step; + vt_data += vt_step; + ++info_data; + } + + return ffi::Error::Success(); +} + +#else + +template +ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + bool full_matrices, bool compute_uv, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result vt, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N'; + + FFI_ASSIGN_OR_RETURN(int lwork, + solver::GesvdBufferSize(handle.get(), job, m, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "gesvd")); + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto s_data = static_cast::value*>(s->untyped_data()); + auto u_data = compute_uv ? static_cast(u->untyped_data()) : nullptr; + auto vt_data = compute_uv ? static_cast(vt->untyped_data()) : nullptr; + auto info_data = info->typed_data(); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + } + + int out_step = m * n; + int u_step = compute_uv ? m * (full_matrices ? m : n) : 0; + int vt_step = compute_uv ? n * n : 0; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS( + solver::Gesvd(handle.get(), job, m, n, out_data, s_data, u_data, + vt_data, workspace, lwork, info_data)); + out_data += out_step; + s_data += n; // n is always less than m because of the logic in dispatch. + u_data += u_step; + vt_data += vt_step; + ++info_data; + } + return ffi::Error::Success(); +} + +#endif // JAX_GPU_64_BIT + +ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + bool full_matrices, bool compute_uv, bool transposed, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result vt, + ffi::Result> info) { + auto dataType = a.element_type(); + if (out->element_type() != dataType || + s->element_type() != ffi::ToReal(dataType) || + u->element_type() != dataType || vt->element_type() != dataType) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to gesvd must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + int64_t m = transposed ? cols : rows; + int64_t n = transposed ? rows : cols; + if (n > m) { + return ffi::Error::InvalidArgument( + "The GPU implementation of gesvd requires that the input matrix be m x " + "n with m >= n"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gesvd")); + FFI_RETURN_IF_ERROR(CheckShape(s->dimensions(), {batch, n}, "s", "gesvd")); + if (compute_uv) { + if (full_matrices) { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, m, m}, "u", "gesvd")); + } else { + if (transposed) { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, n, m}, "u", "gesvd")); + } else { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, m, n}, "u", "gesvd")); + } + } + FFI_RETURN_IF_ERROR( + CheckShape(vt->dimensions(), {batch, n, n}, "vt", "gesvd")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvd")); + +#if JAX_GPU_64_BIT + return Gesvd64Impl(batch, m, n, stream, scratch, full_matrices, compute_uv, a, + out, s, u, vt, info); +#else + SOLVER_DISPATCH_IMPL(GesvdImpl, batch, m, n, stream, scratch, full_matrices, + compute_uv, a, out, s, u, vt, info); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType))); +#endif +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdFfi, GesvdDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("full_matrices") + .Attr("compute_uv") + .Attr("transposed") + .Arg() // a + .Ret() // out + .Ret() // s + .Ret() // u + .Ret() // vt + .Ret>() // info +); + +#ifdef JAX_GPU_CUDA + +template +ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + bool full_matrices, bool compute_uv, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result v, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + + gpusolverEigMode_t job = + compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR; + int econ = full_matrices ? 0 : 1; + + gpuGesvdjInfo_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateGesvdjInfo(¶ms)); + std::unique_ptr params_cleanup( + params, [](gpuGesvdjInfo_t p) { gpusolverDnDestroyGesvdjInfo(p); }); + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto s_data = static_cast::value*>(s->untyped_data()); + auto u_data = static_cast(u->untyped_data()); + auto v_data = static_cast(v->untyped_data()); + auto info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + if (batch <= 1 || batch > std::numeric_limits::max() || m > 32 || + n > 32 || econ) { + FFI_ASSIGN_OR_RETURN(int lwork, solver::GesvdjBufferSize( + handle.get(), job, econ, m, n, params)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "gesvdj")); + int k = std::min(m, n); + int out_step = m * n; + int u_step = m * (full_matrices ? m : k); + int v_step = n * (full_matrices ? n : k); + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Gesvdj( + handle.get(), job, econ, m, n, out_data, s_data, u_data, v_data, + workspace, lwork, info_data, params)); + out_data += out_step; + s_data += k; + u_data += u_step; + v_data += v_step; + ++info_data; + } + } else { + FFI_ASSIGN_OR_RETURN(int lwork, solver::GesvdjBatchedBufferSize( + handle.get(), job, m, n, params, + static_cast(batch))); + FFI_ASSIGN_OR_RETURN( + auto workspace, AllocateWorkspace(scratch, lwork, "gesvdj_batched")); + FFI_RETURN_IF_ERROR_STATUS(solver::GesvdjBatched( + handle.get(), job, m, n, out_data, s_data, u_data, v_data, workspace, + lwork, info_data, params, static_cast(batch))); + } + return ffi::Error::Success(); +} + +ffi::Error GesvdjDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + bool full_matrices, bool compute_uv, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result v, + ffi::Result> info) { + auto dataType = a.element_type(); + if (out->element_type() != dataType || + s->element_type() != ffi::ToReal(dataType) || + u->element_type() != dataType || v->element_type() != dataType) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to gesvdj must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + int64_t size = std::min(rows, cols); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gesvdj")); + FFI_RETURN_IF_ERROR( + CheckShape(s->dimensions(), {batch, size}, "s", "gesvdj")); + // U and V must always be allocated even if compute_uv is false. + if (full_matrices) { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, rows, rows}, "u", "gesvdj")); + FFI_RETURN_IF_ERROR( + CheckShape(v->dimensions(), {batch, cols, cols}, "v", "gesvdj")); + } else { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, rows, size}, "u", "gesvdj")); + FFI_RETURN_IF_ERROR( + CheckShape(v->dimensions(), {batch, cols, size}, "v", "gesvdj")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvdj")); + + SOLVER_DISPATCH_IMPL(GesvdjImpl, batch, rows, cols, stream, scratch, + full_matrices, compute_uv, a, out, s, u, v, info); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in gesvdj", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("full_matrices") + .Attr("compute_uv") + .Arg() // a + .Ret() // out + .Ret() // s + .Ret() // u + .Ret() // v + .Ret>() // info +); + +#endif // JAX_GPU_CUDA + #undef SOLVER_DISPATCH_IMPL #undef SOLVER_BLAS_DISPATCH_IMPL diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 3bebe40bee26..022564eb108c 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -35,6 +35,11 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi); + +#ifdef JAX_GPU_CUDA +XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi); +#endif // JAX_GPU_CUDA } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index bc61d58181ab..fa247b08b207 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -78,6 +78,8 @@ typedef cusolverStatus_t gpusolverStatus_t; typedef cusolverEigMode_t gpusolverEigMode_t; typedef syevjInfo gpuSyevjInfo; typedef syevjInfo_t gpuSyevjInfo_t; +typedef gesvdjInfo gpuGesvdjInfo; +typedef gesvdjInfo_t gpuGesvdjInfo_t; typedef cusparseIndexType_t gpusparseIndexType_t; typedef cusparseHandle_t gpusparseHandle_t; typedef cusparseOperation_t gpusparseOperation_t; @@ -120,6 +122,8 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusolverDnSetStream cusolverDnSetStream #define gpusolverDnCreateSyevjInfo cusolverDnCreateSyevjInfo #define gpusolverDnDestroySyevjInfo cusolverDnDestroySyevjInfo +#define gpusolverDnCreateGesvdjInfo cusolverDnCreateGesvdjInfo +#define gpusolverDnDestroyGesvdjInfo cusolverDnDestroyGesvdjInfo #define gpusolverDnSgeqrf cusolverDnSgeqrf #define gpusolverDnDgeqrf cusolverDnDgeqrf #define gpusolverDnCgeqrf cusolverDnCgeqrf @@ -184,6 +188,22 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; cusolverDnCgesvd_bufferSize(h, m, n, lwork) #define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ cusolverDnZgesvd_bufferSize(h, m, n, lwork) +#define gpusolverDnSgesvdj cusolverDnSgesvdj +#define gpusolverDnDgesvdj cusolverDnDgesvdj +#define gpusolverDnCgesvdj cusolverDnCgesvdj +#define gpusolverDnZgesvdj cusolverDnZgesvdj +#define gpusolverDnSgesvdj_bufferSize cusolverDnSgesvdj_bufferSize +#define gpusolverDnDgesvdj_bufferSize cusolverDnDgesvdj_bufferSize +#define gpusolverDnCgesvdj_bufferSize cusolverDnCgesvdj_bufferSize +#define gpusolverDnZgesvdj_bufferSize cusolverDnZgesvdj_bufferSize +#define gpusolverDnSgesvdjBatched cusolverDnSgesvdjBatched +#define gpusolverDnDgesvdjBatched cusolverDnDgesvdjBatched +#define gpusolverDnCgesvdjBatched cusolverDnCgesvdjBatched +#define gpusolverDnZgesvdjBatched cusolverDnZgesvdjBatched +#define gpusolverDnSgesvdjBatched_bufferSize cusolverDnSgesvdjBatched_bufferSize +#define gpusolverDnDgesvdjBatched_bufferSize cusolverDnDgesvdjBatched_bufferSize +#define gpusolverDnCgesvdjBatched_bufferSize cusolverDnCgesvdjBatched_bufferSize +#define gpusolverDnZgesvdjBatched_bufferSize cusolverDnZgesvdjBatched_bufferSize #define gpusolverDnSsytrd_bufferSize cusolverDnSsytrd_bufferSize #define gpusolverDnDsytrd_bufferSize cusolverDnDsytrd_bufferSize #define gpusolverDnChetrd_bufferSize cusolverDnChetrd_bufferSize @@ -196,6 +216,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSOLVER_FILL_MODE_LOWER CUBLAS_FILL_MODE_LOWER #define GPUSOLVER_FILL_MODE_UPPER CUBLAS_FILL_MODE_UPPER #define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR +#define GPUSOLVER_EIG_MODE_NOVECTOR CUSOLVER_EIG_MODE_NOVECTOR #define GPUSOLVER_STATUS_SUCCESS CUSOLVER_STATUS_SUCCESS #define GPUBLAS_OP_N CUBLAS_OP_N @@ -311,6 +332,22 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuGetDeviceProperties cudaGetDeviceProperties #define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel +#define JAX_GPU_64_BIT 1 + +#define GPU_R_32F CUDA_R_32F +#define GPU_R_64F CUDA_R_64F +#define GPU_C_32F CUDA_C_32F +#define GPU_C_64F CUDA_C_64F + +typedef cudaDataType gpuDataType; +typedef cusolverDnParams gpusolverDnParams; +typedef cusolverDnParams_t gpusolverDnParams_t; +#define gpusolverDnCreateParams cusolverDnCreateParams +#define gpusolverDnDestroyParams cusolverDnDestroyParams + +#define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize +#define gpusolverDnXgesvd cusolverDnXgesvd + namespace jax::JAX_GPU_NAMESPACE { namespace { constexpr uint32_t kNumThreadsPerWarp = 32; @@ -331,6 +368,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #define JAX_GPU_PREFIX "hip" #define JAX_GPU_HAVE_SPARSE 1 +#define JAX_GPU_64_BIT 0 #define JAX_GPU_HAVE_FP8 0 typedef hipFloatComplex gpuComplex; @@ -472,6 +510,7 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSOLVER_FILL_MODE_LOWER HIPSOLVER_FILL_MODE_LOWER #define GPUSOLVER_FILL_MODE_UPPER HIPSOLVER_FILL_MODE_UPPER #define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR +#define GPUSOLVER_EIG_MODE_NOVECTOR HIPSOLVER_EIG_MODE_NOVECTOR #define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS #define GPUBLAS_OP_N HIPBLAS_OP_N From d4d1518c3d852b25465955796e65d74c6d34b029 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Fri, 20 Sep 2024 07:51:48 -0700 Subject: [PATCH 581/702] Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax PiperOrigin-RevId: 676843138 --- .github/ISSUE_TEMPLATE/bug-report.yml | 6 +- .github/ISSUE_TEMPLATE/config.yml | 2 +- .github/workflows/upstream-nightly.yml | 2 +- CHANGELOG.md | 206 +++++++++--------- CITATION.bib | 2 +- README.md | 22 +- cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb | 2 +- cloud_tpu_colabs/JAX_demo.ipynb | 4 +- cloud_tpu_colabs/Pmap_Cookbook.ipynb | 4 +- cloud_tpu_colabs/README.md | 18 +- docs/_tutorials/advanced-autodiff.md | 2 +- docs/autodidax.ipynb | 4 +- docs/autodidax.md | 4 +- docs/autodidax.py | 4 +- docs/beginner_guide.rst | 2 +- docs/conf.py | 4 +- docs/contributing.md | 18 +- docs/developer.md | 14 +- docs/export/export.md | 6 +- docs/export/jax2tf.md | 2 +- docs/faq.rst | 6 +- docs/ffi.ipynb | 4 +- docs/ffi.md | 4 +- docs/installation.md | 4 +- docs/investigating_a_regression.md | 4 +- docs/jep/11830-new-remat-checkpoint.md | 6 +- docs/jep/12049-type-annotations.md | 10 +- docs/jep/15856-jex.md | 2 +- docs/jep/18137-numpy-scipy-scope.md | 4 +- docs/jep/2026-custom-derivatives.md | 34 +-- docs/jep/4008-custom-vjp-update.md | 8 +- docs/jep/4410-omnistaging.md | 4 +- docs/jep/9263-typed-keys.md | 2 +- docs/jep/9407-type-promotion.ipynb | 2 +- docs/jep/9407-type-promotion.md | 2 +- docs/jep/9419-jax-versioning.md | 8 +- docs/jep/index.rst | 2 +- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 8 +- docs/notebooks/Common_Gotchas_in_JAX.md | 8 +- ...tom_derivative_rules_for_Python_code.ipynb | 2 +- ...Custom_derivative_rules_for_Python_code.md | 2 +- ...arrays_and_automatic_parallelization.ipynb | 2 +- ...ed_arrays_and_automatic_parallelization.md | 2 +- docs/notebooks/How_JAX_primitives_work.ipynb | 2 +- docs/notebooks/How_JAX_primitives_work.md | 2 +- .../Neural_Network_and_Data_Loading.ipynb | 6 +- .../Neural_Network_and_Data_Loading.md | 6 +- .../Writing_custom_interpreters_in_Jax.ipynb | 8 +- .../Writing_custom_interpreters_in_Jax.md | 8 +- docs/notebooks/autodiff_cookbook.ipynb | 6 +- docs/notebooks/autodiff_cookbook.md | 6 +- docs/notebooks/convolutions.ipynb | 2 +- docs/notebooks/convolutions.md | 2 +- .../neural_network_with_tfds_data.ipynb | 4 +- .../neural_network_with_tfds_data.md | 4 +- docs/notebooks/thinking_in_jax.ipynb | 2 +- docs/notebooks/thinking_in_jax.md | 2 +- docs/notebooks/vmapped_log_probs.ipynb | 2 +- docs/notebooks/vmapped_log_probs.md | 2 +- docs/pallas/tpu/details.rst | 2 +- docs/persistent_compilation_cache.md | 4 +- docs/sphinxext/jax_extensions.py | 4 +- docs/stateful-computations.md | 2 +- jax/__init__.py | 4 +- jax/_src/ad_checkpoint.py | 2 +- jax/_src/api.py | 4 +- jax/_src/callback.py | 2 +- jax/_src/config.py | 2 +- jax/_src/core.py | 14 +- jax/_src/custom_batching.py | 2 +- jax/_src/custom_derivatives.py | 2 +- jax/_src/custom_partitioning.py | 2 +- jax/_src/debugging.py | 2 +- jax/_src/dtypes.py | 2 +- jax/_src/flatten_util.py | 2 +- jax/_src/internal_test_util/test_harnesses.py | 6 +- jax/_src/interpreters/partial_eval.py | 4 +- jax/_src/interpreters/pxla.py | 6 +- jax/_src/lax/control_flow/conditionals.py | 2 +- jax/_src/lax/control_flow/loops.py | 6 +- jax/_src/lax/fft.py | 2 +- jax/_src/lax/lax.py | 6 +- jax/_src/lax/linalg.py | 2 +- jax/_src/lib/__init__.py | 4 +- jax/_src/mesh.py | 2 +- jax/_src/numpy/lax_numpy.py | 8 +- jax/_src/numpy/reductions.py | 2 +- jax/_src/numpy/setops.py | 4 +- jax/_src/numpy/ufuncs.py | 2 +- jax/_src/pallas/mosaic/error_handling.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 2 +- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/_src/pallas/triton/lowering.py | 2 +- jax/_src/pjit.py | 6 +- jax/_src/random.py | 4 +- jax/_src/scipy/ndimage.py | 2 +- jax/_src/shard_alike.py | 2 +- jax/_src/test_util.py | 2 +- jax/_src/typing.py | 2 +- jax/_src/xla_bridge.py | 2 +- jax/core.py | 2 +- jax/custom_derivatives.py | 2 +- jax/dtypes.py | 2 +- jax/errors.py | 2 +- jax/experimental/__init__.py | 2 +- jax/experimental/array_api/__init__.py | 2 +- jax/experimental/checkify.py | 2 +- jax/experimental/custom_partitioning.py | 2 +- jax/experimental/host_callback.py | 32 +-- .../jax2tf/JAX2TF_getting_started.ipynb | 2 +- jax/experimental/jax2tf/README.md | 34 +-- jax/experimental/jax2tf/call_tf.py | 12 +- jax/experimental/jax2tf/examples/README.md | 24 +- .../jax2tf/examples/serving/README.md | 10 +- .../jax2tf/examples/tflite/mnist/README.md | 2 +- .../jax2tf/g3doc/convert_models_results.md | 98 ++++----- .../g3doc/convert_models_results.md.template | 4 +- .../jax2tf/g3doc/no_xla_limitations.md | 2 +- .../g3doc/primitives_with_limited_support.md | 4 +- ...rimitives_with_limited_support.md.template | 4 +- jax/experimental/jax2tf/impl_no_xla.py | 2 +- jax/experimental/jax2tf/jax2tf.py | 10 +- jax/experimental/jax2tf/tests/call_tf_test.py | 4 +- .../jax2tf/tests/jax2tf_limitations.py | 4 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 14 +- .../jax2tf/tests/primitives_test.py | 2 +- .../jax2tf/tests/savedmodel_test.py | 2 +- .../jax2tf/tests/shape_poly_test.py | 4 +- .../jax2tf/tests/sharding_test.py | 2 +- jax/experimental/jet.py | 4 +- jax/experimental/shard_map.py | 28 +-- jax/experimental/sparse/__init__.py | 2 +- jax/experimental/sparse/bcoo.py | 6 +- jax/extend/backend.py | 2 +- jax/extend/core/__init__.py | 2 +- jax/extend/core/primitives.py | 2 +- jax/extend/ffi.py | 2 +- jax/extend/ifrt_programs.py | 2 +- jax/extend/linear_util.py | 2 +- jax/extend/random.py | 2 +- jax/extend/source_info_util.py | 2 +- jax/image/__init__.py | 2 +- jax/interpreters/ad.py | 2 +- jax/interpreters/batching.py | 2 +- jax/lax/__init__.py | 2 +- jax/nn/__init__.py | 2 +- jax/nn/initializers.py | 2 +- jax/numpy/__init__.py | 2 +- jax/numpy/fft.py | 2 +- jax/numpy/linalg.py | 2 +- jax/ops/__init__.py | 2 +- jax/profiler.py | 2 +- jax/random.py | 4 +- jax/scipy/__init__.py | 2 +- jax/scipy/cluster/__init__.py | 2 +- jax/scipy/cluster/vq.py | 2 +- jax/scipy/fft.py | 2 +- jax/scipy/integrate.py | 2 +- jax/scipy/linalg.py | 2 +- jax/scipy/ndimage.py | 2 +- jax/scipy/optimize/__init__.py | 2 +- jax/scipy/signal.py | 2 +- jax/scipy/sparse/__init__.py | 2 +- jax/scipy/sparse/linalg.py | 2 +- jax/scipy/spatial/transform.py | 2 +- jax/scipy/special.py | 2 +- jax/scipy/stats/__init__.py | 2 +- jax/scipy/stats/bernoulli.py | 2 +- jax/scipy/stats/beta.py | 2 +- jax/scipy/stats/betabinom.py | 2 +- jax/scipy/stats/cauchy.py | 2 +- jax/scipy/stats/chi2.py | 2 +- jax/scipy/stats/dirichlet.py | 2 +- jax/scipy/stats/expon.py | 2 +- jax/scipy/stats/gamma.py | 2 +- jax/scipy/stats/gennorm.py | 2 +- jax/scipy/stats/geom.py | 2 +- jax/scipy/stats/laplace.py | 2 +- jax/scipy/stats/logistic.py | 2 +- jax/scipy/stats/multinomial.py | 2 +- jax/scipy/stats/multivariate_normal.py | 2 +- jax/scipy/stats/norm.py | 2 +- jax/scipy/stats/pareto.py | 2 +- jax/scipy/stats/poisson.py | 2 +- jax/scipy/stats/t.py | 2 +- jax/scipy/stats/truncnorm.py | 2 +- jax/scipy/stats/uniform.py | 2 +- jax/scipy/stats/vonmises.py | 2 +- jax/scipy/stats/wrapcauchy.py | 2 +- jax/sharding.py | 2 +- jax/stages.py | 2 +- jax/test_util.py | 2 +- jax/tree_util.py | 2 +- jax/util.py | 2 +- jax/version.py | 2 +- jax_plugins/rocm/plugin_setup.py | 2 +- jaxlib/README.md | 2 +- jaxlib/setup.py | 2 +- jaxlib/tools/build_wheel.py | 2 +- setup.py | 2 +- tests/api_test.py | 184 ++++++++-------- tests/array_interoperability_test.py | 2 +- tests/attrs_test.py | 2 +- tests/batching_test.py | 16 +- tests/core_test.py | 4 +- tests/custom_linear_solve_test.py | 4 +- tests/debug_nans_test.py | 4 +- tests/dtypes_test.py | 6 +- tests/dynamic_api_test.py | 2 +- tests/export_test.py | 2 +- tests/fft_test.py | 2 +- tests/host_callback_test.py | 6 +- tests/image_test.py | 2 +- tests/jet_test.py | 4 +- tests/lax_autodiff_test.py | 20 +- tests/lax_control_flow_test.py | 34 +-- tests/lax_metal_test.py | 48 ++-- tests/lax_numpy_einsum_test.py | 2 +- tests/lax_numpy_indexing_test.py | 28 +-- tests/lax_numpy_operators_test.py | 2 +- tests/lax_numpy_reducers_test.py | 12 +- tests/lax_numpy_test.py | 44 ++-- tests/lax_numpy_ufuncs_test.py | 2 +- tests/lax_numpy_vectorize_test.py | 2 +- tests/lax_scipy_sparse_test.py | 2 +- tests/lax_scipy_test.py | 12 +- tests/lax_test.py | 26 +-- tests/lax_vmap_test.py | 4 +- tests/linalg_test.py | 6 +- tests/multi_device_test.py | 2 +- tests/multibackend_test.py | 6 +- tests/multiprocess_gpu_test.py | 2 +- tests/nn_test.py | 12 +- tests/notebooks/colab_cpu.ipynb | 2 +- tests/notebooks/colab_gpu.ipynb | 2 +- tests/ode_test.py | 16 +- tests/optimizers_test.py | 2 +- tests/pallas/gpu_ops_test.py | 4 +- tests/pallas/ops_test.py | 6 +- tests/pallas/pallas_test.py | 2 +- tests/pallas/pallas_vmap_test.py | 2 +- tests/pjit_test.py | 6 +- tests/pmap_test.py | 38 ++-- tests/python_callback_test.py | 4 +- tests/pytorch_interoperability_test.py | 2 +- tests/random_lax_test.py | 22 +- tests/random_test.py | 14 +- tests/scipy_ndimage_test.py | 2 +- tests/scipy_stats_test.py | 10 +- tests/shape_poly_test.py | 2 +- tests/shard_map_test.py | 18 +- tests/sparse_bcoo_bcsr_test.py | 12 +- tests/sparse_test.py | 2 +- tests/sparsify_test.py | 2 +- tests/stax_test.py | 2 +- tests/tree_util_test.py | 10 +- tests/x64_context_test.py | 2 +- 257 files changed, 906 insertions(+), 906 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index c19832e63163..628310519b66 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -20,11 +20,11 @@ body: * If you prefer a non-templated issue report, click [here][Raw report]. - [Discussions]: https://github.com/google/jax/discussions + [Discussions]: https://github.com/jax-ml/jax/discussions - [issue search]: https://github.com/google/jax/search?q=is%3Aissue&type=issues + [issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues - [Raw report]: http://github.com/google/jax/issues/new + [Raw report]: http://github.com/jax-ml/jax/issues/new - type: textarea attributes: label: Description diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index cabbed58967a..f078e8e94182 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,5 @@ blank_issues_enabled: false contact_links: - name: Have questions or need support? - url: https://github.com/google/jax/discussions + url: https://github.com/jax-ml/jax/discussions about: Please ask questions on the Discussions tab diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 1e345954d0f7..74cb45920949 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -84,7 +84,7 @@ jobs: failure() && steps.status.outcome == 'failure' && github.event_name == 'schedule' - && github.repository == 'google/jax' + && github.repository == 'jax-ml/jax' uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 with: name: output-${{ matrix.python-version }}-log.jsonl diff --git a/CHANGELOG.md b/CHANGELOG.md index ee782d04a02c..43db6e197b5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -279,7 +279,7 @@ See the 0.4.33 release notes for more details. which manifested as an incorrect output for cumulative reductions (#21403). * Fixed a bug where XLA:CPU miscompiled certain matmul fusions (https://github.com/openxla/xla/pull/13301). - * Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396). + * Fixes a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396). * Deprecations * `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will @@ -401,7 +401,7 @@ See the 0.4.33 release notes for more details. branch consistent with that of NumPy 2.0. * The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'` and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has - changed](https://github.com/google/jax/issues/19085) so that + changed](https://github.com/jax-ml/jax/issues/19085) so that mapping over keys results in random generation only from the first key in the batch. * Docs now use `jax.random.key` for construction of PRNG key arrays @@ -433,7 +433,7 @@ See the 0.4.33 release notes for more details. * JAX export does not support older serialization versions anymore. Version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024. - See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). + See [a description of the versions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). This change could break clients that set a specific JAX serialization version lower than 9. @@ -506,7 +506,7 @@ See the 0.4.33 release notes for more details. * added the ability to specify symbolic constraints on the dimension variables. This makes shape polymorphism more expressive, and gives a way to workaround limitations in the reasoning about inequalities. - See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. * with the addition of symbolic constraints ({jax-issue}`#19235`) we now consider dimension variables from different scopes to be different, even if they have the same name. Symbolic expressions from different scopes @@ -516,7 +516,7 @@ See the 0.4.33 release notes for more details. The scope of a symbolic expression `e` can be read with `e.scope` and passed into the above functions to direct them to construct symbolic expressions in a given scope. - See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. * simplified and faster equality comparisons, where we consider two symbolic dimensions to be equal if the normalized form of their difference reduces to 0 ({jax-issue}`#19231`; note that this may result in user-visible behavior @@ -535,7 +535,7 @@ See the 0.4.33 release notes for more details. strings for polymorphic shapes specifications ({jax-issue}`#19284`). * JAX default native serialization version is now 9. This is relevant for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`. - See [description of version numbers](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). + See [description of version numbers](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). * Refactored the API for `jax.experimental.export`. Instead of `from jax.experimental.export import export` you should use now `from jax.experimental import export`. The old way of importing will @@ -781,19 +781,19 @@ See the 0.4.33 release notes for more details. * When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace" that previously appeared.) This should produce much friendlier-looking tracebacks. See - [here](https://github.com/google/jax/pull/16949) for an example. + [here](https://github.com/jax-ml/jax/pull/16949) for an example. This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two separate unfiltered/filtered tracebacks, which was the old behavior) or `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback). * jax2tf default serialization version is now 7, which introduces new shape - [safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). + [safety assertions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). * Devices passed to `jax.sharding.Mesh` should be hashable. This specifically applies to mock devices or user created devices. `jax.devices()` are already hashable. * Breaking changes: * jax2tf now uses native serialization by default. See - the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) + the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) for details and for mechanisms to override the default. * The option `--jax_coordination_service` has been removed. It is now always `True`. @@ -922,7 +922,7 @@ See the 0.4.33 release notes for more details. arguments will always resolve to the "common operands" `cond` behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See - [#16413](https://github.com/google/jax/issues/16413). + [#16413](https://github.com/jax-ml/jax/issues/16413). * The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`, which did nothing, have been removed. These options have been true by default for many releases. @@ -933,7 +933,7 @@ See the 0.4.33 release notes for more details. serialization version ({jax-issue}`#16746`). * jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. - See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism. + See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism. ## jaxlib 0.4.14 (July 27, 2023) @@ -1095,14 +1095,14 @@ See the 0.4.33 release notes for more details. {func}`jax.experimental.host_callback` is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the [JAX issue - tracker](https://github.com/google/jax/issues) if the new `jax.debug` APIs + tracker](https://github.com/jax-ml/jax/issues) if the new `jax.debug` APIs are insufficient for your use case. The old runtime component will be available for at least the next three months by setting the environment variable `JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new runtime for any reason, please let us know on the [JAX issue - tracker](https://github.com/google/jax/issues). + tracker](https://github.com/jax-ml/jax/issues). * Changes * The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7. @@ -1126,7 +1126,7 @@ See the 0.4.33 release notes for more details. StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. - See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + See [documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md). As part of this change the config flag `--jax2tf_default_experimental_native_lowering` has been renamed to `--jax2tf_native_serialization`. * JAX now depends on `ml_dtypes`, which contains definitions of NumPy types @@ -1403,7 +1403,7 @@ Changes: ## jaxlib 0.3.22 (Oct 11, 2022) ## jax 0.3.21 (Sep 30, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.20...jax-v0.3.21). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.20...jax-v0.3.21). * Changes * The persistent compilation cache will now warn instead of raising an exception on error ({jax-issue}`#12582`), so program execution can continue @@ -1417,18 +1417,18 @@ Changes: * Fix incorrect `pip` url in `setup.py` comment ({jax-issue}`#12528`). ## jaxlib 0.3.20 (Sep 28, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20). * Bug fixes * Fixes support for limiting the visible CUDA devices via `jax_cuda_visible_devices` in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU ({jax-issue}`#12533`). ## jax 0.3.19 (Sep 27, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.18...jax-v0.3.19). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.18...jax-v0.3.19). * Fixes required jaxlib version. ## jax 0.3.18 (Sep 26, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.17...jax-v0.3.18). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.17...jax-v0.3.18). * Changes * Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See [the @@ -1446,7 +1446,7 @@ Changes: would have been provided. ## jax 0.3.17 (Aug 31, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.16...jax-v0.3.17). * Bugs * Fix corner case issue in gradient of `lax.pow` with an exponent of zero ({jax-issue}`12041`) @@ -1462,7 +1462,7 @@ Changes: * `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead. ## jax 0.3.16 -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main). * Breaking changes * Support for NumPy 1.19 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). @@ -1486,7 +1486,7 @@ Changes: deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.15 (July 22, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...jax-v0.3.15). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15). * Changes * `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These classes have been deprecated since v0.3.1 ({jax-issue}`#11248`). @@ -1507,10 +1507,10 @@ Changes: following a similar deprecation in {func}`scipy.linalg.solve`. ## jaxlib 0.3.15 (July 22, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15). ## jax 0.3.14 (June 27, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...jax-v0.3.14). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.13...jax-v0.3.14). * Breaking changes * {func}`jax.experimental.compilation_cache.initialize_cache` does not support `max_cache_size_ bytes` anymore and will not get that as an input. @@ -1563,22 +1563,22 @@ Changes: coefficients have leading zeros ({jax-issue}`#11215`). ## jaxlib 0.3.14 (June 27, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14). * x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement. * The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks. * The Python flatbuffers package is no longer a dependency of jaxlib. ## jax 0.3.13 (May 16, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.12...jax-v0.3.13). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.12...jax-v0.3.13). ## jax 0.3.12 (May 15, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.11...jax-v0.3.12). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.11...jax-v0.3.12). * Changes - * Fixes [#10717](https://github.com/google/jax/issues/10717). + * Fixes [#10717](https://github.com/jax-ml/jax/issues/10717). ## jax 0.3.11 (May 15, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...jax-v0.3.11). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.10...jax-v0.3.11). * Changes * {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument that allows users to opt out of eigenvalue sorting on TPU. @@ -1592,22 +1592,22 @@ Changes: scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead. ## jax 0.3.10 (May 3, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.9...jax-v0.3.10). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.9...jax-v0.3.10). ## jaxlib 0.3.10 (May 3, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10). * Changes * [TF commit](https://github.com/tensorflow/tensorflow/commit/207d50d253e11c3a3430a700af478a1d524a779a) fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs. ## jax 0.3.9 (May 2, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.8...jax-v0.3.9). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.8...jax-v0.3.9). * Changes * Added support for fully asynchronous checkpointing for GlobalDeviceArray. ## jax 0.3.8 (April 29 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...jax-v0.3.8). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.7...jax-v0.3.8). * Changes * {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver. * {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input. @@ -1666,7 +1666,7 @@ Changes: ## jax 0.3.7 (April 15, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.6...jax-v0.3.7). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.6...jax-v0.3.7). * Changes: * Fixed a performance problem if the indices passed to {func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`). @@ -1684,17 +1684,17 @@ Changes: ## jax 0.3.6 (April 12, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.5...jax-v0.3.6). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.5...jax-v0.3.6). * Changes: * Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU - pod. Fixes [#10218](https://github.com/google/jax/issues/10218). + pod. Fixes [#10218](https://github.com/jax-ml/jax/issues/10218). * Deprecations: * {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278` for an alternative API. ## jax 0.3.5 (April 7, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.4...jax-v0.3.5). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.4...jax-v0.3.5). * Changes: * added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta` and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`). @@ -1717,17 +1717,17 @@ Changes: ## jax 0.3.4 (March 18, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.3...jax-v0.3.4). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.3...jax-v0.3.4). ## jax 0.3.3 (March 17, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.2...jax-v0.3.3). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.2...jax-v0.3.3). ## jax 0.3.2 (March 16, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.1...jax-v0.3.2). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.1...jax-v0.3.2). * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use @@ -1751,7 +1751,7 @@ Changes: ## jax 0.3.1 (Feb 18, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.0...jax-v0.3.1). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.0...jax-v0.3.1). * Changes: * `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. @@ -1774,7 +1774,7 @@ Changes: ## jax 0.3.0 (Feb 10, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.28...jax-v0.3.0). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0). * Changes * jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) @@ -1788,7 +1788,7 @@ Changes: ## jax 0.2.28 (Feb 1, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.27...jax-v0.2.28). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.27...jax-v0.2.28). * `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no `dialect=` is passed. * The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR @@ -1813,7 +1813,7 @@ Changes: * The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311). ## jax 0.2.27 (Jan 18 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.26...jax-v0.2.27). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.26...jax-v0.2.27). * Breaking changes: * Support for NumPy 1.18 has been dropped, per the @@ -1858,7 +1858,7 @@ Changes: ## jax 0.2.26 (Dec 8, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.25...jax-v0.2.26). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.25...jax-v0.2.26). * Bug fixes: * Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with @@ -1875,7 +1875,7 @@ Changes: ## jax 0.2.25 (Nov 10, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.24...jax-v0.2.25). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.24...jax-v0.2.25). * New features: * (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend. @@ -1889,7 +1889,7 @@ Changes: ## jax 0.2.24 (Oct 19, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.22...jax-v0.2.24). * New features: * `jax.random.choice` and `jax.random.permutation` now support @@ -1923,7 +1923,7 @@ Changes: ## jax 0.2.22 (Oct 12, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.21...jax-v0.2.22). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.21...jax-v0.2.22). * Breaking Changes * Static arguments to `jax.pmap` must now be hashable. @@ -1958,13 +1958,13 @@ Changes: * Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+. * Bug fixes: - * Fixes https://github.com/google/jax/issues/7461, which caused wrong + * Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler. ## jax 0.2.21 (Sept 23, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.20...jax-v0.2.21). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.20...jax-v0.2.21). * Breaking Changes * `jax.api` has been removed. Functions that were available as `jax.api.*` were aliases for functions in `jax.*`; please use the functions in @@ -1992,7 +1992,7 @@ Changes: ## jax 0.2.20 (Sept 2, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.19...jax-v0.2.20). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.19...jax-v0.2.20). * Breaking Changes * `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`) * `jnp.unique` and other set-like operations now require array-like inputs @@ -2005,7 +2005,7 @@ Changes: ## jax 0.2.19 (Aug 12, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.18...jax-v0.2.19). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19). * Breaking changes: * Support for NumPy 1.17 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). @@ -2042,7 +2042,7 @@ Changes: called in sequence. ## jax 0.2.18 (July 21 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...jax-v0.2.18). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.17...jax-v0.2.18). * Breaking changes: * Support for Python 3.6 has been dropped, per the @@ -2065,7 +2065,7 @@ Changes: * Fix bugs in TFRT CPU backend that results in incorrect results. ## jax 0.2.17 (July 9 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.16...jax-v0.2.17). * Bug fixes: * Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency @@ -2082,12 +2082,12 @@ Changes: ## jax 0.2.16 (June 23 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.15...jax-v0.2.16). ## jax 0.2.15 (June 23 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...jax-v0.2.15). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.14...jax-v0.2.15). * New features: - * [#7042](https://github.com/google/jax/pull/7042) Turned on TFRT CPU backend + * [#7042](https://github.com/jax-ml/jax/pull/7042) Turned on TFRT CPU backend with significant dispatch performance improvements on CPU. * The {func}`jax2tf.convert` supports inequalities and min/max for booleans ({jax-issue}`#6956`). @@ -2107,7 +2107,7 @@ Changes: CPU. ## jax 0.2.14 (June 10 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.13...jax-v0.2.14). * New features: * The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`. * A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters @@ -2165,7 +2165,7 @@ Changes: {func}`jit` transformed functions. ## jax 0.2.13 (May 3 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.12...jax-v0.2.13). * New features: * When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static keyword arguments. A new `static_argnames` option has been added to specify @@ -2209,7 +2209,7 @@ Changes: ## jaxlib 0.1.65 (April 7 2021) ## jax 0.2.12 (April 1 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.11...v0.2.12). * New features * New profiling APIs: {func}`jax.profiler.start_trace`, {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace` @@ -2222,7 +2222,7 @@ Changes: * `TraceContext` --> {func}`~jax.profiler.TraceAnnotation` * `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation` * `trace_function` --> {func}`~jax.profiler.annotate_function` - * Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md) + * Omnistaging can no longer be disabled. See [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md) for more information. * Python integers larger than the maximum `int64` value will now lead to an overflow in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`). @@ -2236,23 +2236,23 @@ Changes: ## jax 0.2.11 (March 23 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.10...jax-v0.2.11). * New features: - * [#6112](https://github.com/google/jax/pull/6112) added context managers: + * [#6112](https://github.com/jax-ml/jax/pull/6112) added context managers: `jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`, `jax.debug_infs`, `jax.log_compiles`. - * [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete` + * [#6085](https://github.com/jax-ml/jax/pull/6085) added `jnp.delete` * Bug fixes: - * [#6136](https://github.com/google/jax/pull/6136) generalized + * [#6136](https://github.com/jax-ml/jax/pull/6136) generalized `jax.flatten_util.ravel_pytree` to handle integer dtypes. - * [#6129](https://github.com/google/jax/issues/6129) fixed a bug with handling + * [#6129](https://github.com/jax-ml/jax/issues/6129) fixed a bug with handling some constants like `enum.IntEnums` - * [#6145](https://github.com/google/jax/pull/6145) fixed batching issues with + * [#6145](https://github.com/jax-ml/jax/pull/6145) fixed batching issues with incomplete beta functions - * [#6014](https://github.com/google/jax/pull/6014) fixed H2D transfers during + * [#6014](https://github.com/jax-ml/jax/pull/6014) fixed H2D transfers during tracing - * [#6165](https://github.com/google/jax/pull/6165) avoids OverflowErrors when + * [#6165](https://github.com/jax-ml/jax/pull/6165) avoids OverflowErrors when converting some large Python integers to floats * Breaking changes: * The minimum jaxlib version is now 0.1.62. @@ -2264,13 +2264,13 @@ Changes: ## jax 0.2.10 (March 5 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.9...jax-v0.2.10). * New features: * {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods. * {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods. * Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions from JAX ({jax-issue}`#5627`) - and [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)). + and [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)). * Extended the batching rule for `lax.pad` to support batching of the padding values. * Bug fixes: * {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`) @@ -2314,7 +2314,7 @@ Changes: ## jax 0.2.9 (January 26 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.8...jax-v0.2.9). * New features: * Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved error checking and error messages. @@ -2330,7 +2330,7 @@ Changes: ## jax 0.2.8 (January 12 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.7...jax-v0.2.8). * New features: * Add {func}`jax.closure_convert` for use with higher-order custom derivative functions. ({jax-issue}`#5244`) @@ -2362,7 +2362,7 @@ Changes: ## jax 0.2.7 (Dec 4 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.6...jax-v0.2.7). * New features: * Add `jax.device_put_replicated` * Add multi-host support to `jax.experimental.sharded_jit` @@ -2382,14 +2382,14 @@ Changes: ## jax 0.2.6 (Nov 18 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.5...jax-v0.2.6). * New Features: * Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. - See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + See [README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md). * Breaking change cleanup * Raise an error on non-hashable static arguments for jax.jit and - xla_computation. See [cb48f42](https://github.com/google/jax/commit/cb48f42). + xla_computation. See [cb48f42](https://github.com/jax-ml/jax/commit/cb48f42). * Improve consistency of type promotion behavior ({jax-issue}`#4744`): * Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously @@ -2441,15 +2441,15 @@ Changes: ## jax 0.2.5 (October 27 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.4...jax-v0.2.5). * Improvements: * Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`. * Expanded the set of JAX primitives converted by jax2tf. - See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md). + See [primitives_with_limited_support.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md). ## jax 0.2.4 (October 19 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.3...jax-v0.2.4). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.3...jax-v0.2.4). * Improvements: * Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`. * Deprecations @@ -2461,17 +2461,17 @@ Changes: ## jax 0.2.3 (October 14 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.2...jax-v0.2.3). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.2...jax-v0.2.3). * The reason for another release so soon is we need to temporarily roll back a new jit fastpath while we look into a performance degradation ## jax 0.2.2 (October 13 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.1...jax-v0.2.2). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.1...jax-v0.2.2). ## jax 0.2.1 (October 6 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.0...jax-v0.2.1). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.0...jax-v0.2.1). * Improvements: * As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/ @@ -2479,10 +2479,10 @@ Changes: ## jax (0.2.0) (September 23 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.77...jax-v0.2.0). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.77...jax-v0.2.0). * Improvements: * Omnistaging on by default. See {jax-issue}`#3370` and - [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md) + [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md) ## jax (0.1.77) (September 15 2020) @@ -2496,11 +2496,11 @@ Changes: ## jax 0.1.76 (September 8, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.75...jax-v0.1.76). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.75...jax-v0.1.76). ## jax 0.1.75 (July 30, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.74...jax-v0.1.75). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.74...jax-v0.1.75). * Bug Fixes: * make jnp.abs() work for unsigned inputs (#3914) * Improvements: @@ -2508,7 +2508,7 @@ Changes: ## jax 0.1.74 (July 29, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.73...jax-v0.1.74). * New Features: * BFGS (#3101) * TPU support for half-precision arithmetic (#3878) @@ -2525,7 +2525,7 @@ Changes: ## jax 0.1.73 (July 22, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.72...jax-v0.1.73). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.72...jax-v0.1.73). * The minimum jaxlib version is now 0.1.51. * New Features: * jax.image.resize. (#3703) @@ -2563,14 +2563,14 @@ Changes: ## jax 0.1.72 (June 28, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.71...jax-v0.1.72). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.71...jax-v0.1.72). * Bug fixes: * Fix an odeint bug introduced in the previous release, see {jax-issue}`#3587`. ## jax 0.1.71 (June 25, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.70...jax-v0.1.71). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.70...jax-v0.1.71). * The minimum jaxlib version is now 0.1.48. * Bug fixes: * Allow `jax.experimental.ode.odeint` dynamics functions to close over @@ -2606,7 +2606,7 @@ Changes: ## jax 0.1.70 (June 8, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.69...jax-v0.1.70). * New features: * `lax.switch` introduces indexed conditionals with multiple branches, together with a generalization of the `cond` @@ -2615,11 +2615,11 @@ Changes: ## jax 0.1.69 (June 3, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.68...jax-v0.1.69). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.68...jax-v0.1.69). ## jax 0.1.68 (May 21, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.67...jax-v0.1.68). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.67...jax-v0.1.68). * New features: * {func}`lax.cond` supports a single-operand form, taken as the argument to both branches @@ -2630,7 +2630,7 @@ Changes: ## jax 0.1.67 (May 12, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.66...jax-v0.1.67). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.66...jax-v0.1.67). * New features: * Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`. @@ -2648,7 +2648,7 @@ Changes: ## jax 0.1.66 (May 5, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.65...jax-v0.1.66). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.65...jax-v0.1.66). * New features: * Support for `in_axes=None` on {func}`pmap` {jax-issue}`#2896`. @@ -2661,7 +2661,7 @@ Changes: ## jax 0.1.65 (April 30, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.64...jax-v0.1.65). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.64...jax-v0.1.65). * New features: * Differentiation of determinants of singular matrices {jax-issue}`#2809`. @@ -2679,7 +2679,7 @@ Changes: ## jax 0.1.64 (April 21, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.63...jax-v0.1.64). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.63...jax-v0.1.64). * New features: * Add syntactic sugar for functional indexed updates {jax-issue}`#2684`. @@ -2706,7 +2706,7 @@ Changes: ## jax 0.1.63 (April 12, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63). * Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). * Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`. * Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`. @@ -2727,7 +2727,7 @@ Changes: ## jax 0.1.62 (March 21, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.61...jax-v0.1.62). * JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer. * Removed the internal function `lax._safe_mul`, which implemented the convention `0. * nan == 0.`. This change means some programs when @@ -2745,13 +2745,13 @@ Changes: ## jax 0.1.61 (March 17, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.60...jax-v0.1.61). * Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5. ## jax 0.1.60 (March 17, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.59...jax-v0.1.60). * New features: * {py:func}`jax.pmap` has `static_broadcast_argnums` argument which allows the user to specify arguments that should be treated as compile-time @@ -2777,7 +2777,7 @@ Changes: ## jax 0.1.59 (February 11, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.58...jax-v0.1.59). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.58...jax-v0.1.59). * Breaking changes * The minimum jaxlib version is now 0.1.38. @@ -2809,7 +2809,7 @@ Changes: ## jax 0.1.58 (January 28, 2020) -* [GitHub commits](https://github.com/google/jax/compare/46014da21...jax-v0.1.58). +* [GitHub commits](https://github.com/jax-ml/jax/compare/46014da21...jax-v0.1.58). * Breaking changes * JAX has dropped Python 2 support, because Python 2 reached its end of life on diff --git a/CITATION.bib b/CITATION.bib index 88049a1469d1..777058b5aaa9 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -1,7 +1,7 @@ @software{jax2018github, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, - url = {http://github.com/google/jax}, + url = {http://github.com/jax-ml/jax}, version = {0.3.13}, year = {2018}, } diff --git a/README.md b/README.md index 35307bee33f6..d67bdac82414 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@
-logo +logo
# Transformable numerical computing at scale -![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg) +![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg) ![PyPI version](https://img.shields.io/pypi/v/jax) [**Quickstart**](#quickstart-colab-in-the-cloud) @@ -50,7 +50,7 @@ parallel programming of multiple accelerators, with more to come. This is a research project, not an official Google product. Expect bugs and [sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). Please help by trying it out, [reporting -bugs](https://github.com/google/jax/issues), and letting us know what you +bugs](https://github.com/jax-ml/jax/issues), and letting us know what you think! ```python @@ -84,16 +84,16 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: - [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) -- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) +- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) **JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU -Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs). +Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). For a deeper dive into JAX: - [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) - [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) - See the [full list of -notebooks](https://github.com/google/jax/tree/main/docs/notebooks). +notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). ## Transformations @@ -300,7 +300,7 @@ print(normalize(jnp.arange(4.))) # prints [0. 0.16666667 0.33333334 0.5 ] ``` -You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more +You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more sophisticated communication patterns. It all composes, so you're free to differentiate through parallel computations: @@ -333,9 +333,9 @@ When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the backward pass of the computation is parallelized just like the forward pass. See the [SPMD -Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) +Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) and the [SPMD MNIST classifier from scratch -example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) +example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) for more. ## Current gotchas @@ -349,7 +349,7 @@ Some standouts: 1. [In-place mutating updates of arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md). + different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. @@ -437,7 +437,7 @@ To cite this repository: @software{jax2018github, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, - url = {http://github.com/google/jax}, + url = {http://github.com/jax-ml/jax}, version = {0.3.13}, year = {2018}, } diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index cb5a42ced8c4..edaa71b93e85 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -451,7 +451,7 @@ "id": "jC-KIMQ1q-lK" }, "source": [ - "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." + "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." ] }, { diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index 4952cdbe9365..d7ba5ed334f4 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -837,7 +837,7 @@ "id": "f-FBsWeo1AXE" }, "source": [ - "" + "" ] }, { @@ -847,7 +847,7 @@ "id": "jC-KIMQ1q-lK" }, "source": [ - "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." + "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." ] }, { diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index 981f0a9e80a7..ea126ac4f1e7 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -15,13 +15,13 @@ "id": "sk-3cPGIBTq8" }, "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n", "\n", "This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n", "\n", "**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Colab TPU, a Google Cloud TPU or a Kaggle TPU VM.\n", "\n", - "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)." + "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)." ] }, { diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index 4a795f718c84..db3dc5f30814 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -13,25 +13,25 @@ VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). The following notebooks showcase how to use and what you can do with Cloud TPUs on Colab: -### [Pmap Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) +### [Pmap Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) A guide to getting started with `pmap`, a transform for easily distributing SPMD computations across devices. -### [Lorentz ODE Solver](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb) +### [Lorentz ODE Solver](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb) Contributed by Alex Alemi (alexalemi@) Solve and plot parallel ODE solutions with `pmap`. - + -### [Wave Equation](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb) +### [Wave Equation](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb) Contributed by Stephan Hoyer (shoyer@) Solve the wave equation with `pmap`, and make cool movies! The spatial domain is partitioned across the 8 cores of a Cloud TPU. -![](https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/wave_movie.gif) +![](https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/wave_movie.gif) -### [JAX Demo](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb) +### [JAX Demo](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb) An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/) and the [Compilers for ML workshop at CGO 2020](https://www.c4ml.org/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`. ## Performance notes @@ -53,7 +53,7 @@ By default\*, matrix multiplication in JAX on TPUs [uses bfloat16](https://cloud JAX also adds the `bfloat16` dtype, which you can use to explicitly cast arrays to bfloat16, e.g., `jax.numpy.array(x, dtype=jax.numpy.bfloat16)`. -\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/google/jax/issues/2161) if it affects you! +\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/jax-ml/jax/issues/2161) if it affects you! ## Running JAX on a Cloud TPU VM @@ -65,8 +65,8 @@ documentation](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). If you run into Cloud TPU-specific issues (e.g. trouble creating a Cloud TPU VM), please email , or if you are a [TRC](https://sites.research.google/trc/) member. You can also [file a -JAX issue](https://github.com/google/jax/issues) or [ask a discussion -question](https://github.com/google/jax/discussions) for any issues with these +JAX issue](https://github.com/jax-ml/jax/issues) or [ask a discussion +question](https://github.com/jax-ml/jax/discussions) for any issues with these notebooks or using JAX in general. If you have any other questions or comments regarding JAX on Cloud TPUs, please diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index 180f65f5d492..287487ad40a0 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -571,7 +571,7 @@ print("Naive full Hessian materialization") ### Jacobian-Matrix and Matrix-Jacobian products -Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: +Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: ```{code-cell} # Isolate the function from the weight matrix to the predictions diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index ed242ecc5710..9a956670ceea 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -27,7 +27,7 @@ "metadata": {}, "source": [ "[![Open in\n", - "Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)" + "Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)" ] }, { @@ -1781,7 +1781,7 @@ "metadata": {}, "source": [ "This is precisely the issue that\n", - "[omnistaging](https://github.com/google/jax/pull/3370) fixed.\n", + "[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.\n", "We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n", "applied, regardless of whether any inputs to `bind` are boxed in corresponding\n", "`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 471dd7c63f6d..937e1012a230 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -33,7 +33,7 @@ limitations under the License. ``` [![Open in -Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb) +Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb) +++ @@ -1399,7 +1399,7 @@ print(jaxpr) ``` This is precisely the issue that -[omnistaging](https://github.com/google/jax/pull/3370) fixed. +[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed. We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always applied, regardless of whether any inputs to `bind` are boxed in corresponding `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` diff --git a/docs/autodidax.py b/docs/autodidax.py index 6d295fc50301..c10e6365e62d 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -27,7 +27,7 @@ # --- # [![Open in -# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb) +# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb) # # Autodidax: JAX core from scratch # @@ -1396,7 +1396,7 @@ def pp_params(params: dict[str, Any]) -> PPrint: # This is precisely the issue that -# [omnistaging](https://github.com/google/jax/pull/3370) fixed. +# [omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed. # We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always # applied, regardless of whether any inputs to `bind` are boxed in corresponding # `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` diff --git a/docs/beginner_guide.rst b/docs/beginner_guide.rst index 204659ec2cb9..783d3b49ae52 100644 --- a/docs/beginner_guide.rst +++ b/docs/beginner_guide.rst @@ -52,4 +52,4 @@ questions answered are: .. _Flax: https://flax.readthedocs.io/ .. _Haiku: https://dm-haiku.readthedocs.io/ .. _JAX on StackOverflow: https://stackoverflow.com/questions/tagged/jax -.. _JAX GitHub discussions: https://github.com/google/jax/discussions \ No newline at end of file +.. _JAX GitHub discussions: https://github.com/jax-ml/jax/discussions \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index ed6fcfd0dc8b..e77916e265ff 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -168,7 +168,7 @@ def _do_not_evaluate_in_jax( # documentation. html_theme_options = { 'show_toc_level': 2, - 'repository_url': 'https://github.com/google/jax', + 'repository_url': 'https://github.com/jax-ml/jax', 'use_repository_button': True, # add a "link to repository" button 'navigation_with_keys': False, } @@ -345,7 +345,7 @@ def linkcode_resolve(domain, info): return None filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__)) lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else "" - return f"https://github.com/google/jax/blob/main/jax/{filename}{lines}" + return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}" # Generate redirects from deleted files to new sources rediraffe_redirects = { diff --git a/docs/contributing.md b/docs/contributing.md index d7fa6e9da8a3..99d78453c436 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -5,22 +5,22 @@ Everyone can contribute to JAX, and we value everyone's contributions. There are several ways to contribute, including: -- Answering questions on JAX's [discussions page](https://github.com/google/jax/discussions) +- Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) - Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) -- Contributing to JAX's [code-base](http://github.com/google/jax/) -- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/google/jax#neural-network-libraries) +- Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) +- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). ## Ways to contribute We welcome pull requests, in particular for those issues marked with -[contributions welcome](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or -[good first issue](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). +[contributions welcome](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or +[good first issue](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). For other proposals, we ask that you first open a GitHub -[Issue](https://github.com/google/jax/issues/new/choose) or -[Discussion](https://github.com/google/jax/discussions) +[Issue](https://github.com/jax-ml/jax/issues/new/choose) or +[Discussion](https://github.com/jax-ml/jax/discussions) to seek feedback on your planned contribution. ## Contributing code using pull requests @@ -33,7 +33,7 @@ Follow these steps to contribute code: For more information, see the Pull Request Checklist below. 2. Fork the JAX repository by clicking the **Fork** button on the - [repository page](http://www.github.com/google/jax). This creates + [repository page](http://www.github.com/jax-ml/jax). This creates a copy of the JAX repository in your own account. 3. Install Python >= 3.10 locally in order to run tests. @@ -52,7 +52,7 @@ Follow these steps to contribute code: changes. ```bash - git remote add upstream https://www.github.com/google/jax + git remote add upstream https://www.github.com/jax-ml/jax ``` 6. Create a branch where you will develop from: diff --git a/docs/developer.md b/docs/developer.md index 40ad51e873ca..4f33614138ef 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -6,7 +6,7 @@ First, obtain the JAX source code: ``` -git clone https://github.com/google/jax +git clone https://github.com/jax-ml/jax cd jax ``` @@ -26,7 +26,7 @@ If you're only modifying Python portions of JAX, we recommend installing pip install jaxlib ``` -See the [JAX readme](https://github.com/google/jax#installation) for full +See the [JAX readme](https://github.com/jax-ml/jax#installation) for full guidance on pip installation (e.g., for GPU and TPU support). ### Building `jaxlib` from source @@ -621,7 +621,7 @@ pytest --doctest-modules jax/_src/numpy/lax_numpy.py Keep in mind that there are several files that are marked to be skipped when the doctest command is run on the full package; you can see the details in -[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml) +[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml) ## Type checking @@ -712,7 +712,7 @@ jupytext --sync docs/notebooks/thinking_in_jax.ipynb ``` The jupytext version should match that specified in -[.pre-commit-config.yaml](https://github.com/google/jax/blob/main/.pre-commit-config.yaml). +[.pre-commit-config.yaml](https://github.com/jax-ml/jax/blob/main/.pre-commit-config.yaml). To check that the markdown and ipynb files are properly synced, you may use the [pre-commit](https://pre-commit.com/) framework to perform the same check used @@ -740,12 +740,12 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked Some of the notebooks are built automatically as part of the pre-submit checks and as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, -or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)). +or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else re-saves the notebook. We exclude some notebooks from the build, e.g., because they contain long computations. -See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs/conf.py). +See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs/conf.py). ### Documentation building on `readthedocs.io` @@ -772,7 +772,7 @@ I saw in the Readthedocs logs: mkvirtualenv jax-docs # A new virtualenv mkdir jax-docs # A new directory cd jax-docs -git clone --no-single-branch --depth 50 https://github.com/google/jax +git clone --no-single-branch --depth 50 https://github.com/jax-ml/jax cd jax git checkout --force origin/test-docs git clean -d -f -f diff --git a/docs/export/export.md b/docs/export/export.md index 0ca1a64800e0..4e4d50556d8e 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -153,7 +153,7 @@ JAX runtime system that are: an inference system that is already deployed when the exporting is done. (The particular compatibility window lengths are the same that JAX -[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), +[promised for jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow). The terminology “backward compatibility” is from the perspective of the consumer, e.g., the inference system.) @@ -626,7 +626,7 @@ We list here a history of the calling convention version numbers: June 13th, 2023 (JAX 0.4.13). * Version 7 adds support for `stablehlo.shape_assertion` operations and for `shape_assertions` specified in `disabled_checks`. - See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule + See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule since July 12th, 2023 (cl/547482522), available in JAX serialization since July 20th, 2023 (JAX 0.4.14), and the default since August 12th, 2023 (JAX 0.4.15). @@ -721,7 +721,7 @@ that live in jaxlib): 2. Day “D”, we add the new custom call target `T_NEW`. We should create a new custom call target, and clean up the old target roughly after 6 months, rather than updating `T` in place: - * See the example [PR #20997](https://github.com/google/jax/pull/20997) + * See the example [PR #20997](https://github.com/jax-ml/jax/pull/20997) implementing the steps below. * We add the custom call target `T_NEW`. * We change the JAX lowering rules that were previous using `T`, diff --git a/docs/export/jax2tf.md b/docs/export/jax2tf.md index 498a0418f232..9c0ee90a0d93 100644 --- a/docs/export/jax2tf.md +++ b/docs/export/jax2tf.md @@ -2,4 +2,4 @@ ## Interoperation with TensorFlow -See the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). +See the [JAX2TF documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md). diff --git a/docs/faq.rst b/docs/faq.rst index 3ac7d89fb36e..af14f382b1d7 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -372,7 +372,7 @@ device. Jitted functions behave like any other primitive operations—they will follow the data and will show errors if invoked on data committed on more than one device. -(Before `PR #6002 `_ in March 2021 +(Before `PR #6002 `_ in March 2021 there was some laziness in creation of array constants, so that ``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar would actually create the array of zeros on ``jax.devices()[1]``, instead of creating the @@ -385,7 +385,7 @@ and its use is not recommended.) For a worked-out example, we recommend reading through ``test_computation_follows_data`` in -`multi_device_test.py `_. +`multi_device_test.py `_. .. _faq-benchmark: @@ -691,7 +691,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.:: Additional reading: - * `Issue: gradients through jnp.where when one of branches is nan `_. + * `Issue: gradients through jnp.where when one of branches is nan `_. * `How to avoid NaN gradients when using where `_. diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index 7f7bcc07ce85..04ae80cbf5b1 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -406,7 +406,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)." + "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)." ] }, { @@ -492,7 +492,7 @@ "source": [ "At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n", "One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n", - "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.\n", + "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice.\n", "\n", "One other JAX feature that this example doesn't support is higher-order AD.\n", "It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n", diff --git a/docs/ffi.md b/docs/ffi.md index d96d9ff8c4fc..03acf876be08 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -333,7 +333,7 @@ def rms_norm_not_vectorized(x, eps=1e-5): jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x) ``` -If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues). +If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues). +++ @@ -406,7 +406,7 @@ np.testing.assert_allclose( At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`. One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode. -JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice. +JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice. One other JAX feature that this example doesn't support is higher-order AD. It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here. diff --git a/docs/installation.md b/docs/installation.md index 93df4a240a55..acb802ea939c 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -176,7 +176,7 @@ installation. JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package. Make sure that it is present in your CUDA installation. -Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues) +Please let the JAX team know on [the GitHub issue tracker](https://github.com/jax-ml/jax/issues) if you run into any errors or problems with the pre-built wheels. (docker-containers-nvidia-gpu)= @@ -216,7 +216,7 @@ refer to **Note:** There are several caveats with the Metal plugin: * The Metal plugin is new and experimental and has a number of - [known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22). + [known issues](https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22). Please report any issues on the JAX issue tracker. * The Metal plugin currently requires very specific versions of `jax` and `jaxlib`. This restriction will be relaxed over time as the plugin API diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index 389cc0b5a9e8..61d219d1bae1 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -9,7 +9,7 @@ Let's first make a JAX issue. But if you can pinpoint the commit that triggered the regression, it will really help us. This document explains how we identified the commit that caused a -[15% performance regression](https://github.com/google/jax/issues/17686). +[15% performance regression](https://github.com/jax-ml/jax/issues/17686). ## Steps @@ -34,7 +34,7 @@ containers](https://github.com/NVIDIA/JAX-Toolbox). - test_runner.sh: will start the containers and the test. - test.sh: will install missing dependencies and run the test -Here are real example scripts used for the issue: https://github.com/google/jax/issues/17686 +Here are real example scripts used for the issue: https://github.com/jax-ml/jax/issues/17686 - test_runner.sh: ``` for m in 7 8 9; do diff --git a/docs/jep/11830-new-remat-checkpoint.md b/docs/jep/11830-new-remat-checkpoint.md index da0adaf18060..019188349257 100644 --- a/docs/jep/11830-new-remat-checkpoint.md +++ b/docs/jep/11830-new-remat-checkpoint.md @@ -14,7 +14,7 @@ ## What’s going on? -As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade) +As of [#11830](https://github.com/jax-ml/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade) ## How can I disable the change, and go back to the old behavior for now? @@ -29,7 +29,7 @@ If you need to revert to the old implementation, **please reach out** on a GitHu As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer available. If you have an issue, please reach out on [the issue -tracker](https://github.com/google/jax/issues) so we can help fix it! +tracker](https://github.com/jax-ml/jax/issues) so we can help fix it! ## Why are we doing this? @@ -82,7 +82,7 @@ The old `jax.checkpoint` implementation was forced to save the value of `a`, whi ### Significantly less Python overhead in some cases -The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/google/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice! +The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/jax-ml/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice! ### Enabling new JAX features by simplifying internals diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 9137e3e71232..7a20958c5cab 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -12,7 +12,7 @@ The current state of type annotations in JAX is a bit patchwork, and efforts to This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX. Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally. -In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/google/jax/pull/9917), [PR #10322](https://github.com/google/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python. +In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/jax-ml/jax/pull/9917), [PR #10322](https://github.com/jax-ml/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python. This document details JAX's goals and recommendations for type annotations within the package. ## Why type annotations? @@ -21,7 +21,7 @@ There are a number of reasons that a Python project might wish to annotate their ### Level 1: Annotations as documentation -When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]: +When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]: ```python Array = Any @@ -44,14 +44,14 @@ Many modern IDEs take advantage of type annotations as inputs to [intelligent co This use of type checking requires going further than the simple aliases used above; for example, knowing that the `slice` function returns an alias of `Any` named `Array` does not add any useful information to the code completion engine. However, were we to annotate the function with a `DeviceArray` return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development. -JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]: +JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]: ```python def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray: ... ``` -In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result: +In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result: ![VSCode Intellisense Screenshot](../_static/vscode-completion.png) @@ -232,7 +232,7 @@ assert jit(f)(x) # x will be a tracer ``` Again, there are a couple mechanisms that could be used for this: -- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/google/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)). +- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/jax-ml/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)). - define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer` - restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer` diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index bec06000194e..a5625abf8930 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -170,7 +170,7 @@ print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32) This module could expose our mechanism for defining new RNG implementations, and functions for working with PRNG key internals -(see issue [#9263](https://github.com/google/jax/issues/9263)), +(see issue [#9263](https://github.com/jax-ml/jax/issues/9263)), such as the current `jax._src.prng.random_wrap` and `random_unwrap`. diff --git a/docs/jep/18137-numpy-scipy-scope.md b/docs/jep/18137-numpy-scipy-scope.md index 2371e11ee07e..eaebe8fb8997 100644 --- a/docs/jep/18137-numpy-scipy-scope.md +++ b/docs/jep/18137-numpy-scipy-scope.md @@ -78,8 +78,8 @@ to JAX which have relatively complex implementations which are difficult to vali and introduce outsized maintenance burdens; an example is {func}`jax.scipy.special.bessel_jn`: as of the writing of this JEP, its current implementation is a non-straightforward iterative approximation that has -[convergence issues in some domains](https://github.com/google/jax/issues/12402#issuecomment-1384828637), -and [proposed fixes](https://github.com/google/jax/pull/17038/files) introduce further +[convergence issues in some domains](https://github.com/jax-ml/jax/issues/12402#issuecomment-1384828637), +and [proposed fixes](https://github.com/jax-ml/jax/pull/17038/files) introduce further complexity. Had we more carefully weighed the complexity and robustness of the implementation when accepting the contribution, we may have chosen not to accept this contribution to the package. diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md index aa568adc0d9a..ce149fa6fb35 100644 --- a/docs/jep/2026-custom-derivatives.md +++ b/docs/jep/2026-custom-derivatives.md @@ -35,9 +35,9 @@ behavior of their code. This customization Python control flow and workflows for NaN debugging. As **JAX developers** we want to write library functions, like -[`logit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83) +[`logit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83) and -[`expit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91), +[`expit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91), that are defined in terms of other primitives, but for the purposes of differentiation have primitive-like behavior in the sense that we want to define custom differentiation rules for them, which may be more numerically stable or @@ -50,9 +50,9 @@ looking to add custom differentiation rules for higher-order functions like want to be confident we’re not going to preclude good solutions to that problem. That is, our primary goals are -1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/google/jax/issues/1249)), and +1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/jax-ml/jax/issues/1249)), and 2. allow Python in custom VJPs, e.g. to debug NaNs - ([#1275](https://github.com/google/jax/issues/1275)). + ([#1275](https://github.com/jax-ml/jax/issues/1275)). Secondary goals are 3. clean up and simplify user experience (symbolic zeros, kwargs, etc) @@ -60,18 +60,18 @@ Secondary goals are `odeint`, `root`, etc. Overall, we want to close -[#116](https://github.com/google/jax/issues/116), -[#1097](https://github.com/google/jax/issues/1097), -[#1249](https://github.com/google/jax/issues/1249), -[#1275](https://github.com/google/jax/issues/1275), -[#1366](https://github.com/google/jax/issues/1366), -[#1723](https://github.com/google/jax/issues/1723), -[#1670](https://github.com/google/jax/issues/1670), -[#1875](https://github.com/google/jax/issues/1875), -[#1938](https://github.com/google/jax/issues/1938), +[#116](https://github.com/jax-ml/jax/issues/116), +[#1097](https://github.com/jax-ml/jax/issues/1097), +[#1249](https://github.com/jax-ml/jax/issues/1249), +[#1275](https://github.com/jax-ml/jax/issues/1275), +[#1366](https://github.com/jax-ml/jax/issues/1366), +[#1723](https://github.com/jax-ml/jax/issues/1723), +[#1670](https://github.com/jax-ml/jax/issues/1670), +[#1875](https://github.com/jax-ml/jax/issues/1875), +[#1938](https://github.com/jax-ml/jax/issues/1938), and replace the custom_transforms machinery (from -[#636](https://github.com/google/jax/issues/636), -[#818](https://github.com/google/jax/issues/818), +[#636](https://github.com/jax-ml/jax/issues/636), +[#818](https://github.com/jax-ml/jax/issues/818), and others). ## Non-goals @@ -400,7 +400,7 @@ There are some other bells and whistles to the API: resolved to positions using the `inspect` module. This is a bit of an experiment with Python 3’s improved ability to programmatically inspect argument signatures. I believe it is sound but not complete, which is a fine place to be. - (See also [#2069](https://github.com/google/jax/issues/2069).) + (See also [#2069](https://github.com/jax-ml/jax/issues/2069).) * Arguments can be marked non-differentiable using `nondiff_argnums`, and as with `jit`’s `static_argnums` these arguments don’t have to be JAX types. We need to set a convention for how these arguments are passed to the rules. For a primal @@ -433,5 +433,5 @@ There are some other bells and whistles to the API: `custom_lin` to the tangent values; `custom_lin` carries with it the user’s custom backward-pass function, and as a primitive it only has a transpose rule. - * This mechanism is described more in [#636](https://github.com/google/jax/issues/636). + * This mechanism is described more in [#636](https://github.com/jax-ml/jax/issues/636). * To prevent diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md index 65235dc64337..1e2270e052a6 100644 --- a/docs/jep/4008-custom-vjp-update.md +++ b/docs/jep/4008-custom-vjp-update.md @@ -9,7 +9,7 @@ notebook. ## What to update -After JAX [PR #4008](https://github.com/google/jax/pull/4008), the arguments +After JAX [PR #4008](https://github.com/jax-ml/jax/pull/4008), the arguments passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or containers of `Tracer`s), which basically means to allow for arbitrarily-transformable code `nondiff_argnums` shouldn't be used for @@ -95,7 +95,7 @@ acted very much like lexical closure. But lexical closure over `Tracer`s wasn't at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing `nondiff_argnums` that way was a mistake! -**[PR #4008](https://github.com/google/jax/pull/4008) fixes all lexical closure +**[PR #4008](https://github.com/jax-ml/jax/pull/4008) fixes all lexical closure issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp` and `custom_vjp` functions and rules can close over `Tracer`s to our hearts' content. For all non-autodiff transformations, things will Just Work. For @@ -120,9 +120,9 @@ manageable, until you think through how we have to handle arbitrary pytrees! Moreover, that complexity isn't necessary: if user code treats array-like non-differentiable arguments just like regular arguments and residuals, everything already works. (Before -[#4039](https://github.com/google/jax/pull/4039) JAX might've complained about +[#4039](https://github.com/jax-ml/jax/pull/4039) JAX might've complained about involving integer-valued inputs and outputs in autodiff, but after -[#4039](https://github.com/google/jax/pull/4039) those will just work!) +[#4039](https://github.com/jax-ml/jax/pull/4039) those will just work!) Unlike `custom_vjp`, it was easy to make `custom_jvp` work with `nondiff_argnums` arguments that were `Tracer`s. So these updates only need to diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md index eb68ee5f0e0a..f95c15f404b6 100644 --- a/docs/jep/4410-omnistaging.md +++ b/docs/jep/4410-omnistaging.md @@ -20,7 +20,7 @@ This is more of an upgrade guide than a design doc. ### What's going on? A change to JAX's tracing infrastructure called “omnistaging” -([google/jax#3370](https://github.com/google/jax/pull/3370)) was switched on in +([jax-ml/jax#3370](https://github.com/jax-ml/jax/pull/3370)) was switched on in jax==0.2.0. This change improves memory performance, trace execution time, and simplifies jax internals, but may cause some existing code to break. Breakage is usually a result of buggy code, so long-term it’s best to fix the bugs, but @@ -191,7 +191,7 @@ and potentially even fragmenting memory. (The `broadcast` that corresponds to the construction of the zeros array for `jnp.zeros_like(x)` is staged out because JAX is lazy about very simple -expressions from [google/jax#1668](https://github.com/google/jax/pull/1668). After +expressions from [jax-ml/jax#1668](https://github.com/jax-ml/jax/pull/1668). After omnistaging, we can remove that lazy sublanguage and simplify JAX internals.) The reason the creation of `mask` is not staged out is that, before omnistaging, diff --git a/docs/jep/9263-typed-keys.md b/docs/jep/9263-typed-keys.md index 828b95e8ce00..d520f6f63df9 100644 --- a/docs/jep/9263-typed-keys.md +++ b/docs/jep/9263-typed-keys.md @@ -321,7 +321,7 @@ Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same extended dtype mechanism elsewhere internally. For example, the `jax._src.core.bint` object, a bounded integer type used for experimental work on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies -the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/google/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)). +the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/jax-ml/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)). ### PRNG dtypes PRNG dtypes are defined as a particular case of extended dtypes. Specifically, diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index 3e99daabed93..a1ede3177a3a 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -8,7 +8,7 @@ "source": [ "# Design of Type Promotion Semantics for JAX\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n", "\n", "*Jake VanderPlas, December 2021*\n", "\n", diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index cdb1f7805b7e..ff67a8c21399 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -16,7 +16,7 @@ kernelspec: # Design of Type Promotion Semantics for JAX -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) *Jake VanderPlas, December 2021* diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index 759a9be86713..b964aa2af45d 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -58,11 +58,11 @@ These constraints imply the following rules for releases: * If a new `jaxlib` is released, a `jax` release must be made at the same time. These -[version constraints](https://github.com/google/jax/blob/main/jax/version.py) +[version constraints](https://github.com/jax-ml/jax/blob/main/jax/version.py) are currently checked by `jax` at import time, instead of being expressed as Python package version constraints. `jax` checks the `jaxlib` version at runtime rather than using a `pip` package version constraint because we -[provide separate `jaxlib` wheels](https://github.com/google/jax#installation) +[provide separate `jaxlib` wheels](https://github.com/jax-ml/jax#installation) for a variety of hardware and software versions (e.g, GPU, TPU, etc.). Since we do not know which is the right choice for any given user, we do not want `pip` to install a `jaxlib` package for us automatically. @@ -119,7 +119,7 @@ no released `jax` version uses that API. ## How is the source to `jaxlib` laid out? `jaxlib` is split across two main repositories, namely the -[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib) +[`jaxlib/` subdirectory in the main JAX repository](https://github.com/jax-ml/jax/tree/main/jaxlib) and in the [XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla). The JAX-specific pieces inside XLA are primarily in the @@ -146,7 +146,7 @@ level. `jaxlib` is built using Bazel out of the `jax` repository. The pieces of `jaxlib` from the XLA repository are incorporated into the build -[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE). +[as a Bazel submodule](https://github.com/jax-ml/jax/blob/main/WORKSPACE). To update the version of XLA used during the build, one must update the pinned version in the Bazel `WORKSPACE`. This is done manually on an as-needed basis, but can be overridden on a build-by-build basis. diff --git a/docs/jep/index.rst b/docs/jep/index.rst index 194eb0cb9d69..f9dda2657ced 100644 --- a/docs/jep/index.rst +++ b/docs/jep/index.rst @@ -32,7 +32,7 @@ should be linked to this issue. Then create a pull request that adds a file named `%d-{short-title}.md` - with the number being the issue number. -.. _JEP label: https://github.com/google/jax/issues?q=label%3AJEP +.. _JEP label: https://github.com/jax-ml/jax/issues?q=label%3AJEP .. toctree:: :maxdepth: 1 diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 0cffc22f1e8d..71bd4527644a 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" ] }, { @@ -661,7 +661,7 @@ "source": [ "Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.\n", "\n", - "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)." + "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)." ] }, { @@ -1003,7 +1003,7 @@ "id": "COjzGBpO4tzL" }, "source": [ - "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", + "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", "\n", "The random state is described by a special array element that we call a __key__:" ] @@ -1349,7 +1349,7 @@ "\n", "For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n", "\n", - "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", + "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", "\n", "By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n", "\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 543d9ecb1558..741fa3af063c 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) +++ {"id": "4k5PVzEo2uJO"} @@ -312,7 +312,7 @@ jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan) Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error. -Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior). +Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior). +++ {"id": "LwB07Kx5sgHu"} @@ -460,7 +460,7 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha +++ {"id": "COjzGBpO4tzL"} -JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. +JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. The random state is described by a special array element that we call a __key__: @@ -623,7 +623,7 @@ When we `jit`-compile a function, we usually want to compile a version of the fu For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. -To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. +To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index dd7a36e57079..5c09a0a4f732 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "There are two ways to define differentiation rules in JAX:\n", "\n", diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 930887af1e1b..8a9b931552d9 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -17,7 +17,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) There are two ways to define differentiation rules in JAX: diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 8bc0e0a52ce6..32d332d9ac7e 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -17,7 +17,7 @@ "id": "pFtQjv4SzHRj" }, "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n", "\n", "This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer." ] diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index 97b07172b707..2142db9866ae 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -19,7 +19,7 @@ kernelspec: +++ {"id": "pFtQjv4SzHRj"} -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer. diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb index 0c20fc47dc47..e9924e18d023 100644 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ b/docs/notebooks/How_JAX_primitives_work.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", "\n", "*necula@google.com*, October 2019.\n", "\n", diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md index b926c22ea32f..7c24ac11a6ce 100644 --- a/docs/notebooks/How_JAX_primitives_work.md +++ b/docs/notebooks/How_JAX_primitives_work.md @@ -17,7 +17,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) *necula@google.com*, October 2019. diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index a4a4d7d1652b..a7ef2a017048 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", "\n", "**Copyright 2018 The JAX Authors.**\n", "\n", @@ -32,9 +32,9 @@ "id": "B_XlLLpcWjkA" }, "source": [ - "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", + "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n", + "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." ] diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index 03b8415fc91c..cd98022e7421 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) **Copyright 2018 The JAX Authors.** @@ -35,9 +35,9 @@ limitations under the License. +++ {"id": "B_XlLLpcWjkA"} -![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) +![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library). +Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 2c231bf99c46..00ba9186eeec 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" ] }, { @@ -79,7 +79,7 @@ "id": "gA8V51wZdsjh" }, "source": [ - "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/google/jax#how-it-works) section in the README." + "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/jax-ml/jax#how-it-works) section in the README." ] }, { @@ -320,7 +320,7 @@ "source": [ "Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n", "\n", - "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover." + "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover." ] }, { @@ -333,7 +333,7 @@ "\n", "An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n", "\n", - "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234)." + "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234)." ] }, { diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 41d7a7e51dfc..10c4e7cb6e3b 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) +++ {"id": "r-3vMiKRYXPJ"} @@ -57,7 +57,7 @@ fast_f = jit(f) +++ {"id": "gA8V51wZdsjh"} -When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/google/jax#how-it-works) section in the README. +When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/jax-ml/jax#how-it-works) section in the README. +++ {"id": "2Th1vYLVaFBz"} @@ -223,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5)) Notice that `eval_jaxpr` will always return a flat list even if the original function does not. -Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover. +Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover. +++ {"id": "0vb2ZoGrCMM4"} @@ -231,7 +231,7 @@ Furthermore, this interpreter does not handle higher-order primitives (like `jit An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry. -It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234). +It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234). ```{code-cell} ipython3 :id: gSMIT2z1vUpO diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 3f2f0fd5650d..5538b70dac93 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", "\n", "JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics." ] @@ -255,7 +255,7 @@ "id": "cJ2NxiN58bfI" }, "source": [ - "You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)." + "You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)." ] }, { @@ -1015,7 +1015,7 @@ "source": [ "### Jacobian-Matrix and Matrix-Jacobian products\n", "\n", - "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products." + "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products." ] }, { diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index 496d676f794a..db6fde8051d1 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics. @@ -151,7 +151,7 @@ print(grad(loss2)({'W': W, 'b': b})) +++ {"id": "cJ2NxiN58bfI"} -You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.). +You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.). +++ {"id": "PaCHzAtGruBz"} @@ -592,7 +592,7 @@ print("Naive full Hessian materialization") ### Jacobian-Matrix and Matrix-Jacobian products -Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products. +Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products. ```{code-cell} ipython3 :id: asAWvxVaCmsx diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index f628625bd041..9d91804b6021 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", "\n", "JAX provides a number of interfaces to compute convolutions across data, including:\n", "\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 83ab2d9fd56d..b98099aa9571 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) JAX provides a number of interfaces to compute convolutions across data, including: diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 91f2ee571b4b..c31a99746866 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -40,11 +40,11 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", "\n", "_Forked from_ `neural_network_and_data_loading.ipynb`\n", "\n", - "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", + "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 480e7477b8ac..53b7d47358c2 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -38,11 +38,11 @@ limitations under the License. -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) _Forked from_ `neural_network_and_data_loading.ipynb` -![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) +![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index e4f9d888e6fc..b5f8074c0f3e 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", "\n", "JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively." ] diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index dd0c73ec7699..b3672b90e653 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -17,7 +17,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively. diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 9aef1a8eb599..dccc83168ac0 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", "\n", "This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n", "\n", diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index 5989a87bc141..3f836e680e88 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs. diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index 93d7e55473f2..4a2d4daa637f 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -21,7 +21,7 @@ software emulation, and can slow down the computation. If you see unexpected outputs, please compare them against a kernel run with ``interpret=True`` passed in to ``pallas_call``. If the results diverge, - please file a `bug report `_. + please file a `bug report `_. What is a TPU? -------------- diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 1d6a5d9b701a..47a7587b620f 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -29,7 +29,7 @@ f(x) ### Setting cache directory The compilation cache is enabled when the -[cache location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206) +[cache location](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206) is set. This should be done prior to the first compilation. Set the location as follows: @@ -54,7 +54,7 @@ os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache" jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") ``` -(3) Using [`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18) +(3) Using [`set_cache_dir()`](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18) ```python from jax.experimental.compilation_cache import compilation_cache as cc diff --git a/docs/sphinxext/jax_extensions.py b/docs/sphinxext/jax_extensions.py index 3a78557632a7..7cce8b88254d 100644 --- a/docs/sphinxext/jax_extensions.py +++ b/docs/sphinxext/jax_extensions.py @@ -26,14 +26,14 @@ def jax_issue_role(name, rawtext, text, lineno, inliner, options=None, :jax-issue:`1234` This will output a hyperlink of the form - `#1234 `_. These links work even + `#1234 `_. These links work even for PR numbers. """ text = text.lstrip('#') if not text.isdigit(): raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.") options = {} if options is None else options - url = f"https://github.com/google/jax/issues/{text}" + url = f"https://github.com/jax-ml/jax/issues/{text}" node = nodes.reference(rawtext, '#' + text, refuri=url, **options) return [node], [] diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 4eb6e7a66cdd..2ff82e0431e2 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -234,4 +234,4 @@ Handling parameters manually seems fine if you're dealing with two parameters, b 2) Are we supposed to pipe all these things around manually? -The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples. +The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/jax-ml/jax#neural-network-libraries) for some examples. diff --git a/jax/__init__.py b/jax/__init__.py index e2e302adb855..c6e073699b0c 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -29,7 +29,7 @@ # Defensively swallow any exceptions to avoid making jax unimportable from warnings import warn as _warn _warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report " - f"an issue at https://github.com/google/jax/issues") + f"an issue at https://github.com/jax-ml/jax/issues") del _warn del _cloud_tpu_init @@ -38,7 +38,7 @@ del _core # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.basearray import Array as Array from jax import tree as tree diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 8c7fe2f489d5..39df07359c18 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -546,7 +546,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params): # To avoid precision mismatches in fwd and bwd passes due to XLA excess # precision, insert explicit x = reduce_precision(x, **finfo(x.dtype)) calls - # on producers of any residuals. See https://github.com/google/jax/pull/22244. + # on producers of any residuals. See https://github.com/jax-ml/jax/pull/22244. jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res) # compute known outputs and residuals (hoisted out of remat primitive) diff --git a/jax/_src/api.py b/jax/_src/api.py index aae99a28bbea..bd8a951954ac 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -956,7 +956,7 @@ def vmap(fun: F, # list: if in_axes is not a leaf, it must be a tuple of trees. However, # in cases like these users expect tuples and lists to be treated # essentially interchangeably, so we canonicalize lists to tuples here - # rather than raising an error. https://github.com/google/jax/issues/2367 + # rather than raising an error. https://github.com/jax-ml/jax/issues/2367 in_axes = tuple(in_axes) if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}): @@ -2505,7 +2505,7 @@ def __eq__(self, other): def __hash__(self): # TODO(frostig): avoid the conversion from dict by addressing - # https://github.com/google/jax/issues/8182 + # https://github.com/jax-ml/jax/issues/8182 return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type)) def _sds_aval_mapping(x): diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 453a4eba47bf..3a18dcdfa2ac 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -196,7 +196,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): device_assignment = axis_context.device_assignment if device_assignment is None: raise AssertionError( - "Please file a bug at https://github.com/google/jax/issues") + "Please file a bug at https://github.com/jax-ml/jax/issues") try: device_index = device_assignment.index(device) except IndexError as e: diff --git a/jax/_src/config.py b/jax/_src/config.py index fe56ec68f6cb..b21d2f35f9a4 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1170,7 +1170,7 @@ def _update_jax_memories_thread_local(val): upgrade=True, help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should ' 'improve memory usage and stability. Set True to use new ' - 'behavior. See https://github.com/google/jax/pull/15677'), + 'behavior. See https://github.com/jax-ml/jax/pull/15677'), update_global_hook=lambda val: _update_global_jit_state( softmax_custom_jvp=val), update_thread_local_hook=lambda val: update_thread_local_jit_state( diff --git a/jax/_src/core.py b/jax/_src/core.py index 057a79925e2e..bff59625b702 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -935,7 +935,7 @@ def unsafe_buffer_pointer(self): class EvalTrace(Trace): - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 def pure(self, x): return x lift = sublift = pure @@ -998,7 +998,7 @@ def with_cur_sublevel(self): return self.trace_type(self, cur_sublevel(), **self.payload) class TraceStack: - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 stack: list[MainTrace] dynamic: MainTrace @@ -1167,7 +1167,7 @@ def _why_alive(ignore_ids: set[int], x: Any) -> str: # parent->child jump. We do that by setting `parent` here to be a # grandparent (or great-grandparent) of `child`, and then handling that case # in _why_alive_container_info. See example: - # https://github.com/google/jax/pull/13022#discussion_r1008456599 + # https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599 # To prevent this collapsing behavior, just comment out this code block. if (isinstance(parent, dict) and getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]): @@ -1213,7 +1213,7 @@ def _why_alive_container_info(container, obj_id) -> str: @contextmanager def new_main(trace_type: type[Trace], dynamic: bool = False, **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 stack = thread_local_state.trace_state.trace_stack level = stack.next_level() main = MainTrace(level, trace_type, **payload) @@ -1254,7 +1254,7 @@ def dynamic_level() -> int: @contextmanager def new_base_main(trace_type: type[Trace], **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 stack = thread_local_state.trace_state.trace_stack main = MainTrace(0, trace_type, **payload) prev_dynamic, stack.dynamic = stack.dynamic, main @@ -1319,7 +1319,7 @@ def f(x): else: return jnp.cos(x) - Here's a real-world example from https://github.com/google/jax/issues/3974:: + Here's a real-world example from https://github.com/jax-ml/jax/issues/3974:: import jax import jax.numpy as jnp @@ -1680,7 +1680,7 @@ def strip_weak_type(self): @property def shape(self): msg = ("UnshapedArray has no shape. Please open an issue at " - "https://github.com/google/jax/issues because it's unexpected for " + "https://github.com/jax-ml/jax/issues because it's unexpected for " "UnshapedArray instances to ever be produced.") raise TypeError(msg) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 07df2321be95..35e7d33430bd 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -191,7 +191,7 @@ def jvp_of_rule_rule(axis_size, in_batched, primals, tangents): # TODO(frostig): assert these also equal: # treedef_tuple((in_tree, in_tree)) - # once https://github.com/google/jax/issues/9066 is fixed + # once https://github.com/jax-ml/jax/issues/9066 is fixed assert tree_ps_ts == tree_ps_ts2 del tree_ps_ts2 diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 88be655a0ddd..f5ecdfcda286 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1144,7 +1144,7 @@ def rev(objective_fn, res, g): def _maybe_perturbed(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value # with a nontrivial tangent attached), up to heuristics, and True otherwise. - # See https://github.com/google/jax/issues/6415 for motivation. + # See https://github.com/jax-ml/jax/issues/6415 for motivation. x = core.full_lower(x) if not isinstance(x, core.Tracer): # If x is not a Tracer, it can't be perturbed. diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 8f48746dda37..984d55fe2f6b 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -492,7 +492,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, devices = axis_context.device_assignment if devices is None: raise AssertionError( - 'Please file a bug at https://github.com/google/jax/issues') + 'Please file a bug at https://github.com/jax-ml/jax/issues') if axis_context.mesh_shape is not None: ma, ms = list(zip(*axis_context.mesh_shape)) mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 3373496940e2..465dc90e21da 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -389,7 +389,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, devices = axis_context.device_assignment if devices is None: raise AssertionError( - 'Please file a bug at https://github.com/google/jax/issues') + 'Please file a bug at https://github.com/jax-ml/jax/issues') elif isinstance(axis_context, sharding_impls.SPMDAxisContext): devices = axis_context.mesh._flat_devices_tuple else: diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index d76b80ad3a89..352a3e550112 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -793,7 +793,7 @@ def check_user_dtype_supported(dtype, fun_name=None): "and will be truncated to dtype {}. To enable more dtypes, set the " "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell " "environment variable. " - "See https://github.com/google/jax#current-gotchas for more.") + "See https://github.com/jax-ml/jax#current-gotchas for more.") fun_name = f"requested in {fun_name}" if fun_name else "" truncated_dtype = canonicalize_dtype(np_dtype).name warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3) diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py index e18ad1f6e793..11a9dda66e74 100644 --- a/jax/_src/flatten_util.py +++ b/jax/_src/flatten_util.py @@ -61,7 +61,7 @@ def _ravel_list(lst): if all(dt == to_dtype for dt in from_dtypes): # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`. - # See https://github.com/google/jax/issues/7809. + # See https://github.com/jax-ml/jax/issues/7809. del from_dtypes, to_dtype raveled = jnp.concatenate([jnp.ravel(e) for e in lst]) return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 4bf6d1ceb145..2c94907568d9 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -30,7 +30,7 @@ particular test, we write them as `Limitation` objects that can be reused in multiple tests and can also be used to generate documentation, e.g., the report of [unsupported and partially-implemented JAX -primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) +primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) The limitations are used to filter out from tests the harnesses that are known to fail. A Limitation is specific to a harness. @@ -515,7 +515,7 @@ def _make_convert_element_type_harness(name, for old_dtype in jtu.dtypes.all: # TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating # point numbers and new_dtype is an unsigned integer. See issue - # https://github.com/google/jax/issues/5082 for details. + # https://github.com/jax-ml/jax/issues/5082 for details. for new_dtype in (jtu.dtypes.all if not (dtypes.issubdtype(old_dtype, np.floating) or dtypes.issubdtype(old_dtype, np.complexfloating)) @@ -2336,7 +2336,7 @@ def _make_select_and_scatter_add_harness(name, # Validate padding for padding in [ # TODO(bchetioui): commented out the test based on - # https://github.com/google/jax/issues/4690 + # https://github.com/jax-ml/jax/issues/4690 # ((1, 2), (2, 3), (3, 4)) # non-zero padding ((1, 1), (1, 1), (1, 1)) # non-zero padding ]: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 6bc6539b9262..6bc3cceb7ab7 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1262,7 +1262,7 @@ def partial_eval_jaxpr_custom_rule_not_implemented( name: str, saveable: Callable[..., RematCases_], unks_in: Sequence[bool], inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult: msg = (f'custom-policy remat rule not implemented for {name}, ' - 'open a feature request at https://github.com/google/jax/issues!') + 'open a feature request at https://github.com/jax-ml/jax/issues!') raise NotImplementedError(msg) @@ -2688,7 +2688,7 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): # TODO(mattjj): the following are deprecated; update callers to _nounits version -# See https://github.com/google/jax/pull/9498 +# See https://github.com/jax-ml/jax/pull/9498 @lu.transformation def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool], pvals: Sequence[PartialVal]): diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 944e20fa7faa..de668090eaa1 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -500,7 +500,7 @@ def process_map(self, map_primitive, fun, tracers, params): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): if symbolic_zeros: msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. " - "Please open an issue at https://github.com/google/jax/issues !") + "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros # always base main, can drop jvp in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) @@ -513,7 +513,7 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): if symbolic_zeros: msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. " - "Please open an issue at https://github.com/google/jax/issues !") + "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) @@ -1869,7 +1869,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "does not preserve sharded data representations and instead collects " "input and output arrays onto a single device. " "Consider removing the outer jit unless you know what you're doing. " - "See https://github.com/google/jax/issues/2926.") + "See https://github.com/jax-ml/jax/issues/2926.") if nreps > xb.device_count(backend): raise ValueError( diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 4cb38d28c36f..d3065d0f96d7 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -389,7 +389,7 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, branch_outs = [] for i, jaxpr in enumerate(branches_batched): # Perform a select on the inputs for safety of reverse-mode autodiff; see - # https://github.com/google/jax/issues/1052 + # https://github.com/jax-ml/jax/issues/1052 predicate = lax.eq(index, lax._const(index, i)) ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops] branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_)) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 41d809f8d688..7a9596bf2c0d 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -715,7 +715,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: raise NotImplementedError if not all(init_lin): - pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963 + pass # TODO(mattjj): error check https://github.com/jax-ml/jax/issues/1963 consts, _, xs = split_list(args, [num_consts, num_carry]) ires, _ = split_list(consts, [num_ires]) @@ -1169,7 +1169,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts, if discharged_consts: raise NotImplementedError("Discharged jaxpr has consts. If you see this, " "please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") def wrapped(*wrapped_args): val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args, [n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs]) @@ -1838,7 +1838,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr, if body_jaxpr_consts: raise NotImplementedError("Body jaxpr has consts. If you see this error, " "please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") # body_jaxpr has the signature (*body_consts, *carry) -> carry. # Some of these body_consts are actually `Ref`s so when we discharge # them, they also turn into outputs, effectively turning those consts into diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 0cbee6d2bfbc..36553e512cd7 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -157,7 +157,7 @@ def _irfft_transpose(t, fft_lengths): out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype) # Use JAX's convention for complex gradients - # https://github.com/google/jax/issues/6223#issuecomment-807740707 + # https://github.com/jax-ml/jax/issues/6223#issuecomment-807740707 return lax.conj(out) def _fft_transpose_rule(t, operand, fft_type, fft_lengths): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 48af9c64ffc9..394a54c357b8 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1077,7 +1077,7 @@ def comp(x, y): if any(isinstance(c, core.Tracer) for c in consts): raise NotImplementedError( "Reduction computations can't close over Tracers. Please open an issue " - "at https://github.com/google/jax.") + "at https://github.com/jax-ml/jax.") return jaxpr, tuple(consts) @cache() @@ -1090,7 +1090,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree): if any(isinstance(c, core.Tracer) for c in consts): raise NotImplementedError( "Reduction computations can't close over Tracers. Please open an issue " - "at https://github.com/google/jax.") + "at https://github.com/jax-ml/jax.") return core.ClosedJaxpr(jaxpr, consts), out_tree() def _get_monoid_reducer(monoid_op: Callable, @@ -4911,7 +4911,7 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs): return tree_util.tree_unflatten(p.out_tree(), out_flat) -# TODO(https://github.com/google/jax/issues/13552): Look into making this a +# TODO(https://github.com/jax-ml/jax/issues/13552): Look into making this a # method on jax.Array so that we can bypass the XLA compilation here. def _copy_impl(prim, *args, **kwargs): a, = args diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index ec0a075dae1b..453e79a5c7f8 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -781,7 +781,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, raise NotImplementedError( 'The derivatives of eigenvectors are not implemented, only ' 'eigenvalues. See ' - 'https://github.com/google/jax/issues/2748 for discussion.') + 'https://github.com/jax-ml/jax/issues/2748 for discussion.') # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in # https://arxiv.org/abs/1701.00392 a, = primals diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index b2bcc53a53f8..e8fcb433438a 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -27,7 +27,7 @@ except ModuleNotFoundError as err: raise ModuleNotFoundError( 'jax requires jaxlib to be installed. See ' - 'https://github.com/google/jax#installation for installation instructions.' + 'https://github.com/jax-ml/jax#installation for installation instructions.' ) from err import jax.version @@ -92,7 +92,7 @@ def _parse_version(v: str) -> tuple[int, ...]: jax_jit = xla_client._xla.jax_jit pmap_lib = xla_client._xla.pmap_lib -# XLA garbage collection: see https://github.com/google/jax/issues/14882 +# XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): xla_client._xla.collect_garbage() gc.callbacks.append(_xla_gc_callback) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 20234b678172..08c8bfcb3a29 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -333,7 +333,7 @@ class AbstractMesh: should use this as an input to the sharding passed to with_sharding_constraint and mesh passed to shard_map to avoid tracing and lowering cache misses when your mesh shape and names stay the same but the devices change. - See the description of https://github.com/google/jax/pull/23022 for more + See the description of https://github.com/jax-ml/jax/pull/23022 for more details. """ diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a1601e9201fe..5b936268581a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3953,7 +3953,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], arrays_out = [asarray(arr, dtype=dtype) for arr in arrays] # lax.concatenate can be slow to compile for wide concatenations, so form a # tree of concatenations as a workaround especially for op-by-op mode. - # (https://github.com/google/jax/issues/653). + # (https://github.com/jax-ml/jax/issues/653). k = 16 while len(arrays_out) > 1: arrays_out = [lax.concatenate(arrays_out[i:i+k], axis) @@ -4645,7 +4645,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in - # https://github.com/google/jax/pull/6047. More correct would be to call + # https://github.com/jax-ml/jax/pull/6047. More correct would be to call # coerce_to_array on each leaf, but this may have performance implications. out = np.asarray(object, dtype=dtype) elif isinstance(object, Array): @@ -10150,11 +10150,11 @@ def _eliminate_deprecated_list_indexing(idx): if any(_should_unpack_list_index(i) for i in idx): msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " "use `arr[tuple(seq)]` instead of `arr[seq]`. " - "See https://github.com/google/jax/issues/4564 for more information.") + "See https://github.com/jax-ml/jax/issues/4564 for more information.") else: msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " "use `arr[array(seq)]` instead of `arr[seq]`. " - "See https://github.com/google/jax/issues/4564 for more information.") + "See https://github.com/jax-ml/jax/issues/4564 for more information.") raise TypeError(msg) else: idx = (idx,) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index b6aea9e195a9..043c976ef6f5 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -968,7 +968,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy msg = ("jax.numpy.var does not yet support real dtype parameters when " "computing the variance of an array of complex values. The " "semantics of numpy.var seem unclear in this case. Please comment " - "on https://github.com/google/jax/issues/2283 if this behavior is " + "on https://github.com/jax-ml/jax/issues/2283 if this behavior is " "important to you.") raise ValueError(msg) computation_dtype = dtype diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 7e8acb090279..6491a7617d8d 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -134,7 +134,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. - The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. + The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. In order to ensure statically-known output shapes, you can pass a static ``size`` argument: @@ -217,7 +217,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. - The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. + The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. In order to ensure statically-known output shapes, you can pass a static ``size`` argument: diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index b45b3370fe53..00b5311b8415 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2176,7 +2176,7 @@ def sinc(x: ArrayLike, /) -> Array: def _sinc_maclaurin(k, x): # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we # compute the monomial term in the jvp rule) - # TODO(mattjj): see https://github.com/google/jax/issues/10750 + # TODO(mattjj): see https://github.com/jax-ml/jax/issues/10750 if k % 2: return x * 0 else: diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py index 5340eb3fa654..f8231f5b24b6 100644 --- a/jax/_src/pallas/mosaic/error_handling.py +++ b/jax/_src/pallas/mosaic/error_handling.py @@ -35,7 +35,7 @@ ) MLIR_ERR_PREFIX = ( 'Pallas encountered an internal verification error.' - 'Please file a bug at https://github.com/google/jax/issues. ' + 'Please file a bug at https://github.com/jax-ml/jax/issues. ' 'Error details: ' ) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f120bbabf8a4..d4dc534d034a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -833,7 +833,7 @@ def write_env(var: jax_core.Var, val): raise NotImplementedError( "Unimplemented primitive in Pallas TPU lowering: " f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/google/jax/issues.") + "Please file an issue on https://github.com/jax-ml/jax/issues.") if eqn.primitive.multiple_results: map(write_env, eqn.outvars, ans) else: diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5eaf6e5233cb..4b8199b36105 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -549,7 +549,7 @@ def write_env(var: jax_core.Var, val): raise NotImplementedError( "Unimplemented primitive in Pallas Mosaic GPU lowering: " f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/google/jax/issues." + "Please file an issue on https://github.com/jax-ml/jax/issues." ) rule = mosaic_lowering_rules[eqn.primitive] rule_ctx = LoweringRuleContext( diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index ac28bd21a3dc..6a16156271d7 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -381,7 +381,7 @@ def write_env(var: jax_core.Var, val): raise NotImplementedError( "Unimplemented primitive in Pallas GPU lowering: " f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/google/jax/issues.") + "Please file an issue on https://github.com/jax-ml/jax/issues.") rule = triton_lowering_rules[eqn.primitive] avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0abaa3fd0139..ac1318ed7810 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -459,7 +459,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, # list: if in_axes is not a leaf, it must be a tuple of trees. However, # in cases like these users expect tuples and lists to be treated # essentially interchangeably, so we canonicalize lists to tuples here - # rather than raising an error. https://github.com/google/jax/issues/2367 + # rather than raising an error. https://github.com/jax-ml/jax/issues/2367 in_shardings = tuple(in_shardings) in_layouts, in_shardings = _split_layout_and_sharding(in_shardings) @@ -1276,7 +1276,7 @@ def unpack(key): return done() # we think this is unreachable... - p("explanation unavailable! please open an issue at https://github.com/google/jax") + p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax") return done() @partial(lu.cache, explain=explain_tracing_cache_miss) @@ -1701,7 +1701,7 @@ def _pjit_call_impl_python( "`jit` decorator, at the cost of losing optimizations. " "\n\n" "If you see this error, consider opening a bug report at " - "https://github.com/google/jax.") + "https://github.com/jax-ml/jax.") raise FloatingPointError(msg) diff --git a/jax/_src/random.py b/jax/_src/random.py index d889713f6c3d..203f72d406e5 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -508,7 +508,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: span = lax.convert_element_type(maxval - minval, unsigned_dtype) # Ensure that span=1 when maxval <= minval, so minval is always returned; - # https://github.com/google/jax/issues/222 + # https://github.com/jax-ml/jax/issues/222 span = lax.select(maxval <= minval, lax.full_like(span, 1), span) # When maxval is out of range, the span has to be one larger. @@ -2540,7 +2540,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array: _btrs(key, count_btrs, q_btrs, shape, dtype, max_iters), ) # ensure nan q always leads to nan output and nan or neg count leads to nan - # as discussed in https://github.com/google/jax/pull/16134#pullrequestreview-1446642709 + # as discussed in https://github.com/jax-ml/jax/pull/16134#pullrequestreview-1446642709 invalid = (q_l_0 | q_is_nan | count_nan_or_neg) samples = lax.select( invalid, diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py index d81008308b94..ee144eaf990a 100644 --- a/jax/_src/scipy/ndimage.py +++ b/jax/_src/scipy/ndimage.py @@ -176,7 +176,7 @@ def map_coordinates( Note: Interpolation near boundaries differs from the scipy function, because JAX - fixed an outstanding bug; see https://github.com/google/jax/issues/11097. + fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097. This function interprets the ``mode`` argument as documented by SciPy, but not as implemented by SciPy. """ diff --git a/jax/_src/shard_alike.py b/jax/_src/shard_alike.py index 2361eaf6426d..574d725c4999 100644 --- a/jax/_src/shard_alike.py +++ b/jax/_src/shard_alike.py @@ -44,7 +44,7 @@ def shard_alike(x, y): raise ValueError( 'The leaves shapes of `x` and `y` should match. Got `x` leaf shape:' f' {x_aval.shape} and `y` leaf shape: {y_aval.shape}. File an issue at' - ' https://github.com/google/jax/issues if you want this feature.') + ' https://github.com/jax-ml/jax/issues if you want this feature.') outs = [shard_alike_p.bind(x_, y_) for x_, y_ in safe_zip(x_flat, y_flat)] x_out_flat, y_out_flat = zip(*outs) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 5afcd5e3a718..81737f27540b 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1208,7 +1208,7 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', y = np.asarray(y) if (not allow_object_dtype) and (x.dtype == object or y.dtype == object): - # See https://github.com/google/jax/issues/17867 + # See https://github.com/jax-ml/jax/issues/17867 raise TypeError( "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. " "If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. " diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 0caa6e7c643b..010841b45dd2 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -21,7 +21,7 @@ and may change without notice. To see the proposal that led to the development of these tools, see -https://github.com/google/jax/pull/11859/. +https://github.com/jax-ml/jax/pull/11859/. """ from __future__ import annotations diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 1d3c50403b47..796093b6225f 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -1232,7 +1232,7 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs): if library_path is None: raise RuntimeError( "JAX TPU support not installed; cannot generate TPU topology. See" - " https://github.com/google/jax#installation") + " https://github.com/jax-ml/jax#installation") c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path) xla_client.profiler.register_plugin_profiler(c_api) assert xla_client.pjrt_plugin_loaded("tpu") diff --git a/jax/core.py b/jax/core.py index 9857fcf88c02..cdf8d76558d9 100644 --- a/jax/core.py +++ b/jax/core.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.core import ( AbstractToken as AbstractToken, diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 8e517f5d4610..ea1ef4f0274e 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.custom_derivatives import ( _initial_style_jaxpr, diff --git a/jax/dtypes.py b/jax/dtypes.py index f2071fd4fe56..a6f1b764510b 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.dtypes import ( bfloat16 as bfloat16, diff --git a/jax/errors.py b/jax/errors.py index 15a6654fa32d..2a811661d1ae 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.errors import ( JAXTypeError as JAXTypeError, diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index caf27ec7a8ca..1b22c2c2ada9 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.experimental.x64_context import ( enable_x64 as enable_x64, diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 69b25d0b6ad9..3ac1d4246f6a 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 import sys as _sys import warnings as _warnings diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index 0b6b51f71a3c..8e11d4173afe 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.checkify import ( Error as Error, diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 3c7bfac40061..6da3ad7c5d4b 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.custom_partitioning import ( custom_partitioning as custom_partitioning, diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 63c3299c5904..49162809a325 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -17,7 +17,7 @@ The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. This module introduces the host callback functions :func:`call`, :func:`id_tap`, and :func:`id_print`, that send their arguments from the device @@ -363,11 +363,11 @@ def power3_with_cotangents(x): This is relatively easy to do, once one understands both the JAX custom VJP and the TensorFlow autodiff mechanisms. The code for how this can be done is shown in the ``call_tf_full_ad`` -function in `host_callback_to_tf_test.py `_. +function in `host_callback_to_tf_test.py `_. This example supports arbitrary higher-order differentiation as well. Note that if you just want to call TensorFlow functions from JAX, you can also -use the `jax2tf.call_tf function `_. +use the `jax2tf.call_tf function `_. Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support ------------------------------------------------------------------------------------------------ @@ -378,7 +378,7 @@ def power3_with_cotangents(x): computation will run, and then the results are sent back to the original accelerator. The code for how this can be done is shown in the ``call_jax_other_device function`` -in `host_callback_test.py `_. +in `host_callback_test.py `_. Low-level details and debugging ------------------------------- @@ -572,7 +572,7 @@ def power3_with_cotangents(x): help=( 'Use old implementation of host_callback, documented in the module docstring.' 'If False, use the jax.experimental.io_callback implementation. ' - 'See https://github.com/google/jax/issues/20385.' + 'See https://github.com/jax-ml/jax/issues/20385.' ) ) @@ -592,7 +592,7 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): "See https://jax.readthedocs.io/en/latest/debugging/index.html and " "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html" " for alternatives. Please file a feature request at " - "https://github.com/google/jax/issues if none of the alternatives are " + "https://github.com/jax-ml/jax/issues if none of the alternatives are " "sufficient.") @@ -608,7 +608,7 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): class CallbackFlavor(enum.Enum): """Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. """ IO_CALLBACK = 1 # uses jax.experimental.io_callback PURE = 2 # uses jax.pure_callback @@ -629,7 +629,7 @@ def _deprecated_id_tap(tap_func, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. ``id_tap`` behaves semantically like the identity function but has the side-effect that a user-defined Python function is called with the runtime @@ -655,7 +655,7 @@ def _deprecated_id_tap(tap_func, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies the flavor of callback to use. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. Returns: ``arg``, or ``result`` if given. @@ -712,7 +712,7 @@ def _deprecated_id_print(arg, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. On each invocation of the printing tap, the ``kwargs`` if present will be printed first (sorted by keys). Then arg will be printed, @@ -730,7 +730,7 @@ def _deprecated_id_print(arg, * ``threshold`` is passed to ``numpy.array2string``. * ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies the flavor of callback to use. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. For more details see the :mod:`jax.experimental.host_callback` module documentation. """ @@ -757,7 +757,7 @@ def _deprecated_call(callback_func: Callable, arg, *, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. Args: callback_func: The Python function to invoke on the host as @@ -787,7 +787,7 @@ def _deprecated_call(callback_func: Callable, arg, *, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies the flavor of callback to use. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. Returns: the result of the ``callback_func`` invocation. @@ -800,7 +800,7 @@ def _deprecated_call(callback_func: Callable, arg, *, raise NotImplementedError( "When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` " "flavor of callback only when the `result_shape` is None. " - "See https://github.com/google/jax/issues/20385." + "See https://github.com/jax-ml/jax/issues/20385." ) return _call(callback_func, arg, result_shape=result_shape, call_with_device=call_with_device, identity=False, @@ -819,7 +819,7 @@ def __init__(self, callback_func, identity, call_with_device): raise NotImplementedError( "When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs" " do not support `tap_with_device` and `call_with_device`. " - "See https://github.com/google/jax/issues/20385.") + "See https://github.com/jax-ml/jax/issues/20385.") def __hash__(self): return hash((self.callback_func, self.identity, self.call_with_device)) @@ -2121,7 +2121,7 @@ def _deprecated_stop_outfeed_receiver(): _deprecation_msg = ( "The host_callback APIs are deprecated as of March 20, 2024. The functionality " "is subsumed by the new JAX external callbacks. " - "See https://github.com/google/jax/issues/20385.") + "See https://github.com/jax-ml/jax/issues/20385.") _deprecations = { # Added March 20, 2024 diff --git a/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb b/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb index 4f23d88e036e..3613dba0ef06 100644 --- a/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb +++ b/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb @@ -26,7 +26,7 @@ "Link: go/jax2tf-colab\n", "\n", "The JAX2TF colab has been deprecated, and the example code has\n", - "been moved to [jax2tf/examples](https://github.com/google/jax/tree/main/jax/experimental/jax2tf/examples). \n" + "been moved to [jax2tf/examples](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf/examples). \n" ] } ] diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index dbdc4f563368..b77474c03728 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -103,10 +103,10 @@ For more involved examples, please see examples involving: * SavedModel for archival ([examples below](#usage-saved-model)), including saving [batch-polymorphic functions](#shape-polymorphic-conversion), - * TensorFlow Lite ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)), - * TensorFlow.js ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)), + * TensorFlow Lite ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)), + * TensorFlow.js ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)), * TFX ([examples](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/README.md#instructions-for-using-flax)), - * TensorFlow Hub and Keras ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md)). + * TensorFlow Hub and Keras ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md)). [TOC] @@ -249,7 +249,7 @@ graph (they will be saved in a `variables` area of the model, which is not subject to the 2GB limitation). For examples of how to save a Flax model as a SavedModel see the -[examples directory](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md). +[examples directory](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md). ### Saved model and differentiation @@ -619,7 +619,7 @@ Cannot solve for values of dimension variables {'a', 'b'}. " We can only solve linear uni-variate constraints. " Using the following polymorphic shapes specifications: args[0].shape = (a + b,). Unprocessed specifications: 'a + b' for dimension size args[0].shape[0]. " -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details. +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details. ``` ### Shape assertion errors @@ -645,7 +645,7 @@ Input shapes do not match the polymorphic shapes specification. Division had remainder 1 when computing the value of 'd'. Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details. +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details. ``` When using native serialization these are checked by the `tf.XlaCallModule` @@ -869,7 +869,7 @@ leads to errors for the following expressions `b == a or b == b` or `b in [a, b] even though the error is avoided if we change the order of the comparisons. We attempted to retain soundness and hashability by creating both hashable and unhashable -kinds of symbolic dimensions [PR #14200](https://github.com/google/jax/pull/14200), +kinds of symbolic dimensions [PR #14200](https://github.com/jax-ml/jax/pull/14200), but it turned out to be very hard to diagnose hashing failures in user programs because often hashing is implicit when using sets or memo tables. @@ -989,7 +989,7 @@ We list here a history of the serialization version numbers: June 13th, 2023 (JAX 0.4.13). * Version 7 adds support for `stablehlo.shape_assertion` operations and for `shape_assertions` specified in `disabled_checks`. - See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule + See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule since July 12th, 2023 (cl/547482522), available in JAX serialization since July 20th, 2023 (JAX 0.4.14), and the default since August 12th, 2023 (JAX 0.4.15). @@ -1164,7 +1164,7 @@ self.assertAllClose(grad_jax.b, grad_tf[1]) Applies to both native and non-native serialization. When JAX differentiates functions with integer or boolean arguments, the gradients will -be zero-vectors with a special `float0` type (see PR 4039](https://github.com/google/jax/pull/4039)). +be zero-vectors with a special `float0` type (see PR 4039](https://github.com/jax-ml/jax/pull/4039)). This type is translated to `int32` when lowering to TF. For example, @@ -1441,7 +1441,7 @@ Operations like ``jax.numpy.cumsum`` are lowered by JAX differently based on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow) operation, which has an efficient implementation for the cases when the reduction function is associative. For CPU and GPU, JAX uses an alternative -lowering using [associative scans](https://github.com/google/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). +lowering using [associative scans](https://github.com/jax-ml/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). jax2tf uses the TPU lowering (because it does not support backend-specific lowering) and hence it can be slow in some cases on CPU and GPU. @@ -1502,7 +1502,7 @@ before conversion. (This is a hypothesis, we have not yet verified it extensivel There is one know case when the performance of the lowered code will be different. JAX programs use a [stateless -deterministic PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md) +deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md) and it has an internal JAX primitive for it. This primitive is at the moment lowered to a soup of tf.bitwise operations, which has a clear performance penalty. We plan to look into using the @@ -1589,7 +1589,7 @@ Applies to non-native serialization only. There are a number of cases when the TensorFlow ops that are used by the `jax2tf` are not supported by TensorFlow for the same data types as in JAX. There is an -[up-to-date list of unimplemented cases](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). +[up-to-date list of unimplemented cases](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). If you try to lower and run in TensorFlow a program with partially supported primitives, you may see TensorFlow errors that @@ -1626,7 +1626,7 @@ the function to a SavedModel, knowing that upon restore the jax2tf-lowered code will be compiled. For a more elaborate example, see the test `test_tf_mix_jax_with_uncompilable` -in [savedmodel_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py). +in [savedmodel_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py). # Calling TensorFlow functions from JAX @@ -1704,7 +1704,7 @@ For a more elaborate example, including round-tripping from JAX to TensorFlow and back through a SavedModel, with support for custom gradients, see the test `test_round_trip_custom_grad_saved_model` -in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py). +in [call_tf_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py). All the metadata inserted by TF during tracing and compilation, e.g., source location information and op names, is carried through to the @@ -1901,7 +1901,7 @@ As of today, the tests are run using `tf_nightly==2.14.0.dev20230720`. To run jax2tf on GPU, both jaxlib and TensorFlow must be installed with support for CUDA. One must be mindful to install a version of CUDA that is compatible -with both [jaxlib](https://github.com/google/jax/blob/main/README.md#pip-installation) and +with both [jaxlib](https://github.com/jax-ml/jax/blob/main/README.md#pip-installation) and [TensorFlow](https://www.tensorflow.org/install/source#tested_build_configurations). ## Updating the limitations documentation @@ -1913,9 +1913,9 @@ JAX primitive, data type, device type, and TensorFlow execution mode (`eager`, `graph`, or `compiled`). These limitations are also used to generate tables of limitations, e.g., - * [List of primitives not supported in JAX](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md), + * [List of primitives not supported in JAX](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md), e.g., due to unimplemented cases in the XLA compiler, and - * [List of primitives not supported in jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md), + * [List of primitives not supported in jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md), e.g., due to unimplemented cases in TensorFlow. This list is incremental on top of the unsupported JAX primitives. diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 037f8bbc2a02..baae52403053 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -19,7 +19,7 @@ TensorFlow functions. For examples and details, see -https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax. +https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax. """ @@ -93,7 +93,7 @@ def call_tf( For an example and more details see the `README - `_. + `_. Args: callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow @@ -460,7 +460,7 @@ def is_fully_known_shape(s): msg = ("call_tf cannot call functions whose output has dynamic shape. " f"Found output shapes: {concrete_function_flat_tf.output_shapes}. " "Consider using the `output_shape_dtype` argument to call_tf. " - "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + "\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" " for a discussion.") raise ValueError(msg) @@ -499,7 +499,7 @@ def _call_tf_lowering( msg = ( "call_tf works best with a TensorFlow function that does not capture " "variables or tensors from the context. " - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. " + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. " f"The following captures were found {concrete_function_flat_tf.captured_inputs}") logging.warning(msg) for inp in concrete_function_flat_tf.captured_inputs: @@ -544,7 +544,7 @@ def convert_to_spec(x): "\ncall_tf can used " + "in a staged context (under jax.jit, lax.scan, etc.) only with " + "compilable functions with static output shapes.\n" + - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." + + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." + "\n\nCaught TensorFlow exception: " + str(e)) raise ValueError(msg) from e @@ -557,7 +557,7 @@ def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: f"{res_shape}. call_tf can used " + "in a staged context (under jax.jit, lax.scan, etc.) only with " + "compilable functions with static output shapes. " + - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") raise ValueError(msg) res_dtype = res_shape.numpy_dtype() diff --git a/jax/experimental/jax2tf/examples/README.md b/jax/experimental/jax2tf/examples/README.md index b049798e7e15..8869a226b675 100644 --- a/jax/experimental/jax2tf/examples/README.md +++ b/jax/experimental/jax2tf/examples/README.md @@ -4,7 +4,7 @@ jax2tf Examples Link: go/jax2tf-examples. This directory contains a number of examples of using the -[jax2tf converter](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) to: +[jax2tf converter](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) to: * save SavedModel from trained MNIST models, using both Flax and pure JAX. * reuse the feature-extractor part of the trained MNIST model @@ -19,12 +19,12 @@ You can also find usage examples in other projects: The functions generated by `jax2tf.convert` are standard TensorFlow functions and you can save them in a SavedModel using standard TensorFlow code, as shown -in the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model). +in the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model). This decoupling of jax2tf from SavedModel is important, because it **allows the user to have full control over what metadata is saved in the SavedModel**. As an example, we provide the function `convert_and_save_model` -(see [saved_model_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py).) +(see [saved_model_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py).) For serious uses, you will probably want to copy and expand this function as needed. @@ -65,7 +65,7 @@ If you are using Flax, then the recipe to obtain this pair is as follows: ``` You can see in -[mnist_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) +[mnist_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) how this can be done for two implementations of MNIST, one using pure JAX (`PureJaxMNIST`) and a CNN one using Flax (`FlaxMNIST`). Other Flax models can be arranged similarly, @@ -91,7 +91,7 @@ embed all parameters in the graph: ``` (The MNIST Flax examples from -[mnist_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) +[mnist_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) normally has a GraphDef of 150k and a variables section of 3Mb. If we embed the parameters as constants in the GraphDef as shown above, the variables section becomes empty and the GraphDef becomes 13Mb. This embedding may allow @@ -112,7 +112,7 @@ If you are using Haiku, then the recipe is along these lines: Once you have the model in this form, you can use the `saved_model_lib.save_model` function from -[saved_model_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py) +[saved_model_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py) to generate the SavedModel. There is very little in that function that is specific to jax2tf. The goal of jax2tf is to convert JAX functions into functions that behave as if they had been written with TensorFlow. @@ -120,7 +120,7 @@ Therefore, if you are familiar with how to generate SavedModel, you can most likely just use your own code for this. The file -[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) +[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) is an executable that shows how to perform the following sequence of steps: @@ -147,9 +147,9 @@ batch sizes: 1, 16, 128. You can see this in the dumped SavedModel. The SavedModel produced by the example in `saved_model_main.py` already implements the [reusable saved models interface](https://www.tensorflow.org/hub/reusable_saved_models). The executable -[keras_reuse_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/keras_reuse_main.py) +[keras_reuse_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/keras_reuse_main.py) extends -[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) +[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) with code to include a jax2tf SavedModel into a larger TensorFlow Keras model. @@ -174,7 +174,7 @@ In particular, you can select the Flax MNIST model: `--model=mnist_flax`. It is also possible to use jax2tf-generated SavedModel with TensorFlow serving. At the moment, the open-source TensorFlow model server is missing XLA support, but the Google version can be used, as shown in the -[serving examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/serving/README.md). +[serving examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/serving/README.md). # Using jax2tf with TensorFlow Lite and TensorFlow JavaScript @@ -186,6 +186,6 @@ can pass the `enable_xla=False` parameter to `jax2tf.convert` to direct `jax2tf` to avoid problematic ops. This will increase the coverage, and in fact most, but not all, Flax examples can be converted this way. -Check out the [MNIST TensorFlow Lite](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md) +Check out the [MNIST TensorFlow Lite](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md) and the -[Quickdraw TensorFlow.js example](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md). +[Quickdraw TensorFlow.js example](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md). diff --git a/jax/experimental/jax2tf/examples/serving/README.md b/jax/experimental/jax2tf/examples/serving/README.md index 0d8f49e45d99..299923109226 100644 --- a/jax/experimental/jax2tf/examples/serving/README.md +++ b/jax/experimental/jax2tf/examples/serving/README.md @@ -2,7 +2,7 @@ Using jax2tf with TensorFlow serving ==================================== This is a supplement to the -[examples/README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md) +[examples/README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md) with example code and instructions for using `jax2tf` with the open source TensorFlow model server. Specific instructions for Google-internal versions of model server are in the `internal` subdirectory. @@ -15,16 +15,16 @@ SavedModel**. The only difference in the SavedModel produced with jax2tf is that the function graphs may contain -[XLA TF ops](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#caveats) +[XLA TF ops](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#caveats) that require enabling CPU/GPU XLA for execution in the model server. This is achieved using a command-line flag. There are no other differences compared to using SavedModel produced by TensorFlow. This serving example uses -[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) +[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) for saving the SavedModel and adds code specific to interacting with the model server: -[model_server_request.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/serving/model_server_request.py). +[model_server_request.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/serving/model_server_request.py). 0. *Set up JAX and TensorFlow serving*. @@ -36,7 +36,7 @@ We also need to install TensorFlow for the `jax2tf` feature and the rest of this We use the `tf_nightly` package to get an up-to-date version. ```shell - git clone https://github.com/google/jax + git clone https://github.com/jax-ml/jax JAX2TF_EXAMPLES=$(pwd)/jax/jax/experimental/jax2tf/examples pip install -e jax pip install flax jaxlib tensorflow_datasets tensorflow_serving_api tf_nightly diff --git a/jax/experimental/jax2tf/examples/tflite/mnist/README.md b/jax/experimental/jax2tf/examples/tflite/mnist/README.md index 9c889e647067..f39bd9c7ea9f 100644 --- a/jax/experimental/jax2tf/examples/tflite/mnist/README.md +++ b/jax/experimental/jax2tf/examples/tflite/mnist/README.md @@ -65,7 +65,7 @@ TensorFlow ops that are only available with the XLA compiler, and which are not understood (yet) by the TFLite converter to be used below. -Check out [more details about this limitation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md), +Check out [more details about this limitation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md), including to which JAX primitives it applies. ### Convert the trained model to the TF Lite format diff --git a/jax/experimental/jax2tf/g3doc/convert_models_results.md b/jax/experimental/jax2tf/g3doc/convert_models_results.md index 545f1faee266..24e2539a3626 100644 --- a/jax/experimental/jax2tf/g3doc/convert_models_results.md +++ b/jax/experimental/jax2tf/g3doc/convert_models_results.md @@ -48,13 +48,13 @@ details on the different converters. ## `flax/actor_critic_[(_, 4*b, 4*b, _)]` ### Example: `flax/actor_critic_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/actor_critic_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -62,13 +62,13 @@ InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -78,13 +78,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -94,13 +94,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -122,13 +122,13 @@ RuntimeError('third_party/tensorflow/lite/kernels/concatenation.cc:159 t->dims-> ## `flax/bilstm_[(b, _), (_,)]` ### Example: `flax/bilstm_[(b, _), (_,)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/bilstm_[(b, _), (_,)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -141,7 +141,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -156,7 +156,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -171,7 +171,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -180,13 +180,13 @@ for more details. ## `flax/bilstm_[(_, _), (b,)]` ### Example: `flax/bilstm_[(_, _), (b,)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/bilstm_[(_, _), (b,)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -199,7 +199,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -214,7 +214,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -229,7 +229,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -238,13 +238,13 @@ for more details. ## `flax/cnn_[(_, b, b, _)]` ### Example: `flax/cnn_[(_, b, b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/cnn_[(_, b, b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -253,13 +253,13 @@ InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_ Conversion error InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'. Details: Cannot divide 'b + -2' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. . @@ -267,7 +267,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -278,13 +278,13 @@ for more details. Conversion error InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'. Details: Cannot divide 'b + -2' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. . @@ -292,7 +292,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -303,13 +303,13 @@ for more details. Conversion error InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'. Details: Cannot divide 'b + -2' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. . @@ -317,7 +317,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -395,13 +395,13 @@ ValueError('Cannot set tensor: Dimension mismatch. Got 8 but expected 1 for dime ## `flax/resnet50_[(_, 4*b, 4*b, _)]` ### Example: `flax/resnet50_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/resnet50_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -409,13 +409,13 @@ InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -425,13 +425,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -441,13 +441,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -613,13 +613,13 @@ IndexError('Cannot use NumPy slice indexing on an array dimension whose size is ## `flax/lm1b_[(b, _)]` ### Example: `flax/lm1b_[(b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/lm1b_[(b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -632,7 +632,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -647,7 +647,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -662,7 +662,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -684,13 +684,13 @@ ValueError('Cannot set tensor: Dimension mismatch. Got 2 but expected 1 for dime ## `flax/wmt_[(b, _), (b, _)]` ### Example: `flax/wmt_[(b, _), (b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/wmt_[(b, _), (b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -703,7 +703,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -718,7 +718,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -733,7 +733,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -798,14 +798,14 @@ This converter simply converts a the forward function of a JAX model to a Tensorflow function with XLA support linked in. This is considered the baseline converter and has the largest coverage, because we expect nearly all ops to be convertible. However, please see -[jax2tf Known Issue](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues) +[jax2tf Known Issue](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues) for a list of known problems. ### `jax2tf_noxla` This converter converts a JAX model to a Tensorflow function without XLA support. This means the Tensorflow XLA ops aren't used. See -[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) +[here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) for more details. ### `jax2tfjs` diff --git a/jax/experimental/jax2tf/g3doc/convert_models_results.md.template b/jax/experimental/jax2tf/g3doc/convert_models_results.md.template index b54c5750334a..54e1d21356a7 100644 --- a/jax/experimental/jax2tf/g3doc/convert_models_results.md.template +++ b/jax/experimental/jax2tf/g3doc/convert_models_results.md.template @@ -29,14 +29,14 @@ This converter simply converts a the forward function of a JAX model to a Tensorflow function with XLA support linked in. This is considered the baseline converter and has the largest coverage, because we expect nearly all ops to be convertible. However, please see -[jax2tf Known Issue](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues) +[jax2tf Known Issue](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues) for a list of known problems. ### `jax2tf_noxla` This converter converts a JAX model to a Tensorflow function without XLA support. This means the Tensorflow XLA ops aren't used. See -[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) +[here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) for more details. ### `jax2tfjs` diff --git a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md index 457dc998abca..24a1d62ee67e 100644 --- a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md +++ b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md @@ -1,6 +1,6 @@ # jax2tf Limitations for `enable_xla=False` -*Note: the list below is only for running jax2tf with `enable_xla=False`. For general jax2tf known issues please see [here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues)* +*Note: the list below is only for running jax2tf with `enable_xla=False`. For general jax2tf known issues please see [here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues)* For most JAX primitives there is a natural TF op that fits the needed semantics (e.g., `jax.lax.abs` is equivalent to `tf.abs`). However, there are a number of diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md index dabbcca4d430..b36b004a9d31 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md @@ -40,7 +40,7 @@ The converter has a mode in which it attempts to avoid special XLA TF ops (`enable_xla=False`). In this mode, some primitives have additional limitations. This table only shows errors for cases that are working in JAX (see [separate -list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) +list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) We do not yet have support for `pmap` (with its collective primitives), nor for `sharded_jit` (SPMD partitioning). @@ -56,7 +56,7 @@ We use the following abbreviations for sets of dtypes: * `all` = `integer`, `inexact`, `bool` More detailed information can be found in the -[source code for the limitation specification](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). +[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). | Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes | diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template index bf5dc41d8b8b..219802f5363a 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template @@ -40,7 +40,7 @@ The converter has a mode in which it attempts to avoid special XLA TF ops (`enable_xla=False`). In this mode, some primitives have additional limitations. This table only shows errors for cases that are working in JAX (see [separate -list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) +list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) We do not yet have support for `pmap` (with its collective primitives), nor for `sharded_jit` (SPMD partitioning). @@ -56,7 +56,7 @@ We use the following abbreviations for sets of dtypes: * `all` = `integer`, `inexact`, `bool` More detailed information can be found in the -[source code for the limitation specification](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). +[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). {{tf_error_table}} diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 5ecde602cdaa..310cbaab6d59 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -591,7 +591,7 @@ def _padding_reduce_window(operand, operand_shape, computation_name, padding_type = pads_to_padtype(operand_shape, window_dimensions, window_strides, padding) - # https://github.com/google/jax/issues/11874. + # https://github.com/jax-ml/jax/issues/11874. needs_manual_padding = ( padding_type == "SAME" and computation_name == "add" and window_dimensions != [1] * len(operand_shape)) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 24dee390f398..8a90c491e526 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -198,7 +198,7 @@ def __init__(self): # A cache for the tf.convert_to_tensor for constants. We try to preserve # sharing for constants, to enable tf.Graph to take advantage of it. - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. self.constant_cache = None # None means that we don't use a cache. We # may be outside a conversion scope. @@ -249,7 +249,7 @@ def convert(fun_jax: Callable, """Allows calling a JAX function from a TensorFlow program. See - [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) + [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) for more details about usage and common problems. Args: @@ -291,12 +291,12 @@ def convert(fun_jax: Callable, polymorphic_shapes are only supported for positional arguments; shape polymorphism is not supported for keyword arguments. - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) + See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) for more details. polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of the form `e1 >= e2` or `e1 <= e2`. - See more details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. with_gradient: if set (default), add a tf.custom_gradient to the lowered function, by converting the ``jax.vjp(fun)``. This means that reverse-mode TensorFlow AD is supported for the output TensorFlow function, and the @@ -3536,7 +3536,7 @@ def _shard_value(val: TfVal, if tf_context.executing_eagerly(): raise ValueError( "A jit function with sharded arguments or results must be used under a `tf.function` context. " - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion") + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion") return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor( val, use_sharding_op=True) diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 2760efea8061..e10c3fbfdff7 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -304,7 +304,7 @@ def fun_tf(x): self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False) def test_with_var_different_shape(self): - # See https://github.com/google/jax/issues/6050 + # See https://github.com/jax-ml/jax/issues/6050 v = tf.Variable((4., 2.), dtype=tf.float32) def tf_func(x): @@ -428,7 +428,7 @@ def loss(functional, x_dict): self.assertAllClose(g_jax, g_tf) def test_grad_int_argument(self): - # Similar to https://github.com/google/jax/issues/6975 + # Similar to https://github.com/jax-ml/jax/issues/6975 # state is a pytree that contains an integer and a boolean. # The function returns an integer and a boolean. def f(param, state, x): diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index 896d0436e3c2..c3b9e96dc320 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -1066,9 +1066,9 @@ def nextafter(cls, harness: test_harnesses.Harness): @classmethod def qr(cls, harness: test_harnesses.Harness): - # See https://github.com/google/jax/pull/3775#issuecomment-659407824; + # See https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824; # # jit_compile=True breaks for complex types. - # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824. + # TODO: see https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824. # - for now, the performance of the HLO QR implementation called when # compiling with TF is expected to have worse performance than the # custom calls made in JAX. diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 64d461fe9996..ef7a5ee2c138 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -595,7 +595,7 @@ def fn(x0, x1, x2, x3): @jtu.sample_product(with_function=[False, True]) def test_gradients_int_argument(self, with_function=False): - # https://github.com/google/jax/issues/6975 + # https://github.com/jax-ml/jax/issues/6975 # Also issue #6975. # An expanded version of test_gradients_unused_argument state = dict( @@ -969,7 +969,7 @@ def caller_jax(x): self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def) def test_bfloat16_constant(self): - # Re: https://github.com/google/jax/issues/3942 + # Re: https://github.com/jax-ml/jax/issues/3942 def jax_fn_scalar(x): x = x.astype(jnp.bfloat16) x *= 2. @@ -990,7 +990,7 @@ def jax_fn_array(x): def test_shared_constants(self): # Check that the constants are shared properly in converted functions - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("shared constants tests not interesting for native serialization") const = np.random.uniform(size=256).astype(np.float32) # A shared constant @@ -1002,7 +1002,7 @@ def f(x): def test_shared_constants_under_cond(self): # Check that the constants are shared properly in converted functions - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("shared constants tests not interesting for native serialization") const_size = 512 @@ -1018,7 +1018,7 @@ def f2(x): self.assertLen(f2_consts, len(f1_consts)) def test_shared_constants_under_scan(self): - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("shared constants tests not interesting for native serialization") const_size = 512 @@ -1092,7 +1092,7 @@ def test_weak_types(self): @jtu.sample_product(with_function=[False, True]) def test_kwargs(self, with_function=False): - # Re: https://github.com/google/jax/issues/6791 + # Re: https://github.com/jax-ml/jax/issues/6791 def f_jax(*, x): return jnp.sum(x) f_tf = jax2tf.convert(f_jax) @@ -1104,7 +1104,7 @@ def f_jax(*, x): @jtu.sample_product(with_function=[False, True]) def test_grad_kwargs(self, with_function=False): - # Re: https://github.com/google/jax/issues/6791 + # Re: https://github.com/jax-ml/jax/issues/6791 x = (np.zeros(3, dtype=np.float32), np.zeros(4, dtype=np.float32)) def f_jax(*, x=(1., 2.)): diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 485fa6e5831f..78c24b7ea411 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -30,7 +30,7 @@ are captured as jax2tf_limitations.Jax2TfLimitation objects. From the limitations objects, we generate a -[report](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). +[report](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). The report has instructions for how to re-generate it. If a harness run fails with error, and a limitation that matches the device diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py index 8b71de7db30c..bc19915d1644 100644 --- a/jax/experimental/jax2tf/tests/savedmodel_test.py +++ b/jax/experimental/jax2tf/tests/savedmodel_test.py @@ -175,7 +175,7 @@ def model_jax(params, inputs): def test_save_grad_integers(self): - # https://github.com/google/jax/issues/7123 + # https://github.com/jax-ml/jax/issues/7123 # In the end this is a test that does not involve JAX at all batch_size = 5 state = np.array([1], dtype=np.int32) # Works if float32 diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 2475a062f5ec..a9ee1776222c 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -933,7 +933,7 @@ def f(x): kwargs=[dict(with_function=v) for v in [True, False]] ) def test_grad_int(self, with_function=False): - # https://github.com/google/jax/issues/7093 + # https://github.com/jax-ml/jax/issues/7093 # Also issue #6975. x_shape = (2, 3, 4) xi = np.arange(math.prod(x_shape), dtype=np.int16).reshape(x_shape) @@ -2172,7 +2172,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] (2, x.shape[0]), (1, 1), "VALID"), arg_descriptors=[RandArg((3, 8), _f32)], polymorphic_shapes=["b, ..."]), - # https://github.com/google/jax/issues/11804 + # https://github.com/jax-ml/jax/issues/11804 # Use the reshape trick to simulate a polymorphic dimension of 16*b. # (See test "conv_general_dilated.1d_1" above for more details.) PolyHarness("reduce_window", "add_monoid_strides_window_size=static", diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index d6349b4870d2..9009c1586f15 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -437,7 +437,7 @@ def f_grad_tf(x_v, res_ct): def test_grad_sharding_different_mesh(self): # Convert with two similar meshes, the only difference being # the order of the devices. grad should not fail. - # https://github.com/google/jax/issues/21314 + # https://github.com/jax-ml/jax/issues/21314 devices = jax.local_devices()[:2] if len(devices) < 2: raise unittest.SkipTest("Test requires 2 local devices") diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 1ed6183b1229..ffe362974dcb 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -45,11 +45,11 @@ and can thus be used for high-order automatic differentiation of :math:`f`. Details are explained in - `these notes `__. + `these notes `__. Note: Help improve :func:`jet` by contributing - `outstanding primitive rules `__. + `outstanding primitive rules `__. """ from collections.abc import Callable diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index fabd45ca069a..f19401525cc0 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -639,7 +639,7 @@ def _rule_missing(prim: core.Primitive, *_, **__): raise NotImplementedError( f"No replication rule for {prim}. As a workaround, pass the " "`check_rep=False` argument to `shard_map`. To get this fixed, open an " - "issue at https://github.com/google/jax/issues") + "issue at https://github.com/jax-ml/jax/issues") # Lowering @@ -845,20 +845,20 @@ def process_call(self, call_primitive, fun, tracers, params): f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " "yet supported. Put a `jax.jit` around the `shard_map`-decorated " "function, and open a feature request at " - "https://github.com/google/jax/issues !") + "https://github.com/jax-ml/jax/issues !") def process_map(self, map_primitive, fun, tracers, params): raise NotImplementedError( "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." "Put a `jax.jit` around the `shard_map`-decorated function, and open " - "a feature request at https://github.com/google/jax/issues !") + "a feature request at https://github.com/jax-ml/jax/issues !") def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Since ShardMapTrace is only used as a base main, we can drop the jvp. if symbolic_zeros: msg = ("custom_jvp symbolic_zeros support with shard_map is not " "implemented; please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) @@ -876,7 +876,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, if symbolic_zeros: msg = ("custom_vjp symbolic_zeros support with shard_map is not " "implemented; please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, fwd, bwd, out_trees, symbolic_zeros in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) @@ -1042,7 +1042,7 @@ def _standard_check(prim, mesh, *in_rep, **__): if in_rep_ and not in_rep_[:-1] == in_rep_[1:]: raise Exception(f"Primitive {prim} requires argument replication types " f"to match, but got {in_rep}. Please open an issue at " - "https://github.com/google/jax/issues and as a temporary " + "https://github.com/jax-ml/jax/issues and as a temporary " "workaround pass the check_rep=False argument to shard_map") return in_rep_[0] if in_rep_ else None @@ -1057,7 +1057,7 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params): raise Exception(f"Collective {prim} must be applied to a device-varying " f"replication type, but got {x_rep} for collective acting " f"over axis name {axis_name}. Please open an issue at " - "https://github.com/google/jax/issues and as a temporary " + "https://github.com/jax-ml/jax/issues and as a temporary " "workaround pass the check_rep=False argument to shard_map") return x_rep @@ -1114,7 +1114,7 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups): raise Exception("Collective psum must be applied to a device-varying " f"replication type, but got {in_rep} for collective acting " f"over axis name {axes}. Please open an issue at " - "https://github.com/google/jax/issues, and as a temporary " + "https://github.com/jax-ml/jax/issues, and as a temporary " "workaround pass the check_rep=False argument to shard_map") in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) return [r | set(axes) for r in in_rep] @@ -1129,7 +1129,7 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): "non-device-varying " f"replication type, but got {in_rep} for collective acting " f"over axis name {axes}. Please open an issue at " - "https://github.com/google/jax/issues, and as a temporary " + "https://github.com/jax-ml/jax/issues, and as a temporary " "workaround pass the check_rep=False argument to shard_map") in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) return [r - set(axes) for r in in_rep] @@ -1216,7 +1216,7 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): if not carry_rep_in == carry_rep_out: raise Exception("Scan carry input and output got mismatched replication " f"types {carry_rep_in} and {carry_rep_out}. Please open an " - "issue at https://github.com/google/jax/issues, and as a " + "issue at https://github.com/jax-ml/jax/issues, and as a " "temporary workaround pass the check_rep=False argument to " "shard_map") return out_rep @@ -1267,7 +1267,7 @@ def _custom_vjp_call_jaxpr_rewrite( mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees, symbolic_zeros): if symbolic_zeros: - msg = ("Please open an issue at https://github.com/google/jax/issues and as" + msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and as" " a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) @@ -1303,7 +1303,7 @@ def _linear_solve_check(mesh, *in_rep, const_lengths, jaxprs): assert in_rep if not in_rep_[:-1] == in_rep_[1:]: msg = ("shard_map check_rep rewrite failed. Please open an issue at " - "https://github.com/google/jax/issues and as a workaround pass the " + "https://github.com/jax-ml/jax/issues and as a workaround pass the " "check_rep=False argument to shard_map") raise Exception(msg) return [in_rep_[0]] * len(jaxprs.solve.out_avals) @@ -1878,7 +1878,7 @@ def post_process_call(self, call_primitive, out_tracers, params): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): if symbolic_zeros: - msg = ("Please open an issue at https://github.com/google/jax/issues and " + msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) @@ -1899,7 +1899,7 @@ def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): if symbolic_zeros: - msg = ("Please open an issue at https://github.com/google/jax/issues and " + msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py index 8ab8cd88721d..f388cd527cf9 100644 --- a/jax/experimental/sparse/__init__.py +++ b/jax/experimental/sparse/__init__.py @@ -189,7 +189,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.experimental.sparse.ad import ( jacfwd as jacfwd, diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index d200577c2416..9f2f0f69be63 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -105,7 +105,7 @@ def _bcoo_set_nse(mat: BCOO, nse: int) -> BCOO: unique_indices=mat.unique_indices) # TODO(jakevdp) this can be problematic when used with autodiff; see -# https://github.com/google/jax/issues/10163. Should this be a primitive? +# https://github.com/jax-ml/jax/issues/10163. Should this be a primitive? # Alternatively, maybe roll this into bcoo_sum_duplicates as an optional argument. def bcoo_eliminate_zeros(mat: BCOO, nse: int | None = None) -> BCOO: data, indices, shape = mat.data, mat.indices, mat.shape @@ -1140,7 +1140,7 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :]) out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1]) # Note: we do not eliminate zeros here, because it can cause issues with autodiff. - # See https://github.com/google/jax/issues/10163. + # See https://github.com/jax-ml/jax/issues/10163. return _bcoo_sum_duplicates(out_data, out_indices, spinfo=SparseInfo(shape=out_shape), nse=out_nse) @bcoo_spdot_general_p.def_impl @@ -1537,7 +1537,7 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse): nse, *data.shape[props.n_batch + 1:]), dtype=data.dtype) data_dot_out = data_out # This check is because scatter-add on zero-sized arrays has poorly defined - # semantics; see https://github.com/google/jax/issues/13656. + # semantics; see https://github.com/jax-ml/jax/issues/13656. if data_out.size: permute = lambda x, i, y: x.at[i].add(y, mode='drop') else: diff --git a/jax/extend/backend.py b/jax/extend/backend.py index 66fd149d7c8e..b1e471133482 100644 --- a/jax/extend/backend.py +++ b/jax/extend/backend.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.api import ( clear_backends as clear_backends, diff --git a/jax/extend/core/__init__.py b/jax/extend/core/__init__.py index 2732b1984c1d..9f1632fb37a9 100644 --- a/jax/extend/core/__init__.py +++ b/jax/extend/core/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.abstract_arrays import ( array_types as array_types diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index e37287180eee..feb70b5171be 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.ad_util import stop_gradient_p as stop_gradient_p diff --git a/jax/extend/ffi.py b/jax/extend/ffi.py index 3a26030c1687..b2d480adc7eb 100644 --- a/jax/extend/ffi.py +++ b/jax/extend/ffi.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.extend.ffi import ( ffi_call as ffi_call, diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py index d5fb9245af91..715dfd43592c 100644 --- a/jax/extend/ifrt_programs.py +++ b/jax/extend/ifrt_programs.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.lib import xla_extension as _xe diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 1706f8c8c30b..74c52dddbae8 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.linear_util import ( StoreException as StoreException, diff --git a/jax/extend/random.py b/jax/extend/random.py index a055c75751bd..d6e0cfaab0e4 100644 --- a/jax/extend/random.py +++ b/jax/extend/random.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.extend.random import ( define_prng_impl as define_prng_impl, diff --git a/jax/extend/source_info_util.py b/jax/extend/source_info_util.py index f74df2cab5e1..f031dabef48d 100644 --- a/jax/extend/source_info_util.py +++ b/jax/extend/source_info_util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.source_info_util import ( NameStack as NameStack, diff --git a/jax/image/__init__.py b/jax/image/__init__.py index c7ee8ffa9c64..993395f503fd 100644 --- a/jax/image/__init__.py +++ b/jax/image/__init__.py @@ -21,7 +21,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.image.scale import ( resize as resize, diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 6bfc3473ff50..28816afb01e3 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from __future__ import annotations diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 98fad903cc4f..607fc6fa596d 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.interpreters.batching import ( Array as Array, diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index e2bcd5de9408..293bd02446fe 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.lax.lax import ( DotDimensionNumbers as DotDimensionNumbers, diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 230aacb7654a..496d03261384 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -15,7 +15,7 @@ """Common functions for neural network libraries.""" # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.numpy import tanh as tanh from jax.nn import initializers as initializers diff --git a/jax/nn/initializers.py b/jax/nn/initializers.py index 6c73356ce1a1..019f3e179215 100644 --- a/jax/nn/initializers.py +++ b/jax/nn/initializers.py @@ -18,7 +18,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.nn.initializers import ( constant as constant, diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index da79f7859bcd..20c37c55902c 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.numpy import fft as fft from jax.numpy import linalg as linalg diff --git a/jax/numpy/fft.py b/jax/numpy/fft.py index 24a271487d5e..c268c2d65597 100644 --- a/jax/numpy/fft.py +++ b/jax/numpy/fft.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.numpy.fft import ( ifft as ifft, diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index c342fde0ae6e..05b5ff6db289 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.numpy.linalg import ( cholesky as cholesky, diff --git a/jax/ops/__init__.py b/jax/ops/__init__.py index c61a44fd1357..5e1f3d682589 100644 --- a/jax/ops/__init__.py +++ b/jax/ops/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.ops.scatter import ( segment_sum as segment_sum, diff --git a/jax/profiler.py b/jax/profiler.py index 01ea6e2222cc..77157dc02a13 100644 --- a/jax/profiler.py +++ b/jax/profiler.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.profiler import ( StepTraceAnnotation as StepTraceAnnotation, diff --git a/jax/random.py b/jax/random.py index 5c2eaf81f2bc..29a625389811 100644 --- a/jax/random.py +++ b/jax/random.py @@ -103,7 +103,7 @@ **TLDR**: JAX PRNG = `Threefry counter PRNG `_ + a functional array-oriented `splitting model `_ -See `docs/jep/263-prng.md `_ +See `docs/jep/263-prng.md `_ for more details. To summarize, among other requirements, the JAX PRNG aims to: @@ -201,7 +201,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.random import ( PRNGKey as PRNGKey, diff --git a/jax/scipy/__init__.py b/jax/scipy/__init__.py index c0746910dd3f..cf44b6e179c0 100644 --- a/jax/scipy/__init__.py +++ b/jax/scipy/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from typing import TYPE_CHECKING diff --git a/jax/scipy/cluster/__init__.py b/jax/scipy/cluster/__init__.py index 5a01ea0ee493..ea35467f6353 100644 --- a/jax/scipy/cluster/__init__.py +++ b/jax/scipy/cluster/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.scipy.cluster import vq as vq diff --git a/jax/scipy/cluster/vq.py b/jax/scipy/cluster/vq.py index 3a46ce52f468..eeeabb7224bc 100644 --- a/jax/scipy/cluster/vq.py +++ b/jax/scipy/cluster/vq.py @@ -13,6 +13,6 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.cluster.vq import vq as vq diff --git a/jax/scipy/fft.py b/jax/scipy/fft.py index b8005b72f349..d3c2de09935a 100644 --- a/jax/scipy/fft.py +++ b/jax/scipy/fft.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.fft import ( dct as dct, diff --git a/jax/scipy/integrate.py b/jax/scipy/integrate.py index b19aa054ca00..3335f12fd381 100644 --- a/jax/scipy/integrate.py +++ b/jax/scipy/integrate.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.integrate import ( trapezoid as trapezoid diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 059f927ec46c..64bc0544000b 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.linalg import ( block_diag as block_diag, diff --git a/jax/scipy/ndimage.py b/jax/scipy/ndimage.py index 2f63e236654c..81d7e3ef27d8 100644 --- a/jax/scipy/ndimage.py +++ b/jax/scipy/ndimage.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.ndimage import ( map_coordinates as map_coordinates, diff --git a/jax/scipy/optimize/__init__.py b/jax/scipy/optimize/__init__.py index 8a2248733145..f1c7167c33f4 100644 --- a/jax/scipy/optimize/__init__.py +++ b/jax/scipy/optimize/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.optimize.minimize import ( minimize as minimize, diff --git a/jax/scipy/signal.py b/jax/scipy/signal.py index 7e39da3f95b1..c46b2fce3572 100644 --- a/jax/scipy/signal.py +++ b/jax/scipy/signal.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.signal import ( fftconvolve as fftconvolve, diff --git a/jax/scipy/sparse/__init__.py b/jax/scipy/sparse/__init__.py index f2e305e829c8..2968a26b4415 100644 --- a/jax/scipy/sparse/__init__.py +++ b/jax/scipy/sparse/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.scipy.sparse import linalg as linalg diff --git a/jax/scipy/sparse/linalg.py b/jax/scipy/sparse/linalg.py index d475ddff81f7..d22e5ec43977 100644 --- a/jax/scipy/sparse/linalg.py +++ b/jax/scipy/sparse/linalg.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.sparse.linalg import ( cg as cg, diff --git a/jax/scipy/spatial/transform.py b/jax/scipy/spatial/transform.py index 4b532d5f3d50..63e8dd3736b2 100644 --- a/jax/scipy/spatial/transform.py +++ b/jax/scipy/spatial/transform.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.spatial.transform import ( Rotation as Rotation, diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 5d72339eaec8..431617d362ea 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.special import ( bernoulli as bernoulli, diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 7aa73f7b5218..7719945f23df 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.scipy.stats import bernoulli as bernoulli from jax.scipy.stats import beta as beta diff --git a/jax/scipy/stats/bernoulli.py b/jax/scipy/stats/bernoulli.py index 46c1e4825d11..1623f71130c1 100644 --- a/jax/scipy/stats/bernoulli.py +++ b/jax/scipy/stats/bernoulli.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.bernoulli import ( logpmf as logpmf, diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py index 5c57dda6bb56..2a4e7f12f7a5 100644 --- a/jax/scipy/stats/beta.py +++ b/jax/scipy/stats/beta.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.beta import ( cdf as cdf, diff --git a/jax/scipy/stats/betabinom.py b/jax/scipy/stats/betabinom.py index 48f955d9eaf3..f8adf68f4b2e 100644 --- a/jax/scipy/stats/betabinom.py +++ b/jax/scipy/stats/betabinom.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.betabinom import ( logpmf as logpmf, diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py index 4ff79f5f9888..34c9972d09bd 100644 --- a/jax/scipy/stats/cauchy.py +++ b/jax/scipy/stats/cauchy.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.cauchy import ( cdf as cdf, diff --git a/jax/scipy/stats/chi2.py b/jax/scipy/stats/chi2.py index e17a2e331958..47fcb76db28d 100644 --- a/jax/scipy/stats/chi2.py +++ b/jax/scipy/stats/chi2.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.chi2 import ( cdf as cdf, diff --git a/jax/scipy/stats/dirichlet.py b/jax/scipy/stats/dirichlet.py index 9368defc8f58..22e9b3cc11cc 100644 --- a/jax/scipy/stats/dirichlet.py +++ b/jax/scipy/stats/dirichlet.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.dirichlet import ( logpdf as logpdf, diff --git a/jax/scipy/stats/expon.py b/jax/scipy/stats/expon.py index 1ec50ac3f604..8f5c0a0680ce 100644 --- a/jax/scipy/stats/expon.py +++ b/jax/scipy/stats/expon.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.expon import ( logpdf as logpdf, diff --git a/jax/scipy/stats/gamma.py b/jax/scipy/stats/gamma.py index 8efecafed3bd..531a1e300ca9 100644 --- a/jax/scipy/stats/gamma.py +++ b/jax/scipy/stats/gamma.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.gamma import ( cdf as cdf, diff --git a/jax/scipy/stats/gennorm.py b/jax/scipy/stats/gennorm.py index c903ff606c25..c760575fa7a6 100644 --- a/jax/scipy/stats/gennorm.py +++ b/jax/scipy/stats/gennorm.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.gennorm import ( cdf as cdf, diff --git a/jax/scipy/stats/geom.py b/jax/scipy/stats/geom.py index 75f917fc27c7..eb12dbb5a183 100644 --- a/jax/scipy/stats/geom.py +++ b/jax/scipy/stats/geom.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.geom import ( logpmf as logpmf, diff --git a/jax/scipy/stats/laplace.py b/jax/scipy/stats/laplace.py index 3abe62020398..8f182804daf0 100644 --- a/jax/scipy/stats/laplace.py +++ b/jax/scipy/stats/laplace.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.laplace import ( cdf as cdf, diff --git a/jax/scipy/stats/logistic.py b/jax/scipy/stats/logistic.py index c25a06856ff7..7cdb26fb1d20 100644 --- a/jax/scipy/stats/logistic.py +++ b/jax/scipy/stats/logistic.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.logistic import ( cdf as cdf, diff --git a/jax/scipy/stats/multinomial.py b/jax/scipy/stats/multinomial.py index 723d1a645726..392ca405581e 100644 --- a/jax/scipy/stats/multinomial.py +++ b/jax/scipy/stats/multinomial.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.multinomial import ( logpmf as logpmf, diff --git a/jax/scipy/stats/multivariate_normal.py b/jax/scipy/stats/multivariate_normal.py index 95ad355c75f1..94c4cc50a18c 100644 --- a/jax/scipy/stats/multivariate_normal.py +++ b/jax/scipy/stats/multivariate_normal.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.multivariate_normal import ( logpdf as logpdf, diff --git a/jax/scipy/stats/norm.py b/jax/scipy/stats/norm.py index f47765adfc68..563e40ce06cd 100644 --- a/jax/scipy/stats/norm.py +++ b/jax/scipy/stats/norm.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.norm import ( cdf as cdf, diff --git a/jax/scipy/stats/pareto.py b/jax/scipy/stats/pareto.py index bf27ea205948..5e46fd5d0bc7 100644 --- a/jax/scipy/stats/pareto.py +++ b/jax/scipy/stats/pareto.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.pareto import ( logpdf as logpdf, diff --git a/jax/scipy/stats/poisson.py b/jax/scipy/stats/poisson.py index 2e857bc15a3b..5fcde905f89b 100644 --- a/jax/scipy/stats/poisson.py +++ b/jax/scipy/stats/poisson.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.poisson import ( logpmf as logpmf, diff --git a/jax/scipy/stats/t.py b/jax/scipy/stats/t.py index d92fcab97bf7..694bcb0b0dfc 100644 --- a/jax/scipy/stats/t.py +++ b/jax/scipy/stats/t.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.t import ( logpdf as logpdf, diff --git a/jax/scipy/stats/truncnorm.py b/jax/scipy/stats/truncnorm.py index 28d5533b02da..cb8e8958d735 100644 --- a/jax/scipy/stats/truncnorm.py +++ b/jax/scipy/stats/truncnorm.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.truncnorm import ( cdf as cdf, diff --git a/jax/scipy/stats/uniform.py b/jax/scipy/stats/uniform.py index d0a06c673b3c..fa754125f556 100644 --- a/jax/scipy/stats/uniform.py +++ b/jax/scipy/stats/uniform.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.uniform import ( logpdf as logpdf, diff --git a/jax/scipy/stats/vonmises.py b/jax/scipy/stats/vonmises.py index 8de7fba47096..6572e43f63c6 100644 --- a/jax/scipy/stats/vonmises.py +++ b/jax/scipy/stats/vonmises.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.vonmises import ( logpdf as logpdf, diff --git a/jax/scipy/stats/wrapcauchy.py b/jax/scipy/stats/wrapcauchy.py index 6e2420c5ae7b..eb1768f0c959 100644 --- a/jax/scipy/stats/wrapcauchy.py +++ b/jax/scipy/stats/wrapcauchy.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.wrapcauchy import ( logpdf as logpdf, diff --git a/jax/sharding.py b/jax/sharding.py index ea92e9d17e42..26c542292e87 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.sharding import Sharding as Sharding from jax._src.sharding_impls import ( diff --git a/jax/stages.py b/jax/stages.py index 6ffc3144c3bc..3e7e461c385b 100644 --- a/jax/stages.py +++ b/jax/stages.py @@ -22,7 +22,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.stages import ( Compiled as Compiled, diff --git a/jax/test_util.py b/jax/test_util.py index 5d4f5ed0aa77..176f4521b281 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.public_test_util import ( check_grads as check_grads, diff --git a/jax/tree_util.py b/jax/tree_util.py index b4854c7dfbf1..956d79b9b4ef 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -36,7 +36,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.tree_util import ( DictKey as DictKey, diff --git a/jax/util.py b/jax/util.py index c1259e9c5f56..8071f77dffe2 100644 --- a/jax/util.py +++ b/jax/util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.util import ( HashableFunction as HashableFunction, diff --git a/jax/version.py b/jax/version.py index c6e4b3ad11ec..6c64d75b9733 100644 --- a/jax/version.py +++ b/jax/version.py @@ -115,7 +115,7 @@ def run(self): # missing or outdated. Because _write_version(...) modifies the copy of # this file in the build tree, re-building from the same JAX directory # would not automatically re-copy a clean version, and _write_version - # would fail without this deletion. See google/jax#18252. + # would fail without this deletion. See jax-ml/jax#18252. if os.path.isfile(this_file_in_build_dir): os.unlink(this_file_in_build_dir) super().run() diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py index 9ccf3bf44339..a84a6b34ea48 100644 --- a/jax_plugins/rocm/plugin_setup.py +++ b/jax_plugins/rocm/plugin_setup.py @@ -51,7 +51,7 @@ def has_ext_modules(self): packages=[package_name], python_requires=">=3.9", install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"], - url="https://github.com/google/jax", + url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", diff --git a/jaxlib/README.md b/jaxlib/README.md index 74e1e5b36ae3..cee5f246d96b 100644 --- a/jaxlib/README.md +++ b/jaxlib/README.md @@ -4,4 +4,4 @@ jaxlib is the support library for JAX. While JAX itself is a pure Python package jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to main -JAX README: https://github.com/google/jax/. +JAX README: https://github.com/jax-ml/jax/. diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 215313f9bb3a..dea9503c7c00 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -66,7 +66,7 @@ def has_ext_modules(self): 'numpy>=1.24', 'ml_dtypes>=0.2.0', ], - url='https://github.com/google/jax', + url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ "Programming Language :: Python :: 3.10", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 6305b0c24aa8..52a17c451aea 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack(): We don't entirely know why this happens, but in some build environments we seem to target the wrong Mac OS version. - https://github.com/google/jax/issues/3867 + https://github.com/jax-ml/jax/issues/3867 This check makes sure we don't release wheels that have this dependency. """ diff --git a/setup.py b/setup.py index 81eef74e0049..e807ff3b0052 100644 --- a/setup.py +++ b/setup.py @@ -104,7 +104,7 @@ def load_version_module(pkg_path): f"jax-cuda12-plugin=={_current_jaxlib_version}", ], }, - url='https://github.com/google/jax', + url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ "Programming Language :: Python :: 3.10", diff --git a/tests/api_test.py b/tests/api_test.py index 1deaa4c08dc8..adce61d650d6 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -467,7 +467,7 @@ def test_jit_donate_weak_type(self, argnum_type, argnum_val): ("argnames", "donate_argnames", ('array',)), ) def test_jnp_array_copy(self, argnum_type, argnum_val): - # https://github.com/google/jax/issues/3412 + # https://github.com/jax-ml/jax/issues/3412 @partial(jit, **{argnum_type: argnum_val}) def _test(array): @@ -905,7 +905,7 @@ def f(x): @jax.legacy_prng_key('allow') def test_omnistaging(self): - # See https://github.com/google/jax/issues/5206 + # See https://github.com/jax-ml/jax/issues/5206 # TODO(frostig): remove `wrap` once we always enable_custom_prng def wrap(arr): @@ -1409,7 +1409,7 @@ def f(d) -> float: f({E.A: 1.0, E.B: 2.0}) def test_jit_static_argnums_requires_type_equality(self): - # See: https://github.com/google/jax/pull/9311 + # See: https://github.com/jax-ml/jax/pull/9311 @partial(jit, static_argnums=(0,)) def f(k): assert python_should_be_executing @@ -1424,7 +1424,7 @@ def f(k): self.assertEqual(x, f(x)) def test_caches_depend_on_axis_env(self): - # https://github.com/google/jax/issues/9187 + # https://github.com/jax-ml/jax/issues/9187 f = lambda: lax.psum(1, "i") g = jax.jit(f) expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)() @@ -1437,7 +1437,7 @@ def test_caches_depend_on_axis_env(self): self.assertEqual(ans, expected) def test_caches_dont_depend_on_unnamed_axis_env(self): - # https://github.com/google/jax/issues/9187 + # https://github.com/jax-ml/jax/issues/9187 f = jax.jit(lambda: jnp.sin(1)) expected = f() with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 @@ -1446,7 +1446,7 @@ def test_caches_dont_depend_on_unnamed_axis_env(self): self.assertArraysAllClose(ans, expected, check_dtypes=True) def test_cache_key_defaults(self): - # https://github.com/google/jax/discussions/11875 + # https://github.com/jax-ml/jax/discussions/11875 f = jit(lambda x: (x ** 2).sum()) self.assertEqual(f._cache_size(), 0) x = jnp.arange(5.0) @@ -1455,7 +1455,7 @@ def test_cache_key_defaults(self): self.assertEqual(f._cache_size(), 1) def test_jit_nan_times_zero(self): - # https://github.com/google/jax/issues/4780 + # https://github.com/jax-ml/jax/issues/4780 def f(x): return 1 + x * 0 self.assertAllClose(f(np.nan), np.nan) @@ -2163,7 +2163,7 @@ def test_grad_and_aux_constant(self): self.assertEqual(aux, [4.**2, 4.]) def test_grad_and_aux_no_tracers(self): - # see https://github.com/google/jax/issues/1950 + # see https://github.com/jax-ml/jax/issues/1950 def f(x): aux = dict(identity=x, p1=x+1) return x ** 2, aux @@ -2322,7 +2322,7 @@ def test_linear_transpose_integer(self): self.assertEqual(actual, expected) def test_linear_transpose_dce(self): - # https://github.com/google/jax/issues/15660 + # https://github.com/jax-ml/jax/issues/15660 f = jit(lambda x: (2 * x, x > 0)) g = lambda x: f(x)[0] api.linear_transpose(g, 1.)(1.) @@ -2389,7 +2389,7 @@ def test_complex_output_jacrev_raises_error(self): self.assertRaises(TypeError, lambda: jacrev(lambda x: jnp.sin(x))(1 + 2j)) def test_nonholomorphic_jacrev(self): - # code based on https://github.com/google/jax/issues/603 + # code based on https://github.com/jax-ml/jax/issues/603 zs = 0.5j * np.arange(5) + np.arange(5) def f(z): @@ -2401,8 +2401,8 @@ def f(z): @jax.numpy_dtype_promotion('standard') # Test explicitly exercises implicit dtype promotion. def test_heterogeneous_jacfwd(self): - # See https://github.com/google/jax/issues/7157 - # See https://github.com/google/jax/issues/7780 + # See https://github.com/jax-ml/jax/issues/7157 + # See https://github.com/jax-ml/jax/issues/7780 x = np.array([2.0], dtype=np.float16) y = np.array([3.0], dtype=np.float32) a = (x, y) @@ -2421,8 +2421,8 @@ def f(tup): @jax.numpy_dtype_promotion('standard') # Test explicitly exercises implicit dtype promotion. def test_heterogeneous_jacrev(self): - # See https://github.com/google/jax/issues/7157 - # See https://github.com/google/jax/issues/7780 + # See https://github.com/jax-ml/jax/issues/7157 + # See https://github.com/jax-ml/jax/issues/7780 x = np.array([2.0], dtype=np.float16) y = np.array([3.0], dtype=np.float32) a = (x, y) @@ -2440,7 +2440,7 @@ def f(tup): jtu.check_eq(actual, desired) def test_heterogeneous_grad(self): - # See https://github.com/google/jax/issues/7157 + # See https://github.com/jax-ml/jax/issues/7157 x = np.array(1.0+1j) y = np.array(2.0) a = (x, y) @@ -2512,7 +2512,7 @@ def test_devicearray_weakref_friendly(self): self.assertIsNone(y()) def test_namedtuple_transparency(self): - # See https://github.com/google/jax/issues/446 + # See https://github.com/jax-ml/jax/issues/446 Point = collections.namedtuple("Point", ["x", "y"]) def f(pt): @@ -2528,7 +2528,7 @@ def f(pt): self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False) def test_namedtuple_subclass_transparency(self): - # See https://github.com/google/jax/issues/806 + # See https://github.com/jax-ml/jax/issues/806 Point = collections.namedtuple("Point", ["x", "y"]) class ZeroPoint(Point): @@ -2705,7 +2705,7 @@ def __init__(self, shape, dtype): self.assertEqual(out_shape.shape, (3, 5)) def test_eval_shape_duck_typing2(self): - # https://github.com/google/jax/issues/5683 + # https://github.com/jax-ml/jax/issues/5683 class EasyDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2980,7 +2980,7 @@ def superfun(a): ])) def test_vmap_in_axes_list(self): - # https://github.com/google/jax/issues/2367 + # https://github.com/jax-ml/jax/issues/2367 dictionary = {'a': 5., 'b': jnp.ones(2)} x = jnp.zeros(3) y = jnp.arange(3.) @@ -2993,7 +2993,7 @@ def f(dct, x, y): self.assertAllClose(out1, out2) def test_vmap_in_axes_non_tuple_error(self): - # https://github.com/google/jax/issues/18548 + # https://github.com/jax-ml/jax/issues/18548 with self.assertRaisesRegex( TypeError, re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding " @@ -3001,7 +3001,7 @@ def test_vmap_in_axes_non_tuple_error(self): jax.vmap(lambda x: x['a'], in_axes={'a': 0}) def test_vmap_in_axes_wrong_length_tuple_error(self): - # https://github.com/google/jax/issues/18548 + # https://github.com/jax-ml/jax/issues/18548 with self.assertRaisesRegex( ValueError, re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding to the " @@ -3009,7 +3009,7 @@ def test_vmap_in_axes_wrong_length_tuple_error(self): jax.vmap(lambda x: x['a'], in_axes=(0, {'a': 0}))({'a': jnp.zeros((3, 3))}) def test_vmap_in_axes_tree_prefix_error(self): - # https://github.com/google/jax/issues/795 + # https://github.com/jax-ml/jax/issues/795 value_tree = jnp.ones(3) self.assertRaisesRegex( ValueError, @@ -3030,14 +3030,14 @@ def test_vmap_out_axes_leaf_types(self): api.vmap(lambda x: x, out_axes=(jnp.array([1., 2.]),))(jnp.array([1., 2.])) def test_vmap_unbatched_object_passthrough_issue_183(self): - # https://github.com/google/jax/issues/183 + # https://github.com/jax-ml/jax/issues/183 fun = lambda f, x: f(x) vfun = api.vmap(fun, (None, 0)) ans = vfun(lambda x: x + 1, jnp.arange(3)) self.assertAllClose(ans, np.arange(1, 4), check_dtypes=False) def test_vmap_mismatched_keyword(self): - # https://github.com/google/jax/issues/10193 + # https://github.com/jax-ml/jax/issues/10193 @jax.vmap def f(x, y): return x + y @@ -3051,7 +3051,7 @@ def f(x, y): f(jnp.array([1], 'int32'), y=jnp.array([1, 2], 'int32')) def test_vmap_mismatched_axis_sizes_error_message_issue_705(self): - # https://github.com/google/jax/issues/705 + # https://github.com/jax-ml/jax/issues/705 def h(a, b): return jnp.sum(a) + jnp.sum(b) @@ -3156,12 +3156,12 @@ def foo(tree_arg): self.assertEqual(vfoo(tree).shape, (6, 2, 5)) def test_vmap_in_axes_bool_error(self): - # https://github.com/google/jax/issues/6372 + # https://github.com/jax-ml/jax/issues/6372 with self.assertRaisesRegex(TypeError, "must be an int"): api.vmap(lambda x: x, in_axes=False)(jnp.zeros(3)) def test_pmap_in_axes_bool_error(self): - # https://github.com/google/jax/issues/6372 + # https://github.com/jax-ml/jax/issues/6372 with self.assertRaisesRegex(TypeError, "must be an int"): api.pmap(lambda x: x, in_axes=False)(jnp.zeros(1)) @@ -3223,7 +3223,7 @@ def test_device_array_hash(self): hash(rep) def test_grad_without_enough_args_error_message(self): - # https://github.com/google/jax/issues/1696 + # https://github.com/jax-ml/jax/issues/1696 def f(x, y): return x + y df = api.grad(f, argnums=0) self.assertRaisesRegex( @@ -3301,7 +3301,7 @@ def f(x): self.assertEqual(count[0], 0) # cache hits on both fwd and bwd def test_grad_does_not_unflatten_tree_with_none(self): - # https://github.com/google/jax/issues/7546 + # https://github.com/jax-ml/jax/issues/7546 class CustomNode(list): pass @@ -3370,7 +3370,7 @@ def test_primitive_compilation_cache(self): self.assertEqual(count[0], 1) def test_arange_jit(self): - # see https://github.com/google/jax/issues/553 + # see https://github.com/jax-ml/jax/issues/553 def fun(x): r = jnp.arange(x.shape[0])[x] return r @@ -3496,7 +3496,7 @@ def test_escaped_tracer_shape_dtype(self): _ = self._saved_tracer+1 def test_pmap_static_kwarg_error_message(self): - # https://github.com/google/jax/issues/3007 + # https://github.com/jax-ml/jax/issues/3007 def f(a, b): return a + b @@ -3650,7 +3650,7 @@ def g(x): g(1) def test_join_concrete_arrays_with_omnistaging(self): - # https://github.com/google/jax/issues/4622 + # https://github.com/jax-ml/jax/issues/4622 x = jnp.array([1., 2., 3.]) y = jnp.array([1., 2., 4.]) @@ -3673,7 +3673,7 @@ def fn(x): self.assertEqual(aux, True) def test_linearize_aval_error(self): - # https://github.com/google/jax/issues/4622 + # https://github.com/jax-ml/jax/issues/4622 f = lambda x: x # these should not error @@ -3691,7 +3691,7 @@ def test_linearize_aval_error(self): f_jvp(np.ones(2, np.int32)) def test_grad_of_token_consuming_primitive(self): - # https://github.com/google/jax/issues/5463 + # https://github.com/jax-ml/jax/issues/5463 tokentest_p = core.Primitive("tokentest") tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p)) tokentest_p.def_abstract_eval(lambda x, y: x) @@ -3823,7 +3823,7 @@ def g(x): f(3) def test_leak_checker_avoids_false_positive_custom_jvp(self): - # see https://github.com/google/jax/issues/5636 + # see https://github.com/jax-ml/jax/issues/5636 with jax.checking_leaks(): @jax.custom_jvp def t(y): @@ -3906,7 +3906,7 @@ def test_default_device(self): self.assertEqual(jnp.ones(1).devices(), system_default_devices) def test_dunder_jax_array(self): - # https://github.com/google/jax/pull/4725 + # https://github.com/jax-ml/jax/pull/4725 class AlexArray: def __init__(self, jax_val): @@ -3939,7 +3939,7 @@ def __jax_array__(self): self.assertAllClose(np.array(((1, 1), (1, 1))), a2) def test_eval_shape_weak_type(self): - # https://github.com/google/jax/issues/23302 + # https://github.com/jax-ml/jax/issues/23302 arr = jax.numpy.array(1) with jtu.count_jit_tracing_cache_miss() as count: @@ -3980,7 +3980,7 @@ def __jax_array__(self) -> jax.Array: f(a, a) # don't crash def test_constant_handler_mro(self): - # https://github.com/google/jax/issues/6129 + # https://github.com/jax-ml/jax/issues/6129 class Foo(enum.IntEnum): bar = 1 @@ -3997,7 +3997,7 @@ def f(_): {"testcase_name": f"{dtype.__name__}", "dtype": dtype} for dtype in jtu.dtypes.all]) def test_constant_handlers(self, dtype): - # https://github.com/google/jax/issues/9380 + # https://github.com/jax-ml/jax/issues/9380 @jax.jit def f(): return jnp.exp(dtype(0)) @@ -4135,7 +4135,7 @@ def f(x): jaxpr = api.make_jaxpr(f)(3) self.assertNotIn('pjit', str(jaxpr)) - # Repro for https://github.com/google/jax/issues/7229. + # Repro for https://github.com/jax-ml/jax/issues/7229. def test_compute_with_large_transfer(self): def f(x, delta): return x + jnp.asarray(delta, x.dtype) @@ -4193,7 +4193,7 @@ def transpose(f, x): self.assertEqual(actual, expected) def test_leaked_tracer_issue_7613(self): - # from https://github.com/google/jax/issues/7613 + # from https://github.com/jax-ml/jax/issues/7613 import numpy.random as npr def sigmoid(x): @@ -4211,7 +4211,7 @@ def loss(A, x): _ = jax.grad(loss)(A, x) # doesn't crash def test_vmap_caching(self): - # https://github.com/google/jax/issues/7621 + # https://github.com/jax-ml/jax/issues/7621 f = lambda x: jnp.square(x).mean() jf = jax.jit(f) @@ -4299,7 +4299,7 @@ def g(x, y): self.assertEqual(2 * i, g(2, i), msg=i) def test_fastpath_cache_confusion(self): - # https://github.com/google/jax/issues/12542 + # https://github.com/jax-ml/jax/issues/12542 @jax.jit def a(x): return () @@ -4344,7 +4344,7 @@ def h(x): b(8) # don't crash def test_vjp_multiple_arguments_error_message(self): - # https://github.com/google/jax/issues/13099 + # https://github.com/jax-ml/jax/issues/13099 def foo(x): return (x, x) _, f_vjp = jax.vjp(foo, 1.0) @@ -4376,7 +4376,7 @@ def foo(x, y, z): self.assertEqual(jfoo.__module__, "jax") def test_inner_jit_function_retracing(self): - # https://github.com/google/jax/issues/7155 + # https://github.com/jax-ml/jax/issues/7155 inner_count = outer_count = 0 @jax.jit @@ -4403,7 +4403,7 @@ def outer_fn(x): self.assertEqual(outer_count, 1) def test_grad_conj_symbolic_zeros(self): - # https://github.com/google/jax/issues/15400 + # https://github.com/jax-ml/jax/issues/15400 f = lambda x: jax.jit(lambda x, y: (x, y))(x, jax.lax.conj(x))[0] out = jax.grad(f)(3.0) # doesn't crash self.assertAllClose(out, 1., check_dtypes=False) @@ -4555,7 +4555,7 @@ def test_jit_custom_floats(self, dtype): self._CompileAndCheck(f, args_maker) def test_jvp_asarray_returns_array(self): - # https://github.com/google/jax/issues/15676 + # https://github.com/jax-ml/jax/issues/15676 p, t = jax.jvp(jax.numpy.asarray, (1.,), (2.,)) _check_instance(self, p) _check_instance(self, t) @@ -4716,7 +4716,7 @@ def g(): f() def test_inline_return_twice(self): - # https://github.com/google/jax/issues/22944 + # https://github.com/jax-ml/jax/issues/22944 @jax.jit def add_one(x: int) -> int: return x + 1 @@ -5074,7 +5074,7 @@ def f_yesremat(x): ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_no_redundant_flops(self, remat): - # see https://github.com/google/jax/pull/1749#issuecomment-558267584 + # see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584 @api.jit def g(x): @@ -5124,7 +5124,7 @@ def binom_checkpoint(funs): ('_new', new_checkpoint), ]) def test_remat_symbolic_zeros(self, remat): - # code from https://github.com/google/jax/issues/1907 + # code from https://github.com/jax-ml/jax/issues/1907 key = jax.random.key(0) key, split = jax.random.split(key) @@ -5177,7 +5177,7 @@ def g(): ('_new', new_checkpoint), ]) def test_remat_nontrivial_env(self, remat): - # simplified from https://github.com/google/jax/issues/2030 + # simplified from https://github.com/jax-ml/jax/issues/2030 @remat def foo(state, dt=0.5, c=1): @@ -5211,7 +5211,7 @@ def loss(u0, target, steps, dt=1/jnp.sqrt(2), c=1): ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_jit3(self, remat): - # https://github.com/google/jax/issues/2180 + # https://github.com/jax-ml/jax/issues/2180 def f(w, x): a = jnp.dot(x, w) b = jnp.einsum("btd,bTd->btT", a, a) @@ -5244,7 +5244,7 @@ def f(w, x): ('_new', new_checkpoint), ]) def test_remat_scan2(self, remat): - # https://github.com/google/jax/issues/1963 + # https://github.com/jax-ml/jax/issues/1963 def scan_bug(x0): f = lambda x, _: (x + 1, None) @@ -5256,7 +5256,7 @@ def scanned_f(x, _): jax.grad(scan_bug)(1.0) # doesn't crash def test_remat_jit_static_argnum_omnistaging(self): - # https://github.com/google/jax/issues/2833 + # https://github.com/jax-ml/jax/issues/2833 # NOTE(mattjj): after #3370, this test doesn't actually call remat... def named_call(f): def named_f(*args): @@ -5281,7 +5281,7 @@ def f(a_bool, y): ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_eval_counter(self, remat): - # https://github.com/google/jax/issues/2737 + # https://github.com/jax-ml/jax/issues/2737 add_one_p = core.Primitive('add_one') add_one = add_one_p.bind @@ -5665,7 +5665,7 @@ def test_constants_not_hoisted(self): # The old implementation of remat worked by data dependence, and so # (potentially large) constants would not be rematerialized and could be # wastefully instantiated. This test checks that the newer remat - # implementation avoids that. See https://github.com/google/jax/pull/8191. + # implementation avoids that. See https://github.com/jax-ml/jax/pull/8191. # no residuals from constants created inside jnp.einsum @partial(new_checkpoint, policy=lambda *_, **__: False) @@ -5790,7 +5790,7 @@ def f(x): _ = jax.grad(f)(3.) # doesn't crash def test_linearize_caching(self): - # https://github.com/google/jax/issues/9661 + # https://github.com/jax-ml/jax/issues/9661 identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) _, f_lin = jax.linearize(identity, 1.) with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 @@ -5799,7 +5799,7 @@ def test_linearize_caching(self): self.assertEqual(count[0], 1) # cached after first execution def test_vjp_caching(self): - # https://github.com/google/jax/issues/9661 + # https://github.com/jax-ml/jax/issues/9661 identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) _, f_vjp = jax.vjp(identity, 1.) with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841 @@ -6928,7 +6928,7 @@ def f_jvp(primals, tangents): check_dtypes=False) def test_kwargs(self): - # from https://github.com/google/jax/issues/1938 + # from https://github.com/jax-ml/jax/issues/1938 @jax.custom_jvp def my_fun(x, y, c=1.): return c * (x + y) @@ -7209,7 +7209,7 @@ def foo_jvp(primals, tangents): def test_jvp_rule_doesnt_return_pair_error_message(self): - # https://github.com/google/jax/issues/2516 + # https://github.com/jax-ml/jax/issues/2516 @jax.custom_jvp def f(x): @@ -7374,7 +7374,7 @@ def _expit_jvp(primals, tangents): api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3))) def test_jaxpr_zeros(self): - # from https://github.com/google/jax/issues/2657 + # from https://github.com/jax-ml/jax/issues/2657 @jax.custom_jvp def f(A, b): return A @ b @@ -7420,7 +7420,7 @@ def foo(x): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_jvps_first_rule_is_none(self): - # https://github.com/google/jax/issues/3389 + # https://github.com/jax-ml/jax/issues/3389 @jax.custom_jvp def f(x, y): return x ** 2 * y @@ -7431,7 +7431,7 @@ def f(x, y): self.assertAllClose(ans, expected, check_dtypes=False) def test_concurrent_initial_style(self): - # https://github.com/google/jax/issues/3843 + # https://github.com/jax-ml/jax/issues/3843 def unroll(param, sequence): def scan_f(prev_state, inputs): return prev_state, jax.nn.sigmoid(param * inputs) @@ -7453,7 +7453,7 @@ def run(): self.assertAllClose(ans, expected) def test_nondiff_argnums_vmap_tracer(self): - # https://github.com/google/jax/issues/3964 + # https://github.com/jax-ml/jax/issues/3964 @partial(jax.custom_jvp, nondiff_argnums=(0, 2)) def sample(shape, param, seed): return jax.random.uniform(key=seed, shape=shape, minval=param) @@ -7495,7 +7495,7 @@ def baz(w): api.vmap(fun_with_nested_calls_2)(jnp.arange(3.)) def test_closure_with_vmap(self): - # https://github.com/google/jax/issues/3822 + # https://github.com/jax-ml/jax/issues/3822 alpha = np.float32(2.) def sample(seed): @@ -7515,7 +7515,7 @@ def f_jvp(primal, tangent): api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash def test_closure_with_vmap2(self): - # https://github.com/google/jax/issues/8783 + # https://github.com/jax-ml/jax/issues/8783 def h(z): def f(x): @jax.custom_jvp @@ -7660,7 +7660,7 @@ def foo(x): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_jvp_vmap_broadcasting_interaction(self): - # https://github.com/google/jax/issues/6452 + # https://github.com/jax-ml/jax/issues/6452 def f2(y, z): v1 = z v2 = jnp.sum(y) + z @@ -7678,7 +7678,7 @@ def f1(y, z): self.assertEqual(g.shape, ()) def test_custom_jvp_vmap_broadcasting_interaction_2(self): - # https://github.com/google/jax/issues/5849 + # https://github.com/jax-ml/jax/issues/5849 @jax.custom_jvp def transform(box, R): if jnp.isscalar(box) or box.size == 1: @@ -7716,7 +7716,7 @@ def energy_fn(box): self.assertEqual(grad(energy_fn)(scalar_box).shape, ()) def test_custom_jvp_implicit_broadcasting(self): - # https://github.com/google/jax/issues/6357 + # https://github.com/jax-ml/jax/issues/6357 if config.enable_x64.value: raise unittest.SkipTest("test only applies when x64 is disabled") @@ -7774,7 +7774,7 @@ def fun(X): self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3) def test_vmap_inside_defjvp(self): - # https://github.com/google/jax/issues/3201 + # https://github.com/jax-ml/jax/issues/3201 seed = 47 key = jax.random.key(seed) mat = jax.random.normal(key, (2, 3)) @@ -7823,7 +7823,7 @@ def operate(mx, val): jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5) # doesn't crash def test_custom_jvp_unbroadcasting(self): - # https://github.com/google/jax/issues/3056 + # https://github.com/jax-ml/jax/issues/3056 a = jnp.array([1., 1.]) @jax.custom_jvp @@ -7841,8 +7841,8 @@ def f_jvp(primals, tangents): def test_maybe_perturbed_internal_helper_function(self): # This is a unit test for an internal API. We include it so as not to - # regress https://github.com/google/jax/issues/9567. For an explanation of - # this helper function, see https://github.com/google/jax/issues/6415. + # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of + # this helper function, see https://github.com/jax-ml/jax/issues/6415. def f(x): def g(y, _): z = y * x @@ -7854,7 +7854,7 @@ def g(y, _): jax.jvp(f, (1.0,), (1.0,)) # assertions inside f def test_maybe_perturbed_int_regression(self): - # see https://github.com/google/jax/discussions/9951 + # see https://github.com/jax-ml/jax/discussions/9951 @jax.jit def f(): @@ -7864,7 +7864,7 @@ def f(): f() def test_sinc_constant_function_batching(self): - # https://github.com/google/jax/pull/10756 + # https://github.com/jax-ml/jax/pull/10756 batch_data = jnp.arange(15.).reshape(5, 3) @jax.vmap @@ -7981,7 +7981,7 @@ def f_jvp(primals, tangents): _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! def test_symbolic_zeros_under_jit(self): - # https://github.com/google/jax/issues/14833 + # https://github.com/jax-ml/jax/issues/14833 Zero = jax.custom_derivatives.SymbolicZero @jax.custom_jvp @@ -8015,7 +8015,7 @@ def jvp_fn(primals, tangents): self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0))) def test_run_rules_more_than_once(self): - # https://github.com/google/jax/issues/16614 + # https://github.com/jax-ml/jax/issues/16614 @jax.custom_jvp def f(x, y): @@ -8206,7 +8206,7 @@ def f_rev(cos_x, g): lambda: api.jvp(jit(f), (3.,), (1.,))) def test_kwargs(self): - # from https://github.com/google/jax/issues/1938 + # from https://github.com/jax-ml/jax/issues/1938 @jax.custom_vjp def my_fun(x, y, c=1.): return c * (x + y) @@ -8502,7 +8502,7 @@ def test_issue2511(self): api.jit(foo)(arr) # doesn't crash def test_lowering_out_of_traces(self): - # https://github.com/google/jax/issues/2578 + # https://github.com/jax-ml/jax/issues/2578 class F(collections.namedtuple("F", ["a"])): def __call__(self, x): @@ -8515,7 +8515,7 @@ def g(f, x): jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash def test_clip_gradient(self): - # https://github.com/google/jax/issues/2784 + # https://github.com/jax-ml/jax/issues/2784 @jax.custom_vjp def _clip_gradient(lo, hi, x): return x # identity function when not differentiating @@ -8538,7 +8538,7 @@ def clip_gradient(x): self.assertAllClose(g, jnp.array(0.2)) def test_nestable_vjp(self): - # Verify that https://github.com/google/jax/issues/3667 is resolved. + # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved. def f(x): return x ** 2 @@ -8571,7 +8571,7 @@ def z(x): self.assertAllClose(y, jnp.array(6.0)) def test_initial_style_vmap_2(self): - # https://github.com/google/jax/issues/4173 + # https://github.com/jax-ml/jax/issues/4173 x = jnp.ones((10, 3)) # Create the custom function @@ -8837,7 +8837,7 @@ def f_rev(cos, g): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_vjp_closure_4521(self): - # https://github.com/google/jax/issues/4521 + # https://github.com/jax-ml/jax/issues/4521 @jax.custom_vjp def g(x, y): return None @@ -8954,7 +8954,7 @@ def closure(x): def test_closure_convert_mixed_consts(self): # Like test_closure_convert, but close over values that # participate in AD as well as values that do not. - # See https://github.com/google/jax/issues/6415 + # See https://github.com/jax-ml/jax/issues/6415 def cos_after(fn, x): converted_fn, aux_args = jax.closure_convert(fn, x) @@ -8993,7 +8993,7 @@ def closure(x): self.assertAllClose(g_x, 17. * x, check_dtypes=False) def test_closure_convert_pytree_mismatch(self): - # See https://github.com/google/jax/issues/23588 + # See https://github.com/jax-ml/jax/issues/23588 def f(x, z): return z * x @@ -9021,7 +9021,7 @@ def f_bwd(_, zbar): jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash def test_custom_vjp_scan_batching_edge_case(self): - # https://github.com/google/jax/issues/5832 + # https://github.com/jax-ml/jax/issues/5832 @jax.custom_vjp def mul(x, coeff): return x * coeff def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff) @@ -9052,7 +9052,7 @@ def f_(x, t): modes=['rev']) def test_closure_with_vmap2(self): - # https://github.com/google/jax/issues/8783 + # https://github.com/jax-ml/jax/issues/8783 def h(z): def f(x): @jax.custom_vjp @@ -9094,7 +9094,7 @@ def f_bwd(_, g): jax.grad(f)(A([1.])) # doesn't crash def test_vmap_vjp_called_twice(self): - # https://github.com/google/jax/pull/14728 + # https://github.com/jax-ml/jax/pull/14728 @jax.custom_vjp def f(x): return x @@ -9390,7 +9390,7 @@ def f_bwd(_, z_bar): _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! def test_run_rules_more_than_once(self): - # https://github.com/google/jax/issues/16614 + # https://github.com/jax-ml/jax/issues/16614 @jax.custom_vjp def f(x, y): @@ -9420,7 +9420,7 @@ def g(x): g(1.) # doesn't crash def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self): - # https://github.com/google/jax/issues/8356 + # https://github.com/jax-ml/jax/issues/8356 @jax.custom_vjp def f(x): return x[0] @@ -9618,7 +9618,7 @@ def f_bwd(res, g): jax.grad(f)(x, y) # Doesn't error def test_optimize_remat_custom_vmap(self): - # See https://github.com/google/jax/pull/23000 + # See https://github.com/jax-ml/jax/pull/23000 @jax.custom_vjp def f(x, y): return jnp.sin(x) * y @@ -10908,7 +10908,7 @@ def test_autodidax_smoketest(self): class GarbageCollectionTest(jtu.JaxTestCase): def test_xla_gc_callback(self): - # https://github.com/google/jax/issues/14882 + # https://github.com/jax-ml/jax/issues/14882 x_np = np.arange(10, dtype='int32') x_jax = jax.device_put(x_np) x_np_weakref = weakref.ref(x_np) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index c2cd4c0f968d..5585f1bcc005 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -185,7 +185,7 @@ def testJaxToTensorFlow(self, shape, dtype): @unittest.skipIf(not tf, "Test requires TensorFlow") def testTensorFlowToJaxInt64(self): - # See https://github.com/google/jax/issues/11895 + # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64))) dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32 diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 4378a3c7526d..4eb354a8d50f 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -323,7 +323,7 @@ def f(obj, x): self.assertEqual(count, 1) def test_tracer_lifetime_bug(self): - # regression test for https://github.com/google/jax/issues/20082 + # regression test for https://github.com/jax-ml/jax/issues/20082 class StatefulRNG: key: jax.Array diff --git a/tests/batching_test.py b/tests/batching_test.py index 6cd8c7bc20ac..2b0b0d63a6f5 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -335,7 +335,7 @@ def testConcatenate(self): self.assertAllClose(ans, expected_ans, check_dtypes=False) def testJacobianIssue54(self): - # test modeling the code in https://github.com/google/jax/issues/54 + # test modeling the code in https://github.com/jax-ml/jax/issues/54 def func(xs): return jnp.array(list(xs)) @@ -345,7 +345,7 @@ def func(xs): jacfwd(func)(xs) # don't crash def testAny(self): - # test modeling the code in https://github.com/google/jax/issues/108 + # test modeling the code in https://github.com/jax-ml/jax/issues/108 ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]])) expected = jnp.array([True, False]) @@ -368,7 +368,7 @@ def fun(x, t): def testDynamicSlice(self): # test dynamic_slice via numpy indexing syntax - # see https://github.com/google/jax/issues/1613 for an explanation of why we + # see https://github.com/jax-ml/jax/issues/1613 for an explanation of why we # need to use np rather than np to create x and idx x = jnp.arange(30).reshape((10, 3)) @@ -933,7 +933,7 @@ def f(scale): rtol=jtu.default_gradient_tolerance) def testIssue387(self): - # https://github.com/google/jax/issues/387 + # https://github.com/jax-ml/jax/issues/387 R = self.rng().rand(100, 2) def dist_sq(R): @@ -951,7 +951,7 @@ def f(R): @jax.legacy_prng_key('allow') def testIssue489(self): - # https://github.com/google/jax/issues/489 + # https://github.com/jax-ml/jax/issues/489 def f(key): def body_fn(uk): key = uk[1] @@ -1131,7 +1131,7 @@ def testAxisIndex(self): x - np.arange(x.shape[0], dtype='int32')) def testVmapKwargs(self): - # https://github.com/google/jax/issues/912 + # https://github.com/jax-ml/jax/issues/912 def f(a, b): return (2*a, 3*b) @@ -1242,7 +1242,7 @@ def f(x): self.assertEqual(jax.vmap(f)(jnp.ones((2, 3))).shape, (2, 3)) def testPpermuteBatcherTrivial(self): - # https://github.com/google/jax/issues/8688 + # https://github.com/jax-ml/jax/issues/8688 def ppermute(input): return jax.lax.ppermute(input, axis_name="i", perm=[[0, 1], [1, 0]]) @@ -1255,7 +1255,7 @@ def ppermute(input): self.assertAllClose(ans, jnp.ones(2), check_dtypes=False) def testBatchingPreservesWeakType(self): - # Regression test for https://github.com/google/jax/issues/10025 + # Regression test for https://github.com/jax-ml/jax/issues/10025 x = jnp.ravel(1) self.assertTrue(dtypes.is_weakly_typed(x)) @vmap diff --git a/tests/core_test.py b/tests/core_test.py index 0838702c4be6..94b7010907a9 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -349,7 +349,7 @@ def g_vmap(x): g_vmap(jnp.ones((1, ))) def test_concrete_array_string_representation(self): - # https://github.com/google/jax/issues/5364 + # https://github.com/jax-ml/jax/issues/5364 self.assertEqual( str(core.ConcreteArray(np.dtype(np.int32), np.array([1], dtype=np.int32))), @@ -369,7 +369,7 @@ def body(c, _): self.assertEqual(dropvar.aval, aval) def test_input_residual_forwarding(self): - # https://github.com/google/jax/pull/11151 + # https://github.com/jax-ml/jax/pull/11151 x = jnp.arange(3 * 4.).reshape(3, 4) y = jnp.arange(4 * 3.).reshape(4, 3) diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py index 830526826059..857dc34d430e 100644 --- a/tests/custom_linear_solve_test.py +++ b/tests/custom_linear_solve_test.py @@ -291,7 +291,7 @@ def transpose_solve(vecmat, x): jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3) - # regression test for https://github.com/google/jax/issues/1536 + # regression test for https://github.com/jax-ml/jax/issues/1536 jtu.check_grads(jax.jit(linear_solve), (a, b), order=2, rtol={np.float32: 2e-3}) @@ -396,7 +396,7 @@ def custom_unrolled_lower_tri_solve(mat, b): def test_custom_linear_solve_pytree_with_aux(self): # Check that lax.custom_linear_solve handles # pytree inputs + has_aux=True - # https://github.com/google/jax/pull/13093 + # https://github.com/jax-ml/jax/pull/13093 aux_orig = {'a': 1, 'b': 2} b = {'c': jnp.ones(2), 'd': jnp.ones(3)} diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 19e2a5893835..020c9f744833 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -127,7 +127,7 @@ def testPjit(self): ans.block_until_ready() def testDebugNansJitWithDonation(self): - # https://github.com/google/jax/issues/12514 + # https://github.com/jax-ml/jax/issues/12514 a = jnp.array(0.) with self.assertRaises(FloatingPointError): ans = jax.jit(lambda x: 0. / x, donate_argnums=(0,))(a) @@ -214,7 +214,7 @@ def f(x): f(1) def testDebugNansDoesntCorruptCaches(self): - # https://github.com/google/jax/issues/6614 + # https://github.com/jax-ml/jax/issues/6614 @jax.jit def f(x): return jnp.divide(x, x) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index c6b12e2d8a16..e736e06da2d0 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -667,7 +667,7 @@ def testJaxTypeWeak(self, dtype): {"testcase_name": f"_{typ}", "typ": typ} for typ in [bool, int, float, complex]) def testScalarWeakTypes(self, typ): - # Regression test for https://github.com/google/jax/issues/11377 + # Regression test for https://github.com/jax-ml/jax/issues/11377 val = typ(0) result1 = jnp.array(val) @@ -806,7 +806,7 @@ def testBinaryPromotionJitInvariance(self, xtype, ytype, xfun, yfun): for weak_type in [True, False] ) def testUnaryPromotion(self, dtype, weak_type): - # Regression test for https://github.com/google/jax/issues/6051 + # Regression test for https://github.com/jax-ml/jax/issues/6051 if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) @@ -852,7 +852,7 @@ def testBinaryNonPromotion(self, dtype, weak_type, promotion): self.skipTest("XLA support for float8 is incomplete.") if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") - # Regression test for https://github.com/google/jax/issues/6051 + # Regression test for https://github.com/jax-ml/jax/issues/6051 x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) with jax.numpy_dtype_promotion(promotion): y = (x + x) diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 101dddccb7c1..cc2419fb3757 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -621,7 +621,7 @@ def test_flattening_basic(self): self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) def test_shape_validation(self): - # Regression test for https://github.com/google/jax/issues/18937 + # Regression test for https://github.com/jax-ml/jax/issues/18937 msg = r"Shapes must be 1D sequences of integer scalars, got .+" with self.assertRaisesRegex(TypeError, msg): jax.make_jaxpr(jnp.ones)(5.0) diff --git a/tests/export_test.py b/tests/export_test.py index d5884b7e6b16..0d946d84d22b 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1334,7 +1334,7 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] def test_grad_sharding_different_mesh(self): # Export and serialize with two similar meshes, the only difference being # the order of the devices. grad and serialization should not fail. - # https://github.com/google/jax/issues/21314 + # https://github.com/jax-ml/jax/issues/21314 def f(x): return jnp.sum(x * 2.) diff --git a/tests/fft_test.py b/tests/fft_test.py index 05fa96a93fae..a87b7b66e150 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -175,7 +175,7 @@ def testFftn(self, inverse, real, shape, dtype, axes, s, norm): self.assertEqual(dtype, expected_dtype) def testIrfftTranspose(self): - # regression test for https://github.com/google/jax/issues/6223 + # regression test for https://github.com/jax-ml/jax/issues/6223 def build_matrix(linear_func, size): return jax.vmap(linear_func)(jnp.eye(size, size)) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 944b47dc8b1d..837d205fbbed 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1035,7 +1035,7 @@ def func(x, yint): ( 5.00 2 )""", testing_stream.output) def test_tap_grad_float0_result(self): - # https://github.com/google/jax/issues/7340 + # https://github.com/jax-ml/jax/issues/7340 # x is a Tuple[f32[2], s32[3]] x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) @@ -1058,7 +1058,7 @@ def f_jax_vjp(x): ( [0.70 0.80] [11 12 13] )""", testing_stream.output) def test_tap_higher_order_grad_float0_result(self): - # https://github.com/google/jax/issues/7340 + # https://github.com/jax-ml/jax/issues/7340 # x is a Tuple[f32[2], s32[3]] x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) @@ -1935,7 +1935,7 @@ def func(x, transforms, y): hcb.id_tap(func, 1, y=2) def test_tap_id_tap_random_key(self): - # See https://github.com/google/jax/issues/13949 + # See https://github.com/jax-ml/jax/issues/13949 with jax.enable_custom_prng(): @jax.jit def f(x): diff --git a/tests/image_test.py b/tests/image_test.py index f3cd56ed7622..0f6341086d19 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -180,7 +180,7 @@ def testResizeGradients(self, dtype, image_shape, target_shape, method, antialias=[False, True], ) def testResizeEmpty(self, dtype, image_shape, target_shape, method, antialias): - # Regression test for https://github.com/google/jax/issues/7586 + # Regression test for https://github.com/jax-ml/jax/issues/7586 image = np.ones(image_shape, dtype) out = jax.image.resize(image, shape=target_shape, method=method, antialias=antialias) self.assertArraysEqual(out, jnp.zeros(target_shape, dtype)) diff --git a/tests/jet_test.py b/tests/jet_test.py index b1e2ef3f8380..4e437c044426 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -404,7 +404,7 @@ def g(x): self.assertArraysEqual(g_out_series, f_out_series) def test_add_any(self): - # https://github.com/google/jax/issues/5217 + # https://github.com/jax-ml/jax/issues/5217 f = lambda x, eps: x * eps + eps + x def g(eps): x = jnp.array(1.) @@ -412,7 +412,7 @@ def g(eps): jet(g, (1.,), ([1.],)) # doesn't crash def test_scatter_add(self): - # very basic test from https://github.com/google/jax/issues/5365 + # very basic test from https://github.com/jax-ml/jax/issues/5365 def f(x): x0 = x[0] x1 = x[1] diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index ab3a183177f6..78d90cb8a072 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -424,7 +424,7 @@ def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, assert "Precision.HIGHEST" in s def testDotPreferredElementType(self): - # https://github.com/google/jax/issues/10818 + # https://github.com/jax-ml/jax/issues/10818 x = jax.numpy.ones((), jax.numpy.float16) def f(x): return jax.lax.dot_general(x, x, (((), ()), ((), ())), @@ -513,7 +513,7 @@ def testReverseGrad(self): rtol={np.float32: 3e-3}) def testPowSecondDerivative(self): - # https://github.com/google/jax/issues/12033 + # https://github.com/jax-ml/jax/issues/12033 x, y = 4.0, 0.0 expected = ((0.0, 1/x), (1/x, np.log(x) ** 2)) @@ -528,18 +528,18 @@ def testPowSecondDerivative(self): with self.subTest("zero to the zero"): result = jax.grad(lax.pow)(0.0, 0.0) # TODO(jakevdp) special-case zero in a way that doesn't break other cases - # See https://github.com/google/jax/pull/12041#issuecomment-1222766191 + # See https://github.com/jax-ml/jax/pull/12041#issuecomment-1222766191 # self.assertEqual(result, 0.0) self.assertAllClose(result, np.nan) def testPowIntPowerAtZero(self): - # https://github.com/google/jax/issues/14397 + # https://github.com/jax-ml/jax/issues/14397 ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0) self.assertAllClose(ans, 0., check_dtypes=False) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testPowIntPowerAtZero2(self): - # https://github.com/google/jax/issues/17995 + # https://github.com/jax-ml/jax/issues/17995 a = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=int)) b = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=float)) c = lambda z: 1 + z @@ -634,7 +634,7 @@ def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.) def testDynamicSliceValueAndGrad(self): - # Regression test for https://github.com/google/jax/issues/10984 + # Regression test for https://github.com/jax-ml/jax/issues/10984 # Issue arose due to an out-of-range negative index. rng = jtu.rand_default(self.rng()) shape = (5, 5) @@ -649,7 +649,7 @@ def f(x): self.assertAllClose(result1, result2) def testDynamicUpdateSliceValueAndGrad(self): - # Regression test for https://github.com/google/jax/issues/10984 + # Regression test for https://github.com/jax-ml/jax/issues/10984 # Issue arose due to an out-of-range negative index. rng = jtu.rand_default(self.rng()) shape = (5, 5) @@ -1004,7 +1004,7 @@ def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums, check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) def testScatterGradSymbolicZeroUpdate(self): - # https://github.com/google/jax/issues/1901 + # https://github.com/jax-ml/jax/issues/1901 def f(x): n = x.shape[0] y = np.arange(n, dtype=x.dtype) @@ -1111,7 +1111,7 @@ def gen_y(rng, size): check_grads(lax.rem, (x, y), 2, ["fwd", "rev"]) def testHigherOrderGradientOfReciprocal(self): - # Regression test for https://github.com/google/jax/issues/3136 + # Regression test for https://github.com/jax-ml/jax/issues/3136 def inv(x): # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x) return 1 / x @@ -1150,7 +1150,7 @@ def f(x): jax.jacrev(f)(x) def testPowShapeMismatch(self): - # Regression test for https://github.com/google/jax/issues/17294 + # Regression test for https://github.com/jax-ml/jax/issues/17294 x = lax.iota('float32', 4) y = 2 actual = jax.jacrev(jax.jit(jax.lax.pow))(x, y) # no error diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 37ad22063c94..7fb118d47256 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -733,7 +733,7 @@ def false_fun(x): self.assertEqual(fun(4), (8, 16)) def testCondPredIsNone(self): - # see https://github.com/google/jax/issues/11574 + # see https://github.com/jax-ml/jax/issues/11574 def f(pred, x): return lax.cond(pred, lambda x: x + 1, lambda x: x + 2, x) @@ -743,7 +743,7 @@ def f(pred, x): lambda: jax.jit(f)(None, 1.)) def testCondTwoOperands(self): - # see https://github.com/google/jax/issues/8469 + # see https://github.com/jax-ml/jax/issues/8469 add, mul = lax.add, lax.mul def fun(x): @@ -775,7 +775,7 @@ def cfun(x): self.assertEqual(fun(1), cfun(1)) def testCondCallableOperands(self): - # see https://github.com/google/jax/issues/16413 + # see https://github.com/jax-ml/jax/issues/16413 @tree_util.register_pytree_node_class class Foo: @@ -1560,7 +1560,7 @@ def f(x): {"testcase_name": f"_{name}", "cond": cond} for cond, name in COND_IMPLS) def testCondVmapGrad(self, cond): - # https://github.com/google/jax/issues/2264 + # https://github.com/jax-ml/jax/issues/2264 def f_1(x): return x ** 2 def f_2(x): return x ** 3 @@ -1839,7 +1839,7 @@ def loss(params, inputs, targets): def testIssue711(self, scan): # Tests reverse-mode differentiation through a scan for which the scanned # function also involves reverse-mode differentiation. - # See https://github.com/google/jax/issues/711 + # See https://github.com/jax-ml/jax/issues/711 def harmonic_bond(conf, params): return jnp.sum(conf * params) @@ -2078,7 +2078,7 @@ def scan_body(c, x): self.assertAllClose(carry_out[0], jnp.array([2., 2., 2.]), check_dtypes = False) def testIssue757(self): - # code from https://github.com/google/jax/issues/757 + # code from https://github.com/jax-ml/jax/issues/757 def fn(a): return jnp.cos(a) @@ -2107,7 +2107,7 @@ def testMap(self): self.assertAllClose(actual, expected) def testMapEmpty(self): - # https://github.com/google/jax/issues/2412 + # https://github.com/jax-ml/jax/issues/2412 ans = lax.map(lambda x: x * x, jnp.array([])) expected = jnp.array([]) self.assertAllClose(ans, expected) @@ -2164,7 +2164,7 @@ def body(x): lax.while_loop(cond, body, 0) def test_caches_depend_on_axis_env(self): - # https://github.com/google/jax/issues/9187 + # https://github.com/jax-ml/jax/issues/9187 scanned_f = lambda _, __: (lax.psum(1, 'i'), None) f = lambda: lax.scan(scanned_f, 0, None, length=1)[0] ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() @@ -2443,7 +2443,7 @@ def f(h, _): self.assertEqual(h, length) def test_disable_jit_cond_with_vmap(self): - # https://github.com/google/jax/issues/3093 + # https://github.com/jax-ml/jax/issues/3093 def fn(t): return lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1) fn = jax.vmap(fn) @@ -2452,14 +2452,14 @@ def fn(t): _ = fn(jnp.array([1])) # doesn't crash def test_disable_jit_while_loop_with_vmap(self): - # https://github.com/google/jax/issues/2823 + # https://github.com/jax-ml/jax/issues/2823 def trivial_while(y): return lax.while_loop(lambda x: x < 10.0, lambda x: x + 1.0, y) with jax.disable_jit(): jax.vmap(trivial_while)(jnp.array([3.0,4.0])) # doesn't crash def test_vmaps_of_while_loop(self): - # https://github.com/google/jax/issues/3164 + # https://github.com/jax-ml/jax/issues/3164 def f(x, n): return lax.fori_loop(0, n, lambda _, x: x + 1, x) x, n = jnp.arange(3), jnp.arange(4) jax.vmap(jax.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash @@ -2567,7 +2567,7 @@ def new_jaxpr(): lambda: core.check_jaxpr(jaxpr)) def test_cond_transformation_rule_with_consts(self): - # https://github.com/google/jax/pull/9731 + # https://github.com/jax-ml/jax/pull/9731 @jax.custom_jvp def f(x): @@ -2584,7 +2584,7 @@ def f_jvp(primals, tangents): jax.jvp(g, (x,), (x,)) # doesn't crash def test_cond_excessive_compilation(self): - # Regression test for https://github.com/google/jax/issues/14058 + # Regression test for https://github.com/jax-ml/jax/issues/14058 def f(x): return x + 1 @@ -2632,7 +2632,7 @@ def body_fun(val): ('new_remat', new_checkpoint), ]) def test_scan_vjp_forwards_extensive_residuals(self, remat): - # https://github.com/google/jax/issues/4510 + # https://github.com/jax-ml/jax/issues/4510 def cumprod(x): s = jnp.ones((2, 32), jnp.float32) return lax.scan(lambda s, x: (x*s, s), s, x) @@ -2671,7 +2671,7 @@ def scan(state, xs): (jnp.array([1.]), jnp.array([[0., 1., 2., 3., 4.]])), check_dtypes=False) def test_xla_cpu_gpu_loop_cond_bug(self): - # https://github.com/google/jax/issues/5900 + # https://github.com/jax-ml/jax/issues/5900 def deriv(f): return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0) @@ -2750,7 +2750,7 @@ def body(c, _): jax.grad(f)(1.) # doesn't crash def test_custom_jvp_tangent_cond_transpose(self): - # https://github.com/google/jax/issues/14026 + # https://github.com/jax-ml/jax/issues/14026 def mask_fun(arr, choice): out = (1 - choice) * arr.sum() + choice * (1 - arr.sum()) return out @@ -2997,7 +2997,7 @@ def test_cond_casting(self): self.assertIsInstance(y, jax.Array) def test_cond_memory_leak(self): - # https://github.com/google/jax/issues/12719 + # https://github.com/jax-ml/jax/issues/12719 def leak(): data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 02fecb7b3f1a..d3dada0d750a 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -914,7 +914,7 @@ def testOperatorRound(self, jit): check_dtypes=False) def testRoundMethod(self): - # https://github.com/google/jax/issues/15190 + # https://github.com/jax-ml/jax/issues/15190 (jnp.arange(3.) / 5.).round() # doesn't crash @jtu.sample_product(shape=[(5,), (5, 2)]) @@ -1425,7 +1425,7 @@ def testIntegerPower(self, ptype): y=[0, 32, 64, 128], ) def testIntegerPowerOverflow(self, x, y): - # Regression test for https://github.com/google/jax/issues/5987 + # Regression test for https://github.com/jax-ml/jax/issues/5987 args_maker = lambda: [x, y] self._CheckAgainstNumpy(np.power, jnp.power, args_maker) self._CompileAndCheck(jnp.power, args_maker) @@ -1536,7 +1536,7 @@ def testConcatenateArray(self, shape, dtype, axis): self._CompileAndCheck(jnp_fun, args_maker) def testConcatenateAxisNone(self): - # https://github.com/google/jax/issues/3419 + # https://github.com/jax-ml/jax/issues/3419 a = jnp.array([[1, 2], [3, 4]]) b = jnp.array([[5]]) jnp.concatenate((a, b), axis=None) @@ -2768,7 +2768,7 @@ def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): self._CompileAndCheck(jnp_fun, args_maker) def testDiffPrepoendScalar(self): - # Regression test for https://github.com/google/jax/issues/19362 + # Regression test for https://github.com/jax-ml/jax/issues/19362 x = jnp.arange(10) result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) @@ -3359,7 +3359,7 @@ def _check(obj, out_dtype, weak_type): _check([jnp.complex128(1)], np.complex128, False) # Mixed inputs use JAX-style promotion. - # (regression test for https://github.com/google/jax/issues/8945) + # (regression test for https://github.com/jax-ml/jax/issues/8945) _check([0, np.int16(1)], np.int16, False) _check([0.0, np.float16(1)], np.float16, False) @@ -3932,17 +3932,17 @@ def testPathologicalFloats(self): # TODO(mattjj): test other ndarray-like method overrides def testNpMean(self): - # from https://github.com/google/jax/issues/125 + # from https://github.com/jax-ml/jax/issues/125 x = jnp.eye(3, dtype=float) + 0. ans = np.mean(x) self.assertAllClose(ans, np.array(1./3), check_dtypes=False) def testArangeOnFloats(self): np_arange = jtu.with_jax_dtype_defaults(np.arange) - # from https://github.com/google/jax/issues/145 + # from https://github.com/jax-ml/jax/issues/145 self.assertAllClose(np_arange(0.0, 1.0, 0.1), jnp.arange(0.0, 1.0, 0.1)) - # from https://github.com/google/jax/issues/3450 + # from https://github.com/jax-ml/jax/issues/3450 self.assertAllClose(np_arange(2.5), jnp.arange(2.5)) self.assertAllClose(np_arange(0., 2.5), @@ -4303,7 +4303,7 @@ def args_maker(): self._CompileAndCheck(jnp_op, args_maker) def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): - # https://github.com/google/jax/issues/5088 + # https://github.com/jax-ml/jax/issues/5088 h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) q0 = jnp.take_along_axis(h, g, axis=-1) @@ -4513,9 +4513,9 @@ def testSymmetrizeDtypePromotion(self): # NOTE(mattjj): I disabled this test when removing lax._safe_mul because # introducing the convention 0 * inf = 0 leads to silently wrong results in # some cases. See this comment for details: - # https://github.com/google/jax/issues/1052#issuecomment-514083352 + # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352 # def testIssue347(self): - # # https://github.com/google/jax/issues/347 + # # https://github.com/jax-ml/jax/issues/347 # def test_fail(x): # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) # ones = jnp.ones_like(x) @@ -4526,7 +4526,7 @@ def testSymmetrizeDtypePromotion(self): # assert not np.any(np.isnan(result)) def testIssue453(self): - # https://github.com/google/jax/issues/453 + # https://github.com/jax-ml/jax/issues/453 a = np.arange(6) + 1 ans = jnp.reshape(a, (3, 2), order='F') expected = np.reshape(a, (3, 2), order='F') @@ -4538,7 +4538,7 @@ def testIssue453(self): op=["atleast_1d", "atleast_2d", "atleast_3d"], ) def testAtLeastNdLiterals(self, dtype, op): - # Fixes: https://github.com/google/jax/issues/634 + # Fixes: https://github.com/jax-ml/jax/issues/634 np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) jnp_fun = lambda arg: getattr(jnp, op)(arg) args_maker = lambda: [dtype(2)] @@ -5147,7 +5147,7 @@ def testDisableNumpyRankPromotionBroadcastingDecorator(self): jnp.ones(2) + 3 # don't want to warn for scalars def testStackArrayArgument(self): - # tests https://github.com/google/jax/issues/1271 + # tests https://github.com/jax-ml/jax/issues/1271 @jax.jit def foo(x): return jnp.stack(x) @@ -5316,7 +5316,7 @@ def testGradient(self, shape, varargs, axis, dtype): self._CompileAndCheck(jnp_fun, args_maker) def testZerosShapeErrors(self): - # see https://github.com/google/jax/issues/1822 + # see https://github.com/jax-ml/jax/issues/1822 self.assertRaisesRegex( TypeError, "Shapes must be 1D sequences of concrete values of integer type.*", @@ -5334,7 +5334,7 @@ def testTraceMethod(self): self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) def testIntegerPowersArePrecise(self): - # See https://github.com/google/jax/pull/3036 + # See https://github.com/jax-ml/jax/pull/3036 # Checks if the squares of float32 integers have no numerical errors. # It should be satisfied with all integers less than sqrt(2**24). x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) @@ -5405,7 +5405,7 @@ def testArange64Bit(self, dtype): self._CompileAndCheck(jnp_fun, args_maker) def testIssue2347(self): - # https://github.com/google/jax/issues/2347 + # https://github.com/jax-ml/jax/issues/2347 object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] self.assertRaises(TypeError, jnp.array, object_list) @@ -5617,7 +5617,7 @@ def jax_metal_supported(target_ver): return False - #https://github.com/google/jax/issues/16420 + #https://github.com/jax-ml/jax/issues/16420 def test_broadcast_dim(self): x = jnp.arange(2) f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (0,)) @@ -5640,7 +5640,7 @@ def test_triu(self): res = jnp.triu(x) jtu.check_eq(res, np.triu(x)) - #https://github.com/google/jax/issues/16471 + #https://github.com/jax-ml/jax/issues/16471 def test_matmul_1d(self): x = np.array(np.random.rand(3, 3)) y = np.array(np.random.rand(3)) @@ -5650,7 +5650,7 @@ def test_matmul_1d(self): res = jnp.dot(x, y) self.assertArraysAllClose(res, np.dot(x,y)) - #https://github.com/google/jax/issues/17175 + #https://github.com/jax-ml/jax/issues/17175 def test_indexing(self): x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) @jax.vmap @@ -5661,7 +5661,7 @@ def f(i): res = f(idx) jtu.check_eq(res, np.array([[4., 5., 6.], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.], [1., 2., 3.]])) - #https://github.com/google/jax/issues/17344 + #https://github.com/jax-ml/jax/issues/17344 def test_take_along_axis(self): @jax.jit def f(): @@ -5672,7 +5672,7 @@ def f(): return jnp.take_along_axis(x, idx, axis=1) jtu.check_eq(f(), self.dispatchOn([], f)) - #https://github.com/google/jax/issues/17590 + #https://github.com/jax-ml/jax/issues/17590 def test_in1d(self): a = np.array([123,2,4]) b = np.array([123,1]) @@ -5688,7 +5688,7 @@ def f(x): res = f(x) jtu.check_eq(res, np.array([[1., 2., 3.], [1., 5., 6.,], [1., 8., 9.], [1., 11., 12.]])) - #https://github.com/google/jax/issues/16326 + #https://github.com/jax-ml/jax/issues/16326 def test_indexing_update2(self): @jax.jit def f(x, r): @@ -5722,7 +5722,7 @@ def test_gather_ir(self): print(res) jtu.check_eq(res, res_ref) - #https://github.com/google/jax/issues/16366 + #https://github.com/jax-ml/jax/issues/16366 def test_pad_interior_1(self): if not ReportedIssuesTests.jax_metal_supported('0.0.6'): raise unittest.SkipTest("jax-metal version doesn't support it.") diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py index 7397cf3e4ee8..ea7bff1d09fc 100644 --- a/tests/lax_numpy_einsum_test.py +++ b/tests/lax_numpy_einsum_test.py @@ -89,7 +89,7 @@ def test_two_operands_5(self): self._check(s, x, y) def test_two_operands_6(self): - # based on https://github.com/google/jax/issues/37#issuecomment-448572187 + # based on https://github.com/jax-ml/jax/issues/37#issuecomment-448572187 r = self.rng() x = r.randn(2, 1) y = r.randn(2, 3, 4) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index bf2785f62d68..d58a5c2c3866 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -496,7 +496,7 @@ def jnp_op(x, idx): self._CompileAndCheck(jnp_op_idx, args_maker) def testIndexApplyBatchingBug(self): - # https://github.com/google/jax/issues/16655 + # https://github.com/jax-ml/jax/issues/16655 arr = jnp.array([[1, 2, 3, 4, 5, 6]]) ind = jnp.array([3]) func = lambda a, i: a.at[i].apply(lambda x: x - 1) @@ -505,7 +505,7 @@ def testIndexApplyBatchingBug(self): self.assertArraysEqual(out, expected) def testIndexUpdateScalarBug(self): - # https://github.com/google/jax/issues/14923 + # https://github.com/jax-ml/jax/issues/14923 a = jnp.arange(10.) out = a.at[0].apply(jnp.cos) self.assertArraysEqual(out, a.at[0].set(1)) @@ -835,7 +835,7 @@ def testBooleanIndexingArray2D(self): self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithEllipsis(self): - # Regression test for https://github.com/google/jax/issues/8412 + # Regression test for https://github.com/jax-ml/jax/issues/8412 x = np.arange(24).reshape(4, 3, 2) idx = (..., np.array([True, False])) ans = jnp.array(x)[idx] @@ -843,7 +843,7 @@ def testBoolean1DIndexingWithEllipsis(self): self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithEllipsis2(self): - # Regression test for https://github.com/google/jax/issues/9050 + # Regression test for https://github.com/jax-ml/jax/issues/9050 x = np.arange(3) idx = (..., np.array([True, False, True])) ans = jnp.array(x)[idx] @@ -936,7 +936,7 @@ def testSimpleIndexingUsesSlice(self): self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) def testTrivialGatherIsntGenerated(self): - # https://github.com/google/jax/issues/1621 + # https://github.com/jax-ml/jax/issues/1621 jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertNotIn('gather', str(jaxpr)) @@ -988,14 +988,14 @@ def testBooleanIndexingWithEmptyResult(self): self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingShapeMismatch(self): - # Regression test for https://github.com/google/jax/issues/7329 + # Regression test for https://github.com/jax-ml/jax/issues/7329 x = jnp.arange(4) idx = jnp.array([True, False]) with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"): x[idx] def testBooleanIndexingWithNone(self): - # Regression test for https://github.com/google/jax/issues/18542 + # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) idx = (None, jnp.array([True, False])) ans = x[idx] @@ -1003,7 +1003,7 @@ def testBooleanIndexingWithNone(self): self.assertAllClose(ans, expected) def testBooleanIndexingWithNoneAndEllipsis(self): - # Regression test for https://github.com/google/jax/issues/18542 + # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) mask = jnp.array([True, False, False]) ans = x[None, ..., mask] @@ -1011,7 +1011,7 @@ def testBooleanIndexingWithNoneAndEllipsis(self): self.assertAllClose(ans, expected) def testBooleanIndexingWithEllipsisAndNone(self): - # Regression test for https://github.com/google/jax/issues/18542 + # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) mask = jnp.array([True, False, False]) ans = x[..., None, mask] @@ -1038,7 +1038,7 @@ def testNontrivialBooleanIndexing(self): [(3, 4, 5), (3, 0)], ) def testEmptyBooleanIndexing(self, x_shape, m_shape): - # Regression test for https://github.com/google/jax/issues/22886 + # Regression test for https://github.com/jax-ml/jax/issues/22886 rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(x_shape, np.int32), np.empty(m_shape, dtype=bool)] @@ -1120,7 +1120,7 @@ def testStrIndexingError(self): with self.assertRaisesRegex(TypeError, msg): jnp.zeros((2, 3))[:, 'abc'] - def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 + def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245 x = jnp.arange(5, dtype=jnp.int32) + 1 self.assertAllClose(x, x[:10]) @@ -1613,7 +1613,7 @@ def np_fun(data, segment_ids): self._CompileAndCheck(jnp_fun, args_maker) def testIndexDtypeError(self): - # https://github.com/google/jax/issues/2795 + # https://github.com/jax-ml/jax/issues/2795 jnp.array(1) # get rid of startup warning with self.assertNoWarnings(): jnp.zeros(5).at[::2].set(1) @@ -1647,13 +1647,13 @@ def testIndexSequenceDeprecation(self, idx, idx_type): x.at[normalize(idx)].set(0) def testIndexedUpdateAliasingBug(self): - # https://github.com/google/jax/issues/7461 + # https://github.com/jax-ml/jax/issues/7461 fn = lambda x: x.at[1:].set(1 + x[:-1]) y = jnp.zeros(8) self.assertArraysEqual(fn(y), jax.jit(fn)(y)) def testScatterValuesCastToTargetDType(self): - # https://github.com/google/jax/issues/15505 + # https://github.com/jax-ml/jax/issues/15505 a = jnp.zeros(1, dtype=jnp.uint32) val = 2**32 - 1 # too large for int32 diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index d9d6fa464d98..45a780c9f721 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -697,7 +697,7 @@ def __rmul__(self, other): self.assertIsInstance(jax.jit(operator.mul)(b, a), MyArray) def testI0Grad(self): - # Regression test for https://github.com/google/jax/issues/11479 + # Regression test for https://github.com/jax-ml/jax/issues/11479 dx = jax.grad(jax.numpy.i0)(0.0) self.assertArraysEqual(dx, 0.0) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 402e206ef37b..33830c541fb9 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -425,7 +425,7 @@ def testReducerWhere(self, name, rng_factory, shape, dtype, axis, if (shape in [()] + scalar_shapes and dtype in [jnp.int16, jnp.uint16] and jnp_op in [jnp.min, jnp.max]): - self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.") + self.skipTest("Known XLA failure; see https://github.com/jax-ml/jax/issues/4971.") rng = rng_factory(self.rng()) is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. @@ -582,7 +582,7 @@ def np_fun(x): size=[0, 1, 2] ) def testStdOrVarLargeDdofReturnsNan(self, jnp_fn, size): - # test for https://github.com/google/jax/issues/21330 + # test for https://github.com/jax-ml/jax/issues/21330 x = jnp.arange(size) self.assertTrue(np.isnan(jnp_fn(x, ddof=size))) self.assertTrue(np.isnan(jnp_fn(x, ddof=size + 1))) @@ -622,7 +622,7 @@ def np_fun(x): atol=tol) def testNanStdGrad(self): - # Regression test for https://github.com/google/jax/issues/8128 + # Regression test for https://github.com/jax-ml/jax/issues/8128 x = jnp.arange(5.0).at[0].set(jnp.nan) y = jax.grad(jnp.nanvar)(x) self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75]), check_dtypes=False) @@ -740,7 +740,7 @@ def assert_warns_or_errors(msg=msg): @unittest.skipIf(not config.enable_x64.value, "test requires X64") @jtu.run_on_devices("cpu") # test is for CPU float64 precision def testPercentilePrecision(self): - # Regression test for https://github.com/google/jax/issues/8513 + # Regression test for https://github.com/jax-ml/jax/issues/8513 x = jnp.float64([1, 2, 3, 4, 7, 10]) self.assertEqual(jnp.percentile(x, 50), 3.5) @@ -778,14 +778,14 @@ def np_fun(*args): self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) def testMeanLargeArray(self): - # https://github.com/google/jax/issues/15068 + # https://github.com/jax-ml/jax/issues/15068 raise unittest.SkipTest("test is slow, but it passes!") x = jnp.ones((16, 32, 1280, 4096), dtype='int8') self.assertEqual(1.0, jnp.mean(x)) self.assertEqual(1.0, jnp.mean(x, where=True)) def testStdLargeArray(self): - # https://github.com/google/jax/issues/15068 + # https://github.com/jax-ml/jax/issues/15068 raise unittest.SkipTest("test is slow, but it passes!") x = jnp.ones((16, 32, 1280, 4096), dtype='int8') self.assertEqual(0.0, jnp.std(x)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ddf42a28e2ba..d3f9f2d615e2 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1043,7 +1043,7 @@ def testOperatorRound(self, jit): check_dtypes=False) def testRoundMethod(self): - # https://github.com/google/jax/issues/15190 + # https://github.com/jax-ml/jax/issues/15190 (jnp.arange(3.) / 5.).round() # doesn't crash @jtu.sample_product(shape=[(5,), (5, 2)]) @@ -1571,7 +1571,7 @@ def testIntegerPower(self, ptype): y=[0, 32, 64, 128], ) def testIntegerPowerOverflow(self, x, y): - # Regression test for https://github.com/google/jax/issues/5987 + # Regression test for https://github.com/jax-ml/jax/issues/5987 args_maker = lambda: [x, y] self._CheckAgainstNumpy(np.power, jnp.power, args_maker) self._CompileAndCheck(jnp.power, args_maker) @@ -1713,7 +1713,7 @@ def testConcatenateArray(self, shape, dtype, axis): self._CompileAndCheck(jnp_fun, args_maker) def testConcatenateAxisNone(self): - # https://github.com/google/jax/issues/3419 + # https://github.com/jax-ml/jax/issues/3419 a = jnp.array([[1, 2], [3, 4]]) b = jnp.array([[5]]) jnp.concatenate((a, b), axis=None) @@ -2977,7 +2977,7 @@ def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): self._CompileAndCheck(jnp_fun, args_maker) def testDiffPrepoendScalar(self): - # Regression test for https://github.com/google/jax/issues/19362 + # Regression test for https://github.com/jax-ml/jax/issues/19362 x = jnp.arange(10) result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) @@ -3611,7 +3611,7 @@ def _check(obj, out_dtype, weak_type): _check([jnp.complex128(1)], np.complex128, False) # Mixed inputs use JAX-style promotion. - # (regression test for https://github.com/google/jax/issues/8945) + # (regression test for https://github.com/jax-ml/jax/issues/8945) _check([0, np.int16(1)], np.int16, False) _check([0.0, np.float16(1)], np.float16, False) @@ -4229,17 +4229,17 @@ def testPathologicalFloats(self): # TODO(mattjj): test other ndarray-like method overrides def testNpMean(self): - # from https://github.com/google/jax/issues/125 + # from https://github.com/jax-ml/jax/issues/125 x = jnp.eye(3, dtype=float) + 0. ans = np.mean(x) self.assertAllClose(ans, np.array(1./3), check_dtypes=False) def testArangeOnFloats(self): np_arange = jtu.with_jax_dtype_defaults(np.arange) - # from https://github.com/google/jax/issues/145 + # from https://github.com/jax-ml/jax/issues/145 self.assertAllClose(np_arange(0.0, 1.0, 0.1), jnp.arange(0.0, 1.0, 0.1)) - # from https://github.com/google/jax/issues/3450 + # from https://github.com/jax-ml/jax/issues/3450 self.assertAllClose(np_arange(2.5), jnp.arange(2.5)) self.assertAllClose(np_arange(0., 2.5), @@ -4400,7 +4400,7 @@ def testPartition(self, shape, dtype, axis, kth): dtype=unsigned_dtypes, ) def testPartitionUnsignedWithZeros(self, kth, dtype): - # https://github.com/google/jax/issues/22137 + # https://github.com/jax-ml/jax/issues/22137 max_val = np.iinfo(dtype).max arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2]], dtype=dtype) axis = -1 @@ -4441,7 +4441,7 @@ def testArgpartition(self, shape, dtype, axis, kth): dtype=unsigned_dtypes, ) def testArgpartitionUnsignedWithZeros(self, kth, dtype): - # https://github.com/google/jax/issues/22137 + # https://github.com/jax-ml/jax/issues/22137 max_val = np.iinfo(dtype).max arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2, 3]], dtype=dtype) axis = -1 @@ -4616,7 +4616,7 @@ def args_maker(): self._CompileAndCheck(jnp_op, args_maker) def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): - # https://github.com/google/jax/issues/5088 + # https://github.com/jax-ml/jax/issues/5088 h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) q0 = jnp.take_along_axis(h, g, axis=-1) @@ -4837,9 +4837,9 @@ def testSymmetrizeDtypePromotion(self): # NOTE(mattjj): I disabled this test when removing lax._safe_mul because # introducing the convention 0 * inf = 0 leads to silently wrong results in # some cases. See this comment for details: - # https://github.com/google/jax/issues/1052#issuecomment-514083352 + # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352 # def testIssue347(self): - # # https://github.com/google/jax/issues/347 + # # https://github.com/jax-ml/jax/issues/347 # def test_fail(x): # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) # ones = jnp.ones_like(x) @@ -4850,7 +4850,7 @@ def testSymmetrizeDtypePromotion(self): # assert not np.any(np.isnan(result)) def testIssue453(self): - # https://github.com/google/jax/issues/453 + # https://github.com/jax-ml/jax/issues/453 a = np.arange(6) + 1 ans = jnp.reshape(a, (3, 2), order='F') expected = np.reshape(a, (3, 2), order='F') @@ -4861,7 +4861,7 @@ def testIssue453(self): op=["atleast_1d", "atleast_2d", "atleast_3d"], ) def testAtLeastNdLiterals(self, dtype, op): - # Fixes: https://github.com/google/jax/issues/634 + # Fixes: https://github.com/jax-ml/jax/issues/634 np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) jnp_fun = lambda arg: getattr(jnp, op)(arg) args_maker = lambda: [dtype(2)] @@ -5489,7 +5489,7 @@ def testDisableNumpyRankPromotionBroadcastingDecorator(self): jnp.ones(2) + 3 # don't want to warn for scalars def testStackArrayArgument(self): - # tests https://github.com/google/jax/issues/1271 + # tests https://github.com/jax-ml/jax/issues/1271 @jax.jit def foo(x): return jnp.stack(x) @@ -5536,7 +5536,7 @@ def testBroadcastTo(self, from_shape, to_shape): self._CompileAndCheck(jnp_op, args_maker) def testBroadcastToInvalidShape(self): - # Regression test for https://github.com/google/jax/issues/20533 + # Regression test for https://github.com/jax-ml/jax/issues/20533 x = jnp.zeros((3, 4, 5)) with self.assertRaisesRegex( ValueError, "Cannot broadcast to shape with fewer dimensions"): @@ -5688,7 +5688,7 @@ def testGradientNonConstant(self, shape, dtype): self._CompileAndCheck(jnp.gradient, args_maker) def testZerosShapeErrors(self): - # see https://github.com/google/jax/issues/1822 + # see https://github.com/jax-ml/jax/issues/1822 self.assertRaisesRegex( TypeError, "Shapes must be 1D sequences of concrete values of integer type.*", @@ -5706,7 +5706,7 @@ def testTraceMethod(self): self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) def testIntegerPowersArePrecise(self): - # See https://github.com/google/jax/pull/3036 + # See https://github.com/jax-ml/jax/pull/3036 # Checks if the squares of float32 integers have no numerical errors. # It should be satisfied with all integers less than sqrt(2**24). x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) @@ -5777,7 +5777,7 @@ def testArange64Bit(self, dtype): self._CompileAndCheck(jnp_fun, args_maker) def testIssue2347(self): - # https://github.com/google/jax/issues/2347 + # https://github.com/jax-ml/jax/issues/2347 object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] self.assertRaises(TypeError, jnp.array, object_list) @@ -6096,7 +6096,7 @@ def testSincGradArrayInput(self): jax.grad(lambda x: jnp.sinc(x).sum())(jnp.arange(10.)) # doesn't crash def testTakeAlongAxisIssue1521(self): - # https://github.com/google/jax/issues/1521 + # https://github.com/jax-ml/jax/issues/1521 idx = jnp.repeat(jnp.arange(3), 10).reshape((30, 1)) def f(x): @@ -6207,7 +6207,7 @@ def testWrappedSignaturesMatch(self): if name == "clip": # JAX's support of the Array API spec for clip, and the way it handles # backwards compatibility was introduced in - # https://github.com/google/jax/pull/20550 with a different signature + # https://github.com/jax-ml/jax/pull/20550 with a different signature # from the one in numpy, introduced in # https://github.com/numpy/numpy/pull/26724 # TODO(dfm): After our deprecation period for the clip arguments ends diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index c65df8aa87a2..630d89f53c5a 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -379,7 +379,7 @@ def np_fun_at(x, idx, y): self._CompileAndCheck(jnp_fun_at, args_maker) def test_frompyfunc_at_broadcasting(self): - # Regression test for https://github.com/google/jax/issues/18004 + # Regression test for https://github.com/jax-ml/jax/issues/18004 args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]), np.arange(9.0).reshape(3, 3)] def np_fun(x, idx, y): diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index 56fd0f7817e3..985dba484845 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -258,7 +258,7 @@ def test_none_arg_bad_signature(self): f(*args) def test_rank_promotion_error(self): - # Regression test for https://github.com/google/jax/issues/22305 + # Regression test for https://github.com/jax-ml/jax/issues/22305 f = jnp.vectorize(jnp.add, signature="(),()->()") rank2 = jnp.zeros((10, 10)) rank1 = jnp.zeros(10) diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index d2e64833b964..303c67c5860d 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -469,7 +469,7 @@ def test_gmres_weak_types(self): self.assertTrue(dtypes.is_weakly_typed(x)) def test_linear_solve_batching_via_jacrev(self): - # See https://github.com/google/jax/issues/14249 + # See https://github.com/jax-ml/jax/issues/14249 rng = np.random.RandomState(0) M = rng.randn(5, 5) A = np.dot(M, M.T) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index df19750b1877..8a8b1dd42c35 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -165,7 +165,7 @@ def testLogSumExpComplexSign(self): self.assertAllClose(sign * np.exp(logsumexp).astype(x.dtype), expected_sumexp, rtol=tol) def testLogSumExpZeros(self): - # Regression test for https://github.com/google/jax/issues/5370 + # Regression test for https://github.com/jax-ml/jax/issues/5370 scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b) lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b) args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])] @@ -173,14 +173,14 @@ def testLogSumExpZeros(self): self._CompileAndCheck(lax_fun, args_maker) def testLogSumExpOnes(self): - # Regression test for https://github.com/google/jax/issues/7390 + # Regression test for https://github.com/jax-ml/jax/issues/7390 args_maker = lambda: [np.ones(4, dtype='float32')] with jax.debug_infs(True): self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker) self._CompileAndCheck(lsp_special.logsumexp, args_maker) def testLogSumExpNans(self): - # Regression test for https://github.com/google/jax/issues/7634 + # Regression test for https://github.com/jax-ml/jax/issues/7634 with jax.debug_nans(True): with jax.disable_jit(): result = lsp_special.logsumexp(1.0) @@ -246,7 +246,7 @@ def testXlogyShouldReturnZero(self): self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False) def testGradOfXlogyAtZero(self): - # https://github.com/google/jax/issues/15598 + # https://github.com/jax-ml/jax/issues/15598 x0, y0 = 0.0, 3.0 d_xlog1py_dx = jax.grad(lsp_special.xlogy, argnums=0)(x0, y0) self.assertAllClose(d_xlog1py_dx, lax.log(y0)) @@ -260,7 +260,7 @@ def testXlog1pyShouldReturnZero(self): self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False) def testGradOfXlog1pyAtZero(self): - # https://github.com/google/jax/issues/15598 + # https://github.com/jax-ml/jax/issues/15598 x0, y0 = 0.0, 3.0 d_xlog1py_dx = jax.grad(lsp_special.xlog1py, argnums=0)(x0, y0) self.assertAllClose(d_xlog1py_dx, lax.log1p(y0)) @@ -284,7 +284,7 @@ def testXLogX(self): rtol=.1, eps=1e-3) def testGradOfEntrAtZero(self): - # https://github.com/google/jax/issues/15709 + # https://github.com/jax-ml/jax/issues/15709 self.assertEqual(jax.jacfwd(lsp_special.entr)(0.0), jnp.inf) self.assertEqual(jax.jacrev(lsp_special.entr)(0.0), jnp.inf) diff --git a/tests/lax_test.py b/tests/lax_test.py index 3d31bcb7d555..0ae5f77afbdb 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1002,7 +1002,7 @@ def fun_via_grad(lhs, rhs): self._CheckAgainstNumpy(fun_via_grad, fun, args_maker) def testConvTransposePaddingList(self): - # Regression test for https://github.com/google/jax/discussions/8695 + # Regression test for https://github.com/jax-ml/jax/discussions/8695 a = jnp.ones((28,28)) b = jnp.ones((3,3)) c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1)) @@ -1280,7 +1280,7 @@ def testBroadcastInDim(self, inshape, dtype, outshape, dimensions): self._CompileAndCheck(op, args_maker) def testBroadcastInDimOperandShapeTranspose(self): - # Regression test for https://github.com/google/jax/issues/5276 + # Regression test for https://github.com/jax-ml/jax/issues/5276 def f(x): return lax.broadcast_in_dim(x, (2, 3, 4), broadcast_dimensions=(0, 1, 2)).sum() def g(x): @@ -1681,7 +1681,7 @@ def args_maker(): lax.dynamic_update_slice, args_maker) def testDynamicUpdateSliceBatched(self): - # Regression test for https://github.com/google/jax/issues/9083 + # Regression test for https://github.com/jax-ml/jax/issues/9083 x = jnp.arange(5) y = jnp.arange(6, 9) ind = jnp.arange(6) @@ -2236,7 +2236,7 @@ def testReduceWindowShapeDilation(self, shape, window_dimensions, self.assertEqual(shape, result.shape) def testReduceWindowWithEmptyOutput(self): - # https://github.com/google/jax/issues/10315 + # https://github.com/jax-ml/jax/issues/10315 shape = (5, 3, 2) operand, padding, strides = np.ones(shape), 'VALID', (1,) * len(shape) out = jax.eval_shape(lambda x: lax.reduce_window(x, 0., lax.add, padding=padding, @@ -2859,13 +2859,13 @@ def test_ops_do_not_accept_complex_dtypes(self, op): op(2+3j, 4+5j) def test_population_count_booleans_not_supported(self): - # https://github.com/google/jax/issues/3886 + # https://github.com/jax-ml/jax/issues/3886 msg = "population_count does not accept dtype bool" with self.assertRaisesRegex(TypeError, msg): lax.population_count(True) def test_conv_general_dilated_different_input_ranks_error(self): - # https://github.com/google/jax/issues/4316 + # https://github.com/jax-ml/jax/issues/4316 msg = ("conv_general_dilated lhs and rhs must have the same number of " "dimensions") dimension_numbers = lax.ConvDimensionNumbers(lhs_spec=(0, 1, 2), @@ -2885,7 +2885,7 @@ def test_conv_general_dilated_different_input_ranks_error(self): lax.conv_general_dilated(lhs, rhs, **kwargs) def test_window_strides_dimension_shape_rule(self): - # https://github.com/google/jax/issues/5087 + # https://github.com/jax-ml/jax/issues/5087 msg = ("conv_general_dilated window and window_strides must have " "the same number of dimensions") lhs = jax.numpy.zeros((1, 1, 3, 3)) @@ -2894,7 +2894,7 @@ def test_window_strides_dimension_shape_rule(self): jax.lax.conv(lhs, rhs, [1], 'SAME') def test_reduce_window_scalar_init_value_shape_rule(self): - # https://github.com/google/jax/issues/4574 + # https://github.com/jax-ml/jax/issues/4574 args = { "operand": np.ones((4, 4), dtype=np.int32) , "init_value": np.zeros((1,), dtype=np.int32) , "computation": lax.max @@ -3045,7 +3045,7 @@ def testDynamicSliceU8Index(self): np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128]) def test_dot_general_batching_python_builtin_arg(self): - # https://github.com/google/jax/issues/16805 + # https://github.com/jax-ml/jax/issues/16805 @jax.remat def f(x): return jax.lax.dot_general(x, x, (([], []), ([], []))) @@ -3053,7 +3053,7 @@ def f(x): jax.hessian(f)(1.0) # don't crash def test_constant_folding_complex_to_real_scan_regression(self): - # regression test for github.com/google/jax/issues/19059 + # regression test for github.com/jax-ml/jax/issues/19059 def g(hiddens): hiddens_aug = jnp.vstack((hiddens[0], hiddens)) new_hiddens = hiddens_aug.copy() @@ -3088,7 +3088,7 @@ def testAsarray(self, typ): jaxpr = jax.make_jaxpr(asarray_closure)() self.assertLen(jaxpr.eqns, 0) - # Regression test for https://github.com/google/jax/issues/19334 + # Regression test for https://github.com/jax-ml/jax/issues/19334 # lax.asarray as a closure should not trigger transfer guard. with jax.transfer_guard('disallow'): jax.jit(asarray_closure)() @@ -3254,7 +3254,7 @@ def testArgMaxOfNanChoosesNaN(self): def testUnaryWeakTypes(self, op_name, rec_dtypes): """Test that all lax unary ops propagate weak_type information appropriately.""" if op_name == "bitwise_not": - raise unittest.SkipTest("https://github.com/google/jax/issues/12066") + raise unittest.SkipTest("https://github.com/jax-ml/jax/issues/12066") # Find a valid dtype for the function. for dtype in [float, int, complex, bool]: dtype = dtypes.canonicalize_dtype(dtype) @@ -3648,7 +3648,7 @@ def test_gather(self): self.assertEqual(ys.shape, (3, 2, 1)) def test_gather_batched_index_dtype(self): - # Regression test for https://github.com/google/jax/issues/16557 + # Regression test for https://github.com/jax-ml/jax/issues/16557 dtype = jnp.int8 size = jnp.iinfo(dtype).max + 10 indices = jnp.zeros(size, dtype=dtype) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 37d51c04f8de..37a0011e7bd0 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -693,8 +693,8 @@ def testSort(self, shape, dimension, arity, bdims, is_stable): # TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU @jtu.skip_on_devices("gpu") def test_variadic_reduce_window(self): - # https://github.com/google/jax/discussions/9818 and - # https://github.com/google/jax/issues/9837 + # https://github.com/jax-ml/jax/discussions/9818 and + # https://github.com/jax-ml/jax/issues/9837 def normpool(x): norms = jnp.linalg.norm(x, axis=-1) idxs = jnp.arange(x.shape[0]) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 446e10abd097..15963b10b6e2 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -329,7 +329,7 @@ def testEigvals(self, shape, dtype): @jtu.run_on_devices("cpu") def testEigvalsInf(self): - # https://github.com/google/jax/issues/2661 + # https://github.com/jax-ml/jax/issues/2661 x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -1004,7 +1004,7 @@ def qr_and_mul(a): @jtu.skip_on_devices("tpu") def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16): - # Regression test for https://github.com/google/jax/issues/10530 + # Regression test for https://github.com/jax-ml/jax/issues/10530 rng = jtu.rand_default(self.rng()) arr = rng(shape, dtype) if jtu.test_device_matches(['cpu']): @@ -1422,7 +1422,7 @@ def testLuOfSingularMatrix(self): @parameterized.parameters(lax_linalg.lu, lax_linalg._lu_python) def testLuOnZeroMatrix(self, lu): - # Regression test for https://github.com/google/jax/issues/19076 + # Regression test for https://github.com/jax-ml/jax/issues/19076 x = jnp.zeros((2, 2), dtype=np.float32) x_lu, _, _ = lu(x) self.assertArraysEqual(x_lu, x) diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 5daa0e0e5b84..a3b6b1efaa76 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -174,7 +174,7 @@ def test_device_put(self): def test_closed_over_values_device_placement(self): - # see https://github.com/google/jax/issues/1431 + # see https://github.com/jax-ml/jax/issues/1431 devices = self.get_devices() def f(): return lax.add(3., 4.) diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 4f2e36c64f4b..4697ba8b2858 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -148,7 +148,7 @@ def get_arr(scale): @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument") def test_closed_over_values_device_placement(self): - # see https://github.com/google/jax/issues/1431 + # see https://github.com/jax-ml/jax/issues/1431 def f(): return jnp.add(3., 4.) self.assertNotEqual(jax.jit(f)().devices(), {jax.devices('cpu')[0]}) @@ -186,7 +186,7 @@ def my_sin(x): return jnp.sin(x) @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends def test_indexing(self): - # https://github.com/google/jax/issues/2905 + # https://github.com/jax-ml/jax/issues/2905 cpus = jax.devices("cpu") x = jax.device_put(np.ones(2), cpus[0]) @@ -195,7 +195,7 @@ def test_indexing(self): @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends def test_sum(self): - # https://github.com/google/jax/issues/2905 + # https://github.com/jax-ml/jax/issues/2905 cpus = jax.devices("cpu") x = jax.device_put(np.ones(2), cpus[0]) diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index 40c5b6b2ea92..5c84f8c69b62 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -46,7 +46,7 @@ @unittest.skipIf(not portpicker, "Test requires portpicker") class DistributedTest(jtu.JaxTestCase): - # TODO(phawkins): Enable after https://github.com/google/jax/issues/11222 + # TODO(phawkins): Enable after https://github.com/jax-ml/jax/issues/11222 # is fixed. @unittest.SkipTest def testInitializeAndShutdown(self): diff --git a/tests/nn_test.py b/tests/nn_test.py index be07de184e60..416beffce17b 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -325,12 +325,12 @@ def testDtypeMatchesInput(self, dtype, fn): self.assertEqual(out.dtype, dtype) def testEluMemory(self): - # see https://github.com/google/jax/pull/1640 + # see https://github.com/jax-ml/jax/pull/1640 with jax.enable_checks(False): # With checks we materialize the array jax.make_jaxpr(lambda: nn.elu(jnp.ones((10 ** 12,)))) # don't oom def testHardTanhMemory(self): - # see https://github.com/google/jax/pull/1640 + # see https://github.com/jax-ml/jax/pull/1640 with jax.enable_checks(False): # With checks we materialize the array jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom @@ -367,7 +367,7 @@ def testSoftmaxWhereMask(self, fn): @parameterized.parameters([nn.softmax, nn.log_softmax]) def testSoftmaxWhereGrad(self, fn): - # regression test for https://github.com/google/jax/issues/19490 + # regression test for https://github.com/jax-ml/jax/issues/19490 x = jnp.array([36., 10000.]) mask = x < 1000 @@ -443,7 +443,7 @@ def testOneHotCustomDtype(self): self.assertAllClose(actual, expected) def testOneHotConcretizationError(self): - # https://github.com/google/jax/issues/3654 + # https://github.com/jax-ml/jax/issues/3654 msg = r"in jax.nn.one_hot argument `num_classes`" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): jax.jit(nn.one_hot)(3, 5) @@ -463,7 +463,7 @@ def testTanhExists(self): nn.tanh # doesn't crash def testCustomJVPLeak(self): - # https://github.com/google/jax/issues/8171 + # https://github.com/jax-ml/jax/issues/8171 @jax.jit def fwd(): a = jnp.array(1.) @@ -479,7 +479,7 @@ def f(hx, _): fwd() # doesn't crash def testCustomJVPLeak2(self): - # https://github.com/google/jax/issues/8171 + # https://github.com/jax-ml/jax/issues/8171 # The above test uses jax.nn.sigmoid, as in the original #8171, but that # function no longer actually has a custom_jvp! So we inline the old def. diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb index e8cd40a67b88..f5dcff837838 100644 --- a/tests/notebooks/colab_cpu.ipynb +++ b/tests/notebooks/colab_cpu.ipynb @@ -20,7 +20,7 @@ "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { diff --git a/tests/notebooks/colab_gpu.ipynb b/tests/notebooks/colab_gpu.ipynb index 8352bdaf71bc..2335455e6cf2 100644 --- a/tests/notebooks/colab_gpu.ipynb +++ b/tests/notebooks/colab_gpu.ipynb @@ -7,7 +7,7 @@ "id": "view-in-github" }, "source": [ - "\"Open" + "\"Open" ] }, { diff --git a/tests/ode_test.py b/tests/ode_test.py index 834745e1cf1c..acdfa1fc6cef 100644 --- a/tests/ode_test.py +++ b/tests/ode_test.py @@ -139,7 +139,7 @@ def swoop(_np, y, t, arg1, arg2): @jtu.skip_on_devices("tpu", "gpu") def test_odeint_vmap_grad(self): - # https://github.com/google/jax/issues/2531 + # https://github.com/jax-ml/jax/issues/2531 def dx_dt(x, *args): return 0.1 * x @@ -169,7 +169,7 @@ def g(x): @jtu.skip_on_devices("tpu", "gpu") def test_disable_jit_odeint_with_vmap(self): - # https://github.com/google/jax/issues/2598 + # https://github.com/jax-ml/jax/issues/2598 with jax.disable_jit(): t = jnp.array([0.0, 1.0]) x0_eval = jnp.zeros((5, 2)) @@ -178,7 +178,7 @@ def test_disable_jit_odeint_with_vmap(self): @jtu.skip_on_devices("tpu", "gpu") def test_grad_closure(self): - # simplification of https://github.com/google/jax/issues/2718 + # simplification of https://github.com/jax-ml/jax/issues/2718 def experiment(x): def model(y, t): return -x * y @@ -188,7 +188,7 @@ def model(y, t): @jtu.skip_on_devices("tpu", "gpu") def test_grad_closure_with_vmap(self): - # https://github.com/google/jax/issues/2718 + # https://github.com/jax-ml/jax/issues/2718 @jax.jit def experiment(x): def model(y, t): @@ -209,7 +209,7 @@ def model(y, t): @jtu.skip_on_devices("tpu", "gpu") def test_forward_mode_error(self): - # https://github.com/google/jax/issues/3558 + # https://github.com/jax-ml/jax/issues/3558 def f(k): return odeint(lambda x, t: k*x, 1., jnp.linspace(0, 1., 50)).sum() @@ -219,7 +219,7 @@ def f(k): @jtu.skip_on_devices("tpu", "gpu") def test_closure_nondiff(self): - # https://github.com/google/jax/issues/3584 + # https://github.com/jax-ml/jax/issues/3584 def dz_dt(z, t): return jnp.stack([z[0], z[1]]) @@ -232,8 +232,8 @@ def f(z): @jtu.skip_on_devices("tpu", "gpu") def test_complex_odeint(self): - # https://github.com/google/jax/issues/3986 - # https://github.com/google/jax/issues/8757 + # https://github.com/jax-ml/jax/issues/3986 + # https://github.com/jax-ml/jax/issues/8757 def dy_dt(y, t, alpha): return alpha * y * jnp.exp(-t).astype(y.dtype) diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index b7710d9b94c2..c4eca070798c 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -260,7 +260,7 @@ def testUtilityClipGrads(self): self.assertAllClose(ans, expected, check_dtypes=False) def testIssue758(self): - # code from https://github.com/google/jax/issues/758 + # code from https://github.com/jax-ml/jax/issues/758 # this is more of a scan + jacfwd/jacrev test, but it lives here to use the # optimizers.py code diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index e2f0e2152dc5..7692294cd6df 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -178,7 +178,7 @@ def setUp(self): (1, 384, 8, 64, True, True, True, {}), (1, 384, 8, 64, True, True, False, {}), (2, 384, 8, 64, True, True, True, {}), - # regression test: https://github.com/google/jax/pull/17314 + # regression test: https://github.com/jax-ml/jax/pull/17314 (1, 384, 8, 64, True, False, False, {'block_q': 128, 'block_k': 64}), ] ] @@ -419,7 +419,7 @@ def test_softmax(self, shape, dtype): }[dtype] # We upcast to float32 because NumPy <2.0 does not handle custom dtypes - # properly. See https://github.com/google/jax/issues/11014. + # properly. See https://github.com/jax-ml/jax/issues/11014. np.testing.assert_allclose( softmax.softmax(x, axis=-1).astype(jnp.float32), jax.nn.softmax(x, axis=-1).astype(jnp.float32), diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 63c3148e8108..d8f890c06c32 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -798,7 +798,7 @@ def kernel(x_ref, o_ref): np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) def test_abs_weak_type(self): - # see https://github.com/google/jax/issues/23191 + # see https://github.com/jax-ml/jax/issues/23191 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32), ) @@ -999,7 +999,7 @@ def kernel(x_ref, o_ref): x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) # We upcast to float32 because NumPy <2.0 does not handle custom dtypes - # properly. See https://github.com/google/jax/issues/11014. + # properly. See https://github.com/jax-ml/jax/issues/11014. np.testing.assert_allclose( kernel(x).astype(jnp.float32), jnp.tanh(x).astype(jnp.float32), @@ -1260,7 +1260,7 @@ def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): np.testing.assert_array_equal(out, o_new) def test_strided_load(self): - # Reproducer from https://github.com/google/jax/issues/20895. + # Reproducer from https://github.com/jax-ml/jax/issues/20895. @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 5ee30ba3382a..aec7fd54c925 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -484,7 +484,7 @@ def kernel(o_ref): self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32)) def test_hoisted_consts(self): - # See https://github.com/google/jax/issues/21557. + # See https://github.com/jax-ml/jax/issues/21557. # to_store will be hoisted as a constant. Choose distinct shapes from in/outs. to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128)) diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index 4b3f47e6f5c1..fefccfe7eb4f 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -209,7 +209,7 @@ def sin(x_ref, o_ref): @jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_large_vmap(self): - # Catches https://github.com/google/jax/issues/18361 + # Catches https://github.com/jax-ml/jax/issues/18361 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), grid=(2,)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 57106948f7d3..c20084c3c8e2 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1510,7 +1510,7 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape): def test_custom_partitioner_with_scan(self): self.skip_if_custom_partitioning_not_supported() - # This is a reproducer from https://github.com/google/jax/issues/20864. + # This is a reproducer from https://github.com/jax-ml/jax/issues/20864. @custom_partitioning def f(x): @@ -1921,7 +1921,7 @@ def f(tree): self.assertArraysEqual(s.data, input_data) def test_sds_full_like(self): - # https://github.com/google/jax/issues/20390 + # https://github.com/jax-ml/jax/issues/20390 mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) x = jax.ShapeDtypeStruct((4, 4), jnp.float32, sharding=s) @@ -4113,7 +4113,7 @@ def f(*args): def test_spmd_preserves_input_sharding_vmap_grad(self): if config.use_shardy_partitioner.value: self.skipTest("Shardy doesn't support PositionalSharding") - # https://github.com/google/jax/issues/20710 + # https://github.com/jax-ml/jax/issues/20710 n_devices = jax.device_count() sharding = PositionalSharding(jax.devices()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 8b121d91ae85..d7dcc7ba3cc4 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -122,7 +122,7 @@ def pmap(self): def testDeviceBufferToArray(self): sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2))) - # Changed in https://github.com/google/jax/pull/10584 not to access + # Changed in https://github.com/jax-ml/jax/pull/10584 not to access # sda.device_buffers, which isn't supported, and instead ensure fast slices # of the arrays returned by pmap are set up correctly. # buf = sda.device_buffers[-1] @@ -336,7 +336,7 @@ def test_jit_lower_compile_with_compiler_options_invalid(self): compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) def test_pmap_replicated_copy(self): - # https://github.com/google/jax/issues/17690 + # https://github.com/jax-ml/jax/issues/17690 inp = jnp.arange(jax.device_count()) x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(inp) out = jnp.copy(x) @@ -605,7 +605,7 @@ def f(x): self.assertAllClose(y, ref) def testNestedPmapAxisSwap(self): - # Regression test for https://github.com/google/jax/issues/5757 + # Regression test for https://github.com/jax-ml/jax/issues/5757 if jax.device_count() < 8: raise SkipTest("test requires at least 8 devices") f = jax.pmap(jax.pmap(lambda x: x, in_axes=1, out_axes=0), in_axes=0, @@ -1180,7 +1180,7 @@ def testPShuffleWithBadPerm(self): "`perm` does not represent a permutation: \\[1.*\\]", g) def testPpermuteWithZipObject(self): - # https://github.com/google/jax/issues/1703 + # https://github.com/jax-ml/jax/issues/1703 num_devices = jax.device_count() perm = [num_devices - 1] + list(range(num_devices - 1)) f = self.pmap(lambda x: lax.ppermute(x, "i", zip(perm, range(num_devices))), "i") @@ -1501,7 +1501,7 @@ def s(keys): self.assertEqual(ans.shape, (13, N_DEVICES)) def testVmapOfPmap3(self): - # https://github.com/google/jax/issues/3399 + # https://github.com/jax-ml/jax/issues/3399 device_count = jax.device_count() if device_count < 2: raise SkipTest("test requires at least two devices") @@ -1661,7 +1661,7 @@ def g(z): @ignore_jit_of_pmap_warning() def testIssue1065(self): - # from https://github.com/google/jax/issues/1065 + # from https://github.com/jax-ml/jax/issues/1065 device_count = jax.device_count() def multi_step_pmap(state, count): @@ -1697,7 +1697,7 @@ def testArrayGetItem(self): # replica. @unittest.skip("need eager multi-replica support") def testPostProcessMap(self): - # test came from https://github.com/google/jax/issues/1369 + # test came from https://github.com/jax-ml/jax/issues/1369 nrep = jax.device_count() def pmvm(a, b): @@ -1730,7 +1730,7 @@ def f(args_list): @jax.default_matmul_precision("float32") def testPostProcessMap2(self): - # code from https://github.com/google/jax/issues/2787 + # code from https://github.com/jax-ml/jax/issues/2787 def vv(x, y): """Vector-vector multiply""" return jnp.dot(x, y) @@ -1758,7 +1758,7 @@ def distributed_matrix_vector(x, y): ('_new', new_checkpoint), ]) def testAxisIndexRemat(self, remat): - # https://github.com/google/jax/issues/2716 + # https://github.com/jax-ml/jax/issues/2716 n = len(jax.devices()) def f(key): @@ -1769,7 +1769,7 @@ def f(key): self.pmap(remat(f), axis_name='i')(keys) def testPmapMapVmapCombinations(self): - # https://github.com/google/jax/issues/2822 + # https://github.com/jax-ml/jax/issues/2822 def vv(x, y): """Vector-vector multiply""" return jnp.dot(x, y) @@ -1802,7 +1802,7 @@ def matrix_vector(x, y, parallel=True): self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3) def testPmapAxisNameError(self): - # https://github.com/google/jax/issues/3120 + # https://github.com/jax-ml/jax/issues/3120 a = np.arange(4)[np.newaxis,:] def test(x): return jax.lax.psum(x, axis_name='batch') @@ -1811,7 +1811,7 @@ def test(x): self.pmap(test)(a) def testPsumOnBooleanDtype(self): - # https://github.com/google/jax/issues/3123 + # https://github.com/jax-ml/jax/issues/3123 n = jax.device_count() if n > 1: x = jnp.array([True, False]) @@ -1889,7 +1889,7 @@ def foo(x): return x + x self.assertIn("mhlo.num_partitions = 1", hlo) def testPsumZeroCotangents(self): - # https://github.com/google/jax/issues/3651 + # https://github.com/jax-ml/jax/issues/3651 def loss(params, meta_params): (net, mpo) = params return meta_params * mpo * net @@ -1914,7 +1914,7 @@ def outer(params): @ignore_jit_of_pmap_warning() def test_issue_1062(self): - # code from https://github.com/google/jax/issues/1062 @shoyer + # code from https://github.com/jax-ml/jax/issues/1062 @shoyer # this tests, among other things, whether ShardedDeviceTuple constants work device_count = jax.device_count() @@ -1938,7 +1938,7 @@ def test_replicate_backend(self): # TODO(skye): fix backend caching so we always have multiple CPUs available if jax.device_count("cpu") < 4: self.skipTest("test requires 4 CPU device") - # https://github.com/google/jax/issues/4223 + # https://github.com/jax-ml/jax/issues/4223 def fn(indices): return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32) mapped_fn = self.pmap(fn, axis_name='i', backend='cpu') @@ -1982,7 +1982,7 @@ def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op): for dtype in [np.float32, np.int32] ) def testPmapDtype(self, dtype): - # Regression test for https://github.com/google/jax/issues/6022 + # Regression test for https://github.com/jax-ml/jax/issues/6022 @partial(self.pmap, axis_name='i') def func(_): return jax.lax.psum(dtype(0), axis_name='i') @@ -1991,7 +1991,7 @@ def func(_): self.assertEqual(out_dtype, dtype) def test_num_replicas_with_switch(self): - # https://github.com/google/jax/issues/7411 + # https://github.com/jax-ml/jax/issues/7411 def identity(x): return x @@ -2154,7 +2154,7 @@ def test_axis_name_shadowing_with_vmap(self): @jtu.run_on_devices("cpu") def test_pmap_stack_size(self): - # Regression test for https://github.com/google/jax/issues/20428 + # Regression test for https://github.com/jax-ml/jax/issues/20428 # pmap isn't particularly important here, but it guarantees that the CPU # client runs the computation on a threadpool rather than inline. if jax.device_count() < 2: @@ -2164,7 +2164,7 @@ def test_pmap_stack_size(self): y.block_until_ready() # doesn't crash def test_pmap_of_prng_key(self): - # Regression test for https://github.com/google/jax/issues/20392 + # Regression test for https://github.com/jax-ml/jax/issues/20392 keys = jax.random.split(jax.random.key(0), jax.device_count()) result1 = jax.pmap(jax.random.bits)(keys) with jtu.ignore_warning( diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 76854beaea3d..d9887cf7b482 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1245,7 +1245,7 @@ def f(shard_ids, x): np.testing.assert_array_equal(shard[0] + 1, shard[1]) def test_batching_with_side_effects(self): - # https://github.com/google/jax/issues/20628#issuecomment-2050800195 + # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195 x_lst = [] def append_x(x): nonlocal x_lst @@ -1261,7 +1261,7 @@ def f(x): self.assertAllClose(x_lst, [0., 1., 2., 0., 2., 4.], check_dtypes=False) def test_batching_with_side_effects_while_loop(self): - # https://github.com/google/jax/issues/20628#issuecomment-2050921219 + # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219 x_lst = [] def append_x(x): nonlocal x_lst diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index 8d00b5eedaf4..2e0fc32238ae 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -109,7 +109,7 @@ def testJaxArrayToTorch(self, shape, dtype): self.assertAllClose(np, y.cpu().numpy()) def testTorchToJaxInt64(self): - # See https://github.com/google/jax/issues/11895 + # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64))) dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32 diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index f69687ddc6cd..63510b7295d6 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -178,7 +178,7 @@ def testNormal(self, dtype): def testNormalBfloat16(self): # Passing bfloat16 as dtype string. - # https://github.com/google/jax/issues/6813 + # https://github.com/jax-ml/jax/issues/6813 res_bfloat16_str = random.normal(self.make_key(0), dtype='bfloat16') res_bfloat16 = random.normal(self.make_key(0), dtype=jnp.bfloat16) self.assertAllClose(res_bfloat16, res_bfloat16_str) @@ -391,7 +391,7 @@ def testBeta(self, a, b, dtype): @jtu.skip_on_devices("tpu") # TPU precision causes issues. def testBetaSmallParameters(self, dtype=np.float32): - # Regression test for beta version of https://github.com/google/jax/issues/9896 + # Regression test for beta version of https://github.com/jax-ml/jax/issues/9896 key = self.make_key(0) a, b = 0.0001, 0.0002 samples = random.beta(key, a, b, shape=(100,), dtype=dtype) @@ -441,7 +441,7 @@ def testDirichlet(self, alpha, dtype): @jtu.skip_on_devices("tpu") # lower accuracy leads to failures. def testDirichletSmallAlpha(self, dtype=np.float32): - # Regression test for https://github.com/google/jax/issues/9896 + # Regression test for https://github.com/jax-ml/jax/issues/9896 key = self.make_key(0) alpha = 0.00001 * jnp.ones(3) samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype) @@ -530,7 +530,7 @@ def testGammaGrad(self, log_space, alpha): rtol=rtol) def testGammaGradType(self): - # Regression test for https://github.com/google/jax/issues/2130 + # Regression test for https://github.com/jax-ml/jax/issues/2130 key = self.make_key(0) a = jnp.array(1., dtype=jnp.float32) b = jnp.array(3., dtype=jnp.float32) @@ -663,7 +663,7 @@ def testGeneralizedNormal(self, p, shape, dtype): ) def testGeneralizedNormalKS(self, p, shape, dtype): self.skipTest( # test is also sometimes slow, with (300, ...)-shape draws - "sensitive to random key - https://github.com/google/jax/issues/18941") + "sensitive to random key - https://github.com/jax-ml/jax/issues/18941") key = lambda: self.make_key(2) rand = lambda key, p: random.generalized_normal(key, p, (300, *shape), dtype) crand = jax.jit(rand) @@ -700,7 +700,7 @@ def testBall(self, d, p, shape, dtype): @jtu.skip_on_devices("tpu") # TPU precision causes issues. def testBallKS(self, d, p, shape, dtype): self.skipTest( - "sensitive to random key - https://github.com/google/jax/issues/18932") + "sensitive to random key - https://github.com/jax-ml/jax/issues/18932") key = lambda: self.make_key(123) rand = lambda key, p: random.ball(key, d, p, (100, *shape), dtype) crand = jax.jit(rand) @@ -800,7 +800,7 @@ def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size, assert samples.shape == shape + (dim,) def testMultivariateNormalCovariance(self): - # test code based on https://github.com/google/jax/issues/1869 + # test code based on https://github.com/jax-ml/jax/issues/1869 N = 100000 mean = jnp.zeros(4) cov = jnp.array([[ 0.19, 0.00, -0.13, 0.00], @@ -827,7 +827,7 @@ def testMultivariateNormalCovariance(self): @jtu.sample_product(method=['cholesky', 'eigh', 'svd']) @jtu.skip_on_devices('gpu', 'tpu') # Some NaNs on accelerators. def testMultivariateNormalSingularCovariance(self, method): - # Singular covariance matrix https://github.com/google/jax/discussions/13293 + # Singular covariance matrix https://github.com/jax-ml/jax/discussions/13293 mu = jnp.zeros((2,)) sigma = jnp.ones((2, 2)) key = self.make_key(0) @@ -889,7 +889,7 @@ def testDtypeErrorMessage(self): def testRandomBroadcast(self): """Issue 4033""" - # test for broadcast issue in https://github.com/google/jax/issues/4033 + # test for broadcast issue in https://github.com/jax-ml/jax/issues/4033 key = lambda: self.make_key(0) shape = (10, 2) with jax.numpy_rank_promotion('allow'): @@ -1071,7 +1071,7 @@ def test_randint_out_of_range(self): self.assertGreater((r == 255).sum(), 0) def test_large_prng(self): - # https://github.com/google/jax/issues/11010 + # https://github.com/jax-ml/jax/issues/11010 def f(): return random.uniform( self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) @@ -1086,7 +1086,7 @@ def f(): logits_shape_base=[(3, 4), (3, 1), (1, 4)], axis=[-3, -2, -1, 0, 1, 2]) def test_categorical_shape_argument(self, shape, logits_shape_base, axis): - # https://github.com/google/jax/issues/13124 + # https://github.com/jax-ml/jax/issues/13124 logits_shape = list(logits_shape_base) logits_shape.insert(axis % (len(logits_shape_base) + 1), 10) assert logits_shape[axis] == 10 diff --git a/tests/random_test.py b/tests/random_test.py index 941172f75278..da182dbccae9 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -436,7 +436,7 @@ def test_threefry_split_fold_in_symmetry(self, make_key): @skipIf(not config.threefry_partitionable.value, 'enable after upgrade') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_threefry_split_vmapped_fold_in_symmetry(self, make_key): - # See https://github.com/google/jax/issues/7708 + # See https://github.com/jax-ml/jax/issues/7708 with jax.default_prng_impl('threefry2x32'): key = make_key(72) f1, f2, f3 = vmap(lambda k, _: random.fold_in(k, lax.axis_index('batch')), @@ -450,7 +450,7 @@ def test_threefry_split_vmapped_fold_in_symmetry(self, make_key): @skipIf(config.threefry_partitionable.value, 'changed random bit values') def test_loggamma_nan_corner_case(self): - # regression test for https://github.com/google/jax/issues/17922 + # regression test for https://github.com/jax-ml/jax/issues/17922 # This particular key previously led to NaN output. # If the underlying implementation ever changes, this test will no longer # exercise this corner case, so we compare to a particular output value @@ -545,7 +545,7 @@ def test_isinstance(self, make_key): @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_key_output_vjp(self, make_key): - # See https://github.com/google/jax/issues/14856 + # See https://github.com/jax-ml/jax/issues/14856 def f(seed): return make_key(seed) jax.vjp(f, 1) # doesn't crash @@ -578,7 +578,7 @@ class ThreefryPrngTest(jtu.JaxTestCase): partial(random.PRNGKey, impl='threefry2x32'), partial(random.key, impl='threefry2x32')]]) def test_seed_no_implicit_transfers(self, make_key): - # See https://github.com/google/jax/issues/15613 + # See https://github.com/jax-ml/jax/issues/15613 with jax.transfer_guard('disallow'): make_key(jax.device_put(42)) # doesn't crash @@ -922,14 +922,14 @@ def test_select(self): self.assertEqual(ys.shape, (3, 2)) def test_select_scalar_cond(self): - # regression test for https://github.com/google/jax/issues/16422 + # regression test for https://github.com/jax-ml/jax/issues/16422 ks = self.make_keys(3) ys = lax.select(True, ks, ks) self.assertIsInstance(ys, prng_internal.PRNGKeyArray) self.assertEqual(ys.shape, (3,)) def test_vmap_of_cond(self): - # See https://github.com/google/jax/issues/15869 + # See https://github.com/jax-ml/jax/issues/15869 def f(x): keys = self.make_keys(*x.shape) return lax.select(x, keys, keys) @@ -1126,7 +1126,7 @@ def test_key_spec_repr(self, name): self.assertEqual(repr(spec), f"PRNGSpec({name!r})") def test_keyarray_custom_vjp(self): - # Regression test for https://github.com/google/jax/issues/18442 + # Regression test for https://github.com/jax-ml/jax/issues/18442 @jax.custom_vjp def f(_, state): return state diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py index 701b7c570937..dd34a99a73b8 100644 --- a/tests/scipy_ndimage_test.py +++ b/tests/scipy_ndimage_test.py @@ -129,7 +129,7 @@ def args_maker(): self._CheckAgainstNumpy(osp_op, lsp_op, args_maker) def testContinuousGradients(self): - # regression test for https://github.com/google/jax/issues/3024 + # regression test for https://github.com/jax-ml/jax/issues/3024 def loss(delta): x = np.arange(100.0) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index ad988bba62d3..983cb6bdc37a 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -314,7 +314,7 @@ def args_maker(): rtol={np.float32: 2e-3, np.float64: 1e-4}) def testBetaLogPdfZero(self): - # Regression test for https://github.com/google/jax/issues/7645 + # Regression test for https://github.com/jax-ml/jax/issues/7645 a = b = 1. x = np.array([0., 1.]) self.assertAllClose( @@ -539,7 +539,7 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker) def testGammaLogPdfZero(self): - # Regression test for https://github.com/google/jax/issues/7256 + # Regression test for https://github.com/jax-ml/jax/issues/7256 self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) @@ -710,7 +710,7 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker) def testLogisticLogpdfOverflow(self): - # Regression test for https://github.com/google/jax/issues/10219 + # Regression test for https://github.com/jax-ml/jax/issues/10219 self.assertAllClose( np.array([-100, -100], np.float32), lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)), @@ -855,7 +855,7 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker) def testNormSfNearZero(self): - # Regression test for https://github.com/google/jax/issues/17199 + # Regression test for https://github.com/jax-ml/jax/issues/17199 value = np.array(10, np.float32) self.assertAllClose(osp_stats.norm.sf(value).astype('float32'), lsp_stats.norm.sf(value), @@ -1208,7 +1208,7 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) def testBinomPmfOutOfRange(self): - # Regression test for https://github.com/google/jax/issues/19150 + # Regression test for https://github.com/jax-ml/jax/issues/19150 self.assertEqual(lsp_stats.binom.pmf(k=6.5, n=5, p=0.8), 0.0) def testBinomLogPmfZerokZeron(self): diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 27199c874332..77e5273d172a 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2973,7 +2973,7 @@ def test_vmap_error(self): (2, x.shape[0]), (1, 1), "VALID"), arg_descriptors=[RandArg((3, 8), _f32)], polymorphic_shapes=["b, ..."]), - # https://github.com/google/jax/issues/11804 + # https://github.com/jax-ml/jax/issues/11804 # Use the reshape trick to simulate a polymorphic dimension of 16*b. # (See test "conv_general_dilated.1d_1" above for more details.) PolyHarness("reduce_window", "add_monoid_strides_window_size=static", diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index ae22eeca0cd4..20bc33475e14 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -724,7 +724,7 @@ def f(x): self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) def test_nested_vmap_with_capture_spmd_axis_name(self): - self.skipTest('https://github.com/google/jax/issues/23476') + self.skipTest('https://github.com/jax-ml/jax/issues/23476') mesh = jtu.create_mesh((2, 2), ('x', 'y')) def to_map_with_capture(x, y): @@ -902,7 +902,7 @@ def f(_): @jax.legacy_prng_key('allow') def test_prngkeyarray_eager(self): - # https://github.com/google/jax/issues/15398 + # https://github.com/jax-ml/jax/issues/15398 mesh = jtu.create_mesh((4,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) @@ -1069,7 +1069,7 @@ def foo(): self.assertEqual(out, 1.) def test_jaxpr_shardings_with_no_outputs(self): - # https://github.com/google/jax/issues/15385 + # https://github.com/jax-ml/jax/issues/15385 mesh = jtu.create_mesh((4,), ('i',)) @jax.jit @@ -1109,7 +1109,7 @@ def g(x): @jtu.run_on_devices('cpu', 'gpu', 'tpu') def test_key_array_with_replicated_last_tile_dim(self): - # See https://github.com/google/jax/issues/16137 + # See https://github.com/jax-ml/jax/issues/16137 mesh = jtu.create_mesh((2, 4), ('i', 'j')) @@ -1690,7 +1690,7 @@ def g(x): self.assertAllClose(grad, jnp.ones(4) * 4 * 4, check_dtypes=False) def test_repeated_psum_allowed(self): - # https://github.com/google/jax/issues/19175 + # https://github.com/jax-ml/jax/issues/19175 mesh = jtu.create_mesh((4,), 'i') @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) @@ -1927,7 +1927,7 @@ def f(): self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) def test_vmap_grad_shmap_spmd_axis_name_residuals(self): - # https://github.com/google/jax/pull/21032 + # https://github.com/jax-ml/jax/pull/21032 mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial( @@ -1944,7 +1944,7 @@ def f(x): jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self): - # https://github.com/google/jax/pull/21056 + # https://github.com/jax-ml/jax/pull/21056 mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial(jax.remat, policy=lambda *_, **__: True) @@ -1962,7 +1962,7 @@ def f(x): jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash def test_grad_shmap_residuals_axis_names_in_mesh_order(self): - # https://github.com/google/jax/issues/21236 + # https://github.com/jax-ml/jax/issues/21236 mesh = jtu.create_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a')) @partial( @@ -2108,7 +2108,7 @@ def test_in_spec_none_rank_errors(self): )((object(), object()), x) def test_custom_linear_solve_rep_rules(self): - # https://github.com/google/jax/issues/20162 + # https://github.com/jax-ml/jax/issues/20162 mesh = jtu.create_mesh((1,), ('i',)) a = jnp.array(1).reshape(1, 1) b = jnp.array(1).reshape(1) diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index 545d73bff291..38fde72f0440 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -310,7 +310,7 @@ def test_bcoo_extract_duplicate_indices_n_sparse_0(self): self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]])) def test_bcoo_extract_batching(self): - # https://github.com/google/jax/issues/9431 + # https://github.com/jax-ml/jax/issues/9431 indices = jnp.zeros((4, 1, 1), dtype=int) mat = jnp.arange(4.).reshape((4, 1)) @@ -353,7 +353,7 @@ def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense): self.assertEqual(hess.shape, data.shape + 2 * M.shape) def test_bcoo_extract_zero_nse(self): - # Regression test for https://github.com/google/jax/issues/13653 + # Regression test for https://github.com/jax-ml/jax/issues/13653 # (n_batch, n_sparse, n_dense) = (1, 0, 0), nse = 2 args_maker = lambda: (jnp.zeros((3, 2, 0), dtype='int32'), jnp.arange(3)) @@ -974,7 +974,7 @@ def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape): self.assertEqual(out.nse, expected_nse) def test_bcoo_spdot_general_ad_bug(self): - # Regression test for https://github.com/google/jax/issues/10163 + # Regression test for https://github.com/jax-ml/jax/issues/10163 A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) A_values = jnp.array([-2.0, 1.0, -1.0, 0.5, 2.0]) A_shape = (2, 3) @@ -1287,7 +1287,7 @@ def test_bcoo_sum_duplicates_remove_zeros(self): self.assertEqual(y2.nse, x.nse) def test_bcoo_sum_duplicates_padding(self): - # Regression test for https://github.com/google/jax/issues/8163 + # Regression test for https://github.com/jax-ml/jax/issues/8163 size = 3 data = jnp.array([1, 0, 0]) indices = jnp.array([1, size, size])[:, None] @@ -1606,7 +1606,7 @@ def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol) def test_bcoo_mul_sparse_with_duplicates(self): - # Regression test for https://github.com/google/jax/issues/8888 + # Regression test for https://github.com/jax-ml/jax/issues/8888 indices = jnp.array([[0, 1, 0, 0, 1, 1], [1, 0, 1, 2, 0, 2]]).T data = jnp.array([1, 2, 3, 4, 5, 6]) @@ -1940,7 +1940,7 @@ def test_bcsr_concatenate(self, shape, dtype, n_batch, n_dense, dimension): self._CheckGradsSparse(dense_func, sparse_func, args_maker) def test_bcoo_spdot_abstract_eval_bug(self): - # Regression test for https://github.com/google/jax/issues/21921 + # Regression test for https://github.com/jax-ml/jax/issues/21921 lhs = sparse.BCOO( (jnp.float32([[1]]), lax.broadcasted_iota(jnp.int32, (10, 1, 1), 0)), shape=(10, 10)) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 616396222ec6..eb8d70be1f05 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -323,7 +323,7 @@ def test_coo_matmat(self, shape, dtype, transpose): self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=sptu.MATMUL_TOL) def test_coo_matmat_layout(self): - # Regression test for https://github.com/google/jax/issues/7533 + # Regression test for https://github.com/jax-ml/jax/issues/7533 d = jnp.array([1.0, 2.0, 3.0, 4.0]) i = jnp.array([0, 0, 1, 2]) j = jnp.array([0, 2, 0, 0]) diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 46086511d8b5..46c2f5aafbf6 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -610,7 +610,7 @@ def func(M): self.assertArraysEqual(jit(func)(Msp).todense(), expected) def testWeakTypes(self): - # Regression test for https://github.com/google/jax/issues/8267 + # Regression test for https://github.com/jax-ml/jax/issues/8267 M = jnp.arange(12, dtype='int32').reshape(3, 4) Msp = BCOO.fromdense(M) self.assertArraysEqual( diff --git a/tests/stax_test.py b/tests/stax_test.py index 6850f36a02ea..e21300ddd119 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -216,7 +216,7 @@ def testBatchNormShapeNHWC(self): def testBatchNormShapeNCHW(self): key = random.PRNGKey(0) - # Regression test for https://github.com/google/jax/issues/461 + # Regression test for https://github.com/jax-ml/jax/issues/461 init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3)) input_shape = (4, 5, 6, 7) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index bc741702ce58..f8792a263117 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -343,7 +343,7 @@ def f(a, b, c): pass self.assertEqual(h.args, (3,)) def testPartialFuncAttributeHasStableHash(self): - # https://github.com/google/jax/issues/9429 + # https://github.com/jax-ml/jax/issues/9429 fun = functools.partial(print, 1) p1 = tree_util.Partial(fun, 2) p2 = tree_util.Partial(fun, 2) @@ -359,7 +359,7 @@ def testChildren(self): self.assertEqual([c0, c1], tree.children()) def testTreedefTupleFromChildren(self): - # https://github.com/google/jax/issues/7377 + # https://github.com/jax-ml/jax/issues/7377 tree = ((1, 2, (3, 4)), (5,)) leaves, treedef1 = tree_util.tree_flatten(tree) treedef2 = tree_util.treedef_tuple(treedef1.children()) @@ -368,7 +368,7 @@ def testTreedefTupleFromChildren(self): self.assertEqual(treedef1.num_nodes, treedef2.num_nodes) def testTreedefTupleComparesEqual(self): - # https://github.com/google/jax/issues/9066 + # https://github.com/jax-ml/jax/issues/9066 self.assertEqual(tree_util.tree_structure((3,)), tree_util.treedef_tuple((tree_util.tree_structure(3),))) @@ -978,7 +978,7 @@ def testEmpty(self): self.assertAllClose(tree, tree_, atol=0., rtol=0.) def testDtypePolymorphicUnravel(self): - # https://github.com/google/jax/issues/7809 + # https://github.com/jax-ml/jax/issues/7809 x = jnp.arange(10, dtype=jnp.float32) x_flat, unravel = flatten_util.ravel_pytree(x) y = x_flat < 5.3 @@ -987,7 +987,7 @@ def testDtypePolymorphicUnravel(self): @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testDtypeMonomorphicUnravel(self): - # https://github.com/google/jax/issues/7809 + # https://github.com/jax-ml/jax/issues/7809 x1 = jnp.arange(10, dtype=jnp.float32) x2 = jnp.arange(10, dtype=jnp.int32) x_flat, unravel = flatten_util.ravel_pytree((x1, x2)) diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 58cf4a2baae3..d4403b7e5e30 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -128,7 +128,7 @@ def test_jit_cache(self): @unittest.skip("test fails, see #8552") def test_convert_element_type(self): - # Regression test for part of https://github.com/google/jax/issues/5982 + # Regression test for part of https://github.com/jax-ml/jax/issues/5982 with enable_x64(): x = jnp.int64(1) self.assertEqual(x.dtype, jnp.int64) From c4c30e1cfd0b006e3c00c7cf419d257fa2e47493 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Sep 2024 14:53:30 +0000 Subject: [PATCH 582/702] Bump actions/upload-artifact from 4.3.6 to 4.4.0 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.3.6 to 4.4.0. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/834a144ee995460fba8ed112a2fc961b36a5ec5a...50769540e7f4bd5e21e526ee35c689e35e0d6874) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/upstream-nightly.yml | 2 +- .github/workflows/wheel_win_x64.yml | 2 +- .github/workflows/windows_ci.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 74cb45920949..79c1e22d2cea 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -85,7 +85,7 @@ jobs: && steps.status.outcome == 'failure' && github.event_name == 'schedule' && github.repository == 'jax-ml/jax' - uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 + uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4 with: name: output-${{ matrix.python-version }}-log.jsonl path: output-${{ matrix.python-version }}-log.jsonl diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 367f8e05bf56..0195032ceaf1 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -45,7 +45,7 @@ jobs: --bazel_options=--config=win_clang ` --verbose - - uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4 with: name: wheels-${{ matrix.os }}-${{ matrix.pyver }} path: ${{ github.workspace }}\dist\*.whl diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 194cac6fa79a..7097d5589426 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -53,7 +53,7 @@ jobs: --bazel_options=--color=yes ` --bazel_options=--config=win_clang - - uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4 with: name: wheels path: ${{ github.workspace }}\jax\dist\*.whl From bc80ecbbe48289af4135054d421277819bf17e0d Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 20 Sep 2024 08:30:04 -0700 Subject: [PATCH 583/702] Remove forward compatibility checks from cholesky_update lowering. The forward compatibility window has ended and it should be safe to remove these checks. PiperOrigin-RevId: 676853740 --- jax/_src/lax/linalg.py | 19 +++---------------- jaxlib/gpu_linalg.py | 40 ---------------------------------------- 2 files changed, 3 insertions(+), 56 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 453e79a5c7f8..ef6a5a11a56e 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -44,7 +44,6 @@ from jax._src.lax.lax import ( standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex, _input_dtype) -from jax._src.lib import gpu_linalg from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack @@ -54,6 +53,9 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike +# The following import is unused but needed to register the custom_call targets +# in the gpu_linalg module. +from jax._src.lib import gpu_linalg # noqa: F401 TFun = TypeVar('TFun', bound=Callable[..., Any]) @@ -551,21 +553,6 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector): return ShapedArray(r_matrix.shape, r_matrix.dtype) def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector): - # TODO(b/360781533): Remove guard after 3 week forward compatibility period. - if ctx.is_forward_compat(): - r_matrix_aval, _ = ctx.avals_in - try: - [platform] = ctx.module_context.platforms - except ValueError: - raise ValueError( - "Can only lower cholesky_update on a single platform." - ) from None - if platform != "cuda": - raise NotImplementedError( - "Can only lower fast cholesky_update on CUDA." - ) - return gpu_linalg.cuda_cholesky_update( - r_matrix, w_vector, r_matrix_aval.dtype) rule = ffi.ffi_lowering(f"{target_name_prefix}_cholesky_update_ffi", operand_output_aliases={0: 0, 1: 1}) sub_ctx = ctx.replace(avals_out=ctx.avals_in) diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 88b7ff463800..9dedc86e4355 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -12,16 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -from functools import partial import importlib -import numpy as np -import operator - -import jaxlib.mlir.ir as ir - -from .hlo_helpers import custom_call -from .gpu_common_utils import GpuLibNotLinkedError from jaxlib import xla_client @@ -62,34 +53,3 @@ xla_client.register_custom_call_target( _name, _value, platform="ROCM", api_version=api_version ) - -_prod = lambda xs: functools.reduce(operator.mul, xs, 1) - - -def _cholesky_update_hlo(platform, gpu_linalg, r_matrix, w_vector, dtype): - """Cholesky update.""" - del platform - r_type = ir.RankedTensorType(r_matrix.type) - dims = r_type.shape - assert dims[0] == dims[1] - n = dims[0] - - if not gpu_linalg: - raise GpuLibNotLinkedError() - - np_type = np.dtype(dtype) - opaque = gpu_linalg.build_cholesky_update_descriptor(np_type, n) - - return custom_call( - "cu_cholesky_update", - operands = [r_matrix, w_vector], - result_types=[ - ir.RankedTensorType.get((n, n), r_type.element_type), - ir.RankedTensorType.get((n,), r_type.element_type), - ], - operand_output_aliases={0: 0, 1: 1}, - backend_config=opaque, - ).results[:1] - - -cuda_cholesky_update = partial(_cholesky_update_hlo, "cu", _cuda_linalg) From 99195ead83e1c854c34f13f978426e212f729907 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 20 Sep 2024 08:35:49 -0700 Subject: [PATCH 584/702] [Mosaic TPU] Try reducing sublane tiling to support more vector.shape_casts In particular, 32-bit values should now support all reshapes that do not modify the last dimension. PiperOrigin-RevId: 676855401 --- .../tpu/transforms/apply_vector_layout.cc | 3 +- .../tpu/transforms/infer_vector_layout.cc | 40 ++++++++++++------- tests/pallas/tpu_pallas_test.py | 4 -- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 9b21ec1803c8..951ed59865c7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -5807,7 +5807,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { // TODO: b/342235360 - This check is temporary while we increase and test // support for offsets outside of the first tile. When support is more broad, // any op without support should check it within their own rule. - if (!isa(op)) { + if (!isa(op)) { for (const Layout &layout : layouts_in) { if (layout && layout->offsets()[1].has_value() && layout->offsets()[1].value() >= layout->tiling()[1]) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 4d8f3db71027..c5a43e898cd0 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1354,6 +1354,9 @@ class VectorLayoutInferer { // TODO(tlongeri): Be smarter about trying implicit dims. We should probably // only add them when folding dimensions, and remove them when unfolding. + // The ordering of candidate implicit dims is important! Inserting an + // implicit second minor can make a reshape possible, but also very + // inefficient. We should always prefer to try with None first. SmallVector candidate_implicit_dims; if (res_shape.size() >= 2) { candidate_implicit_dims.push_back(ImplicitDim::kNone); @@ -1382,21 +1385,28 @@ class VectorLayoutInferer { for (const ImplicitDim implicit_dim : candidate_implicit_dims) { const std::array res_tiled_ishape = VectorLayout::getImplicitTiledDims(implicit_dim, res_shape, 1); - // Sublane (un)folding. - if (src_tiled_ishape[1] == res_tiled_ishape[1] && - src_tiled_ishape[0] % vreg_slice[0] == 0 && - res_tiled_ishape[0] % vreg_slice[0] == 0) { - // TODO(b/343808585): We shouldn't force second minor offset to 0 when - // unfolding, it's still a no-op, but we need to add - // support in apply-vector-layout. - const LayoutOffsets offsets = {0, layout.offsets()[1]}; - setLayout(op, - VectorLayout(layout.bitwidth(), offsets, layout.tiling(), - layout.implicit_dim()), - VectorLayout(layout.bitwidth(), offsets, layout.tiling(), - implicit_dim)); - return success(); - } + // Sublane (un)folding. We attempt to reduce the sublane tiling, which + // might make this reshape a no-op. We use do-while to handle the packed + // 1D tilings that use 1 in the sublane dimension. + int64_t sublane_tiling = vreg_slice[0]; + do { + if (src_tiled_ishape[1] == res_tiled_ishape[1] && + src_tiled_ishape[0] % sublane_tiling == 0 && + res_tiled_ishape[0] % sublane_tiling == 0) { + std::array tiling = {sublane_tiling, target_shape_[1]}; + // TODO(b/343808585): We shouldn't force second minor offset to 0 when + // unfolding, it's still a no-op, but we need to + // add support in apply-vector-layout. + LayoutOffsets offsets = {0, layout.offsets()[1]}; + setLayout(op, + VectorLayout(layout.bitwidth(), offsets, tiling, + layout.implicit_dim()), + VectorLayout(layout.bitwidth(), offsets, tiling, + implicit_dim)); + return success(); + } + sublane_tiling /= 2; + } while (sublane_tiling >= layout.packing()); // Lane (un)folding. if (src_tiled_ishape[1] != res_tiled_ishape[1] && src_tiled_ishape[1] % layout.tiling()[1] == 0 && diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 5bcf0964419f..84403e41b561 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2425,9 +2425,7 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.reshape(x, (1, 256, 8, 128))) - @only_passes_in_interpret() def test_lane_to_chunk_broadcast_fp32(self): - """b/348033362""" x = np.arange(256 * 128, dtype=jnp.float32).reshape(1, 256, 128) def kernel(x_ref, out_ref): @@ -2524,9 +2522,7 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x[:, 7, :], (1, 8, 128))) - @only_passes_in_interpret() def test_sublane_adding_shape_cast_f32(self): - """b/352833257""" x = np.arange(8 * 128, dtype=jnp.float32).reshape(8, 128) def kernel(x_ref, out_ref): From 81b8b4b7b4aaa80e0a68d1747055306f948182aa Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 20 Sep 2024 08:41:30 -0700 Subject: [PATCH 585/702] [Mosaic GPU] Clean up the module structure Previously the code was awkwardly split between the `jax.experimental.mosaic.gpu` and `jax.experimental.mosaic.gpu.dsl` namespaces. I've now merged both so that all user-visible APIs are accessible from `jax.experimental.mosaic.gpu`. PiperOrigin-RevId: 676857257 --- jax/_src/pallas/mosaic_gpu/core.py | 8 +- jax/_src/pallas/mosaic_gpu/lowering.py | 13 +- .../mosaic_gpu/pallas_call_registration.py | 4 +- jax/experimental/mosaic/gpu/__init__.py | 1013 +---------------- jax/experimental/mosaic/gpu/core.py | 979 ++++++++++++++++ jax/experimental/mosaic/gpu/dsl.py | 58 - .../mosaic/gpu/examples/flash_attention.py | 28 +- .../mosaic/gpu/examples/matmul.py | 17 +- .../mosaic/gpu/fragmented_array.py | 2 +- jax/experimental/mosaic/gpu/wgmma.py | 4 +- tests/mosaic/gpu_test.py | 99 +- 11 files changed, 1111 insertions(+), 1114 deletions(-) create mode 100644 jax/experimental/mosaic/gpu/core.py delete mode 100644 jax/experimental/mosaic/gpu/dsl.py diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index dc698b8747d9..6ef4cd1621f4 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -23,7 +23,7 @@ from jax._src import dtypes from jax._src import tree_util from jax._src.pallas import core as pallas_core -from jax.experimental.mosaic import gpu as mosaic_gpu +import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp @@ -64,7 +64,7 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol): - def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform: + def to_gpu_transform(self) -> mgpu.MemRefTransform: ... @@ -101,8 +101,8 @@ def __call__( inner_aval=block_aval.inner_aval.update(shape=new_block_shape) ) - def to_gpu_transform(self) -> mosaic_gpu.MemRefTransform: - return mosaic_gpu.TileTransform(self.tiling) + def to_gpu_transform(self) -> mgpu.MemRefTransform: + return mgpu.TileTransform(self.tiling) @dataclasses.dataclass(frozen=True) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 4b8199b36105..6eae64b7affa 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -39,8 +39,7 @@ from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import primitives as sp -from jax.experimental.mosaic import gpu as mosaic_gpu -from jax.experimental.mosaic.gpu import dsl as mgpu +import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp import numpy as np @@ -160,7 +159,7 @@ def stack_free_smem(self, bytes: int): @dataclasses.dataclass(frozen=True) class LoweringRuleContext: module_ctx: ModuleContext - launch_ctx: mosaic_gpu.LaunchContext + launch_ctx: mgpu.LaunchContext avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] @@ -180,7 +179,7 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name def _eval_index_map( module_ctx: ModuleContext, - launch_ctx: mosaic_gpu.LaunchContext, + launch_ctx: mgpu.LaunchContext, idx: ir.Value, block_mapping: pallas_core.BlockMapping, ) -> Sequence[ir.Value]: @@ -300,7 +299,7 @@ def lower_jaxpr_to_module( ) ] - def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers: ir.Value): + def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): *buffers_gmem, ( buffers_smem, *scratch_buffers_smem, @@ -494,7 +493,7 @@ def _(step, _): jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) ) - module, out_structs_smem, _ = mosaic_gpu._lower_as_gpu_kernel( + module, out_structs_smem, _ = mgpu._lower_as_gpu_kernel( body, grid=grid, cluster=(), @@ -528,7 +527,7 @@ def deco(fn): def lower_jaxpr_to_mosaic_gpu( module_ctx: ModuleContext, - launch_ctx: mosaic_gpu.LaunchContext, + launch_ctx: mgpu.LaunchContext, jaxpr: jax_core.Jaxpr, args: Sequence[ir.Value], consts=(), diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 5b09cad176a6..510d4032f3dd 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -23,7 +23,7 @@ from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering -from jax.experimental.mosaic import gpu as mosaic_gpu +import jax.experimental.mosaic.gpu.core as mosaic_core def pallas_call_lowering( @@ -67,7 +67,7 @@ def pallas_call_lowering( print(lowering_result.module.operation) module = lowering_result.module - return mosaic_gpu._mosaic_gpu_lowering_rule( + return mosaic_core._mosaic_gpu_lowering_rule( ctx, *args, module=module.operation.get_asm(binary=True, enable_debug_info=True), diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 0e263844b18e..21c7f666b233 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -13,967 +13,52 @@ # limitations under the License. # ============================================================================== -from collections.abc import Callable, Sequence -import contextlib -import ctypes -import dataclasses -import functools -import hashlib -import itertools -import math -import os -import pathlib -import subprocess -import tempfile -import time -from typing import Any, Generic, TypeVar -import weakref - -import jax -from jax._src import config -from jax._src import core as jax_core -from jax._src.interpreters import mlir -from jax._src.lib import xla_client -from jaxlib.mlir import ir -from jaxlib.mlir.dialects import arith -from jaxlib.mlir.dialects import builtin -from jaxlib.mlir.dialects import func -from jaxlib.mlir.dialects import gpu -from jaxlib.mlir.dialects import llvm -from jaxlib.mlir.dialects import memref -from jaxlib.mlir.dialects import nvvm -from jaxlib.mlir.passmanager import PassManager -import numpy as np - -from . import profiler -from . import utils - -# mypy: ignore-errors - -# MLIR can't find libdevice unless we point it to the CUDA path -# TODO(apaszke): Unify with jax._src.lib.cuda_path -CUDA_ROOT = "/usr/local/cuda" -if os.environ.get("CUDA_ROOT") is None: - os.environ["CUDA_ROOT"] = CUDA_ROOT -else: - CUDA_ROOT = os.environ["CUDA_ROOT"] - -PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") -NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") - -TMA_DESCRIPTOR_BYTES = 128 -TMA_DESCRIPTOR_ALIGNMENT = 64 - - -c = utils.c # This is too common to fully qualify. - - -RUNTIME_PATH = None -try: - from jax._src.lib import mosaic_gpu as mosaic_gpu_lib - - RUNTIME_PATH = ( - pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent - / "libmosaic_gpu_runtime.so" - ) -except ImportError: - pass - -if RUNTIME_PATH and RUNTIME_PATH.exists(): - # Set this so that the custom call can find it - os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) - - -mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") -mosaic_gpu_p.multiple_results = True - - -@mosaic_gpu_p.def_abstract_eval -def _mosaic_gpu_abstract_eval(*_, module, out_types): - del module # Unused. - return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] - -# TODO(apaszke): Implement a proper system for managing kernel lifetimes -KNOWN_KERNELS = {} - -def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): - del out_types # Unused. - kernel_id = hashlib.sha256(module).digest() - # Note that this is technically only a half measure. Someone might load a - # compiled module with a hash collision from disk. But that's so unlikely with - # SHA256 that it shouldn't be a problem. - if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None: - if kernel_text != module: - raise RuntimeError("Hash collision!") - else: - KNOWN_KERNELS[kernel_id] = module - op = mlir.custom_call( - "mosaic_gpu", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=args, - operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], - result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], - backend_config=kernel_id + module, - ) - return op.results - -mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") - - -@dataclasses.dataclass(frozen=True) -class MemRefTransform: - def apply(self, ref: ir.Value) -> ir.Value: - raise NotImplementedError("Subclasses should override this method") - - def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: - raise NotImplementedError("Subclasses should override this method") - - def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: - raise NotImplementedError("Subclasses should override this method") - - -@dataclasses.dataclass(frozen=True) -class TileTransform(MemRefTransform): - """Tiles a suffix of memref dimensions. - - For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), - the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with - the tile shape, and the size of tiled dimensions is divided by the tile size. - This is especially useful for swizzled WGMMA, which expect tiled layouts in - shared memory. - """ - tiling: tuple[int, ...] - - def apply(self, ref: ir.Value) -> ir.Value: - untiled_rank = ir.MemRefType(ref.type).rank - tiling_rank = len(self.tiling) - tiled_rank = untiled_rank + tiling_rank - for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]): - s = ir.MemRefType(ref.type).shape[d] - if s % t and s > t: - raise ValueError( - f"Dimension {d} must have size smaller or a multiple of its tiling" - f" {t}, but got {s}" - ) - ref = utils.memref_unfold(ref, d, (None, min(t, s))) - permutation = ( - *range(untiled_rank - tiling_rank), - *range(untiled_rank - tiling_rank, tiled_rank, 2), - *range(untiled_rank - tiling_rank + 1, tiled_rank, 2), - ) - return utils.memref_transpose(ref, permutation) - - def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: - index = ir.IndexType.get() - tiling_rank = len(self.tiling) - return ( - *idx[:-tiling_rank], - *( - arith.divui(i, c(t, index)) - for i, t in zip(idx[-tiling_rank:], self.tiling) - ), - *( - arith.remui(i, c(t, index)) - for i, t in zip(idx[-tiling_rank:], self.tiling) - ), - ) - - def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: - # Note that this also checks that tiled dims are not squeezed. Their slice - # size would be 1 if so. - tiling_rank = len(self.tiling) - for size, tile_size in zip(shape[-tiling_rank:], self.tiling): - if size % tile_size: - raise ValueError( - f"Expected GMEM slice shape {shape} suffix to be a multiple of" - f" tiling {self.tiling}.\nIf you're using padded async copies, your" - " slice might need to extend out of bounds of the GMEM buffer (OOB" - " accesses will be skipped)." - ) - return ( - *shape[:-tiling_rank], - *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)), - *self.tiling, - ) - - -@dataclasses.dataclass(frozen=True) -class TransposeTransform(MemRefTransform): - """Transposes memref dimensions.""" - permutation: tuple[int, ...] - - def __post_init__(self): - if len(self.permutation) != len(set(self.permutation)): - raise ValueError("Permutation must be a permutation") - - def apply(self, ref: ir.Value) -> ir.Value: - return utils.memref_transpose(ref, self.permutation) - - def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: - return tuple(idx[p] for p in self.permutation) - - def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: - return tuple(shape[p] for p in self.permutation) - - -OnDeviceProfiler = profiler.OnDeviceProfiler - - -@dataclasses.dataclass() -class LaunchContext: - launch_op: gpu.LaunchOp - gmem_scratch_ptr: ir.Value - cluster_size: tuple[int, int, int] - profiler: OnDeviceProfiler | None = None - next_scratch_offset: int = 0 - host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( - default_factory=list, init=False - ) - tma_descriptors: dict[ - tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], - ir.Value, - ] = dataclasses.field(default_factory=dict, init=False) - - @contextlib.contextmanager - def named_region(self, *args, **kwargs): - if self.profiler is not None: - with self.profiler.record(*args, **kwargs): - yield - else: - yield - - def _alloc_scratch( - self, - size: int, - alignment: int | None = None, - host_init: Callable[[ir.Value], None] = lambda _: None, - device_init: Callable[[ir.Value], Any] = lambda x: x, - ) -> ir.Value: - """Allocates a GMEM scratch buffer. - - The buffer is initialized on the host and then copied to GMEM before the - kernel launch. - """ - i8 = ir.IntegerType.get_signless(8) - ptr_ty = ir.Type.parse("!llvm.ptr") - if alignment is None: - alignment = size - if self.next_scratch_offset % alignment: - raise NotImplementedError # TODO(apaszke): Pad to match alignment - alloc_base = self.next_scratch_offset - self.next_scratch_offset += size - def host_init_wrapped(host_ptr): - host_init( - llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) - ) - self.host_scratch_init.append(host_init_wrapped) - # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): - # There is no way to create an insertion point after an operation... - gep = llvm.GEPOp( - ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 - ) - gep.move_after(self.gmem_scratch_ptr.owner) - return device_init(gep.result) - - def _get_tma_desc( - self, - gmem_ref, - gmem_transform: tuple[MemRefTransform, ...], - transformed_slice_shape: tuple[int, ...], - swizzle: int | None, - ): - tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) - if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: - i64 = ir.IntegerType.get_signless(64) - ptr_ty = ir.Type.parse("!llvm.ptr") - def init_tma_desc(host_ptr): - ref = gmem_ref - for t in gmem_transform: - ref = t.apply(ref) - ref_ty = ir.MemRefType(ref.type) - # TODO(apaszke): Use utils.memref_ptr to compute base_ptr - _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) - aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) - as_i64 = lambda i: arith.index_cast(i64, i) - alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) - llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... - base_ptr = llvm.getelementptr( - ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, - ) - rank = ref_ty.rank - assert rank * 2 == len(sizes_and_strides) - args = [ - host_ptr, - base_ptr, - c(utils.bytewidth(ref_ty.element_type), i64), - c(rank, i64), - utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), - utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), - c(0 if swizzle is None else swizzle, i64), - utils.pack_array([c(v, i64) for v in transformed_slice_shape]), - ] - func.call([], "mosaic_gpu_init_tma_desc", args) - def cast_tma_desc(device_ptr): - # TODO(apaszke): Investigate why prefetching can cause launch failures - # nvvm.prefetch_tensormap(device_ptr) - return device_ptr - tma_desc = self._alloc_scratch( - TMA_DESCRIPTOR_BYTES, - alignment=TMA_DESCRIPTOR_ALIGNMENT, - host_init=init_tma_desc, - device_init=cast_tma_desc, - ) - self.tma_descriptors[tma_desc_key] = tma_desc - return tma_desc - - def async_copy( - self, - *, - src_ref, - dst_ref, - gmem_slice: Any = (), - gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (), - barrier: utils.BarrierRef | None = None, - swizzle: int | None = None, - arrive: bool | None = None, - uniform: bool = True, - collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, - ): - index = ir.IndexType.get() - i16 = ir.IntegerType.get_signless(16) - i32 = ir.IntegerType.get_signless(32) - smem = ir.Attribute.parse("#gpu.address_space") - src_ref_ty = ir.MemRefType(src_ref.type) - dst_ref_ty = ir.MemRefType(dst_ref.type) - element_type = src_ref_ty.element_type - element_bytewidth = utils.bytewidth(element_type) - if element_type != dst_ref_ty.element_type: - raise ValueError( - f"Expected same element type, got {element_type} and" - f" {dst_ref_ty.element_type}" - ) - if not isinstance(gmem_transform, tuple): - gmem_transform = (gmem_transform,) - - if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem: - gmem_ref, smem_ref = src_ref, dst_ref - if barrier is None: - raise ValueError("Barriers are required for GMEM -> SMEM copies") - if arrive is None: - arrive = True # Arrive by default - elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None: - gmem_ref, smem_ref = dst_ref, src_ref - if barrier is not None: - raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") - if arrive is not None: - raise ValueError("arrive is unsupported for SMEM -> GMEM copies") - else: - raise ValueError("Only SMEM <-> GMEM copies supported") - # TODO(apaszke): This is a very approximate check. Improve it! - expected_name = "builtin.unrealized_conversion_cast" - if ( - gmem_ref.owner is None - or gmem_ref.owner.opview.OPERATION_NAME != expected_name - ): - raise ValueError("GMEM reference in async_copy must be a kernel argument") - - base_indices, slice_shape, is_squeezed = utils.parse_indices( - gmem_slice, ir.MemRefType(gmem_ref.type).shape - ) - dyn_base_indices = tuple( - c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices - ) - slice_shape = tuple(slice_shape) - for t in gmem_transform: - dyn_base_indices = t.transform_index(dyn_base_indices) - slice_shape = t.transform_shape(slice_shape) - for dim, squeezed in enumerate(is_squeezed): - if squeezed: - smem_ref = utils.memref_unsqueeze(smem_ref, dim) - smem_ref_ty = ir.MemRefType(smem_ref.type) - - if slice_shape != tuple(smem_ref_ty.shape): - raise ValueError( - "Expected the SMEM reference to have the same shape as the" - f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" - ) - - dyn_base_indices = list(dyn_base_indices) - slice_shape = list(slice_shape) - collective_size = 1 - if collective is not None: - if isinstance(collective, gpu.Dimension): - collective = (collective,) - collective_size = math.prod(self.cluster_size[d] for d in collective) - if collective_size > 1: - def partition_dim(dim: int, idx: ir.Value, num_chunks: int): - nonlocal smem_ref - slice_shape[dim] //= num_chunks - block_offset = arith.muli(idx, c(slice_shape[dim], index)) - dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) - smem_ref = utils.memref_slice( - smem_ref, - (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),) - ) - stride = 1 - idx = c(0, index) - for d in sorted(collective): - if self.cluster_size[d] == 1: # Optimize a multiply by 0. - continue - idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))) - stride *= self.cluster_size[d] - rem_collective_size = collective_size - for dim, slice_size in enumerate(slice_shape[:-1]): - if slice_size % rem_collective_size == 0: - partition_dim(dim, idx, rem_collective_size) - rem_collective_size = 1 - break - elif rem_collective_size % slice_size == 0: - dim_idx = arith.remui(idx, c(slice_size, index)) - partition_dim(dim, dim_idx, slice_size) - idx = arith.divui(idx, c(slice_size, index)) - rem_collective_size //= slice_size - else: - break # We failed to partition the leading dimensions. - del idx # We overwrote the block index in the loop. - if rem_collective_size > 1: - raise ValueError( - "None of the leading dimensions in the transformed slice shape" - f" {slice_shape} is divisible by the collective size" - f" {collective_size}" - ) - # Make each block load a smaller slice, adjust the GMEM indices and slice - # the SMEM reference accordingly. - multicast_mask = arith.trunci( - i16, utils.cluster_collective_mask(self.cluster_size, collective) - ) - else: - multicast_mask = None - - tma_desc = self._get_tma_desc( - gmem_ref, gmem_transform, tuple(slice_shape), swizzle, - ) - - # We constuct TMA descriptors in column-major order. - rev_dyn_base_indices = [ - arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) - ] - - uniform_ctx = ( - functools.partial(utils.single_thread, per_block=False) - if uniform - else contextlib.nullcontext - ) - - rank = len(slice_shape) - if rank > 5: # TODO: apaszke - Implement stride compression - raise ValueError("Async copies only support striding up to 5 dimensions") - if max(slice_shape) > 256: - raise ValueError( - "Async copies only support copying <=256 elements along each" - " dimension" - ) - if (zeroth_bw := slice_shape[-1] * element_bytewidth) % 16 != 0: - raise ValueError( - "Async copies require the number of bytes copied along the last" - f" dimension to be divisible by 16, but got {zeroth_bw}" - ) - if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth: - raise ValueError( - f"Async copies with {swizzle=} require last dimension of the slice to" - f" be exactly {swizzle} bytes" - f" ({swizzle // element_bytewidth} elements), but got" - f" {slice_shape[-1]}" - ) - smem_ptr = utils.memref_ptr(smem_ref, memory_space=3) - if gmem_ref is src_ref: - assert barrier is not None # for pytype - transfer_bytes = c( - np.prod(slice_shape) * element_bytewidth * collective_size, i32 - ) - barrier_ptr = barrier.get_ptr() - with uniform_ctx(): - if arrive: - nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, transfer_bytes) - nvvm.cp_async_bulk_tensor_shared_cluster_global( - smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], multicast_mask=multicast_mask, - ) - else: - with uniform_ctx(): - nvvm.cp_async_bulk_tensor_global_shared_cta( - tma_desc, smem_ptr, rev_dyn_base_indices - ) - nvvm.cp_async_bulk_commit_group() - - def await_async_copy( - self, allow_groups: int, await_read_only: bool = False - ): - nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) - utils.warpgroup_barrier() - - -# ShapeTrees currently can not contain unions. -ShapeTree = Any -RefTree = Any -T = TypeVar('T') - - -@dataclasses.dataclass(frozen=True) -class Union(Generic[T]): - members: Sequence[T] - - def __iter__(self): - return iter(self.members) - -@dataclasses.dataclass(frozen=True) -class TMABarrier: - num_barriers: int = 1 - -@dataclasses.dataclass(frozen=True) -class Barrier: - arrival_count: int - num_barriers: int = 1 - -@dataclasses.dataclass(frozen=True) -class ClusterBarrier: - collective_dims: Sequence[gpu.Dimension] - num_barriers: int = 1 - - -def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: - return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize - - -def _construct_smem_reftree( - cluster_shape: tuple[int, int, int], - dynamic_smem: ir.Value, - smem_buffers: ShapeTree, - dynamic_smem_offset: int = 0, -) -> RefTree: - index = ir.IndexType.get() - i8 = ir.IntegerType.get_signless(8) - ptr = ir.Type.parse("!llvm.ptr") - smem = ir.Attribute.parse("#gpu.address_space") - flat_ref_tys, smem_buffer_tree = jax.tree.flatten( - smem_buffers, is_leaf=lambda x: isinstance(x, Union) - ) - smem_refs = [] - for ref_ty in flat_ref_tys: - def get_barrier_ptr(num_barriers: int) -> ir.Value: - nonlocal dynamic_smem_offset - smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3) - barrier_base_ptr = llvm.getelementptr( - ptr, smem_base_ptr, [], [dynamic_smem_offset], i8 - ) - dynamic_smem_offset += num_barriers * MBARRIER_BYTES - return barrier_base_ptr - match ref_ty: - case Union(members): - member_trees = [ - _construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset) - for m in members - ] - # TODO(apaszke): This is quadratic, but it shouldn't matter for now... - dynamic_smem_offset += _smem_tree_size(ref_ty) - ref = Union(member_trees) - case TMABarrier(num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), num_barriers, arrival_count=1 - ) - case Barrier(arrival_count, num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, - arrival_count=arrival_count, - ) - case ClusterBarrier(collective_dims, num_barriers): - ref = utils.CollectiveBarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, - collective_dims, - cluster_shape, - ) - case _: - mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype) - tile_smem = memref.view( - ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), - dynamic_smem, c(dynamic_smem_offset, index), [], - ) - dynamic_smem_offset += _count_buffer_bytes(ref_ty) - ref = tile_smem - smem_refs.append(ref) - return jax.tree.unflatten(smem_buffer_tree, smem_refs) - - -MBARRIER_BYTES = 8 - - -def _smem_tree_size(smem_buffers: ShapeTree) -> int: - leaves = jax.tree.leaves( - smem_buffers, is_leaf=lambda x: isinstance(x, Union) - ) - size = 0 - for l in leaves: - match l: - case Union(members): - size += max(_smem_tree_size(s) for s in members) - case ( - TMABarrier(num_barriers) - | ClusterBarrier(_, num_barriers=num_barriers) - | Barrier(_, num_barriers=num_barriers) - ): - if size % MBARRIER_BYTES: - raise NotImplementedError("Misaligned barrier allocation") - size += num_barriers * MBARRIER_BYTES - case _: - size += _count_buffer_bytes(l) - return size - - -# TODO(apaszke): Inline this -@contextlib.contextmanager -def _launch( - token, - grid: tuple[int, int, int], - cluster: tuple[int, int, int], - block: tuple[int, int, int], - scratch_arr, - smem_buffers: ShapeTree | Union[ShapeTree], - profiler_spec: profiler.ProfilerSpec | None = None, - maybe_prof_buffer: ir.Value | None = None, -): - if (profiler_spec is None) != (maybe_prof_buffer is None): - raise ValueError - index = ir.IndexType.get() - i32 = ir.IntegerType.get_signless(32) - i8 = ir.IntegerType.get_signless(8) - grid_vals = [c(i, index) for i in grid] - block_vals = [c(i, index) for i in block] - - user_smem_bytes = _smem_tree_size(smem_buffers) - - smem_bytes = user_smem_bytes - if profiler_spec is not None: - smem_bytes += profiler_spec.smem_bytes(block=block) - - # TODO(cperivol): Query the shared memory size programmatically. - if smem_bytes > 228 * 1024: - raise ValueError(f"Mosaic GPU kernel exceeds available shared memory {smem_bytes=} > 228000") - if math.prod(cluster) != 1: - if len(cluster) != 3: - raise ValueError("Clusters must be 3D") - cluster_kwargs = { - "clusterSize" + d: c(s, index) for s, d in zip(cluster, "XYZ") - } - for d, grid_size, cluster_size in zip("xyz", grid, cluster): - if grid_size % cluster_size != 0: - raise ValueError( - f"Grid dimension {d} must be divisible by cluster dimension:" - f" {grid_size} % {cluster_size} != 0" - ) - else: - cluster_kwargs = {} - launch_op = gpu.LaunchOp( - token.type, [token], *grid_vals, *block_vals, - dynamicSharedMemorySize=c(smem_bytes, i32), **cluster_kwargs) - launch_op.body.blocks.append(*([index] * (12 + 2 * len(cluster_kwargs)))) # Append an empty block - smem = ir.Attribute.parse("#gpu.address_space") - with ir.InsertionPoint(launch_op.body.blocks[0]): - dynamic_smem = gpu.dynamic_shared_memory( - ir.MemRefType.get( - (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem - ) - ) - - smem_ref_tree = _construct_smem_reftree( - cluster, dynamic_smem, smem_buffers - ) - # TODO(apaszke): Skip the following if no barriers were initialized. - nvvm.fence_mbarrier_init() - if math.prod(cluster) != 1: - nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) - nvvm.cluster_wait(aligned=ir.UnitAttr.get()) - gpu.barrier() - - if profiler_spec: - prof_smem = memref.view( - ir.MemRefType.get( - (profiler_spec.smem_i32_elements(block=block),), - i32, memory_space=smem, - ), - dynamic_smem, c(user_smem_bytes, index), [], - ) - prof = profiler.OnDeviceProfiler( - profiler_spec, prof_smem, maybe_prof_buffer - ) - else: - prof = None - - ptr_ty = ir.Type.parse("!llvm.ptr") - scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) - yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree - if prof is not None: - prof.finalize(grid=grid, block=block) - gpu.terminator() - - -def _lower_as_gpu_kernel( - body, - grid: tuple[int, int, int], - cluster: tuple[int, int, int], - block: tuple[int, int, int], - in_shapes: tuple[Any, ...], - out_shape, - smem_scratch_shape: ShapeTree | Union[ShapeTree], - module_name: str, - prof_spec: profiler.ProfilerSpec | None = None, -): - ptr_ty = ir.Type.parse("!llvm.ptr") - token_ty = ir.Type.parse("!gpu.async.token") - i32 = ir.IntegerType.get_signless(32) - i64 = ir.IntegerType.get_signless(64) - - def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: - return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) - - in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] - - unwrap_output_tuple = False - if isinstance(out_shape, list): - out_shape = tuple(out_shape) - elif not isinstance(out_shape, tuple): - out_shape = (out_shape,) - unwrap_output_tuple = True - out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] - if prof_spec is not None: - out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) - out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) - - module = ir.Module.create() - attrs = module.operation.attributes - attrs["sym_name"] = ir.StringAttr.get(module_name) - with ir.InsertionPoint(module.body): - _declare_runtime_functions() - gmem_scratch_bytes = 0 - global_scratch = llvm.GlobalOp( - ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet. - "global_scratch", - ir.Attribute.parse("#llvm.linkage"), - addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. - ) - @func.FuncOp.from_py_func(ptr_ty, ptr_ty) - def main(token_ptr, buffers): - nonlocal gmem_scratch_bytes - token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) - arg_refs = [] - for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): - ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) - arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) - in_refs = arg_refs[:len(in_ref_tys)] - out_refs = arg_refs[len(in_ref_tys):] - prof_buffer = out_refs.pop() if prof_spec is not None else None - empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") - scratch_alloc = llvm.AllocaOp( - ptr_ty, c(1, i64), empty_arr_ty, alignment=TMA_DESCRIPTOR_ALIGNMENT - ) - scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result) - with _launch( - token, grid, cluster, block, scratch_arr, smem_scratch_shape, - prof_spec, prof_buffer - ) as (launch_ctx, smem_refs): - body(launch_ctx, *in_refs, *out_refs, smem_refs) - gmem_scratch_bytes = launch_ctx.next_scratch_offset - # Allocate and initialize the host buffer right before the launch. - # Note that we couldn't do that before, because we had to run the body - # to learn what the scratch contains. - with ir.InsertionPoint(scratch_arr.owner): - scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") - scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty) - scratch_arr.set_type(scratch_arr_ty) - for init_callback in launch_ctx.host_scratch_init: - init_callback(scratch_alloc.result) - main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - sym_tab = ir.SymbolTable(module.operation) - sym_tab.insert(main.func_op) - sym_tab.insert(global_scratch) - module.operation.verify() - - return module, out_shape, unwrap_output_tuple - - -def _declare_runtime_functions(): - """Declares the runtime functions that can be used by the generated code.""" - ptr_ty = ir.Type.parse("!llvm.ptr") - i64 = ir.IntegerType.get_signless(64) - arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] - init_tma_desc_type = ir.FunctionType.get(arg_tys, []) - func.FuncOp( - "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" - ) - memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) - func.FuncOp( - "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" - ) - - -def as_gpu_kernel( - body, - grid: tuple[int, int, int], - block: tuple[int, int, int], - in_shape, - out_shape, - smem_scratch_shape: ShapeTree | Union[ShapeTree], - prof_spec: profiler.ProfilerSpec | None = None, - cluster: tuple[int, int, int] = (1, 1, 1), - module_name: str = "unknown", -): - if isinstance(in_shape, list): - in_shape = tuple(in_shape) - elif not isinstance(in_shape, tuple): - in_shape = (in_shape,) - - module, out_shape, unwrap_output_tuple = ( - _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, prof_spec - ) - ) - - expected_arg_treedef = jax.tree.structure(in_shape) - def _check_args(*args): - arg_treedef = jax.tree.structure(args) - if arg_treedef != expected_arg_treedef: - raise ValueError( - f"Invalid argument structure: expected {expected_arg_treedef}, got" - f" {arg_treedef}, ({args=})" - ) - - module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) - def bind(*args): - return mosaic_gpu_p.bind( - *args, - out_types=out_shape, - module=module_asm, - ) - - if prof_spec is not None: - @jax.jit - def prof_kernel(*args): - _check_args(*args) - *results, prof_buffer = bind(*args) - def dump_profile(prof_buffer): - out_file = os.path.join( - os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), - f"{time.time_ns()}-trace.json", - ) - try: - with open(out_file, "x") as f: - prof_spec.dump(prof_buffer, f, grid=grid, block=block) - except FileExistsError: - pass # TODO: Retry - jax.debug.callback(dump_profile, prof_buffer) - return results[0] if unwrap_output_tuple else results - return prof_kernel - else: - @jax.jit - def kernel(*args): - _check_args(*args) - results = bind(*args) - return results[0] if unwrap_output_tuple else results - return kernel - - -def as_torch_gpu_kernel( - body, - grid: tuple[int, int, int], - block: tuple[int, int, int], - in_shape, - out_shape, - smem_scratch_shape: ShapeTree | Union[ShapeTree], - prof_spec: profiler.ProfilerSpec | None = None, - cluster: tuple[int, int, int] = (1, 1, 1), - module_name: str = "unknown", -): - try: - import torch - except ImportError: - raise RuntimeError("as_torch_gpu_kernel requires PyTorch") - torch.cuda.init() # Make sure CUDA context is set up. - - if isinstance(in_shape, list): - in_shape = tuple(in_shape) - elif not isinstance(in_shape, tuple): - in_shape = (in_shape,) - - flat_out_types, out_treedef = jax.tree.flatten(out_shape) - expected_arg_treedef = jax.tree.structure(in_shape) - - module, out_shape, unwrap_output_tuple = ( - _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, prof_spec - ) - ) - - # Get our hands on the compilation and unload functions - try: - import jax_plugins.xla_cuda12 as cuda_plugin - except ImportError: - raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds " - "that use backend plugins") - dll = ctypes.CDLL(cuda_plugin._get_library_path()) - compile_func = dll.MosaicGpuCompile - compile_func.argtypes = [ctypes.c_void_p] - compile_func.restype = ctypes.POINTER(ctypes.c_void_p) - unload_func = dll.MosaicGpuUnload - unload_func.argtypes = [compile_func.restype] - unload_func.restype = None - - module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) - compiled = compile_func(ctypes.c_char_p(module_asm)) - if compiled is None: - raise RuntimeError("Failed to compile the module") - ctx, launch_ptr = compiled[0], compiled[1] - ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx)) - launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) - - def as_torch_dtype(dtype): - # torch contains NumPy-compatible dtypes in its top namespace - return getattr(torch, np.dtype(dtype).name) - - def apply(*args): - flat_args, arg_treedef = jax.tree.flatten(args) - if arg_treedef != expected_arg_treedef: - raise ValueError( - f"Invalid argument structure: expected {expected_arg_treedef}, got" - f" {arg_treedef}, ({args=})" - ) - - # Construct a device pointer list like in the XLA calling convention - buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))() - i = -1 # Define i in case there are no args - device = 'cuda' - for i, arg in enumerate(flat_args): - buffers[i] = arg.data_ptr() - device = arg.device - flat_outs = [] - for i, t in enumerate(flat_out_types, i + 1): - out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device) - flat_outs.append(out) - buffers[i] = out.data_ptr() - # Allocate another buffer for args of the host-side program. This is sadly - # the default MLIR calling convention. - args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() - args_ptr[0] = ctx_ptr_ptr - args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_) - args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)), - ctypes.POINTER(ctypes.c_void_p)) - launch(args_ptr) - return jax.tree.unflatten(out_treedef, flat_outs) - - # Unload the compiled code when the Python function is destroyed. - def unload(_): - unload_func(compiled) - apply.destructor = weakref.ref(apply, unload) - - return apply +from jax import ShapeDtypeStruct +from .core import ( + Barrier, + ClusterBarrier, + LaunchContext, + MemRefTransform, + TMABarrier, + TileTransform, + TransposeTransform, + Union, + as_gpu_kernel, +) +from .fragmented_array import ( + FragmentedArray, + FragmentedLayout, + WGMMA_LAYOUT, + WGMMA_ROW_LAYOUT, + WGStridedFragLayout, +) +from .utils import ( + BarrierRef, + CollectiveBarrierRef, + DynamicSlice, + Partition, + Partition1D, + bytewidth, + c, + commit_shared, + debug_print, + ds, + fori, + memref_fold, + memref_slice, + memref_transpose, + memref_unfold, + memref_unsqueeze, + single_thread, + thread_idx, + tile_shape, + warp_idx, + warpgroup_barrier, + warpgroup_idx, + when, +) +from .wgmma import ( + WGMMAAccumulator, + WGMMALayout, + wgmma, +) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py new file mode 100644 index 000000000000..0e263844b18e --- /dev/null +++ b/jax/experimental/mosaic/gpu/core.py @@ -0,0 +1,979 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable, Sequence +import contextlib +import ctypes +import dataclasses +import functools +import hashlib +import itertools +import math +import os +import pathlib +import subprocess +import tempfile +import time +from typing import Any, Generic, TypeVar +import weakref + +import jax +from jax._src import config +from jax._src import core as jax_core +from jax._src.interpreters import mlir +from jax._src.lib import xla_client +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin +from jaxlib.mlir.dialects import func +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvvm +from jaxlib.mlir.passmanager import PassManager +import numpy as np + +from . import profiler +from . import utils + +# mypy: ignore-errors + +# MLIR can't find libdevice unless we point it to the CUDA path +# TODO(apaszke): Unify with jax._src.lib.cuda_path +CUDA_ROOT = "/usr/local/cuda" +if os.environ.get("CUDA_ROOT") is None: + os.environ["CUDA_ROOT"] = CUDA_ROOT +else: + CUDA_ROOT = os.environ["CUDA_ROOT"] + +PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") +NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") + +TMA_DESCRIPTOR_BYTES = 128 +TMA_DESCRIPTOR_ALIGNMENT = 64 + + +c = utils.c # This is too common to fully qualify. + + +RUNTIME_PATH = None +try: + from jax._src.lib import mosaic_gpu as mosaic_gpu_lib + + RUNTIME_PATH = ( + pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent + / "libmosaic_gpu_runtime.so" + ) +except ImportError: + pass + +if RUNTIME_PATH and RUNTIME_PATH.exists(): + # Set this so that the custom call can find it + os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) + + +mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") +mosaic_gpu_p.multiple_results = True + + +@mosaic_gpu_p.def_abstract_eval +def _mosaic_gpu_abstract_eval(*_, module, out_types): + del module # Unused. + return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] + +# TODO(apaszke): Implement a proper system for managing kernel lifetimes +KNOWN_KERNELS = {} + +def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): + del out_types # Unused. + kernel_id = hashlib.sha256(module).digest() + # Note that this is technically only a half measure. Someone might load a + # compiled module with a hash collision from disk. But that's so unlikely with + # SHA256 that it shouldn't be a problem. + if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None: + if kernel_text != module: + raise RuntimeError("Hash collision!") + else: + KNOWN_KERNELS[kernel_id] = module + op = mlir.custom_call( + "mosaic_gpu", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], + backend_config=kernel_id + module, + ) + return op.results + +mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") + + +@dataclasses.dataclass(frozen=True) +class MemRefTransform: + def apply(self, ref: ir.Value) -> ir.Value: + raise NotImplementedError("Subclasses should override this method") + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + raise NotImplementedError("Subclasses should override this method") + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + raise NotImplementedError("Subclasses should override this method") + + +@dataclasses.dataclass(frozen=True) +class TileTransform(MemRefTransform): + """Tiles a suffix of memref dimensions. + + For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), + the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with + the tile shape, and the size of tiled dimensions is divided by the tile size. + This is especially useful for swizzled WGMMA, which expect tiled layouts in + shared memory. + """ + tiling: tuple[int, ...] + + def apply(self, ref: ir.Value) -> ir.Value: + untiled_rank = ir.MemRefType(ref.type).rank + tiling_rank = len(self.tiling) + tiled_rank = untiled_rank + tiling_rank + for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]): + s = ir.MemRefType(ref.type).shape[d] + if s % t and s > t: + raise ValueError( + f"Dimension {d} must have size smaller or a multiple of its tiling" + f" {t}, but got {s}" + ) + ref = utils.memref_unfold(ref, d, (None, min(t, s))) + permutation = ( + *range(untiled_rank - tiling_rank), + *range(untiled_rank - tiling_rank, tiled_rank, 2), + *range(untiled_rank - tiling_rank + 1, tiled_rank, 2), + ) + return utils.memref_transpose(ref, permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + index = ir.IndexType.get() + tiling_rank = len(self.tiling) + return ( + *idx[:-tiling_rank], + *( + arith.divui(i, c(t, index)) + for i, t in zip(idx[-tiling_rank:], self.tiling) + ), + *( + arith.remui(i, c(t, index)) + for i, t in zip(idx[-tiling_rank:], self.tiling) + ), + ) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + # Note that this also checks that tiled dims are not squeezed. Their slice + # size would be 1 if so. + tiling_rank = len(self.tiling) + for size, tile_size in zip(shape[-tiling_rank:], self.tiling): + if size % tile_size: + raise ValueError( + f"Expected GMEM slice shape {shape} suffix to be a multiple of" + f" tiling {self.tiling}.\nIf you're using padded async copies, your" + " slice might need to extend out of bounds of the GMEM buffer (OOB" + " accesses will be skipped)." + ) + return ( + *shape[:-tiling_rank], + *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)), + *self.tiling, + ) + + +@dataclasses.dataclass(frozen=True) +class TransposeTransform(MemRefTransform): + """Transposes memref dimensions.""" + permutation: tuple[int, ...] + + def __post_init__(self): + if len(self.permutation) != len(set(self.permutation)): + raise ValueError("Permutation must be a permutation") + + def apply(self, ref: ir.Value) -> ir.Value: + return utils.memref_transpose(ref, self.permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + return tuple(idx[p] for p in self.permutation) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + return tuple(shape[p] for p in self.permutation) + + +OnDeviceProfiler = profiler.OnDeviceProfiler + + +@dataclasses.dataclass() +class LaunchContext: + launch_op: gpu.LaunchOp + gmem_scratch_ptr: ir.Value + cluster_size: tuple[int, int, int] + profiler: OnDeviceProfiler | None = None + next_scratch_offset: int = 0 + host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( + default_factory=list, init=False + ) + tma_descriptors: dict[ + tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], + ir.Value, + ] = dataclasses.field(default_factory=dict, init=False) + + @contextlib.contextmanager + def named_region(self, *args, **kwargs): + if self.profiler is not None: + with self.profiler.record(*args, **kwargs): + yield + else: + yield + + def _alloc_scratch( + self, + size: int, + alignment: int | None = None, + host_init: Callable[[ir.Value], None] = lambda _: None, + device_init: Callable[[ir.Value], Any] = lambda x: x, + ) -> ir.Value: + """Allocates a GMEM scratch buffer. + + The buffer is initialized on the host and then copied to GMEM before the + kernel launch. + """ + i8 = ir.IntegerType.get_signless(8) + ptr_ty = ir.Type.parse("!llvm.ptr") + if alignment is None: + alignment = size + if self.next_scratch_offset % alignment: + raise NotImplementedError # TODO(apaszke): Pad to match alignment + alloc_base = self.next_scratch_offset + self.next_scratch_offset += size + def host_init_wrapped(host_ptr): + host_init( + llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) + ) + self.host_scratch_init.append(host_init_wrapped) + # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): + # There is no way to create an insertion point after an operation... + gep = llvm.GEPOp( + ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 + ) + gep.move_after(self.gmem_scratch_ptr.owner) + return device_init(gep.result) + + def _get_tma_desc( + self, + gmem_ref, + gmem_transform: tuple[MemRefTransform, ...], + transformed_slice_shape: tuple[int, ...], + swizzle: int | None, + ): + tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) + if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + def init_tma_desc(host_ptr): + ref = gmem_ref + for t in gmem_transform: + ref = t.apply(ref) + ref_ty = ir.MemRefType(ref.type) + # TODO(apaszke): Use utils.memref_ptr to compute base_ptr + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) + aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) + as_i64 = lambda i: arith.index_cast(i64, i) + alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) + llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... + base_ptr = llvm.getelementptr( + ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, + ) + rank = ref_ty.rank + assert rank * 2 == len(sizes_and_strides) + args = [ + host_ptr, + base_ptr, + c(utils.bytewidth(ref_ty.element_type), i64), + c(rank, i64), + utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), + utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), + c(0 if swizzle is None else swizzle, i64), + utils.pack_array([c(v, i64) for v in transformed_slice_shape]), + ] + func.call([], "mosaic_gpu_init_tma_desc", args) + def cast_tma_desc(device_ptr): + # TODO(apaszke): Investigate why prefetching can cause launch failures + # nvvm.prefetch_tensormap(device_ptr) + return device_ptr + tma_desc = self._alloc_scratch( + TMA_DESCRIPTOR_BYTES, + alignment=TMA_DESCRIPTOR_ALIGNMENT, + host_init=init_tma_desc, + device_init=cast_tma_desc, + ) + self.tma_descriptors[tma_desc_key] = tma_desc + return tma_desc + + def async_copy( + self, + *, + src_ref, + dst_ref, + gmem_slice: Any = (), + gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (), + barrier: utils.BarrierRef | None = None, + swizzle: int | None = None, + arrive: bool | None = None, + uniform: bool = True, + collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, + ): + index = ir.IndexType.get() + i16 = ir.IntegerType.get_signless(16) + i32 = ir.IntegerType.get_signless(32) + smem = ir.Attribute.parse("#gpu.address_space") + src_ref_ty = ir.MemRefType(src_ref.type) + dst_ref_ty = ir.MemRefType(dst_ref.type) + element_type = src_ref_ty.element_type + element_bytewidth = utils.bytewidth(element_type) + if element_type != dst_ref_ty.element_type: + raise ValueError( + f"Expected same element type, got {element_type} and" + f" {dst_ref_ty.element_type}" + ) + if not isinstance(gmem_transform, tuple): + gmem_transform = (gmem_transform,) + + if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem: + gmem_ref, smem_ref = src_ref, dst_ref + if barrier is None: + raise ValueError("Barriers are required for GMEM -> SMEM copies") + if arrive is None: + arrive = True # Arrive by default + elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None: + gmem_ref, smem_ref = dst_ref, src_ref + if barrier is not None: + raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") + if arrive is not None: + raise ValueError("arrive is unsupported for SMEM -> GMEM copies") + else: + raise ValueError("Only SMEM <-> GMEM copies supported") + # TODO(apaszke): This is a very approximate check. Improve it! + expected_name = "builtin.unrealized_conversion_cast" + if ( + gmem_ref.owner is None + or gmem_ref.owner.opview.OPERATION_NAME != expected_name + ): + raise ValueError("GMEM reference in async_copy must be a kernel argument") + + base_indices, slice_shape, is_squeezed = utils.parse_indices( + gmem_slice, ir.MemRefType(gmem_ref.type).shape + ) + dyn_base_indices = tuple( + c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices + ) + slice_shape = tuple(slice_shape) + for t in gmem_transform: + dyn_base_indices = t.transform_index(dyn_base_indices) + slice_shape = t.transform_shape(slice_shape) + for dim, squeezed in enumerate(is_squeezed): + if squeezed: + smem_ref = utils.memref_unsqueeze(smem_ref, dim) + smem_ref_ty = ir.MemRefType(smem_ref.type) + + if slice_shape != tuple(smem_ref_ty.shape): + raise ValueError( + "Expected the SMEM reference to have the same shape as the" + f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" + ) + + dyn_base_indices = list(dyn_base_indices) + slice_shape = list(slice_shape) + collective_size = 1 + if collective is not None: + if isinstance(collective, gpu.Dimension): + collective = (collective,) + collective_size = math.prod(self.cluster_size[d] for d in collective) + if collective_size > 1: + def partition_dim(dim: int, idx: ir.Value, num_chunks: int): + nonlocal smem_ref + slice_shape[dim] //= num_chunks + block_offset = arith.muli(idx, c(slice_shape[dim], index)) + dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) + smem_ref = utils.memref_slice( + smem_ref, + (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),) + ) + stride = 1 + idx = c(0, index) + for d in sorted(collective): + if self.cluster_size[d] == 1: # Optimize a multiply by 0. + continue + idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))) + stride *= self.cluster_size[d] + rem_collective_size = collective_size + for dim, slice_size in enumerate(slice_shape[:-1]): + if slice_size % rem_collective_size == 0: + partition_dim(dim, idx, rem_collective_size) + rem_collective_size = 1 + break + elif rem_collective_size % slice_size == 0: + dim_idx = arith.remui(idx, c(slice_size, index)) + partition_dim(dim, dim_idx, slice_size) + idx = arith.divui(idx, c(slice_size, index)) + rem_collective_size //= slice_size + else: + break # We failed to partition the leading dimensions. + del idx # We overwrote the block index in the loop. + if rem_collective_size > 1: + raise ValueError( + "None of the leading dimensions in the transformed slice shape" + f" {slice_shape} is divisible by the collective size" + f" {collective_size}" + ) + # Make each block load a smaller slice, adjust the GMEM indices and slice + # the SMEM reference accordingly. + multicast_mask = arith.trunci( + i16, utils.cluster_collective_mask(self.cluster_size, collective) + ) + else: + multicast_mask = None + + tma_desc = self._get_tma_desc( + gmem_ref, gmem_transform, tuple(slice_shape), swizzle, + ) + + # We constuct TMA descriptors in column-major order. + rev_dyn_base_indices = [ + arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) + ] + + uniform_ctx = ( + functools.partial(utils.single_thread, per_block=False) + if uniform + else contextlib.nullcontext + ) + + rank = len(slice_shape) + if rank > 5: # TODO: apaszke - Implement stride compression + raise ValueError("Async copies only support striding up to 5 dimensions") + if max(slice_shape) > 256: + raise ValueError( + "Async copies only support copying <=256 elements along each" + " dimension" + ) + if (zeroth_bw := slice_shape[-1] * element_bytewidth) % 16 != 0: + raise ValueError( + "Async copies require the number of bytes copied along the last" + f" dimension to be divisible by 16, but got {zeroth_bw}" + ) + if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth: + raise ValueError( + f"Async copies with {swizzle=} require last dimension of the slice to" + f" be exactly {swizzle} bytes" + f" ({swizzle // element_bytewidth} elements), but got" + f" {slice_shape[-1]}" + ) + smem_ptr = utils.memref_ptr(smem_ref, memory_space=3) + if gmem_ref is src_ref: + assert barrier is not None # for pytype + transfer_bytes = c( + np.prod(slice_shape) * element_bytewidth * collective_size, i32 + ) + barrier_ptr = barrier.get_ptr() + with uniform_ctx(): + if arrive: + nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, transfer_bytes) + nvvm.cp_async_bulk_tensor_shared_cluster_global( + smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], multicast_mask=multicast_mask, + ) + else: + with uniform_ctx(): + nvvm.cp_async_bulk_tensor_global_shared_cta( + tma_desc, smem_ptr, rev_dyn_base_indices + ) + nvvm.cp_async_bulk_commit_group() + + def await_async_copy( + self, allow_groups: int, await_read_only: bool = False + ): + nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) + utils.warpgroup_barrier() + + +# ShapeTrees currently can not contain unions. +ShapeTree = Any +RefTree = Any +T = TypeVar('T') + + +@dataclasses.dataclass(frozen=True) +class Union(Generic[T]): + members: Sequence[T] + + def __iter__(self): + return iter(self.members) + +@dataclasses.dataclass(frozen=True) +class TMABarrier: + num_barriers: int = 1 + +@dataclasses.dataclass(frozen=True) +class Barrier: + arrival_count: int + num_barriers: int = 1 + +@dataclasses.dataclass(frozen=True) +class ClusterBarrier: + collective_dims: Sequence[gpu.Dimension] + num_barriers: int = 1 + + +def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: + return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize + + +def _construct_smem_reftree( + cluster_shape: tuple[int, int, int], + dynamic_smem: ir.Value, + smem_buffers: ShapeTree, + dynamic_smem_offset: int = 0, +) -> RefTree: + index = ir.IndexType.get() + i8 = ir.IntegerType.get_signless(8) + ptr = ir.Type.parse("!llvm.ptr") + smem = ir.Attribute.parse("#gpu.address_space") + flat_ref_tys, smem_buffer_tree = jax.tree.flatten( + smem_buffers, is_leaf=lambda x: isinstance(x, Union) + ) + smem_refs = [] + for ref_ty in flat_ref_tys: + def get_barrier_ptr(num_barriers: int) -> ir.Value: + nonlocal dynamic_smem_offset + smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3) + barrier_base_ptr = llvm.getelementptr( + ptr, smem_base_ptr, [], [dynamic_smem_offset], i8 + ) + dynamic_smem_offset += num_barriers * MBARRIER_BYTES + return barrier_base_ptr + match ref_ty: + case Union(members): + member_trees = [ + _construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset) + for m in members + ] + # TODO(apaszke): This is quadratic, but it shouldn't matter for now... + dynamic_smem_offset += _smem_tree_size(ref_ty) + ref = Union(member_trees) + case TMABarrier(num_barriers): + ref = utils.BarrierRef.initialize( + get_barrier_ptr(num_barriers), num_barriers, arrival_count=1 + ) + case Barrier(arrival_count, num_barriers): + ref = utils.BarrierRef.initialize( + get_barrier_ptr(num_barriers), + num_barriers, + arrival_count=arrival_count, + ) + case ClusterBarrier(collective_dims, num_barriers): + ref = utils.CollectiveBarrierRef.initialize( + get_barrier_ptr(num_barriers), + num_barriers, + collective_dims, + cluster_shape, + ) + case _: + mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype) + tile_smem = memref.view( + ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), + dynamic_smem, c(dynamic_smem_offset, index), [], + ) + dynamic_smem_offset += _count_buffer_bytes(ref_ty) + ref = tile_smem + smem_refs.append(ref) + return jax.tree.unflatten(smem_buffer_tree, smem_refs) + + +MBARRIER_BYTES = 8 + + +def _smem_tree_size(smem_buffers: ShapeTree) -> int: + leaves = jax.tree.leaves( + smem_buffers, is_leaf=lambda x: isinstance(x, Union) + ) + size = 0 + for l in leaves: + match l: + case Union(members): + size += max(_smem_tree_size(s) for s in members) + case ( + TMABarrier(num_barriers) + | ClusterBarrier(_, num_barriers=num_barriers) + | Barrier(_, num_barriers=num_barriers) + ): + if size % MBARRIER_BYTES: + raise NotImplementedError("Misaligned barrier allocation") + size += num_barriers * MBARRIER_BYTES + case _: + size += _count_buffer_bytes(l) + return size + + +# TODO(apaszke): Inline this +@contextlib.contextmanager +def _launch( + token, + grid: tuple[int, int, int], + cluster: tuple[int, int, int], + block: tuple[int, int, int], + scratch_arr, + smem_buffers: ShapeTree | Union[ShapeTree], + profiler_spec: profiler.ProfilerSpec | None = None, + maybe_prof_buffer: ir.Value | None = None, +): + if (profiler_spec is None) != (maybe_prof_buffer is None): + raise ValueError + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + grid_vals = [c(i, index) for i in grid] + block_vals = [c(i, index) for i in block] + + user_smem_bytes = _smem_tree_size(smem_buffers) + + smem_bytes = user_smem_bytes + if profiler_spec is not None: + smem_bytes += profiler_spec.smem_bytes(block=block) + + # TODO(cperivol): Query the shared memory size programmatically. + if smem_bytes > 228 * 1024: + raise ValueError(f"Mosaic GPU kernel exceeds available shared memory {smem_bytes=} > 228000") + if math.prod(cluster) != 1: + if len(cluster) != 3: + raise ValueError("Clusters must be 3D") + cluster_kwargs = { + "clusterSize" + d: c(s, index) for s, d in zip(cluster, "XYZ") + } + for d, grid_size, cluster_size in zip("xyz", grid, cluster): + if grid_size % cluster_size != 0: + raise ValueError( + f"Grid dimension {d} must be divisible by cluster dimension:" + f" {grid_size} % {cluster_size} != 0" + ) + else: + cluster_kwargs = {} + launch_op = gpu.LaunchOp( + token.type, [token], *grid_vals, *block_vals, + dynamicSharedMemorySize=c(smem_bytes, i32), **cluster_kwargs) + launch_op.body.blocks.append(*([index] * (12 + 2 * len(cluster_kwargs)))) # Append an empty block + smem = ir.Attribute.parse("#gpu.address_space") + with ir.InsertionPoint(launch_op.body.blocks[0]): + dynamic_smem = gpu.dynamic_shared_memory( + ir.MemRefType.get( + (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem + ) + ) + + smem_ref_tree = _construct_smem_reftree( + cluster, dynamic_smem, smem_buffers + ) + # TODO(apaszke): Skip the following if no barriers were initialized. + nvvm.fence_mbarrier_init() + if math.prod(cluster) != 1: + nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) + nvvm.cluster_wait(aligned=ir.UnitAttr.get()) + gpu.barrier() + + if profiler_spec: + prof_smem = memref.view( + ir.MemRefType.get( + (profiler_spec.smem_i32_elements(block=block),), + i32, memory_space=smem, + ), + dynamic_smem, c(user_smem_bytes, index), [], + ) + prof = profiler.OnDeviceProfiler( + profiler_spec, prof_smem, maybe_prof_buffer + ) + else: + prof = None + + ptr_ty = ir.Type.parse("!llvm.ptr") + scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) + yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree + if prof is not None: + prof.finalize(grid=grid, block=block) + gpu.terminator() + + +def _lower_as_gpu_kernel( + body, + grid: tuple[int, int, int], + cluster: tuple[int, int, int], + block: tuple[int, int, int], + in_shapes: tuple[Any, ...], + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + module_name: str, + prof_spec: profiler.ProfilerSpec | None = None, +): + ptr_ty = ir.Type.parse("!llvm.ptr") + token_ty = ir.Type.parse("!gpu.async.token") + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + + def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: + return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) + + in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] + + unwrap_output_tuple = False + if isinstance(out_shape, list): + out_shape = tuple(out_shape) + elif not isinstance(out_shape, tuple): + out_shape = (out_shape,) + unwrap_output_tuple = True + out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] + if prof_spec is not None: + out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) + out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) + + module = ir.Module.create() + attrs = module.operation.attributes + attrs["sym_name"] = ir.StringAttr.get(module_name) + with ir.InsertionPoint(module.body): + _declare_runtime_functions() + gmem_scratch_bytes = 0 + global_scratch = llvm.GlobalOp( + ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet. + "global_scratch", + ir.Attribute.parse("#llvm.linkage"), + addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. + ) + @func.FuncOp.from_py_func(ptr_ty, ptr_ty) + def main(token_ptr, buffers): + nonlocal gmem_scratch_bytes + token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) + arg_refs = [] + for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): + ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) + arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) + in_refs = arg_refs[:len(in_ref_tys)] + out_refs = arg_refs[len(in_ref_tys):] + prof_buffer = out_refs.pop() if prof_spec is not None else None + empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") + scratch_alloc = llvm.AllocaOp( + ptr_ty, c(1, i64), empty_arr_ty, alignment=TMA_DESCRIPTOR_ALIGNMENT + ) + scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result) + with _launch( + token, grid, cluster, block, scratch_arr, smem_scratch_shape, + prof_spec, prof_buffer + ) as (launch_ctx, smem_refs): + body(launch_ctx, *in_refs, *out_refs, smem_refs) + gmem_scratch_bytes = launch_ctx.next_scratch_offset + # Allocate and initialize the host buffer right before the launch. + # Note that we couldn't do that before, because we had to run the body + # to learn what the scratch contains. + with ir.InsertionPoint(scratch_arr.owner): + scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") + scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty) + scratch_arr.set_type(scratch_arr_ty) + for init_callback in launch_ctx.host_scratch_init: + init_callback(scratch_alloc.result) + main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + sym_tab = ir.SymbolTable(module.operation) + sym_tab.insert(main.func_op) + sym_tab.insert(global_scratch) + module.operation.verify() + + return module, out_shape, unwrap_output_tuple + + +def _declare_runtime_functions(): + """Declares the runtime functions that can be used by the generated code.""" + ptr_ty = ir.Type.parse("!llvm.ptr") + i64 = ir.IntegerType.get_signless(64) + arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] + init_tma_desc_type = ir.FunctionType.get(arg_tys, []) + func.FuncOp( + "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" + ) + memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) + func.FuncOp( + "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" + ) + + +def as_gpu_kernel( + body, + grid: tuple[int, int, int], + block: tuple[int, int, int], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, + cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", +): + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + + module, out_shape, unwrap_output_tuple = ( + _lower_as_gpu_kernel( + body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, + module_name, prof_spec + ) + ) + + expected_arg_treedef = jax.tree.structure(in_shape) + def _check_args(*args): + arg_treedef = jax.tree.structure(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + def bind(*args): + return mosaic_gpu_p.bind( + *args, + out_types=out_shape, + module=module_asm, + ) + + if prof_spec is not None: + @jax.jit + def prof_kernel(*args): + _check_args(*args) + *results, prof_buffer = bind(*args) + def dump_profile(prof_buffer): + out_file = os.path.join( + os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), + f"{time.time_ns()}-trace.json", + ) + try: + with open(out_file, "x") as f: + prof_spec.dump(prof_buffer, f, grid=grid, block=block) + except FileExistsError: + pass # TODO: Retry + jax.debug.callback(dump_profile, prof_buffer) + return results[0] if unwrap_output_tuple else results + return prof_kernel + else: + @jax.jit + def kernel(*args): + _check_args(*args) + results = bind(*args) + return results[0] if unwrap_output_tuple else results + return kernel + + +def as_torch_gpu_kernel( + body, + grid: tuple[int, int, int], + block: tuple[int, int, int], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, + cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", +): + try: + import torch + except ImportError: + raise RuntimeError("as_torch_gpu_kernel requires PyTorch") + torch.cuda.init() # Make sure CUDA context is set up. + + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + + flat_out_types, out_treedef = jax.tree.flatten(out_shape) + expected_arg_treedef = jax.tree.structure(in_shape) + + module, out_shape, unwrap_output_tuple = ( + _lower_as_gpu_kernel( + body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, + module_name, prof_spec + ) + ) + + # Get our hands on the compilation and unload functions + try: + import jax_plugins.xla_cuda12 as cuda_plugin + except ImportError: + raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds " + "that use backend plugins") + dll = ctypes.CDLL(cuda_plugin._get_library_path()) + compile_func = dll.MosaicGpuCompile + compile_func.argtypes = [ctypes.c_void_p] + compile_func.restype = ctypes.POINTER(ctypes.c_void_p) + unload_func = dll.MosaicGpuUnload + unload_func.argtypes = [compile_func.restype] + unload_func.restype = None + + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + compiled = compile_func(ctypes.c_char_p(module_asm)) + if compiled is None: + raise RuntimeError("Failed to compile the module") + ctx, launch_ptr = compiled[0], compiled[1] + ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx)) + launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) + + def as_torch_dtype(dtype): + # torch contains NumPy-compatible dtypes in its top namespace + return getattr(torch, np.dtype(dtype).name) + + def apply(*args): + flat_args, arg_treedef = jax.tree.flatten(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + # Construct a device pointer list like in the XLA calling convention + buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))() + i = -1 # Define i in case there are no args + device = 'cuda' + for i, arg in enumerate(flat_args): + buffers[i] = arg.data_ptr() + device = arg.device + flat_outs = [] + for i, t in enumerate(flat_out_types, i + 1): + out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device) + flat_outs.append(out) + buffers[i] = out.data_ptr() + # Allocate another buffer for args of the host-side program. This is sadly + # the default MLIR calling convention. + args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() + args_ptr[0] = ctx_ptr_ptr + args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_) + args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)), + ctypes.POINTER(ctypes.c_void_p)) + launch(args_ptr) + return jax.tree.unflatten(out_treedef, flat_outs) + + # Unload the compiled code when the Python function is destroyed. + def unload(_): + unload_func(compiled) + apply.destructor = weakref.ref(apply, unload) + + return apply diff --git a/jax/experimental/mosaic/gpu/dsl.py b/jax/experimental/mosaic/gpu/dsl.py deleted file mode 100644 index a12e5bc18803..000000000000 --- a/jax/experimental/mosaic/gpu/dsl.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2024 The JAX Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from . import ( - Barrier, - ClusterBarrier, - TMABarrier, - Union, -) -from .fragmented_array import ( - FragmentedArray, - FragmentedLayout, - WGMMA_LAYOUT, - WGMMA_ROW_LAYOUT, - WGStridedFragLayout, -) -from .utils import ( - BarrierRef, - CollectiveBarrierRef, - DynamicSlice, - Partition, - Partition1D, - bytewidth, - c, - commit_shared, - debug_print, - ds, - fori, - memref_fold, - memref_slice, - memref_transpose, - memref_unfold, - memref_unsqueeze, - single_thread, - thread_idx, - tile_shape, - warp_idx, - warpgroup_barrier, - warpgroup_idx, - when, -) -from .wgmma import ( - WGMMAAccumulator, - WGMMALayout, - wgmma, -) diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index a9a533ca361c..99586875ae90 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -24,9 +24,8 @@ from jax import random from jax._src.interpreters import mlir from jax._src import test_util as jtu -from jax.experimental.mosaic import gpu as mosaic_gpu from jax.experimental.mosaic.gpu import profiler -from jax.experimental.mosaic.gpu.dsl import * # noqa: F403 +from jax.experimental.mosaic.gpu import * # noqa: F403 import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -144,7 +143,7 @@ def c(value, ty=index): return _utils_c(value, ty) def tma_wg_kernel( - ctx: mosaic_gpu.LaunchContext, + ctx: LaunchContext, q_gmem, k_gmem, v_gmem, @@ -190,7 +189,7 @@ def only_wg(idx): ctx.async_copy( src_ref=q_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), dst_ref=qo_smem, barrier=q_barriers[wg_idx], swizzle=128, @@ -294,7 +293,7 @@ def kv_loop(kv_step, carry): src_ref=qo_smem, dst_ref=out_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), swizzle=128, ) ctx.await_async_copy(0) @@ -304,10 +303,9 @@ def kv_loop(kv_step, carry): nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) with single_thread(per_block=False): k_tr = ( - mosaic_gpu.TileTransform(tiling), - mosaic_gpu.TransposeTransform((0, 2, 1, 3, 4)), + TileTransform(tiling), TransposeTransform((0, 2, 1, 3, 4)), ) - v_tr = mosaic_gpu.TileTransform(tiling) + v_tr = TileTransform(tiling) kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform): ctx.async_copy( @@ -350,7 +348,7 @@ def _kv_loop_memory(i, _): scf.yield_([]) def compute_only_kernel( - ctx: mosaic_gpu.LaunchContext, + ctx: LaunchContext, q_gmem, k_gmem, v_gmem, @@ -388,7 +386,7 @@ def only_wg(idx): ctx.async_copy( src_ref=q_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), dst_ref=qo_smem, barrier=barriers[q_barrier], swizzle=128, @@ -401,10 +399,10 @@ def kv_copy_init(slot, kv_seq_base): txcount = 2 * blocks.kv * head_dim * bytewidth(f16) barriers[slot].arrive_expect_tx(txcount) k_tr = ( - mosaic_gpu.TileTransform(tiling), - mosaic_gpu.TransposeTransform((0, 2, 1, 3, 4)), + TileTransform(tiling), + TransposeTransform((0, 2, 1, 3, 4)), ) - v_tr = mosaic_gpu.TileTransform(tiling) + v_tr = TileTransform(tiling) for smem, gmem, t in ((k_smem, k_gmem, k_tr), (v_smem, v_gmem, v_tr)): ctx.async_copy( dst_ref=memref_slice(smem, slot), @@ -526,7 +524,7 @@ def kv_loop(kv_step, carry): src_ref=qo_smem, dst_ref=out_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), swizzle=128, ) ctx.await_async_copy(0) @@ -551,7 +549,7 @@ def kv_loop(kv_step, carry): Barrier(arrival_count=256, num_barriers=2), Barrier(arrival_count=256, num_barriers=1), ) - return mosaic_gpu.as_gpu_kernel( + return as_gpu_kernel( kernel, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec ) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index 775b7c2ea898..c56c5cd6b982 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -22,17 +22,14 @@ import jax from jax import random from jax._src.interpreters import mlir -from jax.experimental.mosaic import gpu as mosaic_gpu from jax.experimental.mosaic.gpu import profiler -from jax.experimental.mosaic.gpu.dsl import * # noqa: F403 +from jax.experimental.mosaic.gpu import * # noqa: F403 import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import gpu -from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import scf -from jaxlib.mlir.dialects import vector import numpy as np # mypy: ignore-errors @@ -190,7 +187,7 @@ def safe_div(x, y): wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose), ) epilogue_scratch_shape = jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype) - smem_shape = mosaic_gpu.Union([compute_scratch_shape, epilogue_scratch_shape]) + smem_shape = Union([compute_scratch_shape, epilogue_scratch_shape]) def _main(ctx, a_device, b_device, c_device, smem): ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), *barriers = smem @@ -218,15 +215,15 @@ def fetch(slot, ki): src_ref=a_device, dst_ref=memref_slice(lhs_smem, slot), gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)), - gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk), + gmem_transform=TileTransform(tma_tiling.mk), collective=(gpu.Dimension.x, gpu.Dimension.z), **common_copy_args, ) rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n)) - rhs_transform = (mosaic_gpu.TileTransform(tma_tiling.kn),) + rhs_transform = (TileTransform(tma_tiling.kn),) if rhs_transpose: rhs_slice = rhs_slice[::-1] - rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform += (TransposeTransform((1, 0, 2, 3)),) assert tma_tiling.n == tma_tiling.k, block_tiling # No need to flip the tiling. ctx.async_copy( src_ref=b_device, @@ -292,7 +289,7 @@ def stage_loop_body(ki, accs): src_ref=epilogue_smem, dst_ref=c_device, gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)), - gmem_transform=mosaic_gpu.TileTransform(out_tiling), + gmem_transform=TileTransform(out_tiling), swizzle=out_swizzle, ) ctx.await_async_copy(0) @@ -304,7 +301,7 @@ def stage_loop_body(ki, accs): f" {grid_tile_n=})" ) cluster = (cluster_tile_n, cluster_m, cluster_n // cluster_tile_n) - return mosaic_gpu.as_gpu_kernel( + return as_gpu_kernel( _main, grid, block, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 502373bdc91e..d5a6e9eb69d1 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -29,7 +29,7 @@ from jaxlib.mlir.dialects import vector import numpy as np -from . import dsl as mgpu +import jax.experimental.mosaic.gpu as mgpu from . import utils # mypy: ignore-errors diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index b64418022d0e..5b0282080c55 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -21,13 +21,11 @@ import jax from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith -from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import llvm -from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import vector import numpy as np -from . import dsl as mgpu +import jax.experimental.mosaic.gpu as mgpu from . import utils # mypy: ignore-errors diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1a29bbb5736d..9f2f1222bde4 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -43,8 +43,7 @@ class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok y = 1 z = 2 else: - from jax.experimental.mosaic import gpu as mosaic_gpu - from jax.experimental.mosaic.gpu import dsl as mgpu + import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import utils as utils from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu.utils import * # noqa: F403 @@ -171,14 +170,14 @@ def test_copy_basic(self): def kernel(ctx, src, dst, _): copy(src, dst) x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) np.testing.assert_array_equal(y, x) def test_copy_swizzle(self): def kernel(ctx, src, dst, _): copy(src, dst, swizzle=128) x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) expected = np.zeros_like(y) for i in range(8): for j in range(8): @@ -192,7 +191,7 @@ def kernel(ctx, src, dst, smem): copy(src, smem, swizzle=128) copy(smem, dst, swizzle=128) x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) np.testing.assert_array_equal(y, x) def test_iota_tensor(self): @@ -209,7 +208,7 @@ def kernel(ctx, dst, _): reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)] ) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) - regs = mosaic_gpu.as_gpu_kernel( + regs = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() thread_ids = np.arange(128) @@ -248,7 +247,7 @@ def kernel(ctx, inp, out, _): out_shape = list(x.shape) out_shape.insert(dim, 1) out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) np.testing.assert_array_equal(y, x.reshape(out_shape)) @@ -276,7 +275,7 @@ def kernel(ctx, inp, out, _): out_shape = list(in_shape) out_shape[dim:dim + 1] = [2, 2, out_shape[dim] // 4] out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) @@ -290,7 +289,7 @@ def kernel(ctx, inp, out, _): x = np.arange(8 * 2 * 8, dtype=jnp.float32).reshape(8, 2, 8) out_ty = jax.ShapeDtypeStruct((16, 8) if dim == 0 else (8, 16), jnp.float32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) @@ -329,7 +328,7 @@ def kernel(ctx, inp, out, _): copy(memref_fold(memref_slice(inp, index), dim, fold_rank), out) out = np_fold(np_inp[index], dim, fold_rank) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), np_inp, out, () )(np_inp) assert ( @@ -371,7 +370,7 @@ def kernel(ctx, out, _): del ctx iota_tensor(64, 64, mlir_dtype).store_untiled(out) expected = np.arange(64 * 64, dtype=jax_dtype).reshape(64, 64) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() np.testing.assert_array_equal(iota, expected) @@ -387,7 +386,7 @@ def kernel(ctx, out, _): del ctx mgpu.FragmentedArray.splat(c(1., mlir_dtype), (size,)).store_untiled(out) expected = np.ones((size,), jax_dtype) - mosaic_ones = mosaic_gpu.as_gpu_kernel( + mosaic_ones = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() np.testing.assert_array_equal(mosaic_ones, expected) @@ -419,7 +418,7 @@ def kernel(ctx, out, smem): .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) .transpose(0, 2, 1, 3) ) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, expected )() np.testing.assert_array_equal(iota, expected) @@ -445,12 +444,12 @@ def kernel(ctx, out, smem): dst_ref=out, swizzle=swizzle, gmem_slice=(ds(0, m), ds(0, col_tiling)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=mgpu.TileTransform(tiling), ) ctx.await_async_copy(0) smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), jax_dtype) expected = np.arange(m * n, dtype=jax_dtype).reshape(m, n) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, smem_shape )() np.testing.assert_array_equal(iota, expected) @@ -493,7 +492,7 @@ def kernel(ctx, inp, out, smem): expected_from = expected(jax_dtype_from, from_tiling) expected_to = expected(jax_dtype_to, to_tiling) - res = mosaic_gpu.as_gpu_kernel( + res = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), @@ -525,7 +524,7 @@ def kernel(ctx, in_, out, smem): .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) .transpose(0, 2, 1, 3) ) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), expected, expected, (expected,) * 2 )(expected) np.testing.assert_array_equal(iota, expected) @@ -593,13 +592,13 @@ def test_wgmma_basic( def kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers = scratch if tma_inputs: - lhs_transform = (mosaic_gpu.TileTransform((64, nk_tile)),) + lhs_transform = (mgpu.TileTransform((64, nk_tile)),) if lhs_transpose: assert nk_tile == 64 # Make sure we didn't have to transpose tiling. - lhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) - rhs_transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),) + lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),) if rhs_transpose: - rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=lhs, dst_ref=lhs_smem, @@ -666,7 +665,7 @@ def quantize(x): ), mgpu.TMABarrier(2), ] - z = mosaic_gpu.as_gpu_kernel( + z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) @@ -724,7 +723,7 @@ def kernel(ctx, rhs, out, rhs_smem): scratch_shape = jax.ShapeDtypeStruct( (k_steps, n // nk_tile, nk_tile, nk_tile), jax_dtype ) - z = mosaic_gpu.as_gpu_kernel( + z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape )(y) x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) @@ -752,11 +751,11 @@ def kernel(ctx, rhs, out, smem): rhs_smem, barrier = smem gmem_slice = (ds(0, k), ds(0, nk_tile)) smem_slice = (slice(None), slice(None), slice(None), ds(0, n)) - transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),) + transform = (mgpu.TileTransform((nk_tile, nk_tile)),) if rhs_transpose: gmem_slice = gmem_slice[::-1] smem_slice = (slice(None), slice(None), ds(0, n), slice(None)) - transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=rhs, dst_ref=rhs_smem, @@ -781,7 +780,7 @@ def kernel(ctx, rhs, out, smem): rhs_scratch_shape = jax.ShapeDtypeStruct( (k_steps, 1, nk_tile, nk_tile), jax_dtype ) - z = mosaic_gpu.as_gpu_kernel( + z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()), )(y) x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) @@ -823,7 +822,7 @@ def kernel(ctx, dst, scratch): final_arr.store_untiled(memref_slice(dst, 1)) scf.yield_([]) out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (2 * 128, 1, 1), @@ -892,7 +891,7 @@ def kernel(ctx, dst, mask, collective_barrier): if group_dims: barrier_dims = (collective_dims[:2], *collective_dims[2:]) scratch = mgpu.ClusterBarrier(barrier_dims) - y, mask = mosaic_gpu.as_gpu_kernel( + y, mask = mgpu.as_gpu_kernel( kernel, cluster, (128, 1, 1), (), (out_shape, mask_shape), scratch, cluster=cluster, )() np.testing.assert_array_equal( @@ -931,7 +930,7 @@ def kernel(ctx, src, dst, smem): copy(tmp, dst, swizzle=swizzle) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) smem = (x, mgpu.TMABarrier()) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) np.testing.assert_array_equal(y, x) @parameterized.named_parameters( @@ -1009,7 +1008,7 @@ def kernel(ctx, src, dst, scratch): ) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) smem_shape = (jax.ShapeDtypeStruct(shape[1:], dtype), mgpu.TMABarrier()) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, cluster, (128, 1, 1), x, x, smem_shape, cluster=cluster )(x) np.testing.assert_array_equal(y, x) @@ -1033,7 +1032,7 @@ def kernel(ctx, src, dst, scratch): dst_ref=tmp, swizzle=swizzle, barrier=barrier, - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=mgpu.TileTransform(tiling), ) barrier.wait_parity(c(0, i1)) for idxs in np.ndindex(tiled_shape): @@ -1048,7 +1047,7 @@ def kernel(ctx, src, dst, scratch): jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype), mgpu.TMABarrier(), ) - f = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem) y = f(x) np.testing.assert_array_equal(y, x) @@ -1075,7 +1074,7 @@ def kernel(ctx, src, dst, smem): copy(tmp, dst, swizzle=swizzle) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) smem = (x, mgpu.TMABarrier()) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) np.testing.assert_array_equal(y, x) def test_parity_tracking(self): @@ -1091,7 +1090,7 @@ def kernel(ctx, src, dst, smem): barrier.wait() copy(tmp, memref_slice(dst, s)) x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, x, (x[0:1], mgpu.TMABarrier()) )(x) np.testing.assert_array_equal(y, x) @@ -1109,7 +1108,7 @@ def kernel(ctx, src, dst, tmp): ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle) ctx.await_async_copy(0) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) np.testing.assert_array_equal(y, x) @parameterized.parameters(0, 1) @@ -1128,7 +1127,7 @@ def kernel(ctx, src, dst, smem): src_ref=src, dst_ref=tmp, swizzle=128, - gmem_transform=mosaic_gpu.TileTransform((64, 64)), + gmem_transform=mgpu.TileTransform((64, 64)), gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])), barrier=barrier, ) @@ -1136,7 +1135,7 @@ def kernel(ctx, src, dst, smem): copy(tmp, dst, swizzle=128) x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape) tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16) - y_tiled = mosaic_gpu.as_gpu_kernel( + y_tiled = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, tiled, (tiled, mgpu.TMABarrier()), )(x) y = y_tiled.swapaxes(1, 2).reshape(padded_shape) @@ -1165,13 +1164,13 @@ def kernel(ctx, dst, tmp): src_ref=tmp, dst_ref=dst, swizzle=128, - gmem_transform=mosaic_gpu.TileTransform((64, 64)), + gmem_transform=mgpu.TileTransform((64, 64)), gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])), ) ctx.await_async_copy(0) tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16) out = jax.ShapeDtypeStruct(shape, jnp.float16) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out, tiled, )() iota = np.arange(np.prod(padded_shape), dtype=jnp.float16).reshape( @@ -1187,7 +1186,7 @@ def kernel(ctx, src, dst, tmp): def run_kernel(shape): x = np.arange(np.prod(shape)).reshape(shape) - _ = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) with self.assertRaisesRegex(ValueError, "only support striding up to 5"): run_kernel([1] * 6) @@ -1224,7 +1223,7 @@ def kernel(ctx, dst, _): rhs = iota if scalar_rhs is None else c(scalar_rhs, iota.mlir_dtype) op(iota, rhs).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() ref_x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) @@ -1257,7 +1256,7 @@ def kernel(ctx, dst, _): iota = iota_tensor(m=m, n=n, mlir_dtype=f32) op(iota).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) @@ -1276,7 +1275,7 @@ def kernel(ctx, dst, _): iota = iota_tensor(m=m, n=n, mlir_dtype=f32) iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) @@ -1298,7 +1297,7 @@ def kernel(ctx, dst, _): cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) (iota + cte_arr).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() expected = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 1 @@ -1311,7 +1310,7 @@ def kernel(ctx, dst, _): t = mgpu.FragmentedArray.splat(v, (128,), mgpu.WGMMA_ROW_LAYOUT) t.broadcast_minor(32).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32)) @@ -1326,7 +1325,7 @@ def kernel(ctx, *args): copy(smem_output, gmem_output) inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], )(inp) np.testing.assert_array_equal(inp, result) @@ -1341,7 +1340,7 @@ def kernel(ctx, out, *_): memref.store(grp, out, [tid]) x = np.arange(128, dtype=jnp.int32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), x, [], )() for i in range(0, 128, 4): @@ -1364,7 +1363,7 @@ def kernel(ctx, inp, out, smem): x = jnp.arange(-128, 128, dtype=jax_dtype_from) reference = x.astype(jax_dtype_to) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, reference, None, )(x) np.testing.assert_array_equal(result, reference) @@ -1382,7 +1381,7 @@ def test_multigpu(self): def kernel(ctx, src, dst, _): mgpu.FragmentedArray.load_strided(src).store_untiled(dst) x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) - f = jax.jit(mosaic_gpu.as_gpu_kernel( + f = jax.jit(mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, x, () )) # Make sure we can invoke the same program on different devices. @@ -1407,7 +1406,7 @@ def kernel(ctx, i_gmem, o_gmem, _): ty = jax.ShapeDtypeStruct((128, 128), jnp.float32) x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda') - f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) + f = mgpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) y = f(x) np.testing.assert_allclose(y.cpu(), x.cpu() * 2) del y # Make sure the destructor runs successfully. From 71450cad5689e6466583b0d705be766397fec21c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 20 Sep 2024 09:14:38 -0700 Subject: [PATCH 586/702] Add docstrings for jnp.blackman, jnp.bartlett, jnp.hamming, jnp.hanning, jnp.kaiser Part of https://github.com/jax-ml/jax/issues/21461 PiperOrigin-RevId: 676866721 --- jax/_src/numpy/lax_numpy.py | 111 ++++++++++++++++++++++++++++++++++-- 1 file changed, 106 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 5b936268581a..716c764ee074 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -10328,8 +10328,28 @@ def clamp_index(i: DimSize, which: str): return start, step, slice_size -@util.implements(np.blackman) def blackman(M: int) -> Array: + """Return a Blackman window of size M. + + JAX implementation of :func:`numpy.blackman`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Blackman window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.blackman(4)) + [-0. 0.63 0.63 -0. ] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.blackman") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -10338,8 +10358,28 @@ def blackman(M: int) -> Array: return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) -@util.implements(np.bartlett) def bartlett(M: int) -> Array: + """Return a Bartlett window of size M. + + JAX implementation of :func:`numpy.bartlett`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Bartlett window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.bartlett(4)) + [0. 0.67 0.67 0. ] + + See also: + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.bartlett") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -10348,8 +10388,28 @@ def bartlett(M: int) -> Array: return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) -@util.implements(np.hamming) def hamming(M: int) -> Array: + """Return a Hamming window of size M. + + JAX implementation of :func:`numpy.hamming`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Hamming window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.hamming(4)) + [0.08 0.77 0.77 0.08] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.hamming") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -10358,8 +10418,28 @@ def hamming(M: int) -> Array: return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) -@util.implements(np.hanning) def hanning(M: int) -> Array: + """Return a Hanning window of size M. + + JAX implementation of :func:`numpy.hanning`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Hanning window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.hanning(4)) + [0. 0.75 0.75 0. ] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.hanning") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -10368,8 +10448,29 @@ def hanning(M: int) -> Array: return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) -@util.implements(np.kaiser) def kaiser(M: int, beta: ArrayLike) -> Array: + """Return a Kaiser window of size M. + + JAX implementation of :func:`numpy.kaiser`. + + Args: + M: The window size. + beta: The Kaiser window parameter. + + Returns: + An array of size M containing the Kaiser window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.kaiser(4, 1.5)) + [0.61 0.95 0.95 0.61] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.kaiser") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: From 339db2b4337736ca3e04ccd0349d98543410d956 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 20 Sep 2024 12:35:23 -0400 Subject: [PATCH 587/702] Format MLIR dump names with leading zeros. This means the dumps sort in order in a directory listing. --- jax/_src/interpreters/mlir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index c4c77c72b88b..5b9b71e4865f 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -537,7 +537,7 @@ def dump_module_to_file(module: ir.Module, stage_name: str) -> str | None: sym_name = module.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value - name = f"jax_ir{id}_{_make_string_safe_for_filename(module_name)}_{stage_name}.mlir" + name = f"jax_ir{id:04d}_{_make_string_safe_for_filename(module_name)}_{stage_name}.mlir" out_dir = path.Path(out_dir_name) out_dir.mkdir(parents=True, exist_ok=True) From 81e50118cfe868b11d6b8b7e126d6d94d00cd441 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 20 Sep 2024 22:19:31 +0530 Subject: [PATCH 588/702] Better doc for jax.numpy.i0 --- jax/_src/numpy/lax_numpy.py | 46 ++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a1601e9201fe..6e1834c0e92e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5973,19 +5973,53 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, return output -@custom_jvp -@util.implements(np.i0) @jit def i0(x: ArrayLike) -> Array: + r"""Calculate modified Bessel function of first kind, zeroth order. + + JAX implementation of :func:`numpy.i0`. + + Modified Bessel function of first kind, zeroth order is defined by: + + .. math:: + + \mathrm{i0}(x) = I_0(x) = \sum_{k=0}^{\infty} \frac{(x^2/4)^k}{(k!)^2} + + Args: + x: scalar or array. Specifies the argument of Bessel function. Complex inputs + are not supported. + + Returns: + An array containing the corresponding vlaues of the modified Bessel function + of ``x``. + + See also: + - :func:`jax.scipy.special.i0`: Calculates the modified Bessel function of + zeroth order. + - :func:`jax.scipy.special.i1`: Calculates the modified Bessel function of + first order. + - :func:`jax.scipy.special.i0e`: Calculates the exponentially scaled modified + Bessel function of zeroth order. + + Examples: + >>> x = jnp.array([-2, -1, 0, 1, 2]) + >>> jnp.i0(x) + Array([2.2795851, 1.266066 , 1.0000001, 1.266066 , 2.2795851], dtype=float32) + """ x_arr, = util.promote_args_inexact("i0", x) if not issubdtype(x_arr.dtype, np.floating): raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}") - x_arr = lax.abs(x_arr) - return lax.mul(lax.exp(x_arr), lax.bessel_i0e(x_arr)) + return _i0(x_arr) + + +@custom_jvp +def _i0(x): + abs_x = lax.abs(x) + return lax.mul(lax.exp(abs_x), lax.bessel_i0e(abs_x)) -@i0.defjvp +@_i0.defjvp def _i0_jvp(primals, tangents): - primal_out, tangent_out = jax.jvp(i0.fun, primals, tangents) + primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) def ix_(*args: ArrayLike) -> tuple[Array, ...]: From 0c87a23a265f270b206efe4a6f144aac5e64f442 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 20 Sep 2024 22:22:17 +0530 Subject: [PATCH 589/702] Improve docs for jax.numpy: deg2rad, rad2deg, degrees, radians --- jax/_src/numpy/ufuncs.py | 68 +++++++++++++++++++++++++++++++++++++--- tests/lax_numpy_test.py | 4 +-- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index b45b3370fe53..5682c3c72a49 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2016,22 +2016,82 @@ def square(x: ArrayLike, /) -> Array: return lax.integer_pow(x, 2) -@implements(np.deg2rad, module='numpy') @partial(jit, inline=True) def deg2rad(x: ArrayLike, /) -> Array: + r"""Convert angles from degrees to radians. + + JAX implementation of :obj:`numpy.deg2rad`. + + The angle in degrees is converted to radians by: + + .. math:: + + deg2rad(x) = x * \frac{pi}{180} + + Args: + x: scalar or array. Specifies the angle in degrees. + + Returns: + An array containing the angles in radians. + + See also: + - :func:`jax.numpy.rad2deg` and :func:`jax.numpy.degrees`: Converts the angles + from radians to degrees. + - :func:`jax.numpy.radians`: Alias of ``deg2rad``. + + Examples: + >>> x = jnp.array([60, 90, 120, 180]) + >>> jnp.deg2rad(x) + Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32) + >>> x * jnp.pi / 180 + Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32, weak_type=True) + """ x, = promote_args_inexact("deg2rad", x) return lax.mul(x, _lax_const(x, np.pi / 180)) -@implements(np.rad2deg, module='numpy') @partial(jit, inline=True) def rad2deg(x: ArrayLike, /) -> Array: + r"""Convert angles from radians to degrees. + + JAX implementation of :obj:`numpy.rad2deg`. + + The angle in radians is converted to degrees by: + + .. math:: + + rad2deg(x) = x * \frac{180}{pi} + + Args: + x: scalar or array. Specifies the angle in radians. + + Returns: + An array containing the angles in degrees. + + See also: + - :func:`jax.numpy.deg2rad` and :func:`jax.numpy.radians`: Converts the angles + from degrees to radians. + - :func:`jax.numpy.degrees`: Alias of ``rad2deg``. + + Examples: + >>> pi = jnp.pi + >>> x = jnp.array([pi/4, pi/2, 2*pi/3]) + >>> jnp.rad2deg(x) + Array([ 45. , 90. , 120.00001], dtype=float32) + >>> x * 180 / pi + Array([ 45., 90., 120.], dtype=float32) + """ x, = promote_args_inexact("rad2deg", x) return lax.mul(x, _lax_const(x, 180 / np.pi)) -degrees = rad2deg -radians = deg2rad +def degrees(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.rad2deg`""" + return rad2deg(x) + +def radians(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.deg2rad`""" + return deg2rad(x) @implements(np.conjugate, module='numpy') diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ddf42a28e2ba..b85966fa6e4a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6308,8 +6308,8 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', - 'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'pow', - 'round_'] + 'amax', 'amin', 'around', 'bitwise_right_shift', 'degrees', 'divide', + 'pow', 'radians', 'round_'] skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split'] for name in dir(jnp): From 629be0b701bf04d163b2270ee81a99c50f4cb0cc Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 20 Sep 2024 10:02:53 -0700 Subject: [PATCH 590/702] Tighten test tolerances after the underlying issue causing nondeterministic results for _nrm2 in Eigen BLAS was fixed in https://gitlab.com/libeigen/eigen/-/merge_requests/1667 -> cl/663346025 PiperOrigin-RevId: 676881791 --- tests/lax_scipy_test.py | 2 -- tests/linalg_test.py | 74 ++++++++++++++++++++++++----------------- 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 8a8b1dd42c35..4840972e9483 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -523,8 +523,6 @@ def testPolar( tol = 650 * float(jnp.finfo(matrix.dtype).eps) eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype) with self.subTest('Test unitarity.'): - if jtu.test_device_matches(["cpu"]): - tol = max(tol, 1e-8) self.assertAllClose( eye_mat, should_be_eye, atol=tol * 1000 * min(shape)) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 15963b10b6e2..e52582eb7526 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -269,10 +269,7 @@ def check_left_eigenvectors(a, w, vl): if compute_right_eigenvectors: check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) - # TODO(phawkins): we are seeing nondeterminism in LAPACK routines with - # avx enabled, because for Eigen BLAS nrm2 has an alignment dependence. - # self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, - # rtol=1e-3) + self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, rtol=1e-3) @jtu.sample_product( shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)], @@ -1399,7 +1396,6 @@ def testBlockDiag(self, args): args_maker, check_dtypes=False) self._CompileAndCheck(jsp.linalg.block_diag, args_maker) - @jtu.sample_product( shape=[(1, 1), (4, 5), (10, 5), (50, 50)], dtype=float_types + complex_types, @@ -1782,7 +1778,6 @@ def sp_func(a): self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-4, atol=1e-4, check_dtypes=False) - @jtu.sample_product( n=[1, 4, 5, 20, 50, 100], dtype=float_types + complex_types, @@ -1818,7 +1813,6 @@ def args_maker(): self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve, args_maker, tol=1e-3) - @jtu.sample_product( n=[1, 4, 5, 20, 50, 100], dtype=float_types + complex_types, @@ -1844,13 +1838,13 @@ def args_maker(): e = rng((n, n), dtype) return [a, e, ] - #compute_expm is True + # compute_expm is True osp_fun = lambda a,e: osp.linalg.expm_frechet(a,e,compute_expm=True) jsp_fun = lambda a,e: jsp.linalg.expm_frechet(a,e,compute_expm=True) self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False) - #compute_expm is False + # compute_expm is False osp_fun = lambda a,e: osp.linalg.expm_frechet(a,e,compute_expm=False) jsp_fun = lambda a,e: jsp.linalg.expm_frechet(a,e,compute_expm=False) self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, @@ -1886,18 +1880,31 @@ def expm(x): jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol, rtol=tol) + @jtu.sample_product( + shape=[(4, 4), (15, 15), (50, 50), (100, 100)], + dtype=float_types + complex_types, + ) + @jtu.run_on_devices("cpu") + def testSchur(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(osp.linalg.schur, jsp.linalg.schur, args_maker) + self._CompileAndCheck(jsp.linalg.schur, args_maker) + @jtu.sample_product( shape=[(1, 1), (4, 4), (15, 15), (50, 50), (100, 100)], dtype=float_types + complex_types, ) @jtu.run_on_devices("cpu") def testRsf2csf(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] - tol = 3e-5 - self._CheckAgainstNumpy(osp.linalg.rsf2csf, jsp.linalg.rsf2csf, - args_maker, tol=tol) - self._CompileAndCheck(jsp.linalg.rsf2csf, args_maker) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] + tol = 3e-5 + self._CheckAgainstNumpy( + osp.linalg.rsf2csf, jsp.linalg.rsf2csf, args_maker, tol=tol + ) + self._CompileAndCheck(jsp.linalg.rsf2csf, args_maker) @jtu.sample_product( shape=[(1, 1), (5, 5), (20, 20), (50, 50)], @@ -1908,17 +1915,22 @@ def testRsf2csf(self, shape, dtype): # backend only, so tests on GPU and TPU backends are skipped here @jtu.run_on_devices("cpu") def testFunm(self, shape, dtype, disp): - def func(x): - return x**-2.718 - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - jnp_fun = lambda arr: jsp.linalg.funm(arr, func, disp=disp) - scp_fun = lambda arr: osp.linalg.funm(arr, func, disp=disp) - self._CheckAgainstNumpy( - jnp_fun, scp_fun, args_maker, check_dtypes=False, - tol={np.float32: 2e-3,np.complex64: 2e-3, np.complex128: 1e-6}) - # TODO(phawkins): nondeterminism due to alignment. - # self._CompileAndCheck(jnp_fun, args_maker, atol=2e-5) + + def func(x): + return x**-2.718 + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + jnp_fun = lambda arr: jsp.linalg.funm(arr, func, disp=disp) + scp_fun = lambda arr: osp.linalg.funm(arr, func, disp=disp) + self._CheckAgainstNumpy( + jnp_fun, + scp_fun, + args_maker, + check_dtypes=False, + tol={np.complex64: 1e-5, np.complex128: 1e-6}, + ) + self._CompileAndCheck(jnp_fun, args_maker, atol=2e-5) @jtu.sample_product( shape=[(4, 4), (15, 15), (50, 50), (100, 100)], @@ -1933,9 +1945,9 @@ def testSqrtmPSDMatrix(self, shape, dtype): mat = arg @ arg.T args_maker = lambda : [mat] if dtype == np.float32 or dtype == np.complex64: - tol = 1e-4 + tol = 1e-4 else: - tol = 1e-8 + tol = 1e-8 self._CheckAgainstNumpy(osp.linalg.sqrtm, jsp.linalg.sqrtm, args_maker, @@ -2144,8 +2156,10 @@ def testSchur(self, shape, dtype): args = rng(shape, dtype) Ts, Ss = lax.linalg.schur(args) eps = np.finfo(dtype).eps - self.assertAllClose(args, Ss @ Ts @ jnp.conj(Ss.T), atol=eps * 600) - self.assertAllClose(np.eye(*shape, dtype=dtype), Ss @ jnp.conj(Ss.T), atol=eps * 100) + self.assertAllClose(args, Ss @ Ts @ jnp.conj(Ss.T), atol=600 * eps) + self.assertAllClose( + np.eye(*shape, dtype=dtype), Ss @ jnp.conj(Ss.T), atol=100 * eps + ) @jtu.sample_product( shape=[(2, 2), (4, 4), (15, 15), (50, 50), (100, 100)], From c88c3aecae03351b3541bac4bb0447c10b1e33a3 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Mon, 8 Jul 2024 05:08:25 +0000 Subject: [PATCH 591/702] add k8s cluster environment --- jax/BUILD | 1 + jax/_src/clusters/__init__.py | 1 + jax/_src/clusters/k8s_cluster.py | 124 +++++++++++++++++++++++++++++++ pyproject.toml | 1 + setup.py | 5 ++ 5 files changed, 132 insertions(+) create mode 100644 jax/_src/clusters/k8s_cluster.py diff --git a/jax/BUILD b/jax/BUILD index c6d8fe25af59..3809192a8295 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -945,6 +945,7 @@ pytype_strict_library( "_src/clusters/cloud_tpu_cluster.py", "_src/clusters/cluster.py", "_src/clusters/mpi4py_cluster.py", + "_src/clusters/k8s_cluster.py", "_src/clusters/ompi_cluster.py", "_src/clusters/slurm_cluster.py", "_src/distributed.py", diff --git a/jax/_src/clusters/__init__.py b/jax/_src/clusters/__init__.py index 73e4ac9412f7..9abb628f8ae3 100644 --- a/jax/_src/clusters/__init__.py +++ b/jax/_src/clusters/__init__.py @@ -25,3 +25,4 @@ from .mpi4py_cluster import Mpi4pyCluster from .cloud_tpu_cluster import GkeTpuCluster from .cloud_tpu_cluster import GceTpuCluster +from .k8s_cluster import K8sCluster diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py new file mode 100644 index 000000000000..1274724b8ebd --- /dev/null +++ b/jax/_src/clusters/k8s_cluster.py @@ -0,0 +1,124 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager +from functools import cache +import os +import socket +import textwrap +import warnings +from jax._src import clusters + + +class K8sCluster(clusters.ClusterEnv): + + # Use an arbitrarily chosen port for the coordinator since we cannot + # rely on communication to choose one in real time. + _coordinator_port = '55527' + + @classmethod + def is_env_present(cls) -> bool: + if 'KUBERNETES_SERVICE_HOST' in os.environ: + try: + import kubernetes as k8s # pytype: disable=import-error + except ImportError as e: + warnings.warn(textwrap.fill( + "Kubernetes environment detected, but the `kubernetes` package is " + "not installed to enable automatic bootstrapping in this " + "environment. To enable automatic boostrapping, please install " + "jax with the [k8s] extra. For example:" + " pip install jax[k8s]" + " OR" + " pip install jax[k8s,]" + )) + return False + + k8s.config.load_incluster_config() + cls._core_api = k8s.client.CoreV1Api() + cls._batch_api = k8s.client.BatchV1Api() + cls._ApiException = k8s.client.exceptions.ApiException + return True + else: + return False + + @classmethod + @contextmanager + def _handle_api_exception(cls): + try: + yield + except cls._ApiException as e: + err_msg = [f"Kubernetes API Error: {e.status} - {e.reason}"] + if e.status == 403: + err_msg.append(textwrap.fill( + "It appears that the Kubernetes service account (SA) associated with " + "this job does not have the permission for pod introspection. Please " + "either grant the default SA permission to read pod info, or create a " + "dedicated service account with the permission and associated with " + "the job. For more details, see .", + width=80 + )) + raise RuntimeError('\n'.join(err_msg)) from e + + @classmethod + @cache + def _namespace(cls): + return open( + '/var/run/secrets/kubernetes.io/serviceaccount/namespace' + ).read().strip() + + @classmethod + @cache + def _pod(cls): + with cls._handle_api_exception(): + ip = socket.gethostbyname(os.getenv('HOSTNAME')) + pods = cls._core_api.list_namespaced_pod( + namespace=cls._namespace(), + field_selector=f'status.podIP={ip}' + ).items + assert len(pods) == 1, \ + f"Exactly 1 Kubernetes pod should have IP {ip}, got {len(pods)}." + return pods[0] + + @classmethod + @cache + def _job(cls): + with cls._handle_api_exception(): + return cls._batch_api.read_namespaced_job( + name=cls._pod().metadata.labels['job-name'], namespace=cls._namespace() + ) + + @classmethod + def get_coordinator_address(cls, timeout_secs: int | None) -> str: + return '{job_name}-0.{jobset_name}:{port}'.format( + job_name=cls._pod().metadata.labels['job-name'], + jobset_name=cls._job().metadata.labels['jobset.sigs.k8s.io/jobset-name'], + port=cls._coordinator_port + ) + + @classmethod + def get_process_count(cls) -> int: + # https://kubernetes.io/docs/concepts/workloads/controllers/job/#controlling-parallelism + return cls._job().spec.parallelism + + @classmethod + def get_process_id(cls) -> int: + # https://kubernetes.io/docs/concepts/workloads/controllers/job/#completion-mode + try: + return int(os.environ['JOB_COMPLETION_INDEX']) + except KeyError: + raise RuntimeError( + 'K8s job must be run with `completionMode: "Indexed"`.' + ) diff --git a/pyproject.toml b/pyproject.toml index 23135b95e126..a69adbfae2fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ module = [ "tensorstore.*", "web_pdb.*", "zstandard.*", + "kubernetes.*" ] ignore_missing_imports = true diff --git a/setup.py b/setup.py index e807ff3b0052..762b5ad7a281 100644 --- a/setup.py +++ b/setup.py @@ -103,6 +103,11 @@ def load_version_module(pkg_path): f"jaxlib=={_current_jaxlib_version}", f"jax-cuda12-plugin=={_current_jaxlib_version}", ], + + # For automatic bootstrapping distributed jobs in Kubernetes + 'k8s': [ + 'kubernetes', + ], }, url='https://github.com/jax-ml/jax', license='Apache-2.0', From 1acf9567aae0742ced26475e2fe4ec3b551a16bd Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 20 Sep 2024 11:24:36 -0700 Subject: [PATCH 592/702] Add get_replication to shard_map.py for verifying if an array is replicated. PiperOrigin-RevId: 676910872 --- jax/experimental/shard_map.py | 10 ++++++++++ tests/shard_map_test.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index f19401525cc0..35d665943792 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -2011,3 +2011,13 @@ def _match_replication(src, dst, x): if src - dst: x = pbroadcast(x, tuple(n for n in src if n not in dst)) return x + +# TODO(parkers,mattjj): change implementation when we have sharding-in-types. +def get_replication(x: jax.Array) -> set[AxisName]: + """For a jax.Array, return what axes it is known to be replicated along.""" + + if isinstance(x, RewriteTracer): + return x.rep + if isinstance(x, batching.BatchTracer): + return get_replication(x.val) + raise ValueError("get_replication not defined on %s" % repr(type(x))) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 20bc33475e14..fbe9746513f5 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2151,6 +2151,34 @@ def f(a): f(A()) # don't crash + def test_get_check_rep(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + + def f(x, reduce_along, use_jit): + out_spec = P(*(n for n in ('x', 'y') if n not in reduce_along)) + + @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec) + def g(x): + result = lax.psum(x, axis_name=reduce_along) + def check_rep(result): + self.assertEqual( + jax.experimental.shard_map.get_replication(result), + set(reduce_along)) + return result + result = check_rep(result) + result = jax.vmap(check_rep)(result) + return result + if use_jit: + return jax.jit(g)(x) + else: + return g(x) + + for use_jit in [True, False]: + x = np.zeros((8, 8), dtype=np.float32) + f(x, reduce_along=('y',), use_jit=use_jit) + f(x, reduce_along=('x',), use_jit=use_jit) + f(x, reduce_along=('x', 'y'), use_jit=use_jit) + class FunSpec(NamedTuple): name: str From 6a5553d6beaa30a98b4912e9e7a15968749ae2db Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Sat, 21 Sep 2024 00:09:42 +0530 Subject: [PATCH 593/702] Improve docs for jax.numpy: remainder, mod and fmod --- jax/_src/numpy/ufuncs.py | 75 ++++++++++++++++++++++++++++++++++++++-- tests/lax_numpy_test.py | 2 +- 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 4f491e7f9b49..857ed8668d59 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -1984,9 +1984,42 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) -@implements(np.remainder, module='numpy') @jit def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Returns element-wise remainder of the division. + + JAX implementation of :obj:`numpy.remainder`. + + Args: + x1: scalar or array. Specifies the dividend. + x2: scalar or array. Specifies the divisor. ``x1`` and ``x2`` should either + have same shape or be broadcast compatible. + + Returns: + An array containing the remainder of element-wise division of ``x1`` by + ``x2`` with same sign as the elements of ``x2``. + + Note: + The result of ``jnp.remainder`` is equivalent to ``x1 - x2 * jnp.floor(x1 / x2)``. + + See also: + - :func:`jax.numpy.mod`: Returns the element-wise remainder of the division. + - :func:`jax.numpy.fmod`: Calculates the element-wise floating-point modulo + operation. + - :func:`jax.numpy.divmod`: Calculates the integer quotient and remainder of + ``x1`` by ``x2``, element-wise. + + Examples: + >>> x1 = jnp.array([[3, -1, 4], + ... [8, 5, -2]]) + >>> x2 = jnp.array([2, 3, -5]) + >>> jnp.remainder(x1, x2) + Array([[ 1, 2, -1], + [ 0, 2, -2]], dtype=int32) + >>> x1 - x2 * jnp.floor(x1 / x2) + Array([[ 1., 2., -1.], + [ 0., 2., -2.]], dtype=float32) + """ x1, x2 = promote_args_numeric("remainder", x1, x2) zero = _constant_like(x1, 0) if dtypes.issubdtype(x2.dtype, np.integer): @@ -1996,12 +2029,48 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: do_plus = lax.bitwise_and( lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) -mod = implements(np.mod, module='numpy')(remainder) -@implements(np.fmod, module='numpy') +def mod(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.remainder`""" + return remainder(x1, x2) + + @jit def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Calculate element-wise floating-point modulo operation. + + JAX implementation of :obj:`numpy.fmod`. + + Args: + x1: scalar or array. Specifies the dividend. + x2: scalar or array. Specifies the divisor. ``x1`` and ``x2`` should either + have same shape or be broadcast compatible. + + Returns: + An array containing the result of the element-wise floating-point modulo + operation of ``x1`` and ``x2`` with same sign as the elements of ``x1``. + + Note: + The result of ``jnp.fmod`` is equivalent to ``x1 - x2 * jnp.fix(x1 / x2)``. + + See also: + - :func:`jax.numpy.mod` and :func:`jax.numpy.remainder`: Returns the element-wise + remainder of the division. + - :func:`jax.numpy.divmod`: Calculates the integer quotient and remainder of + ``x1`` by ``x2``, element-wise. + + Examples: + >>> x1 = jnp.array([[3, -1, 4], + ... [8, 5, -2]]) + >>> x2 = jnp.array([2, 3, -5]) + >>> jnp.fmod(x1, x2) + Array([[ 1, -1, 4], + [ 0, 2, -2]], dtype=int32) + >>> x1 - x2 * jnp.fix(x1 / x2) + Array([[ 1., -1., 4.], + [ 0., 2., -2.]], dtype=float32) + """ check_arraylike("fmod", x1, x2) if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 70c9b503b895..a1d1a0292338 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6309,7 +6309,7 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', 'amax', 'amin', 'around', 'bitwise_right_shift', 'degrees', 'divide', - 'pow', 'radians', 'round_'] + 'mod', 'pow', 'radians', 'round_'] skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split'] for name in dir(jnp): From ca97af9d43697ddb48ea4ba39669540adbeb3432 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 20 Sep 2024 13:05:14 -0700 Subject: [PATCH 594/702] Change the default implementation of GeLU to a numerically stable formulation. The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to cancellation. PiperOrigin-RevId: 676944344 --- jax/_src/nn/functions.py | 14 ++++++++------ tests/nn_test.py | 15 +++++++++++---- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index c1f4831e5ec0..c81d51ea054b 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -430,8 +430,8 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array: If ``approximate=False``, computes the element-wise function: .. math:: - \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( - \frac{x}{\sqrt{2}} \right) \right) + \mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( + \frac{-x}{\sqrt{2}} \right) \right) If ``approximate=True``, uses the approximate formulation of GELU: @@ -443,7 +443,7 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array: `_, section 2. Args: - x : input array + x: input array approximate: whether to use the approximate or exact formulation. """ [x_arr] = numpy_util.promote_args_inexact("gelu", x) @@ -453,8 +453,10 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array: cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x_arr + 0.044715 * (x_arr ** 3)))) return x_arr * cdf else: - sqrt_2 = np.sqrt(2).astype(x_arr.dtype) - return jnp.array(x_arr * (lax.erf(x_arr / sqrt_2) + 1) / 2, dtype=x_arr.dtype) + sqrt_half = np.sqrt(0.5).astype(x_arr.dtype) + return jnp.array( + 0.5 * x_arr * (lax.erfc(-x_arr * sqrt_half)), dtype=x_arr.dtype + ) @partial(jax.jit, static_argnames=("axis",)) def glu(x: ArrayLike, axis: int = -1) -> Array: @@ -541,7 +543,7 @@ def log_softmax(x: ArrayLike, # TODO(phawkins): this jit was found to change numerics in a test. Debug this. -#@partial(jax.jit, static_argnames=("axis",)) +# @partial(jax.jit, static_argnames=("axis",)) def softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, diff --git a/tests/nn_test.py b/tests/nn_test.py index 416beffce17b..d6153d32c63e 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -308,11 +308,18 @@ def testGeluIntType(self, approximate): def testGelu(self, approximate): def gelu_reference(x): return x * scipy.stats.norm.cdf(x) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng((4, 5, 6), jnp.float32)] + args_maker = lambda: [jnp.linspace(-12, 5, 10000, dtype=jnp.float32)] + rtol = 2e-5 + atol = 1e-3 if approximate else 0 self._CheckAgainstNumpy( - gelu_reference, partial(nn.gelu, approximate=approximate), args_maker, - check_dtypes=False, tol=1e-3 if approximate else None) + gelu_reference, + partial(nn.gelu, approximate=approximate), + args_maker, + check_dtypes=False, + tol=0, + rtol=rtol, + atol=atol, + ) @parameterized.parameters(*itertools.product( (jnp.float32, jnp.bfloat16, jnp.float16), From a533635898ea78591a4f84efe8f2bd7aea992e39 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 20 Sep 2024 14:17:11 -0700 Subject: [PATCH 595/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/44d14566fc5d298d5d410efa24ed8630ce137137. PiperOrigin-RevId: 676967851 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 72bad324e0f0..376f2167e35c 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "a0cb79873742367204ad1386e9ca4fd815b3f860" -XLA_SHA256 = "bcedc70cf3cdcc94159313365b15eb49e25e0d8a9d4713c290ead5a507d2b366" +XLA_COMMIT = "44d14566fc5d298d5d410efa24ed8630ce137137" +XLA_SHA256 = "2aa4d49121faa95c063c413e28fa87e0a5af64177588acd35459401f0e76f2ea" def repo(): tf_http_archive( From 6b93b35842c5973f3aede340b1bf6fd285cd2e0c Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 20 Sep 2024 15:00:00 -0700 Subject: [PATCH 596/702] [Mosaic:TPU] Efficient relayout with internal scratch We should support all different retilings (x*packing1, 128) <-> (y*packing2, 128) with any dtype in this cl at this moment. The efficient relayout with scratch brings significant improvements on current retiling in <= TPUv4 and retiling with (packing, 128) in TPUv5. All missing retiling supports are added in this cl, including increase sublane retiling and packed type retiling. PiperOrigin-RevId: 676982957 --- jaxlib/mosaic/dialect/tpu/tpu.td | 1 + jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 1 + .../tpu/transforms/apply_vector_layout.cc | 566 ++++++++++++++---- tests/pallas/tpu_pallas_test.py | 6 +- 4 files changed, 439 insertions(+), 135 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index c1fba60f4cc5..ffcc8d52cd05 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -790,6 +790,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">, Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">, Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">, + Option<"vmem_banks", "vmem-banks", "int", /*default=*/"-1", "">, ]; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 510bd384d656..00bd15b57153 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -62,6 +62,7 @@ struct ApplyVectorLayoutContext { // mxu_shape = {contracting_size, non_contracting_size} std::array mxu_shape = {128, 128}; int64_t max_sublanes_in_scratch = 0; + int64_t vmem_banks = -1; // -1 means "unspecified". }; std::pair mightCommunicateBetweenChips(Operation* op); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 951ed59865c7..d3e1b59afe16 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" @@ -46,6 +47,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -139,18 +141,21 @@ void moveAllRegions(Operation &src, Operation &dst) { // // Returns: // A memref of the requested shape and type. -FailureOr getInternalScratch(RewriteContext &ctx, OpBuilder &builder, - Location loc, ArrayRef shape, - Type elem_ty) { +FailureOr> getInternalScratch( + RewriteContext &ctx, OpBuilder &builder, Location loc, + ArrayRef shape, Type elem_ty, int64_t sublane_tiling = 0) { if (shape.empty()) { return failure(); } if (shape.back() % ctx.target_shape[1] != 0) { return failure(); } - int sublane_count = + int packing = 32 / elem_ty.getIntOrFloatBitWidth(); + int sublane_count = llvm::divideCeil( std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) / - ctx.target_shape[1]; + ctx.target_shape[1], + packing); + if (sublane_count > ctx.max_sublanes_in_scratch) { return failure(); } @@ -159,7 +164,7 @@ FailureOr getInternalScratch(RewriteContext &ctx, OpBuilder &builder, FAILUREOR_ASSIGN_OR_RETURN( MemRefType scratch_ref_ty, inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation, - /*tpu_tiling_flags=*/{})); + /*tpu_tiling_flags=*/{}, sublane_tiling)); return builder.create(loc, scratch_ref_ty) .getResult(); } @@ -4752,30 +4757,6 @@ xla::Array retileToReducedSublanes( return dst_vreg_array; } -// Returns true iff the layout changes involve reduced sublanes per tile. -// -// Arguments: -// src: The existing layout. -// dst: The new layout based on which the retiling is to be carried out. -bool isSupportedReducedSublanesRetile( - const VectorLayout &src, const VectorLayout &dst, - const std::array target_shape) { - return src.implicit_dim() == dst.implicit_dim() && - llvm::all_of(llvm::zip_equal(src.offsets(), dst.offsets()), - [](auto tup) { - auto [lhs, rhs] = tup; - return lhs.value_or(0) == rhs.value_or(0); - }) - // TODO (kumudbhandari): We have not tested any tile size where - // tile[-1] != TARGET_SHAPE.lanes. It should work but needs to be - // tested. - && src.tiling()[1] == target_shape[1] && - dst.tiling()[1] == target_shape[1] && - dst.tiling()[0] < src.tiling()[0] && - src.bitwidth() == dst.bitwidth() && - llvm::isPowerOf2_64(src.tiling()[0]) && - llvm::isPowerOf2_64(dst.tiling()[0]); -} // Copy one sublane from a vreg to another vreg. // @@ -5368,13 +5349,353 @@ FailureOr>> changeOffsets( return std::make_pair(dst, std::move(vregs)); } -// TODO(b/265133506): Generalize retiling. +LogicalResult retileToLargeTileWithScratch( + RewriteContext &ctx, OpBuilder &builder, const Location loc, + xla::Array &dst_tiles, const std::array &dst_tile, + const xla::Array &src_tiles, const std::array &src_tile, + TypedValue scratch_ref) { + if (dst_tile[0] % src_tile[0] != 0) { + return failure(); + } + // Number of src vregs needed to assemble one dst vreg. + int vregs_per_group = dst_tile[0] / src_tile[0]; + // Number of sublanes needed per src vreg to assemble one dst vreg. + int sl_per_vreg = ctx.target_shape[0] / vregs_per_group; + int stride = vregs_per_group; + + xla::Array sublane_offsets( + {ctx.target_shape[0] / dst_tile[0], src_tile[0], vregs_per_group}, 0); + absl::c_iota(sublane_offsets, 0); + // The older hardware has limited support for shuffles so even if we have bank + // conflicts, we just accept them and will have the lowering unroll the + // loads/stores. + bool should_handle_bank_confict = + ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && + ctx.vmem_banks < stride * ctx.target_shape[0]; + // Add one extra sublane to stride to avoid bank conflict. + if (should_handle_bank_confict) { + // Adjust sublane offsets to match the stride. + for (int i = 0; i < sublane_offsets.num_elements(); i += 1) { + *(sublane_offsets.begin() + i) += i / stride; + } + stride += 1; + } + sublane_offsets.TransposeDimensions({0, 2, 1}); + + auto mlirIndexConst = [&](int d) { + return builder.create( + src_tiles.begin()->getLoc(), + builder.getIntegerAttr(builder.getIndexType(), d)); + }; + auto cst_0 = mlirIndexConst(0); + // Each group has exact number of src vregs needed to assemble one dst vreg. + // We can not use circular buffer here because we need to have enough space to + // strided load/store. + int64_t sublanes_per_group = stride * sl_per_vreg * vregs_per_group; + int64_t max_groups_in_scratch = + ctx.max_sublanes_in_scratch / sublanes_per_group; + if (max_groups_in_scratch < 1) { + return emitError(loc, + "scratch space is not enough for retiling to large tile"); + } + int64_t stored_group_cnt = 0; + auto dst_vreg_ty = src_tiles.begin()->getType(); + // Create a new vreg type that can be stored in scratch memref. + auto temp_vreg_ty = + VectorType::get(ctx.target_shape, scratch_ref.getType().getElementType()); + SmallVector sublane_mask(ctx.target_shape[0], true); + // (dst_vreg, load_offset) + std::vector> delayed_loads; + delayed_loads.reserve(max_groups_in_scratch * vregs_per_group); + // We only emit the loads when we run out of scratch space or we are at the + // last vreg of the batch to help bundle scheduling. + auto emit_all_delayed_loads = [&]() { + for (auto [dst_vreg, load_offset] : delayed_loads) { + Value load_op = builder.create( + loc, temp_vreg_ty, scratch_ref, ArrayRef({load_offset, cst_0}), + ArrayRef(sublane_mask), + ArrayRef(sublane_offsets.begin(), sublane_offsets.end())); + *dst_vreg = builder.create(loc, dst_vreg_ty, load_op); + } + delayed_loads.clear(); + }; + + int rank = src_tiles.dimensions().size(); + if (rank != dst_tiles.dimensions().size()) { + return emitError(loc, "src and dst tiles have different ranks"); + } + for (int i = 0; i < rank - 2; ++i) { + if (src_tiles.dim(i) != dst_tiles.dim(i)) { + return emitError(loc, + "Expected src and dst tiles have same dimension " + "sizes on dim") + << i << ", but got " << src_tiles.dim(i) << " vs " + << dst_tiles.dim(i); + } + } + SmallVector src_idx(rank); + dst_tiles.Each([&](absl::Span dst_idx, Value *dst_vreg) { + int64_t dst_row_idx = *(dst_idx.end() - 2); + int64_t dst_col_idx = *(dst_idx.end() - 1); + int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group; + int64_t load_offset = sublanes_per_group * stored_group_cnt + + vreg_idx_in_group * sl_per_vreg * stride; + delayed_loads.push_back( + std::make_pair(dst_vreg, mlirIndexConst(load_offset))); + // When dst vreg is at the last vreg of the group or the current dst + // vregs' row, this indicates we have scheduled delayed loads for all + // the vregs from current group and now we need to store corresponding + // group of src vregs before actually emitting the loads. + if (vreg_idx_in_group == vregs_per_group - 1 || + dst_col_idx == dst_tiles.dimensions().back() - 1) { + auto src_row_idx = dst_row_idx * vregs_per_group; + auto src_col_idx = dst_col_idx / vregs_per_group; + std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin()); + for (int vi = 0; vi < vregs_per_group; ++vi) { + if (src_row_idx + vi >= src_tiles.dim(rank - 2) || + src_col_idx >= src_tiles.dim(rank - 1)) { + break; + } + *(src_idx.end() - 2) = src_row_idx + vi; + *(src_idx.end() - 1) = src_col_idx; + Value src_vreg = src_tiles(src_idx); + src_vreg = + builder.create(loc, temp_vreg_ty, src_vreg); + Value store_offset = + mlirIndexConst(sublanes_per_group * stored_group_cnt + vi); + builder.create( + loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}), + ArrayRef(sublane_mask), + /*mask=*/nullptr, builder.getI32IntegerAttr(stride)); + } + stored_group_cnt = (stored_group_cnt + 1) % max_groups_in_scratch; + // We emit loads when we run out of scratch space or we are at the + // last vreg of the batch. + if (stored_group_cnt == 0 || + (*(dst_idx.end() - 2) == dst_tiles.dim(rank - 2) - 1 && + *(dst_idx.end() - 1) == dst_tiles.dim(rank - 1) - 1)) { + emit_all_delayed_loads(); + } + } + }); + return success(); +} + +LogicalResult retileToSmallTileWithScratch( + RewriteContext &ctx, OpBuilder &builder, const Location loc, + xla::Array &dst_tiles, const std::array &dst_tile, + const xla::Array &src_tiles, const std::array &src_tile, + TypedValue scratch_ref) { + if (src_tile[0] % dst_tile[0] != 0) { + return failure(); + } + // Number of src vregs needed to assemble one dst vreg. + int vregs_per_group = src_tile[0] / dst_tile[0]; + // Number of sublanes needed per src vreg to assemble one dst vreg. + int sl_per_vreg = ctx.target_shape[0] / vregs_per_group; + int stride = vregs_per_group; + + xla::Array sublane_offsets( + {ctx.target_shape[0] / src_tile[0], dst_tile[0], vregs_per_group}, 0); + absl::c_iota(sublane_offsets, 0); + // The older hardware has limited support for shuffles so even if we have + // bank conflicts, we just accept them and will have the lowering unroll the + // loads/stores. + bool should_handle_bank_confict = + ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && + ctx.vmem_banks < stride * ctx.target_shape[0]; + bool use_shuffled_load = false; + if (ctx.hardware_generation <= 4) { + if (src_tile[0] == 8) { + // The older hardware does not support shuffled store. However, if the src + // tile is (8, 128), we can convert (shuffled store + strided load) to + // (strided store + shuffled load). + use_shuffled_load = true; + } else if (src_tile[0] == 4) { + // In this case, the trick of replacing a shuffled store with a shuffled + // load does not work. Handling bank conflicts will cause the sublane + // offsets to increase which might make emulation harder, so we avoid + // doing so. + should_handle_bank_confict = false; + } + } + + // Add one extra sublane to stride to avoid bank conflict. + if (should_handle_bank_confict) { + // Adjust sublane offsets to match the stride. + for (int i = 0; i < sublane_offsets.num_elements(); i += 1) { + *(sublane_offsets.begin() + i) += i / stride; + } + stride += 1; + } + sublane_offsets.TransposeDimensions({0, 2, 1}); + auto mlirIndexConst = [&](int d) { + return builder.create( + src_tiles.begin()->getLoc(), + builder.getIntegerAttr(builder.getIndexType(), d)); + }; + auto cst_0 = mlirIndexConst(0); + // Each group has exact number of src vregs needed to assemble one dst vreg. + // We can not use circular buffer here because we need to have enough space + // to strided load/store. + int64_t sublanes_per_group = stride * sl_per_vreg * vregs_per_group; + int64_t max_groups_in_scratch = + ctx.max_sublanes_in_scratch / sublanes_per_group; + if (max_groups_in_scratch < 1) { + return emitError(loc, + "scratch space is not enough for retiling to small tile"); + } + int64_t stored_group_cnt = 0; + auto dst_vreg_ty = src_tiles.begin()->getType(); + // Create a new vreg type that can be stored in scratch memref. + auto temp_vreg_ty = + VectorType::get(ctx.target_shape, scratch_ref.getType().getElementType()); + SmallVector sublane_mask(ctx.target_shape[0], true); + // (dst_vreg, load_offset) + std::vector> delayed_loads; + delayed_loads.reserve(max_groups_in_scratch * vregs_per_group); + // We only emit the loads when we run out of scratch space or we are at the + // last vreg of the batch to help bundle scheduling. + auto emit_all_delayed_loads = [&]() { + for (auto [dst_vreg, load_offset] : delayed_loads) { + Value load_op; + if (use_shuffled_load) { + load_op = builder.create( + loc, temp_vreg_ty, scratch_ref, + ArrayRef({load_offset, cst_0}), ArrayRef(sublane_mask), + ArrayRef(sublane_offsets.begin(), sublane_offsets.end())); + } else { + load_op = builder.create( + loc, temp_vreg_ty, scratch_ref, + ArrayRef({load_offset, cst_0}), ArrayRef(sublane_mask), + builder.getI32IntegerAttr(stride)); + } + *dst_vreg = builder.create(loc, dst_vreg_ty, load_op); + } + delayed_loads.clear(); + }; + int rank = src_tiles.dimensions().size(); + if (rank != dst_tiles.dimensions().size()) { + return emitError(loc, "src and dst tiles have different ranks"); + } + for (int i = 0; i < rank - 2; ++i) { + if (src_tiles.dim(i) != dst_tiles.dim(i)) { + return emitError(loc, + "Expected src and dst tiles have same dimension " + "sizes on dim") + << i << ", but got " << src_tiles.dim(i) << " vs " + << dst_tiles.dim(i); + } + } + SmallVector dst_idx(rank); + src_tiles.Each([&](absl::Span src_idx, Value src_vreg) { + int64_t src_row_idx = *(src_idx.end() - 2); + int64_t src_col_idx = *(src_idx.end() - 1); + int64_t vreg_idx_in_group = src_col_idx % vregs_per_group; + src_vreg = builder.create(loc, temp_vreg_ty, src_vreg); + if (use_shuffled_load) { + Value store_offset = mlirIndexConst( + sublanes_per_group * stored_group_cnt + vreg_idx_in_group); + builder.create( + loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}), + ArrayRef(sublane_mask), + /*mask=*/nullptr, builder.getI32IntegerAttr(stride)); + } else { + Value store_offset = + mlirIndexConst(sublanes_per_group * stored_group_cnt + + vreg_idx_in_group * sl_per_vreg * stride); + builder.create( + loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}), + ArrayRef(sublane_mask), + ArrayRef(sublane_offsets.begin(), sublane_offsets.end())); + } + // When src vreg is at the last vreg of the group or the current src + // vregs' row, this indicates we have stored all the vregs needed to + // assemble a new group of dst vreg. + if (vreg_idx_in_group == vregs_per_group - 1 || + src_col_idx == src_tiles.dimensions().back() - 1) { + auto dst_row_idx = src_row_idx * vregs_per_group; + auto dst_col_idx = src_col_idx / vregs_per_group; + std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin()); + for (int vi = 0; vi < vregs_per_group; ++vi) { + if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) || + dst_col_idx >= dst_tiles.dim(rank - 1)) { + break; + } + *(dst_idx.end() - 2) = dst_row_idx + vi; + *(dst_idx.end() - 1) = dst_col_idx; + Value *dst_vreg = &dst_tiles(dst_idx); + int64_t load_offset = + use_shuffled_load ? (sublanes_per_group * stored_group_cnt + + vi * sl_per_vreg * stride) + : (sublanes_per_group * stored_group_cnt + vi); + delayed_loads.push_back( + std::make_pair(dst_vreg, mlirIndexConst(load_offset))); + } + stored_group_cnt = (stored_group_cnt + 1) % max_groups_in_scratch; + // We emit loads when we run out of scratch space or we are at the + // last vreg of the batch. + if (stored_group_cnt == 0 || + (*(src_idx.end() - 2) == src_tiles.dim(rank - 2) - 1 && + *(src_idx.end() - 1) == src_tiles.dim(rank - 1) - 1)) { + emit_all_delayed_loads(); + } + } + }); + return success(); +} + +// go/mosaic-retiling-in-scratch is the full internal documentation that +// includes more details about the TPU generations. +LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder, + const Location loc, + xla::Array &dst_tiles, + const std::array &dst_tiling, + const xla::Array &src_tiles, + const std::array &src_tiling, + int packing) { + if (!(src_tiling[1] == ctx.target_shape[1] && + dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 && + dst_tiling[0] % packing == 0)) { + return failure(); + } + // Try to get i32 vector scratch space. Because we will bitcast vregs to + // i32 vregs before using scratch for retiling. Through this way we can + // handle packed types as well. + auto vi32_scratch_ref = getInternalScratch( + ctx, builder, loc, {ctx.max_sublanes_in_scratch, ctx.target_shape[1]}, + builder.getI32Type(), /*sublane_tiling=*/1); + if (failed(vi32_scratch_ref)) { + return emitError(loc, "Failed to get scratch ref for retiling"); + } + auto ref = vi32_scratch_ref.value(); + std::array vi32_dst_tiling = {dst_tiling[0] / packing, + dst_tiling[1]}; + std::array vi32_src_tiling = {src_tiling[0] / packing, + src_tiling[1]}; + if (src_tiling[0] > dst_tiling[0]) { + return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles, + vi32_dst_tiling, src_tiles, + vi32_src_tiling, ref); + } + if (src_tiling[0] < dst_tiling[0]) { + return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles, + vi32_dst_tiling, src_tiles, + vi32_src_tiling, ref); + } + dst_tiles = std::move(src_tiles); + return success(); +} + FailureOr>> changeTiling( RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty, const VectorLayout src, xla::Array vregs, const std::array dst_tiling, bool try_replicate_rows) { + bool has_enough_scratch = ctx.max_sublanes_in_scratch >= + ctx.target_shape[0] * (ctx.target_shape[0] + 1); const auto &target_shape = ctx.target_shape; - if (src.tiling() == dst_tiling) { + const std::array src_tiling = src.tiling(); + if (src_tiling == dst_tiling) { return std::pair(src, std::move(vregs)); } const int packing = src.packing(); @@ -5384,106 +5705,62 @@ FailureOr>> changeTiling( if (!dst.isValid(target_shape)) { return emitError(loc, "Not implemented: invalid offsets in tiling target"); } - // Handle retiling from (packing, 128) to (8 * packing, 128). - if (src.offsets() == LayoutOffsets{0, 0} && - src.tiling() == std::array{packing, 128} && - dst_tiling == std::array{8 * packing, 128}) { - bool replicate_sublanes = try_replicate_rows && packing == 1 && - *(vregs.dimensions().end() - 2) == 1; - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + auto dst_tiles_shape = + dst.tileArrayImplicitShape(vty.getShape(), target_shape); + // Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating + // sublanes. + if (try_replicate_rows && packing == 1 && + *(vregs.dimensions().end() - 2) == 1 && + src.offsets() == LayoutOffsets{0, 0} && + src.tiling() == std::array{1, 128} && + dst_tiling == std::array{8, 128}) { + xla::Array retiled(dst_tiles_shape); retiled.Each([&](absl::Span idx, Value *tile) { SmallVector src_idx(idx.begin(), idx.end()); *(src_idx.end() - 2) *= target_shape[0]; *(src_idx.end() - 1) /= target_shape[0]; const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0]; - if (replicate_sublanes) { - CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1); - *tile = - broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape); - } else { - for (int dst_sl_idx = 0; - dst_sl_idx < target_shape[0] && - *(src_idx.end() - 2) < *(vregs.dimensions().end() - 2); - ++dst_sl_idx, ++*(src_idx.end() - 2)) { - *tile = copy_one_sublane(builder, vregs(src_idx), src_sl_idx, *tile, - dst_sl_idx, target_shape); - } - } + CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1); + *tile = + broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape); }); // We have successfully replicated sublanes. - if (replicate_sublanes) { - dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling, - dst.implicit_dim()); - } - return std::pair(dst, std::move(retiled)); - } - // Handle retiling from (m, 128) to (8, 128) for 32-bit data - // where m < 8 and m is a power of 2. - // TODO(b/306692696): Handle any vregs.dimensions(). - if (bitwidth == 32 && src.offsets() == LayoutOffsets{0, 0} && - target_shape[0] % src.tiling()[0] == 0 && - src.tiling()[1] == target_shape[1] && dst.tiling() == target_shape && - *(vregs.dimensions().end() - 2) == 1) { - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - retiled.Each([&](const absl::Span idx, - Value *const new_src_tile) { - const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape); - const int64_t dst_col = idx.back(); - const int64_t src_col = dst_col / tiles_per_vreg; - const int64_t start_slane_idx = - src.tiling()[0] * (dst_col % tiles_per_vreg); - SmallVector src_idx(toArrayRef(idx)); - src_idx.back() = src_col; - Value src_tile = vregs(src_idx); - if (start_slane_idx) { - SmallVector slane_idxs; - slane_idxs.reserve(target_shape[0]); - for (int i = 0; i < target_shape[0]; ++i) { - slane_idxs.push_back(start_slane_idx + (i % src.tiling()[0])); - } - const DenseI32ArrayAttr gather_indices = - builder.getDenseI32ArrayAttr(slane_idxs); - *new_src_tile = builder.create(loc, src_tile.getType(), - src_tile, gather_indices, - /*dimension=*/0); - } else { - *new_src_tile = src_tile; - } - }); + dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling, + dst.implicit_dim()); return std::pair(dst, std::move(retiled)); } // (8,128) -> (8 * packing,128) tiling change for packed type. if (bitwidth < 32 && 32 % bitwidth == 0 && - src.tiling() == std::array{8, 128} && - dst.tiling() == std::array{8 * dst.packing(), 128}) { - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - int vty_packing = dst.packing(); - VectorType vreg_x32 = - vty.getElementType().isSignlessInteger() - ? VectorType::get(target_shape, builder.getI32Type()) - : VectorType::get(target_shape, builder.getF32Type()); - retiled.Each([&](absl::Span idx, Value *tile) { - const int vreg_part = idx.back() % vty_packing; - SmallVector parts; - parts.reserve(vty_packing); - SmallVector src_idx(idx.begin(), idx.end()); - src_idx[src_idx.size() - 2] *= vty_packing; - src_idx[src_idx.size() - 1] /= vty_packing; - for (int i = 0; i < vty_packing; ++i) { - parts.push_back(builder.create( - loc, vreg_x32, vregs(src_idx), vreg_part)); - if (src_idx[src_idx.size() - 2] < - vregs.dim(vregs.num_dimensions() - 2) - 1) { - ++src_idx[src_idx.size() - 2]; + src_tiling == std::array{8, 128} && + dst_tiling == std::array{8 * dst.packing(), 128}) { + // Note: for int4, retiling with scratch is always faster. + if (bitwidth != 4 || !has_enough_scratch) { + xla::Array retiled(dst_tiles_shape); + int vty_packing = dst.packing(); + VectorType vreg_x32 = + vty.getElementType().isSignlessInteger() + ? VectorType::get(target_shape, builder.getI32Type()) + : VectorType::get(target_shape, builder.getF32Type()); + retiled.Each([&](absl::Span idx, Value *tile) { + const int vreg_part = idx.back() % vty_packing; + SmallVector parts; + parts.reserve(vty_packing); + SmallVector src_idx(idx.begin(), idx.end()); + src_idx[src_idx.size() - 2] *= vty_packing; + src_idx[src_idx.size() - 1] /= vty_packing; + for (int i = 0; i < vty_packing; ++i) { + parts.push_back(builder.create( + loc, vreg_x32, vregs(src_idx), vreg_part)); + if (src_idx[src_idx.size() - 2] < + vregs.dim(vregs.num_dimensions() - 2) - 1) { + ++src_idx[src_idx.size() - 2]; + } } - } - *tile = builder.create( - loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed); - }); - return std::pair(dst, std::move(retiled)); + *tile = builder.create( + loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed); + }); + return std::pair(dst, std::move(retiled)); + } } // Handle retiling from (1, 128 * packing) to (packing, 128) for // packed data. @@ -5497,8 +5774,8 @@ FailureOr>> changeTiling( // match corresponding elements without shifting. It's just that // the tiles are not adjacent (no contiguous vreg slice). if (bitwidth < 32 && 32 % bitwidth == 0 && - src.tiling() == std::array{1, 128 * packing} && - dst.tiling() == std::array{packing, 128}) { + src_tiling == std::array{1, 128 * packing} && + dst_tiling == std::array{packing, 128}) { // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of // 4 sublanes and 2 lanes (this is convenient for to keep the example small // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling. @@ -5539,8 +5816,7 @@ FailureOr>> changeTiling( // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before // moving to the next one. This is exactly an interleaving of the sublanes // of the vreg parts. - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + xla::Array retiled(dst_tiles_shape); const VectorType vreg_x32 = vty.getElementType().isSignlessInteger() ? VectorType::get(target_shape, builder.getI32Type()) @@ -5565,13 +5841,41 @@ FailureOr>> changeTiling( }); return std::pair(dst, std::move(retiled)); } - if (isSupportedReducedSublanesRetile(src, dst, target_shape)) { - return std::pair(dst, retileToReducedSublanes(builder, vty.getShape(), src, - vregs, dst, target_shape)); + if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) { + // TODO(b/368088671): When sublane tiling changes, we should be able to + // preserve some replications from the source layout. But we need to + // make sure they are implemented efficiently and well-tested. For now, we + // just simply use 0 for the replicated offset after retiling. + dst = VectorLayout( + bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)}, + dst_tiling, dst.implicit_dim()); + + // All clauses in the and expression are based on performance benchmarking. + bool use_alu = !has_enough_scratch || + (ctx.hardware_generation >= 5 && src_tiling[0] != packing && + dst_tiling[0] != packing); + + if (use_alu) { + if (src_tiling[0] > dst_tiling[0]) { + return std::pair( + dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs, + dst, target_shape)); + } else if (!has_enough_scratch) { + // TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops. + return emitError( + loc, + "Not implemented: retiling to increase sublane tiling with ALU"); + } + } + xla::Array retiled(dst_tiles_shape); + if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs, + src_tiling, packing))) { + return failure(); + } + return std::pair(dst, std::move(retiled)); } return emitError(loc, "Not implemented: Unsupported tiling change for ") - << vty << ": from " << src << " to tiling (" << dst_tiling[0] << ", " - << dst_tiling[1] << ")"; + << vty << ": from " << src << " to " << dst; } FailureOr>> changeImplicitDim( @@ -5878,6 +6182,7 @@ struct ApplyVectorLayoutPass mxu_contracting_size = ctx.mxu_shape[0]; mxu_noncontracting_size = ctx.mxu_shape[1]; max_sublanes_in_scratch = ctx.max_sublanes_in_scratch; + vmem_banks = ctx.vmem_banks; } void runOnOperation() override { // Fail if hardware_generation has not been set from the default value. @@ -5889,7 +6194,8 @@ struct ApplyVectorLayoutPass .hardware_generation = hardware_generation, .target_shape = {sublane_count, lane_count}, .mxu_shape = {mxu_contracting_size, mxu_noncontracting_size}, - .max_sublanes_in_scratch = max_sublanes_in_scratch}; + .max_sublanes_in_scratch = max_sublanes_in_scratch, + .vmem_banks = vmem_banks}; if (failed(applyLayoutFunc(ctx, getOperation()))) { signalPassFailure(); return; diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 84403e41b561..87ccaa644e8c 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2464,9 +2464,7 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.broadcast_to(x, (256, 512))) - @only_passes_in_interpret(unless_generation=4) def test_bfloat16_to_uint32_bitcast(self): - """b/347771903""" x = np.arange(16 * 2 * 256, dtype=jnp.bfloat16).reshape(16, 2, 256) def kernel(x_ref, out_ref): @@ -2475,7 +2473,7 @@ def kernel(x_ref, out_ref): out = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((16, 1, 256), jnp.uint32) )(x) - # FIXME: Add correctness test for result. + np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32)) @only_passes_in_interpret() def test_roll_partial(self): @@ -2548,9 +2546,7 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128))) - @only_passes_in_interpret() def test_mixed_strides(self): - """b/352841329""" x = np.zeros((8, 128), dtype=jnp.float32) y = np.zeros((8, 2, 128), dtype=jnp.bfloat16) From 0cf040c9a129c7ed47f9acdee725c0dba97e8ec1 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Fri, 20 Sep 2024 22:19:14 +0000 Subject: [PATCH 597/702] Add/update JAX Advanced Tutorials docs, ToC structure --- docs/_tutorials/index.rst | 5 - docs/{_tutorials => }/advanced-autodiff.md | 0 docs/conf.py | 2 + docs/extensions.rst | 2 - docs/{_tutorials => }/external-callbacks.md | 0 docs/ffi.ipynb | 2 +- docs/ffi.md | 2 +- docs/glossary.rst | 2 +- .../gradient-checkpointing.md | 0 docs/index.rst | 7 +- docs/{_tutorials => }/jax-primitives.md | 0 docs/{_tutorials => }/jaxpr.md | 0 docs/jaxpr.rst | 472 ----- docs/jit-compilation.md | 2 +- docs/notebooks/How_JAX_primitives_work.ipynb | 1532 ----------------- docs/notebooks/How_JAX_primitives_work.md | 771 --------- docs/notebooks/external_callbacks.ipynb | 1121 ------------ docs/notebooks/external_callbacks.md | 515 ------ docs/tutorials.rst | 10 + docs/user_guides.rst | 1 - 20 files changed, 20 insertions(+), 4426 deletions(-) rename docs/{_tutorials => }/advanced-autodiff.md (100%) rename docs/{_tutorials => }/external-callbacks.md (100%) rename docs/{_tutorials => }/gradient-checkpointing.md (100%) rename docs/{_tutorials => }/jax-primitives.md (100%) rename docs/{_tutorials => }/jaxpr.md (100%) delete mode 100644 docs/jaxpr.rst delete mode 100644 docs/notebooks/How_JAX_primitives_work.ipynb delete mode 100644 docs/notebooks/How_JAX_primitives_work.md delete mode 100644 docs/notebooks/external_callbacks.ipynb delete mode 100644 docs/notebooks/external_callbacks.md diff --git a/docs/_tutorials/index.rst b/docs/_tutorials/index.rst index 5b3d690d5e96..0e5a6a16dcfc 100644 --- a/docs/_tutorials/index.rst +++ b/docs/_tutorials/index.rst @@ -38,10 +38,7 @@ JAX 201 :maxdepth: 1 parallelism - advanced-autodiff - gradient-checkpointing advanced-debugging - external-callbacks profiling-and-performance JAX 301 @@ -50,6 +47,4 @@ JAX 301 .. toctree:: :maxdepth: 1 - jax-primitives - jaxpr advanced-compilation diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/advanced-autodiff.md similarity index 100% rename from docs/_tutorials/advanced-autodiff.md rename to docs/advanced-autodiff.md diff --git a/docs/conf.py b/docs/conf.py index e77916e265ff..d57420dec881 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -360,4 +360,6 @@ def linkcode_resolve(domain, info): 'jax-101/07-state.md': 'stateful-computations.md', 'jax-101/08-pjit.rst': 'sharded-computation.md', 'jax-101/index.rst': 'tutorials.rst', + 'notebooks/external_callbacks.md': 'external-callbacks.md', + 'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md', } diff --git a/docs/extensions.rst b/docs/extensions.rst index 92963b71f20f..856153cd8723 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -10,8 +10,6 @@ that use or interface with JAX. :caption: Extensible JAX internals :maxdepth: 1 - notebooks/How_JAX_primitives_work - jaxpr notebooks/Writing_custom_interpreters_in_Jax Custom_Operation_for_GPUs jax.extend diff --git a/docs/_tutorials/external-callbacks.md b/docs/external-callbacks.md similarity index 100% rename from docs/_tutorials/external-callbacks.md rename to docs/external-callbacks.md diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index 04ae80cbf5b1..a8cd5219d4b5 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -364,7 +364,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:" + "We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:" ] }, { diff --git a/docs/ffi.md b/docs/ffi.md index 03acf876be08..cc3863ed99b2 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -311,7 +311,7 @@ Our implementation of `rms_norm` has the appropriate semantics, and it supports np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5) ``` -We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`: +We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`: ```{code-cell} ipython3 jax.make_jaxpr(jax.vmap(rms_norm))(x) diff --git a/docs/glossary.rst b/docs/glossary.rst index 4bb9fa15667e..286b07e21a66 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -30,7 +30,7 @@ Glossary of terms jaxpr Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution. - See :ref:`understanding-jaxprs` for more discussion and examples. + See :ref:`jax-internals-jaxpr` for more discussion and examples. JIT Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of diff --git a/docs/_tutorials/gradient-checkpointing.md b/docs/gradient-checkpointing.md similarity index 100% rename from docs/_tutorials/gradient-checkpointing.md rename to docs/gradient-checkpointing.md diff --git a/docs/index.rst b/docs/index.rst index 92422edc069f..2dd856ab88ef 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -121,8 +121,6 @@ maintains an up-to-date list. installation quickstart - notebooks/Common_Gotchas_in_JAX - faq .. toctree:: :hidden: @@ -130,11 +128,14 @@ maintains an up-to-date list. tutorials + notebooks/Common_Gotchas_in_JAX + + faq .. toctree:: :hidden: :maxdepth: 2 - :caption: Resources + :caption: More guides/resources user_guides advanced_guide diff --git a/docs/_tutorials/jax-primitives.md b/docs/jax-primitives.md similarity index 100% rename from docs/_tutorials/jax-primitives.md rename to docs/jax-primitives.md diff --git a/docs/_tutorials/jaxpr.md b/docs/jaxpr.md similarity index 100% rename from docs/_tutorials/jaxpr.md rename to docs/jaxpr.md diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst deleted file mode 100644 index d7b50dcb301e..000000000000 --- a/docs/jaxpr.rst +++ /dev/null @@ -1,472 +0,0 @@ -.. _understanding-jaxprs: - -Understanding Jaxprs -==================== - -Updated: May 3, 2020 (for commit f1a46fe). - -Conceptually, one can think of JAX transformations as first trace-specializing -the Python function to be transformed into a small and well-behaved -intermediate form that is then interpreted with transformation-specific -interpretation rules. One of the reasons JAX can pack so much power into such a -small software package is that it starts with a familiar and flexible -programming interface (Python with NumPy) and it uses the actual Python -interpreter to do most of the heavy lifting to distill the essence of the -computation into a simple statically-typed expression language with limited -higher-order features. That language is the jaxpr language. - -Not all Python programs can be processed this way, but it turns out that many -scientific computing and machine learning programs can. - -Before we proceed, it is important to point out that not all JAX -transformations literally materialize a jaxpr as described above; some, e.g., -differentiation or batching, will apply transformations incrementally during -tracing. Nevertheless, if one wants to understand how JAX works internally, or -to make use of the result of JAX tracing, it is useful to understand jaxprs. - -A jaxpr instance represents a function with one or more typed parameters (input -variables) and one or more typed results. The results depend only on the input -variables; there are no free variables captured from enclosing scopes. The -inputs and outputs have types, which in JAX are represented as abstract values. -There are two related representations in the code for jaxprs, -:py:class:`jax.core.Jaxpr` and :py:class:`jax.core.ClosedJaxpr`. A -:py:class:`jax.core.ClosedJaxpr` represents a partially-applied -:py:class:`jax.core.Jaxpr`, and is what you obtain when you use -:py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following fields: - - * ``jaxpr`` is a :py:class:`jax.core.Jaxpr` representing the actual - computation content of the function (described below). - * ``consts`` is a list of constants. - -The most interesting part of the ClosedJaxpr is the actual execution content, -represented as a :py:class:`jax.core.Jaxpr` as printed using the following -grammar:: - - Jaxpr ::= { lambda Var* ; Var+. let - Eqn* - in [Expr+] } - -where: - * The parameters of the jaxpr are shown as two lists of variables separated by - ``;``. The first set of variables are the ones that have been introduced - to stand for constants that have been hoisted out. These are called the - ``constvars``, and in a :py:class:`jax.core.ClosedJaxpr` the ``consts`` - field holds corresponding values. The second list of variables, called - ``invars``, correspond to the inputs of the traced Python function. - * ``Eqn*`` is a list of equations, defining intermediate variables referring to - intermediate expressions. Each equation defines one or more variables as the - result of applying a primitive on some atomic expressions. Each equation uses only - input variables and intermediate variables defined by previous equations. - * ``Expr+``: is a list of output atomic expressions (literals or variables) - for the jaxpr. - -Equations are printed as follows:: - - Eqn ::= Var+ = Primitive [ Param* ] Expr+ - -where: - * ``Var+`` are one or more intermediate variables to be defined as the output - of a primitive invocation (some primitives can return multiple values). - * ``Expr+`` are one or more atomic expressions, each either a variable or a - literal constant. A special variable ``unitvar`` or literal ``unit``, - printed as ``*``, represents a value that is not needed - in the rest of the computation and has been elided. That is, units are just - placeholders. - * ``Param*`` are zero or more named parameters to the primitive, printed in - square brackets. Each parameter is shown as ``Name = Value``. - - -Most jaxpr primitives are first-order (they take just one or more ``Expr`` as arguments):: - - Primitive := add | sub | sin | mul | ... - - -The jaxpr primitives are documented in the :py:mod:`jax.lax` module. - -For example, here is the jaxpr produced for the function ``func1`` below - ->>> from jax import make_jaxpr ->>> import jax.numpy as jnp ->>> def func1(first, second): -... temp = first + jnp.sin(second) * 3. -... return jnp.sum(temp) -... ->>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))) -{ lambda ; a:f32[8] b:f32[8]. let - c:f32[8] = sin b - d:f32[8] = mul c 3.0 - e:f32[8] = add a d - f:f32[] = reduce_sum[axes=(0,)] e - in (f,) } - -Here there are no constvars, ``a`` and ``b`` are the input variables -and they correspond respectively to -``first`` and ``second`` function parameters. The scalar literal ``3.0`` is kept -inline. -The ``reduce_sum`` primitive has named parameter ``axes``, in addition to the -operand ``e``. - -Note that even though execution of a program that calls into JAX builds a jaxpr, -Python-level control-flow and Python-level functions execute normally. -This means that just because a Python program contains functions and control-flow, -the resulting jaxpr does not have to contain control-flow or higher-order features. - -For example, when tracing the function ``func3`` JAX will inline the call to -``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same -jaxpr as before - ->>> def func2(inner, first, second): -... temp = first + inner(second) * 3. -... return jnp.sum(temp) -... ->>> def inner(second): -... if second.shape[0] > 4: -... return jnp.sin(second) -... else: -... assert False -... ->>> def func3(first, second): -... return func2(inner, first, second) -... ->>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8))) -{ lambda ; a:f32[8] b:f32[8]. let - c:f32[8] = sin b - d:f32[8] = mul c 3.0 - e:f32[8] = add a d - f:f32[] = reduce_sum[axes=(0,)] e - in (f,) } - - -Handling PyTrees ----------------- - -In jaxpr there are no tuple types; instead primitives take multiple inputs -and produce multiple outputs. When processing a function that has structured -inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists -of inputs and outputs. For more details, please see the documentation for -PyTrees (:ref:`pytrees`). - -For example, the following code produces an identical jaxpr to what we saw -before (with two input vars, one for each element of the input tuple) - - ->>> def func4(arg): # Arg is a pair -... temp = arg[0] + jnp.sin(arg[1]) * 3. -... return jnp.sum(temp) -... ->>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))) -{ lambda ; a:f32[8] b:f32[8]. let - c:f32[8] = sin b - d:f32[8] = mul c 3.0 - e:f32[8] = add a d - f:f32[] = reduce_sum[axes=(0,)] e - in (f,) } - - - -Constant vars -------------- - -Some values in jaxprs are constants, in that their value does not depend on the -jaxpr's arguments. When these values are scalars they are represented directly -in the jaxpr equations; non-scalar array constants are instead hoisted out to -the top-level jaxpr, where they correspond to constant variables ("constvars"). -These constvars differ from the other jaxpr parameters ("invars") only as a -bookkeeping convention. - - -Higher-order primitives ------------------------ - -jaxpr includes several higher-order primitives. They are more complicated because -they include sub-jaxprs. - -Conditionals -^^^^^^^^^^^^ - -JAX traces through normal Python conditionals. To capture a -conditional expression for dynamic execution, one must use the -:py:func:`jax.lax.switch` and :py:func:`jax.lax.cond` constructors, -which have the signatures:: - - lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B - - lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B - -Both of these will bind a primitive called ``cond`` internally. The -``cond`` primitive in jaxprs reflects the more general signature of -:py:func:`lax.switch`: it takes an integer denoting the index of the branch -to execute (clamped into valid indexing range). - -For example: - ->>> from jax import lax ->>> ->>> def one_of_three(index, arg): -... return lax.switch(index, [lambda x: x + 1., -... lambda x: x - 2., -... lambda x: x + 3.], -... arg) -... ->>> print(make_jaxpr(one_of_three)(1, 5.)) -{ lambda ; a:i32[] b:f32[]. let - c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a - d:i32[] = clamp 0 c 2 - e:f32[] = cond[ - branches=( - { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) } - { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) } - { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) } - ) - ] d b - in (e,) } - -The `branches` parameter to the cond primitive corresponds to the branch -functionals. In this example, those functionals each take one input variable, -corresponding to ``x``. - -The above instance of the cond primitive takes two operands. The first -one (``d``) is the branch index, then ``b`` is the operand (``arg``) to -be passed to whichever jaxpr in ``branches`` is selected by the branch -index. - -Another example, using :py:func:`lax.cond`: - ->>> from jax import lax ->>> ->>> def func7(arg): -... return lax.cond(arg >= 0., -... lambda xtrue: xtrue + 3., -... lambda xfalse: xfalse - 3., -... arg) -... ->>> print(make_jaxpr(func7)(5.)) -{ lambda ; a:f32[]. let - b:bool[] = ge a 0.0 - c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b - d:f32[] = cond[ - branches=( - { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) } - { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) } - ) - ] c a - in (d,) } - -In this case, the boolean predicate is converted to an integer index -(0 or 1), and ``branches`` are jaxprs that correspond to the false and -true branch functionals, in that order. Again, each functional takes -one input variable, corresponding to ``xfalse`` and ``xtrue`` -respectively. - -The following example shows a more complicated situation when the input -to the branch functionals is a tuple, and the `false` branch functional -contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar` - ->>> def func8(arg1, arg2): # arg2 is a pair -... return lax.cond(arg1 >= 0., -... lambda xtrue: xtrue[0], -... lambda xfalse: jnp.array([1]) + xfalse[1], -... arg2) -... ->>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.))) -{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let - e:bool[] = ge b 0.0 - f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e - g:f32[1] = cond[ - branches=( - { lambda ; h:i32[1] i:f32[1] j:f32[]. let - k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h - l:f32[1] = add k j - in (l,) } - { lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) } - ) - ] f a c d - in (g,) } - - - -While -^^^^^ - -Just like for conditionals, Python loops are inlined during tracing. -If you want to capture a loop for dynamic execution, you must use one of several -special operations, :py:func:`jax.lax.while_loop` (a primitive) -and :py:func:`jax.lax.fori_loop` -(a helper that generates a while_loop primitive):: - - lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C - lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C - - -In the above signature, “C” stands for the type of the loop “carry” value. -For example, here is an example fori loop - ->>> import numpy as np ->>> ->>> def func10(arg, n): -... ones = jnp.ones(arg.shape) # A constant -... return lax.fori_loop(0, n, -... lambda i, carry: carry + ones * 3. + arg, -... arg + ones) -... ->>> print(make_jaxpr(func10)(np.ones(16), 5)) -{ lambda ; a:f32[16] b:i32[]. let - c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 - d:f32[16] = add a c - _:i32[] _:i32[] e:f32[16] = while[ - body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let - k:i32[] = add h 1 - l:f32[16] = mul f 3.0 - m:f32[16] = add j l - n:f32[16] = add m g - in (k, i, n) } - body_nconsts=2 - cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let - r:bool[] = lt o p - in (r,) } - cond_nconsts=0 - ] c a 0 b d - in (e,) } - -The while primitive takes 5 arguments: ``c a 0 b d``, as follows: - - * 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0) - * 2 constants for ``body_jaxpr`` (``c``, and ``a``) - * 3 parameters for the initial value of carry - -Scan -^^^^ - -JAX supports a special form of loop over the elements of an array (with -statically known shape). The fact that there are a fixed number of iterations -makes this form of looping easily reverse-differentiable. Such loops are -constructed with the :py:func:`jax.lax.scan` function:: - - lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B]) - -This is written in terms of a `Haskell Type Signature`_: -``C`` is the type of the scan carry, ``A`` is the element type of the -input array(s), and ``B`` is the element type of the output array(s). - -For the example consider the function ``func11`` below - ->>> def func11(arr, extra): -... ones = jnp.ones(arr.shape) # A constant -... def body(carry, aelems): -... # carry: running dot-product of the two arrays -... # aelems: a pair with corresponding elements from the two arrays -... ae1, ae2 = aelems -... return (carry + ae1 * ae2 + extra, carry) -... return lax.scan(body, 0., (arr, ones)) -... ->>> print(make_jaxpr(func11)(np.ones(16), 5.)) -{ lambda ; a:f32[16] b:f32[]. let - c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 - d:f32[] e:f32[16] = scan[ - _split_transpose=False - jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let - j:f32[] = mul h i - k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g - l:f32[] = add k j - m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f - n:f32[] = add l m - in (n, g) } - length=16 - linear=(False, False, False, False) - num_carry=1 - num_consts=1 - reverse=False - unroll=1 - ] b 0.0 a c - in (d, e) } - -The ``linear`` parameter describes for each of the input variables whether they -are guaranteed to be used linearly in the body. Once the scan goes through -linearization, more arguments will be linear. - -The scan primitive takes 4 arguments: ``b 0.0 a c``, of which: - - * one is the free variable for the body - * one is the initial value of the carry - * The next 2 are the arrays over which the scan operates. - -XLA_call -^^^^^^^^ - -The call primitive arises from JIT compilation, and it encapsulates -a sub-jaxpr along with parameters that specify the backend and the device on -which the computation should run. For example - ->>> from jax import jit ->>> ->>> def func12(arg): -... @jit -... def inner(x): -... return x + arg * jnp.ones(1) # Include a constant in the inner function -... return arg + inner(arg - 2.) -... ->>> print(make_jaxpr(func12)(1.)) # doctest:+ELLIPSIS -{ lambda ; a:f32[]. let - b:f32[] = sub a 2.0 - c:f32[1] = pjit[ - name=inner - jaxpr={ lambda ; d:f32[] e:f32[]. let - f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 - g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d - h:f32[1] = mul g f - i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e - j:f32[1] = add i h - in (j,) } - ] a b - k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a - l:f32[1] = add k c - in (l,) } - - -XLA_pmap -^^^^^^^^ - -If you use the :py:func:`jax.pmap` transformation, the function to be mapped is -captured using the ``xla_pmap`` primitive. Consider this example - ->>> from jax import pmap ->>> ->>> def func13(arr, extra): -... def inner(x): -... # use a free variable "extra" and a constant jnp.ones(1) -... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows') -... return pmap(inner, axis_name='rows')(arr) -... ->>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.)) -{ lambda ; a:f32[1,3] b:f32[]. let - c:f32[1,3] = xla_pmap[ - axis_name=rows - axis_size=1 - backend=None - call_jaxpr={ lambda ; d:f32[] e:f32[3]. let - f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d - g:f32[3] = add e f - h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 - i:f32[3] = add g h - j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e - k:f32[3] = div i j - in (k,) } - devices=None - donated_invars=(False, False) - global_axis_size=1 - in_axes=(None, 0) - is_explicit_global_axis_size=False - name=inner - out_axes=(0,) - ] b a - in (c,) } - -The ``xla_pmap`` primitive specifies the name of the axis (parameter -``axis_name``) and the body of the function to be mapped as the ``call_jaxpr`` -parameter. The value of this parameter is a Jaxpr with 2 input variables. - -The parameter ``in_axes`` specifies which of the input variables should be -mapped and which should be broadcast. In our example, the value of ``extra`` -is broadcast and the value of ``arr`` is mapped. - -.. _Haskell Type Signature: https://wiki.haskell.org/Type_signature diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index bc6cb3c04cf8..59c7bbd8fb90 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -51,7 +51,7 @@ def log2(x): print(jax.make_jaxpr(log2)(3.0)) ``` -The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output. +The {ref}`jax-internals-jaxpr` section of the documentation provides more information on the meaning of the above output. Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb deleted file mode 100644 index e9924e18d023..000000000000 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ /dev/null @@ -1,1532 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "vfxqky4PCUnh" - }, - "source": [ - "# How JAX primitives work\n", - "\n", - "\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", - "\n", - "*necula@google.com*, October 2019.\n", - "\n", - "JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,\n", - "`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,\n", - "which means that as the Python function executes\n", - "the only operations it applies to the data are either inspections of data\n", - "attributes such as shape or type, or special operations called JAX primitives.\n", - "In particular, a JAX-traceable function is sometimes invoked by JAX with\n", - "abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,\n", - "which captures the type and the shape of values, but not the concrete data values.\n", - "JAX primitives know how to operate on both concrete data\n", - "values and on the JAX abstract values.\n", - "\n", - "\n", - "The JAX-transformed functions must themselves be JAX-traceable functions,\n", - "to ensure that these transformations\n", - "can be composed, e.g., `jit(jacfwd(grad(f)))`.\n", - "\n", - "There are pre-defined JAX primitives corresponding to most XLA operations,\n", - "e.g., add, matmul, sin, cos, indexing.\n", - "JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs\n", - "using JAX’s implementation of numpy are JAX-traceable and therefore transformable.\n", - "Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.\n", - "\n", - "The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives,\n", - "one can define a new primitive that encapsulates the behavior of the function.\n", - "\n", - "**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**\n", - "\n", - "Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n", - "as \"multiply_add(x, y, z) = x * y + z\".\n", - "This function operates on 3 identically-shaped tensors of floating point\n", - "values and performs the operations pointwise." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HIJYIHNTD1yI" - }, - "source": [ - "## Using existing primitives\n", - "\n", - "The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other\n", - "functions that are themselves written using JAX primitives, e.g., those\n", - "defined in the `jax.lax` module:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "tbOF0LB0EMne", - "outputId": "3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "square_add_lax = 14.0\n", - "grad(square_add_lax) = 4.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.\n", - " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" - ] - } - ], - "source": [ - "from jax import lax\n", - "from jax._src import api\n", - "\n", - "def multiply_add_lax(x, y, z):\n", - " \"\"\"Implementation of multiply-add using the jax.lax primitives.\"\"\"\n", - " return lax.add(lax.mul(x, y), z)\n", - "\n", - "\n", - "def square_add_lax(a, b):\n", - " \"\"\"A square-add function using the newly defined multiply-add.\"\"\"\n", - " return multiply_add_lax(a, a, b)\n", - "\n", - "print(\"square_add_lax = \", square_add_lax(2., 10.))\n", - "# Differentiate w.r.t. the first argument\n", - "print(\"grad(square_add_lax) = \", api.grad(square_add_lax, argnums=0)(2.0, 10.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Cgv60Wm3E_D5" - }, - "source": [ - "In order to understand how JAX is internally using the primitives,\n", - "we add some helpers for tracing function calls." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "mQRQGEGiE53K" - }, - "outputs": [], - "source": [ - "#@title Helper functions (execute this cell)\n", - "import functools\n", - "import traceback\n", - "\n", - "_indentation = 0\n", - "def _trace(msg=None):\n", - " \"\"\"Print a message at current indentation.\"\"\"\n", - " if msg is not None:\n", - " print(\" \" * _indentation + msg)\n", - "\n", - "def _trace_indent(msg=None):\n", - " \"\"\"Print a message and then indent the rest.\"\"\"\n", - " global _indentation\n", - " _trace(msg)\n", - " _indentation = 1 + _indentation\n", - "\n", - "def _trace_unindent(msg=None):\n", - " \"\"\"Unindent then print a message.\"\"\"\n", - " global _indentation\n", - " _indentation = _indentation - 1\n", - " _trace(msg)\n", - "\n", - "def trace(name):\n", - " \"\"\"A decorator for functions to trace arguments and results.\"\"\"\n", - "\n", - " def trace_func(func): # pylint: disable=missing-docstring\n", - " def pp(v):\n", - " \"\"\"Print certain values more succinctly\"\"\"\n", - " vtype = str(type(v))\n", - " if \"jax._src.xla_bridge._JaxComputationBuilder\" in vtype:\n", - " return \"\"\n", - " elif \"jaxlib.xla_extension.XlaOp\" in vtype:\n", - " return \"\".format(id(v))\n", - " elif (\"partial_eval.JaxprTracer\" in vtype or\n", - " \"batching.BatchTracer\" in vtype or\n", - " \"ad.JVPTracer\" in vtype):\n", - " return \"Traced<{}>\".format(v.aval)\n", - " elif isinstance(v, tuple):\n", - " return \"({})\".format(pp_values(v))\n", - " else:\n", - " return str(v)\n", - " def pp_values(args):\n", - " return \", \".join([pp(arg) for arg in args])\n", - "\n", - " @functools.wraps(func)\n", - " def func_wrapper(*args):\n", - " _trace_indent(\"call {}({})\".format(name, pp_values(args)))\n", - " res = func(*args)\n", - " _trace_unindent(\"|<- {} = {}\".format(name, pp(res)))\n", - " return res\n", - "\n", - " return func_wrapper\n", - "\n", - " return trace_func\n", - "\n", - "class expectNotImplementedError(object):\n", - " \"\"\"Context manager to check for NotImplementedError.\"\"\"\n", - " def __enter__(self): pass\n", - " def __exit__(self, type, value, tb):\n", - " global _indentation\n", - " _indentation = 0\n", - " if type is NotImplementedError:\n", - " print(\"\\nFound expected exception:\")\n", - " traceback.print_exc(limit=3)\n", - " return True\n", - " elif type is None: # No exception\n", - " assert False, \"Expected NotImplementedError\"\n", - " else:\n", - " return False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qf4eLrLCFYDl" - }, - "source": [ - "Instead of using `jax.lax` primitives directly, we can use other functions\n", - "that are already written in terms of those primitives, such as those in `jax.numpy`:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "QhKorz6cFRJb", - "outputId": "aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Normal evaluation:\n", - "call square_add_numpy(2.0, 10.0)\n", - " call multiply_add_numpy(2.0, 2.0, 10.0)\n", - " |<- multiply_add_numpy = 14.0\n", - "|<- square_add_numpy = 14.0\n", - "square_add_numpy = 14.0\n", - "\n", - "Gradient evaluation:\n", - "call square_add_numpy(Traced, 10.0)\n", - " call multiply_add_numpy(Traced, Traced, 10.0)\n", - " |<- multiply_add_numpy = Traced\n", - "|<- square_add_numpy = Traced\n", - "grad(square_add_numpy) = 4.0\n" - ] - } - ], - "source": [ - "import jax.numpy as jnp\n", - "import numpy as np\n", - "\n", - "@trace(\"multiply_add_numpy\")\n", - "def multiply_add_numpy(x, y, z):\n", - " return jnp.add(jnp.multiply(x, y), z)\n", - "\n", - "@trace(\"square_add_numpy\")\n", - "def square_add_numpy(a, b):\n", - " return multiply_add_numpy(a, a, b)\n", - "\n", - "print(\"\\nNormal evaluation:\")\n", - "print(\"square_add_numpy = \", square_add_numpy(2., 10.))\n", - "print(\"\\nGradient evaluation:\")\n", - "print(\"grad(square_add_numpy) = \", api.grad(square_add_numpy)(2.0, 10.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sg-D8EdeFn4a" - }, - "source": [ - "Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and\n", - "`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further\n", - "below in this colab).\n", - "It is important to remember that a JAX-traceable function must be able to\n", - "operate not only on concrete arguments but also on special abstract arguments\n", - "that JAX may use to abstract the function execution.\n", - "\n", - "The JAX traceability property is satisfied as long as the function is written\n", - "in terms of JAX primitives." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WxrQO7-XGLcg" - }, - "source": [ - "## Defining new JAX primitives\n", - "\n", - "The right way to add support for multiply-add is in terms of existing\n", - "JAX primitives, as shown above. However, in order to demonstrate how JAX\n", - "primitives work let us pretend that we want to add a new primitive to\n", - "JAX for the multiply-add functionality." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cPqAH1XOGTN4" - }, - "outputs": [], - "source": [ - "from jax import core\n", - "multiply_add_p = core.Primitive(\"multiply_add\") # Create the primitive\n", - "\n", - "@trace(\"multiply_add_prim\")\n", - "def multiply_add_prim(x, y, z):\n", - " \"\"\"The JAX-traceable way to use the JAX primitive.\n", - "\n", - " Note that the traced arguments must be passed as positional arguments\n", - " to `bind`.\n", - " \"\"\"\n", - " return multiply_add_p.bind(x, y, z)\n", - "\n", - "@trace(\"square_add_prim\")\n", - "def square_add_prim(a, b):\n", - " \"\"\"A square-add function implemented using the new JAX-primitive.\"\"\"\n", - " return multiply_add_prim(a, a, b)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LMzs5PAKGr-4" - }, - "source": [ - "If we try to call the newly defined functions we get an error, because\n", - "we have not yet told JAX anything about the semantics of the new primitive." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "_X3PAYxhGpWd", - "outputId": "90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(2.0, 10.0)\n", - " call multiply_add_prim(2.0, 2.0, 10.0)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " square_add_prim(2., 10.)\n", - " File \"\", line 47, in func_wrapper\n", - " res = func(*args)\n", - " File \"\", line 16, in square_add_prim\n", - " return multiply_add_prim(a, a, b)\n", - "NotImplementedError: Evaluation rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "with expectNotImplementedError():\n", - " square_add_prim(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "elha0FdgHSEF" - }, - "source": [ - "### Primal evaluation rules" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "FT34FFAGHARU", - "outputId": "4c54f1c2-8a50-4788-90e1-06aee412c43b" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "@trace(\"multiply_add_impl\")\n", - "def multiply_add_impl(x, y, z):\n", - " \"\"\"Concrete implementation of the primitive.\n", - "\n", - " This function does not need to be JAX traceable.\n", - " Args:\n", - " x, y, z: the concrete arguments of the primitive. Will only be called with\n", - " concrete values.\n", - " Returns:\n", - " the concrete result of the primitive.\n", - " \"\"\"\n", - " # Note that we can use the original numpy, which is not JAX traceable\n", - " return np.add(np.multiply(x, y), z)\n", - "\n", - "# Now we register the primal implementation with JAX\n", - "multiply_add_p.def_impl(multiply_add_impl)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "G5bstKaeNAVV", - "outputId": "deb94d5b-dfea-4e6f-9ec2-70b416c996c5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(2.0, 10.0)\n", - " call multiply_add_prim(2.0, 2.0, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - "|<- square_add_prim = 14.0\n" - ] - } - ], - "source": [ - "assert square_add_prim(2., 10.) == 14." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "upBf-uAuHhPJ" - }, - "source": [ - "### JIT\n", - "\n", - "If we now try to use `jit` we get a `NotImplementedError`:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "QG-LULjiHk4b", - "outputId": "d4ef4406-8dae-4c96-97ca-b662340474ee" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.jit(square_add_prim)(2., 10.)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n", - " out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n", - " outs = primitive.impl(f, *args, **params)\n", - "NotImplementedError: Abstract evaluation for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "with expectNotImplementedError():\n", - " api.jit(square_add_prim)(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rHS1bAGHH44E" - }, - "source": [ - "#### Abstract evaluation rules\n", - "In order to JIT the function, and for other transformations as well,\n", - "JAX first evaluates it abstractly using only the\n", - "shape and type of the arguments. This abstract evaluation serves multiple\n", - "purposes:\n", - "\n", - " * Gets the sequence of JAX primitives that are used in the computation. This\n", - " sequence will be compiled.\n", - " * Computes the shape and type of all vectors and operations used in the computation.\n", - "\n", - "\n", - "For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.\n", - "In the latter case, JAX uses the actual concrete value wrapped as an abstract value." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "ctQmEeckIbdo", - "outputId": "e751d0cc-460e-4ffd-df2e-fdabf9cffdc2" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 9, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "from jax import core\n", - "@trace(\"multiply_add_abstract_eval\")\n", - "def multiply_add_abstract_eval(xs, ys, zs):\n", - " \"\"\"Abstract evaluation of the primitive.\n", - "\n", - " This function does not need to be JAX traceable. It will be invoked with\n", - " abstractions of the actual arguments.\n", - " Args:\n", - " xs, ys, zs: abstractions of the arguments.\n", - " Result:\n", - " a ShapedArray for the result of the primitive.\n", - " \"\"\"\n", - " assert xs.shape == ys.shape\n", - " assert xs.shape == zs.shape\n", - " return core.ShapedArray(xs.shape, xs.dtype)\n", - "\n", - "# Now we register the abstract evaluation with JAX\n", - "multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RPN88X6YI43A" - }, - "source": [ - "If we re-attempt to JIT, we see how the abstract evaluation proceeds, but\n", - "we get another error, about missing the actual XLA compilation rule:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "eOcNR92SI2h-", - "outputId": "356ef229-3703-4696-cc3d-7c05de405fb0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.jit(square_add_prim)(2., 10.)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n", - " out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n", - " outs = primitive.impl(f, *args, **params)\n", - "NotImplementedError: XLA translation rule for primitive 'multiply_add' not found\n" - ] - } - ], - "source": [ - "with expectNotImplementedError():\n", - " api.jit(square_add_prim)(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9IOV1R-fJMHp" - }, - "source": [ - "#### XLA Compilation rules\n", - "\n", - "JAX compilation works by compiling each primitive into a graph of XLA operations.\n", - "\n", - "This is the biggest hurdle to adding new functionality to JAX, because the\n", - "set of XLA operations is limited, and JAX already has pre-defined primitives\n", - "for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FYQWSSjKJaWP" - }, - "outputs": [], - "source": [ - "from jax._src.lib.mlir.dialects import hlo\n", - "@trace(\"multiply_add_lowering\")\n", - "def multiply_add_lowering(ctx, xc, yc, zc):\n", - " \"\"\"The compilation to XLA of the primitive.\n", - "\n", - " Given an mlir.ir.Value for each argument, return the mlir.ir.Values for\n", - " the results of the function.\n", - "\n", - " Does not need to be a JAX-traceable function.\n", - " \"\"\"\n", - " return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]\n", - "\n", - "# Now we register the lowering rule with JAX\n", - "# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)\n", - "# TODO: TPU?\n", - "from jax.interpreters import mlir\n", - "mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K98LX-VaJkFu" - }, - "source": [ - "Now we succeed to JIT. Notice below that JAX first evaluates the function\n", - "abstractly, which triggers the `multiply_add_abstract_eval` function, and\n", - "then compiles the set of primitives it has encountered, including `multiply_add`.\n", - "At this point JAX invokes `multiply_add_xla_translation`." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "rj3TLsolJgEc", - "outputId": "e384bee4-1e9c-4344-f49c-d3b5ec08eb32" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Omrez-2_KFfo" - }, - "source": [ - "Below is another use of `jit` where we compile only\n", - "with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads\n", - "in the third argument to `multiply_add_abstract_eval` being\n", - "`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with\n", - "both `ShapedArray` and `ConcreteArray`." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "mPfTwIBoKOEK", - "outputId": "b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, 10.0)\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(10.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(lambda x, y: square_add_prim(x, y),\n", - " static_argnums=1)(2., 10.) == 14." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_Ya3B5l4J1VA" - }, - "source": [ - "### Forward differentiation\n", - "\n", - "JAX implements forward differentiation in the form of\n", - "a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)).\n", - "\n", - "If we attempt now to compute the `jvp` function we get an\n", - "error because we have not yet told JAX how to differentiate\n", - "the `multiply_add` primitive." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "OxDx6NQnKwMI", - "outputId": "ce659ef3-c03c-4856-f252-49ec4b6eb964" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 217, in process_primitive\n", - " jvp = primitive_jvps[primitive]\n", - "KeyError: multiply_add\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.jvp(square_add_prim, (2., 10.), (1., 1.))\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 978, in jvp\n", - " out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/linear_util.py\", line 165, in call_wrapped\n", - " ans = self.f(*args, **dict(self.params, **kwargs))\n", - "NotImplementedError: Forward-mode differentiation rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "# The second argument `(2., 10.)` are the argument values\n", - "# where we evaluate the Jacobian, and the third `(1., 1.)`\n", - "# are the values of the tangents for the arguments.\n", - "with expectNotImplementedError():\n", - " api.jvp(square_add_prim, (2., 10.), (1., 1.))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zxG24C1JMIMM" - }, - "outputs": [], - "source": [ - "from jax.interpreters import ad\n", - "\n", - "\n", - "@trace(\"multiply_add_value_and_jvp\")\n", - "def multiply_add_value_and_jvp(arg_values, arg_tangents):\n", - " \"\"\"Evaluates the primal output and the tangents (Jacobian-vector product).\n", - "\n", - " Given values of the arguments and perturbation of the arguments (tangents),\n", - " compute the output of the primitive and the perturbation of the output.\n", - "\n", - " This method must be JAX-traceable. JAX may invoke it with abstract values\n", - " for the arguments and tangents.\n", - "\n", - " Args:\n", - " arg_values: a tuple of arguments\n", - " arg_tangents: a tuple with the tangents of the arguments. The tuple has\n", - " the same length as the arg_values. Some of the tangents may also be the\n", - " special value ad.Zero to specify a zero tangent.\n", - " Returns:\n", - " a pair of the primal output and the tangent.\n", - " \"\"\"\n", - " x, y, z = arg_values\n", - " xt, yt, zt = arg_tangents\n", - " _trace(\"Primal evaluation:\")\n", - " # Now we have a JAX-traceable computation of the output.\n", - " # Normally, we can use the ma primitive itself to compute the primal output.\n", - " primal_out = multiply_add_prim(x, y, z)\n", - "\n", - " _trace(\"Tangent evaluation:\")\n", - " # We must use a JAX-traceable way to compute the tangent. It turns out that\n", - " # the output tangent can be computed as (xt * y + x * yt + zt),\n", - " # which we can implement in a JAX-traceable way using the same \"multiply_add_prim\" primitive.\n", - "\n", - " # We do need to deal specially with Zero. Here we just turn it into a\n", - " # proper tensor of 0s (of the same shape as 'x').\n", - " # An alternative would be to check for Zero and perform algebraic\n", - " # simplification of the output tangent computation.\n", - " def make_zero(tan):\n", - " return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan\n", - "\n", - " output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))\n", - " return (primal_out, output_tangent)\n", - "\n", - "# Register the forward differentiation rule with JAX\n", - "ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "ma3KBkiAMfW1", - "outputId": "f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))\n", - " Primal evaluation:\n", - " call multiply_add_prim(2.0, 2.0, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - " Tangent evaluation:\n", - " call multiply_add_prim(2.0, 1.0, 1.0)\n", - " call multiply_add_impl(2.0, 1.0, 1.0)\n", - " |<- multiply_add_impl = 3.0\n", - " |<- multiply_add_prim = 3.0\n", - " call multiply_add_prim(1.0, 2.0, 3.0)\n", - " call multiply_add_impl(1.0, 2.0, 3.0)\n", - " |<- multiply_add_impl = 5.0\n", - " |<- multiply_add_prim = 5.0\n", - " |<- multiply_add_value_and_jvp = (14.0, 5.0)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n" - ] - } - ], - "source": [ - "# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.\n", - "assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "69QsEcu-lP4u" - }, - "source": [ - "TO EXPLAIN:\n", - "\n", - " * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.\n", - " * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n", - " we do not call the multiply_add_abstract_eval.\n", - " * I think it would be useful to show the jaxpr here" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sb6e3ZAHOPHv" - }, - "source": [ - "#### JIT of forward differentiation\n", - "\n", - "We can apply JIT to the forward differentiation function:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "hg-hzVu-N-hv", - "outputId": "38d32067-e152-4046-ad80-7f95a31ba628" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_value_and_jvp((Traced, Traced, Traced), (Traced, Traced, Traced))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (Traced, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(lambda arg_values, arg_tangents:\n", - " api.jvp(square_add_prim, arg_values, arg_tangents))(\n", - " (2., 10.), (1., 1.)) == (14., 5.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jlZt1_v2mU88" - }, - "source": [ - "Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn\n", - "evaluates abstractly both the primal and the tangent evaluation (a total of\n", - "3 invocations of the `ma` primitive). Then we compile the 3 occurrences\n", - "of the primitive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "555yt6ZIOePB" - }, - "source": [ - "### Reverse differentiation\n", - "\n", - "If we attempt now to use reverse differentiation we\n", - "see that JAX starts by using the `multiply_add_value_and_jvp` to\n", - "compute the forward differentiation for abstract values, but then runs\n", - "into a `NotImplementedError`.\n", - "\n", - "When computing the reverse differentiation JAX first does abstract evaluation\n", - "of the forward differentiation code `multiply_add_value_and_jvp` to obtain a\n", - "trace of primitives that compute the output tangent.\n", - "Observe that JAX performs this abstract evaluation with concrete values\n", - "for the differentiation point, and abstract values for the tangents.\n", - "Observe also that JAX uses the special abstract tangent value `Zero` for\n", - "the tangent corresponding to the 3rd argument of `ma`. This reflects the\n", - "fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n", - "which flows to the 3rd argument to `multiply_add_prim`.\n", - "\n", - "Observe also that during the abstract evaluation of the tangent we pass the\n", - "value 0.0 as the tangent for the 3rd argument. This is due to the use\n", - "of the `make_zero` function in the definition of `multiply_add_value_and_jvp`." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "8eAVnexaOjBn", - "outputId": "e4ee89cf-ab4a-4505-9817-fa978a2865ab" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, 10.0)\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_value_and_jvp((Traced, Traced, 10.0), (Traced, Traced, Zero))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, 0.0)\n", - " call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (14.0, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 198, in get_primitive_transpose\n", - " return primitive_transposes[p]\n", - "KeyError: multiply_add\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.grad(square_add_prim)(2., 10.)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 340, in grad_f\n", - " _, g = value_and_grad_f(*args, **kwargs)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 398, in value_and_grad_f\n", - " g = vjp_py(np.ones((), dtype=dtype))\n", - "NotImplementedError: Reverse-mode differentiation rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "# This is reverse differentiation w.r.t. the first argument of square_add_prim\n", - "with expectNotImplementedError():\n", - " api.grad(square_add_prim)(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fSHLUMDN26AY" - }, - "source": [ - "The above error is because there is a missing piece for JAX to be able\n", - "to use the forward differentiation code to compute reverse differentiation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3ibDbGF-PjK9" - }, - "source": [ - "#### Transposition\n", - "\n", - "\n", - "As explained above, when computing reverse differentiation JAX obtains\n", - "a trace of primitives that compute the tangent using forward differentiation.\n", - "Then, **JAX interprets this trace abstractly backwards** and for each\n", - "primitive it applies a **transposition** rule.\n", - "\n", - "To understand what is going on, consider for now a simpler example of the function \"f(x, y) = x * y + y\". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:\n", - "```\n", - " a = xt * 4.\n", - " b = 2. * yt\n", - " c = a + b\n", - " ft = c + yt\n", - "```\n", - "\n", - "By construction, the tangent calculation is always linear in the input tangents.\n", - "The only non-linear operator that may arise in the tangent calculation is multiplication,\n", - "but then one of the operands is constant.\n", - "\n", - "JAX will produce the reverse differentiation computation by processing the\n", - "JVP computation backwards. For each operation in the tangent computation,\n", - "it accumulates the cotangents\n", - "of the variables used by the operation, using the cotangent of the result\n", - "of the operation:\n", - "```\n", - " # Initialize cotangents of inputs and intermediate vars\n", - " xct = yct = act = bct = cct = 0.\n", - " # Initialize cotangent of the output\n", - " fct = 1.\n", - " # Process \"ft = c + yt\"\n", - " cct += fct\n", - " yct += fct\n", - " # Process \"c = a + b\"\n", - " act += cct\n", - " bct += cct\n", - " # Process \"b = 2. * yt\"\n", - " yct += 2. * bct\n", - " # Process \"a = xt * 4.\"\n", - " xct += act * 4.\n", - "```\n", - "\n", - "One can verify that this computation produces `xct = 4.` and `yct = 3.`, which\n", - "are the partial derivatives of the function `f`.\n", - "\n", - "JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:\n", - "```\n", - "p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)\n", - "```\n", - "\n", - "Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other\n", - "arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned\n", - "for the constant arguments.\n", - "\n", - "In particular,\n", - "```\n", - " add_transpose(out_ct, _, _) = (out_ct, out_ct)\n", - " mult_transpose(out_ct, x, _) = (None, x * out_ct)\n", - " mult_transpose(out_ct, _, y) = (out_ct * y, None)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JaHxFdkRO42r" - }, - "outputs": [], - "source": [ - "@trace(\"multiply_add_transpose\")\n", - "def multiply_add_transpose(ct, x, y, z):\n", - " \"\"\"Evaluates the transpose of a linear primitive.\n", - "\n", - " This method is only used when computing the backward gradient following\n", - " value_and_jvp, and is only needed for primitives that are used in the JVP\n", - " calculation for some other primitive. We need transposition for multiply_add_prim,\n", - " because we have used multiply_add_prim in the computation of the output_tangent in\n", - " multiply_add_value_and_jvp.\n", - "\n", - " In our case, multiply_add is not a linear primitive. However, it is used linearly\n", - " w.r.t. tangents in multiply_add_value_and_jvp:\n", - " output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n", - "\n", - " Always one of the first two multiplicative arguments is a constant.\n", - "\n", - " Args:\n", - " ct: the cotangent of the output of the primitive.\n", - " x, y, z: values of the arguments. The arguments that are used linearly\n", - " get an ad.UndefinedPrimal value. The other arguments get a constant\n", - " value.\n", - " Returns:\n", - " a tuple with the cotangent of the inputs, with the value None\n", - " corresponding to the constant arguments.\n", - " \"\"\"\n", - " if not ad.is_undefined_primal(x):\n", - " # This use of multiply_add is with a constant \"x\"\n", - " assert ad.is_undefined_primal(y)\n", - " ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))\n", - " res = None, ct_y, ct\n", - " else:\n", - " # This use of multiply_add is with a constant \"y\"\n", - " assert ad.is_undefined_primal(x)\n", - " ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))\n", - " res = ct_x, None, ct\n", - " return res\n", - "\n", - "\n", - "ad.primitive_transposes[multiply_add_p] = multiply_add_transpose" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PpChox-Jp7wb" - }, - "source": [ - "Now we can complete the run of the `grad`:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "PogPKS4MPevd", - "outputId": "d33328d4-3e87-45b5-9b31-21ad624b67af" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, 10.0)\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_value_and_jvp((Traced, Traced, 10.0), (Traced, Traced, Zero))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, 0.0)\n", - " call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (14.0, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_transpose(1.0, _, 2.0, _)\n", - " call multiply_add_prim(1.0, 2.0, 0.0)\n", - " call multiply_add_impl(1.0, 2.0, 0.0)\n", - " |<- multiply_add_impl = 2.0\n", - " |<- multiply_add_prim = 2.0\n", - "|<- multiply_add_transpose = (2.0, None, 1.0)\n", - "call multiply_add_transpose(1.0, 2.0, _, 0.0)\n", - " call multiply_add_prim(2.0, 1.0, 0.0)\n", - " call multiply_add_impl(2.0, 1.0, 0.0)\n", - " |<- multiply_add_impl = 2.0\n", - " |<- multiply_add_prim = 2.0\n", - "|<- multiply_add_transpose = (None, 2.0, 1.0)\n" - ] - } - ], - "source": [ - "assert api.grad(square_add_prim)(2., 10.) == 4." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8M1xLCXW4fK7" - }, - "source": [ - "Notice the two calls to `multiply_add_transpose`. They correspond to the two\n", - "uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the\n", - "last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EIJs6FYmPg6c" - }, - "source": [ - "#### JIT of reverse differentiation\n", - "\n", - "Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n", - "abstract values, while in the absence of JIT we used `ConcreteArray`." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "FZ-JGbWZPq2-", - "outputId": "e42b5222-9c3e-4853-e13a-874f6605d178" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_value_and_jvp((Traced, Traced, Traced), (Traced, Traced, Zero))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (Traced, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_transpose(1.0, _, Traced, _)\n", - " call multiply_add_prim(1.0, Traced, Traced)\n", - " call multiply_add_abstract_eval(ConcreteArray(1.0), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- multiply_add_transpose = (Traced, None, 1.0)\n", - "call multiply_add_transpose(1.0, Traced, _, Traced)\n", - " call multiply_add_prim(Traced, 1.0, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(1.0), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- multiply_add_transpose = (None, Traced, 1.0)\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(api.grad(square_add_prim))(2., 10.) == 4." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-3lqPkdQPvl5" - }, - "source": [ - "### Batching\n", - "\n", - "The batching transformation takes a point-wise computation and turns it\n", - "into a computation on vectors. If we try it right now, we get a `NotImplementedError`:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "hFvBR3I9Pzh3", - "outputId": "434608bc-281f-4d3b-83bd-eaaf3b51b1cd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 163, in get_primitive_batcher\n", - " return primitive_batchers[p]\n", - "KeyError: multiply_add\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"\", line 3, in \n", - " np.array([10., 20.]))\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n", - " lambda: _flatten_axes(out_tree(), out_axes))\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n", - " out_vals, out_dims = batch2(fun, in_vals, in_dims)\n", - "NotImplementedError: Batching rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "# The arguments are two vectors instead of two scalars\n", - "with expectNotImplementedError():\n", - " api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),\n", - " np.array([10., 20.]))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gILasMiP6elR" - }, - "source": [ - "We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, the `multiply_add_prim` already operates pointwise for any dimension of input vectors. So the batched version can use the same `multiply_add_prim` implementation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KQfeqRIrP7zg" - }, - "outputs": [], - "source": [ - "from jax.interpreters import batching\n", - "\n", - "\n", - "@trace(\"multiply_add_batch\")\n", - "def multiply_add_batch(vector_arg_values, batch_axes):\n", - " \"\"\"Computes the batched version of the primitive.\n", - "\n", - " This must be a JAX-traceable function.\n", - "\n", - " Since the multiply_add primitive already operates pointwise on arbitrary\n", - " dimension tensors, to batch it we can use the primitive itself. This works as\n", - " long as both the inputs have the same dimensions and are batched along the\n", - " same axes. The result is batched along the axis that the inputs are batched.\n", - "\n", - " Args:\n", - " vector_arg_values: a tuple of two arguments, each being a tensor of matching\n", - " shape.\n", - " batch_axes: the axes that are being batched. See vmap documentation.\n", - " Returns:\n", - " a tuple of the result, and the result axis that was batched.\n", - " \"\"\"\n", - " assert batch_axes[0] == batch_axes[1]\n", - " assert batch_axes[0] == batch_axes[2]\n", - " _trace(\"Using multiply_add to compute the batch:\")\n", - " res = multiply_add_prim(*vector_arg_values)\n", - " return res, batch_axes[0]\n", - "\n", - "\n", - "batching.primitive_batchers[multiply_add_p] = multiply_add_batch" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "id": "VwxNk869P_YG", - "outputId": "9d22c921-5803-4d33-9e88-b6e439ba9738" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))\n", - " Using multiply_add to compute the batch:\n", - " call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])\n", - " call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])\n", - " |<- multiply_add_impl = [14. 29.]\n", - " |<- multiply_add_prim = [14. 29.]\n", - " |<- multiply_add_batch = ([14. 29.], 0)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n" - ] - } - ], - "source": [ - "assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(\n", - " np.array([2., 3.]),\n", - " np.array([10., 20.])),\n", - " [14., 29.])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NmqLlV1TQDCC" - }, - "source": [ - "#### JIT of batching" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "id": "xqEdXVUgQCTt", - "outputId": "9c22fd9c-919c-491d-bbeb-32c241b808fa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_batch((Traced, Traced, Traced), (0, 0, 0))\n", - " Using multiply_add to compute the batch:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[2])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_batch = (Traced, 0)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))\n", - " (np.array([2., 3.]),\n", - " np.array([10., 20.])),\n", - " [14., 29.])" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "How JAX primitives work.ipynb", - "provenance": [], - "toc_visible": true - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md deleted file mode 100644 index 7c24ac11a6ce..000000000000 --- a/docs/notebooks/How_JAX_primitives_work.md +++ /dev/null @@ -1,771 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.4 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "vfxqky4PCUnh"} - -# How JAX primitives work - - - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) - -*necula@google.com*, October 2019. - -JAX implements certain transformations of Python functions, e.g., `jit`, `grad`, -`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, -which means that as the Python function executes -the only operations it applies to the data are either inspections of data -attributes such as shape or type, or special operations called JAX primitives. -In particular, a JAX-traceable function is sometimes invoked by JAX with -abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, -which captures the type and the shape of values, but not the concrete data values. -JAX primitives know how to operate on both concrete data -values and on the JAX abstract values. - - -The JAX-transformed functions must themselves be JAX-traceable functions, -to ensure that these transformations -can be composed, e.g., `jit(jacfwd(grad(f)))`. - -There are pre-defined JAX primitives corresponding to most XLA operations, -e.g., add, matmul, sin, cos, indexing. -JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs -using JAX’s implementation of numpy are JAX-traceable and therefore transformable. -Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives. - -The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, -one can define a new primitive that encapsulates the behavior of the function. - -**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.** - -Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically -as "multiply_add(x, y, z) = x * y + z". -This function operates on 3 identically-shaped tensors of floating point -values and performs the operations pointwise. - -+++ {"id": "HIJYIHNTD1yI"} - -## Using existing primitives - -The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other -functions that are themselves written using JAX primitives, e.g., those -defined in the `jax.lax` module: - -```{code-cell} ipython3 -:id: tbOF0LB0EMne -:outputId: 3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528 - -from jax import lax -from jax._src import api - -def multiply_add_lax(x, y, z): - """Implementation of multiply-add using the jax.lax primitives.""" - return lax.add(lax.mul(x, y), z) - - -def square_add_lax(a, b): - """A square-add function using the newly defined multiply-add.""" - return multiply_add_lax(a, a, b) - -print("square_add_lax = ", square_add_lax(2., 10.)) -# Differentiate w.r.t. the first argument -print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.)) -``` - -+++ {"id": "Cgv60Wm3E_D5"} - -In order to understand how JAX is internally using the primitives, -we add some helpers for tracing function calls. - -```{code-cell} ipython3 -:cellView: form -:id: mQRQGEGiE53K - -#@title Helper functions (execute this cell) -import functools -import traceback - -_indentation = 0 -def _trace(msg=None): - """Print a message at current indentation.""" - if msg is not None: - print(" " * _indentation + msg) - -def _trace_indent(msg=None): - """Print a message and then indent the rest.""" - global _indentation - _trace(msg) - _indentation = 1 + _indentation - -def _trace_unindent(msg=None): - """Unindent then print a message.""" - global _indentation - _indentation = _indentation - 1 - _trace(msg) - -def trace(name): - """A decorator for functions to trace arguments and results.""" - - def trace_func(func): # pylint: disable=missing-docstring - def pp(v): - """Print certain values more succinctly""" - vtype = str(type(v)) - if "jax._src.xla_bridge._JaxComputationBuilder" in vtype: - return "" - elif "jaxlib.xla_extension.XlaOp" in vtype: - return "".format(id(v)) - elif ("partial_eval.JaxprTracer" in vtype or - "batching.BatchTracer" in vtype or - "ad.JVPTracer" in vtype): - return "Traced<{}>".format(v.aval) - elif isinstance(v, tuple): - return "({})".format(pp_values(v)) - else: - return str(v) - def pp_values(args): - return ", ".join([pp(arg) for arg in args]) - - @functools.wraps(func) - def func_wrapper(*args): - _trace_indent("call {}({})".format(name, pp_values(args))) - res = func(*args) - _trace_unindent("|<- {} = {}".format(name, pp(res))) - return res - - return func_wrapper - - return trace_func - -class expectNotImplementedError(object): - """Context manager to check for NotImplementedError.""" - def __enter__(self): pass - def __exit__(self, type, value, tb): - global _indentation - _indentation = 0 - if type is NotImplementedError: - print("\nFound expected exception:") - traceback.print_exc(limit=3) - return True - elif type is None: # No exception - assert False, "Expected NotImplementedError" - else: - return False -``` - -+++ {"id": "Qf4eLrLCFYDl"} - -Instead of using `jax.lax` primitives directly, we can use other functions -that are already written in terms of those primitives, such as those in `jax.numpy`: - -```{code-cell} ipython3 -:id: QhKorz6cFRJb -:outputId: aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a - -import jax.numpy as jnp -import numpy as np - -@trace("multiply_add_numpy") -def multiply_add_numpy(x, y, z): - return jnp.add(jnp.multiply(x, y), z) - -@trace("square_add_numpy") -def square_add_numpy(a, b): - return multiply_add_numpy(a, a, b) - -print("\nNormal evaluation:") -print("square_add_numpy = ", square_add_numpy(2., 10.)) -print("\nGradient evaluation:") -print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) -``` - -+++ {"id": "Sg-D8EdeFn4a"} - -Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and -`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further -below in this colab). -It is important to remember that a JAX-traceable function must be able to -operate not only on concrete arguments but also on special abstract arguments -that JAX may use to abstract the function execution. - -The JAX traceability property is satisfied as long as the function is written -in terms of JAX primitives. - -+++ {"id": "WxrQO7-XGLcg"} - -## Defining new JAX primitives - -The right way to add support for multiply-add is in terms of existing -JAX primitives, as shown above. However, in order to demonstrate how JAX -primitives work let us pretend that we want to add a new primitive to -JAX for the multiply-add functionality. - -```{code-cell} ipython3 -:id: cPqAH1XOGTN4 - -from jax import core -multiply_add_p = core.Primitive("multiply_add") # Create the primitive - -@trace("multiply_add_prim") -def multiply_add_prim(x, y, z): - """The JAX-traceable way to use the JAX primitive. - - Note that the traced arguments must be passed as positional arguments - to `bind`. - """ - return multiply_add_p.bind(x, y, z) - -@trace("square_add_prim") -def square_add_prim(a, b): - """A square-add function implemented using the new JAX-primitive.""" - return multiply_add_prim(a, a, b) -``` - -+++ {"id": "LMzs5PAKGr-4"} - -If we try to call the newly defined functions we get an error, because -we have not yet told JAX anything about the semantics of the new primitive. - -```{code-cell} ipython3 -:id: _X3PAYxhGpWd -:outputId: 90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8 - -with expectNotImplementedError(): - square_add_prim(2., 10.) -``` - -+++ {"id": "elha0FdgHSEF"} - -### Primal evaluation rules - -```{code-cell} ipython3 -:id: FT34FFAGHARU -:outputId: 4c54f1c2-8a50-4788-90e1-06aee412c43b - -@trace("multiply_add_impl") -def multiply_add_impl(x, y, z): - """Concrete implementation of the primitive. - - This function does not need to be JAX traceable. - Args: - x, y, z: the concrete arguments of the primitive. Will only be called with - concrete values. - Returns: - the concrete result of the primitive. - """ - # Note that we can use the original numpy, which is not JAX traceable - return np.add(np.multiply(x, y), z) - -# Now we register the primal implementation with JAX -multiply_add_p.def_impl(multiply_add_impl) -``` - -```{code-cell} ipython3 -:id: G5bstKaeNAVV -:outputId: deb94d5b-dfea-4e6f-9ec2-70b416c996c5 - -assert square_add_prim(2., 10.) == 14. -``` - -+++ {"id": "upBf-uAuHhPJ"} - -### JIT - -If we now try to use `jit` we get a `NotImplementedError`: - -```{code-cell} ipython3 -:id: QG-LULjiHk4b -:outputId: d4ef4406-8dae-4c96-97ca-b662340474ee - -with expectNotImplementedError(): - api.jit(square_add_prim)(2., 10.) -``` - -+++ {"id": "rHS1bAGHH44E"} - -#### Abstract evaluation rules -In order to JIT the function, and for other transformations as well, -JAX first evaluates it abstractly using only the -shape and type of the arguments. This abstract evaluation serves multiple -purposes: - - * Gets the sequence of JAX primitives that are used in the computation. This - sequence will be compiled. - * Computes the shape and type of all vectors and operations used in the computation. - - -For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. -In the latter case, JAX uses the actual concrete value wrapped as an abstract value. - -```{code-cell} ipython3 -:id: ctQmEeckIbdo -:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2 - -from jax import core -@trace("multiply_add_abstract_eval") -def multiply_add_abstract_eval(xs, ys, zs): - """Abstract evaluation of the primitive. - - This function does not need to be JAX traceable. It will be invoked with - abstractions of the actual arguments. - Args: - xs, ys, zs: abstractions of the arguments. - Result: - a ShapedArray for the result of the primitive. - """ - assert xs.shape == ys.shape - assert xs.shape == zs.shape - return core.ShapedArray(xs.shape, xs.dtype) - -# Now we register the abstract evaluation with JAX -multiply_add_p.def_abstract_eval(multiply_add_abstract_eval) -``` - -+++ {"id": "RPN88X6YI43A"} - -If we re-attempt to JIT, we see how the abstract evaluation proceeds, but -we get another error, about missing the actual XLA compilation rule: - -```{code-cell} ipython3 -:id: eOcNR92SI2h- -:outputId: 356ef229-3703-4696-cc3d-7c05de405fb0 - -with expectNotImplementedError(): - api.jit(square_add_prim)(2., 10.) -``` - -+++ {"id": "9IOV1R-fJMHp"} - -#### XLA Compilation rules - -JAX compilation works by compiling each primitive into a graph of XLA operations. - -This is the biggest hurdle to adding new functionality to JAX, because the -set of XLA operations is limited, and JAX already has pre-defined primitives -for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++. - -```{code-cell} ipython3 -:id: FYQWSSjKJaWP - -from jax._src.lib.mlir.dialects import hlo -@trace("multiply_add_lowering") -def multiply_add_lowering(ctx, xc, yc, zc): - """The compilation to XLA of the primitive. - - Given an mlir.ir.Value for each argument, return the mlir.ir.Values for - the results of the function. - - Does not need to be a JAX-traceable function. - """ - return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result] - -# Now we register the lowering rule with JAX -# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html) -# TODO: TPU? -from jax.interpreters import mlir -mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') -``` - -+++ {"id": "K98LX-VaJkFu"} - -Now we succeed to JIT. Notice below that JAX first evaluates the function -abstractly, which triggers the `multiply_add_abstract_eval` function, and -then compiles the set of primitives it has encountered, including `multiply_add`. -At this point JAX invokes `multiply_add_xla_translation`. - -```{code-cell} ipython3 -:id: rj3TLsolJgEc -:outputId: e384bee4-1e9c-4344-f49c-d3b5ec08eb32 - -assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. -``` - -+++ {"id": "Omrez-2_KFfo"} - -Below is another use of `jit` where we compile only -with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads -in the third argument to `multiply_add_abstract_eval` being -`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with -both `ShapedArray` and `ConcreteArray`. - -```{code-cell} ipython3 -:id: mPfTwIBoKOEK -:outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b - -assert api.jit(lambda x, y: square_add_prim(x, y), - static_argnums=1)(2., 10.) == 14. -``` - -+++ {"id": "_Ya3B5l4J1VA"} - -### Forward differentiation - -JAX implements forward differentiation in the form of -a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)). - -If we attempt now to compute the `jvp` function we get an -error because we have not yet told JAX how to differentiate -the `multiply_add` primitive. - -```{code-cell} ipython3 -:id: OxDx6NQnKwMI -:outputId: ce659ef3-c03c-4856-f252-49ec4b6eb964 - -# The second argument `(2., 10.)` are the argument values -# where we evaluate the Jacobian, and the third `(1., 1.)` -# are the values of the tangents for the arguments. -with expectNotImplementedError(): - api.jvp(square_add_prim, (2., 10.), (1., 1.)) -``` - -```{code-cell} ipython3 -:id: zxG24C1JMIMM - -from jax.interpreters import ad - - -@trace("multiply_add_value_and_jvp") -def multiply_add_value_and_jvp(arg_values, arg_tangents): - """Evaluates the primal output and the tangents (Jacobian-vector product). - - Given values of the arguments and perturbation of the arguments (tangents), - compute the output of the primitive and the perturbation of the output. - - This method must be JAX-traceable. JAX may invoke it with abstract values - for the arguments and tangents. - - Args: - arg_values: a tuple of arguments - arg_tangents: a tuple with the tangents of the arguments. The tuple has - the same length as the arg_values. Some of the tangents may also be the - special value ad.Zero to specify a zero tangent. - Returns: - a pair of the primal output and the tangent. - """ - x, y, z = arg_values - xt, yt, zt = arg_tangents - _trace("Primal evaluation:") - # Now we have a JAX-traceable computation of the output. - # Normally, we can use the ma primitive itself to compute the primal output. - primal_out = multiply_add_prim(x, y, z) - - _trace("Tangent evaluation:") - # We must use a JAX-traceable way to compute the tangent. It turns out that - # the output tangent can be computed as (xt * y + x * yt + zt), - # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive. - - # We do need to deal specially with Zero. Here we just turn it into a - # proper tensor of 0s (of the same shape as 'x'). - # An alternative would be to check for Zero and perform algebraic - # simplification of the output tangent computation. - def make_zero(tan): - return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan - - output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt))) - return (primal_out, output_tangent) - -# Register the forward differentiation rule with JAX -ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp -``` - -```{code-cell} ipython3 -:id: ma3KBkiAMfW1 -:outputId: f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd - -# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5. -assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.) -``` - -+++ {"id": "69QsEcu-lP4u"} - -TO EXPLAIN: - - * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here. - * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet - we do not call the multiply_add_abstract_eval. - * I think it would be useful to show the jaxpr here - -+++ {"id": "Sb6e3ZAHOPHv"} - -#### JIT of forward differentiation - -We can apply JIT to the forward differentiation function: - -```{code-cell} ipython3 -:id: hg-hzVu-N-hv -:outputId: 38d32067-e152-4046-ad80-7f95a31ba628 - -assert api.jit(lambda arg_values, arg_tangents: - api.jvp(square_add_prim, arg_values, arg_tangents))( - (2., 10.), (1., 1.)) == (14., 5.) -``` - -+++ {"id": "jlZt1_v2mU88"} - -Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn -evaluates abstractly both the primal and the tangent evaluation (a total of -3 invocations of the `ma` primitive). Then we compile the 3 occurrences -of the primitive. - -+++ {"id": "555yt6ZIOePB"} - -### Reverse differentiation - -If we attempt now to use reverse differentiation we -see that JAX starts by using the `multiply_add_value_and_jvp` to -compute the forward differentiation for abstract values, but then runs -into a `NotImplementedError`. - -When computing the reverse differentiation JAX first does abstract evaluation -of the forward differentiation code `multiply_add_value_and_jvp` to obtain a -trace of primitives that compute the output tangent. -Observe that JAX performs this abstract evaluation with concrete values -for the differentiation point, and abstract values for the tangents. -Observe also that JAX uses the special abstract tangent value `Zero` for -the tangent corresponding to the 3rd argument of `ma`. This reflects the -fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`, -which flows to the 3rd argument to `multiply_add_prim`. - -Observe also that during the abstract evaluation of the tangent we pass the -value 0.0 as the tangent for the 3rd argument. This is due to the use -of the `make_zero` function in the definition of `multiply_add_value_and_jvp`. - -```{code-cell} ipython3 -:id: 8eAVnexaOjBn -:outputId: e4ee89cf-ab4a-4505-9817-fa978a2865ab - -# This is reverse differentiation w.r.t. the first argument of square_add_prim -with expectNotImplementedError(): - api.grad(square_add_prim)(2., 10.) -``` - -+++ {"id": "fSHLUMDN26AY"} - -The above error is because there is a missing piece for JAX to be able -to use the forward differentiation code to compute reverse differentiation. - -+++ {"id": "3ibDbGF-PjK9"} - -#### Transposition - - -As explained above, when computing reverse differentiation JAX obtains -a trace of primitives that compute the tangent using forward differentiation. -Then, **JAX interprets this trace abstractly backwards** and for each -primitive it applies a **transposition** rule. - -To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`: -``` - a = xt * 4. - b = 2. * yt - c = a + b - ft = c + yt -``` - -By construction, the tangent calculation is always linear in the input tangents. -The only non-linear operator that may arise in the tangent calculation is multiplication, -but then one of the operands is constant. - -JAX will produce the reverse differentiation computation by processing the -JVP computation backwards. For each operation in the tangent computation, -it accumulates the cotangents -of the variables used by the operation, using the cotangent of the result -of the operation: -``` - # Initialize cotangents of inputs and intermediate vars - xct = yct = act = bct = cct = 0. - # Initialize cotangent of the output - fct = 1. - # Process "ft = c + yt" - cct += fct - yct += fct - # Process "c = a + b" - act += cct - bct += cct - # Process "b = 2. * yt" - yct += 2. * bct - # Process "a = xt * 4." - xct += act * 4. -``` - -One can verify that this computation produces `xct = 4.` and `yct = 3.`, which -are the partial derivatives of the function `f`. - -JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is: -``` -p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz) -``` - -Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other -arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned -for the constant arguments. - -In particular, -``` - add_transpose(out_ct, _, _) = (out_ct, out_ct) - mult_transpose(out_ct, x, _) = (None, x * out_ct) - mult_transpose(out_ct, _, y) = (out_ct * y, None) -``` - -```{code-cell} ipython3 -:id: JaHxFdkRO42r - -@trace("multiply_add_transpose") -def multiply_add_transpose(ct, x, y, z): - """Evaluates the transpose of a linear primitive. - - This method is only used when computing the backward gradient following - value_and_jvp, and is only needed for primitives that are used in the JVP - calculation for some other primitive. We need transposition for multiply_add_prim, - because we have used multiply_add_prim in the computation of the output_tangent in - multiply_add_value_and_jvp. - - In our case, multiply_add is not a linear primitive. However, it is used linearly - w.r.t. tangents in multiply_add_value_and_jvp: - output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt)) - - Always one of the first two multiplicative arguments is a constant. - - Args: - ct: the cotangent of the output of the primitive. - x, y, z: values of the arguments. The arguments that are used linearly - get an ad.UndefinedPrimal value. The other arguments get a constant - value. - Returns: - a tuple with the cotangent of the inputs, with the value None - corresponding to the constant arguments. - """ - if not ad.is_undefined_primal(x): - # This use of multiply_add is with a constant "x" - assert ad.is_undefined_primal(y) - ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x)) - res = None, ct_y, ct - else: - # This use of multiply_add is with a constant "y" - assert ad.is_undefined_primal(x) - ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y)) - res = ct_x, None, ct - return res - - -ad.primitive_transposes[multiply_add_p] = multiply_add_transpose -``` - -+++ {"id": "PpChox-Jp7wb"} - -Now we can complete the run of the `grad`: - -```{code-cell} ipython3 -:id: PogPKS4MPevd -:outputId: d33328d4-3e87-45b5-9b31-21ad624b67af - -assert api.grad(square_add_prim)(2., 10.) == 4. -``` - -+++ {"id": "8M1xLCXW4fK7"} - -Notice the two calls to `multiply_add_transpose`. They correspond to the two -uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the -last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0. - -+++ {"id": "EIJs6FYmPg6c"} - -#### JIT of reverse differentiation - -Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only -abstract values, while in the absence of JIT we used `ConcreteArray`. - -```{code-cell} ipython3 -:id: FZ-JGbWZPq2- -:outputId: e42b5222-9c3e-4853-e13a-874f6605d178 - -assert api.jit(api.grad(square_add_prim))(2., 10.) == 4. -``` - -+++ {"id": "-3lqPkdQPvl5"} - -### Batching - -The batching transformation takes a point-wise computation and turns it -into a computation on vectors. If we try it right now, we get a `NotImplementedError`: - -```{code-cell} ipython3 -:id: hFvBR3I9Pzh3 -:outputId: 434608bc-281f-4d3b-83bd-eaaf3b51b1cd - -# The arguments are two vectors instead of two scalars -with expectNotImplementedError(): - api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]), - np.array([10., 20.])) -``` - -+++ {"id": "gILasMiP6elR"} - -We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, the `multiply_add_prim` already operates pointwise for any dimension of input vectors. So the batched version can use the same `multiply_add_prim` implementation. - -```{code-cell} ipython3 -:id: KQfeqRIrP7zg - -from jax.interpreters import batching - - -@trace("multiply_add_batch") -def multiply_add_batch(vector_arg_values, batch_axes): - """Computes the batched version of the primitive. - - This must be a JAX-traceable function. - - Since the multiply_add primitive already operates pointwise on arbitrary - dimension tensors, to batch it we can use the primitive itself. This works as - long as both the inputs have the same dimensions and are batched along the - same axes. The result is batched along the axis that the inputs are batched. - - Args: - vector_arg_values: a tuple of two arguments, each being a tensor of matching - shape. - batch_axes: the axes that are being batched. See vmap documentation. - Returns: - a tuple of the result, and the result axis that was batched. - """ - assert batch_axes[0] == batch_axes[1] - assert batch_axes[0] == batch_axes[2] - _trace("Using multiply_add to compute the batch:") - res = multiply_add_prim(*vector_arg_values) - return res, batch_axes[0] - - -batching.primitive_batchers[multiply_add_p] = multiply_add_batch -``` - -```{code-cell} ipython3 -:id: VwxNk869P_YG -:outputId: 9d22c921-5803-4d33-9e88-b6e439ba9738 - -assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)( - np.array([2., 3.]), - np.array([10., 20.])), - [14., 29.]) -``` - -+++ {"id": "NmqLlV1TQDCC"} - -#### JIT of batching - -```{code-cell} ipython3 -:id: xqEdXVUgQCTt -:outputId: 9c22fd9c-919c-491d-bbeb-32c241b808fa - -assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0)) - (np.array([2., 3.]), - np.array([10., 20.])), - [14., 29.]) -``` diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb deleted file mode 100644 index 3c022124e3cc..000000000000 --- a/docs/notebooks/external_callbacks.ipynb +++ /dev/null @@ -1,1121 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "7XNMxdTwURqI" - }, - "source": [ - "# External callbacks\n", - "\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "h6lXo6bSUYGq" - }, - "source": [ - "This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Xi_nhfpnlmbm" - }, - "source": [ - "## Why callbacks?\n", - "\n", - "A callback routine is a way to perform **host-side** execution of code at runtime.\n", - "As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation.\n", - "Using a simple Python `print` statement, it looks like this:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "lz8rEL1Amb4r", - "outputId": "bbd37102-19f2-46d2-b794-3d4952c6fe97" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "intermediate value: Tracedwith\n" - ] - } - ], - "source": [ - "import jax\n", - "\n", - "@jax.jit\n", - "def f(x):\n", - " y = x + 1\n", - " print(\"intermediate value: {}\".format(y))\n", - " return y * 2\n", - "\n", - "result = f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yEy41sFAmxOp" - }, - "source": [ - "What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).\n", - "\n", - "To print the value at runtime we need a callback, for example `jax.debug.print`:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "wFfHmoQxnKDF", - "outputId": "6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "intermediate value: 3\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def f(x):\n", - " y = x + 1\n", - " jax.debug.print(\"intermediate value: {}\", y)\n", - " return y * 2\n", - "\n", - "result = f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CvWv3pudn9X5" - }, - "source": [ - "This works by passing the runtime value represented by `y` back to the host process, where the host can print the value." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X0vR078znuT-" - }, - "source": [ - "## Flavors of Callback\n", - "\n", - "In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:\n", - "\n", - "- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect.\n", - "- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk.\n", - "- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.\n", - "\n", - "(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`).\n", - "\n", - "From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.\n", - "\n", - "|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |\n", - "|-------------------------------------|----|----|----|----|----|----|\n", - "|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |\n", - "|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |\n", - "|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |\n", - "\n", - "¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff\n", - "\n", - "² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`.\n", - "\n", - "³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hE_M8DaPvoym" - }, - "source": [ - "### Exploring `jax.pure_callback`\n", - "\n", - "`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).\n", - "\n", - "The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "4lQDzXy6t_-k", - "outputId": "279e4daf-0540-4eab-f535-d3bcbac74c44" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "\n", - "def f_host(x):\n", - " # call a numpy (not jax.numpy) operation:\n", - " return np.sin(x).astype(x.dtype)\n", - "\n", - "def f(x):\n", - " result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - " return jax.pure_callback(f_host, result_shape, x)\n", - "\n", - "x = jnp.arange(5.0)\n", - "f(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "q7YCIr8qMrDs" - }, - "source": [ - "Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:\"" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "bgoZ0fxsuoWV", - "outputId": "901443bd-5cb4-4923-ce53-6f832ac22ca9" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.jit(f)(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "ajBRGWGfupu2", - "outputId": "b28e31ee-7457-4b92-872b-52d819f53ddf" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.vmap(f)(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "xe7AOGexvC13", - "outputId": "8fa77977-1f2b-41c5-cc5e-11993ee5aa3e" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def body_fun(_, x):\n", - " return _, f(x)\n", - "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tMzAVs2VNj5G" - }, - "source": [ - "However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "4QAF4VhUu5bb", - "outputId": "f8a06d02-47e9-4240-8077-d7be81e5a480" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Exception reporting mode: Minimal\n" - ] - } - ], - "source": [ - "%xmode minimal" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "qUpKPxlOurfY", - "outputId": "11a665e8-40eb-4b0e-dc2e-a544a25fc57e", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n" - ] - } - ], - "source": [ - "jax.grad(f)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "y9DAibV4Nwpo" - }, - "source": [ - "For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LrvdAloMZbIe" - }, - "source": [ - "By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "mmFc_zawZrBq", - "outputId": "a4df7568-3f64-4b2f-9a2c-7adb2e0815e0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "printing something\n" - ] - } - ], - "source": [ - "def print_something():\n", - " print('printing something')\n", - " return np.int32(0)\n", - "\n", - "@jax.jit\n", - "def f1():\n", - " return jax.pure_callback(print_something, np.int32(0))\n", - "f1();" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "tTwE4kpmaNei" - }, - "outputs": [], - "source": [ - "@jax.jit\n", - "def f2():\n", - " jax.pure_callback(print_something, np.int32(0))\n", - " return 1.0\n", - "f2();" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qfyGYbw4Z5U3" - }, - "source": [ - "In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.\n", - "In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JHcJybr7OEBM" - }, - "source": [ - "### Exploring `jax.experimental.io_callback`\n", - "\n", - "In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.\n", - "\n", - "As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "eAg5xIhrOiWV", - "outputId": "e3cfec21-d843-4852-a49d-69a69fba9fc1" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generating float32[5]\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ], dtype=float32)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jax.experimental import io_callback\n", - "from functools import partial\n", - "\n", - "global_rng = np.random.default_rng(0)\n", - "\n", - "def host_side_random_like(x):\n", - " \"\"\"Generate a random array like x using the global_rng state\"\"\"\n", - " # We have two side-effects here:\n", - " # - printing the shape and dtype\n", - " # - calling global_rng, thus updating its state\n", - " print(f'generating {x.dtype}{list(x.shape)}')\n", - " return global_rng.uniform(size=x.shape).astype(x.dtype)\n", - "\n", - "@jax.jit\n", - "def numpy_random_like(x):\n", - " return io_callback(host_side_random_like, x, x)\n", - "\n", - "x = jnp.zeros(5)\n", - "numpy_random_like(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mAIF31MlXj33" - }, - "source": [ - "The `io_callback` is compatible with `vmap` by default:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "NY3o5dG6Vg6u", - "outputId": "a67a8a98-214e-40ca-ad98-a930cd3db85e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.91275555, 0.60663575, 0.72949654, 0.543625 , 0.9350724 ], dtype=float32)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.vmap(numpy_random_like)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XXvSeeOXXquZ" - }, - "source": [ - "Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.\n", - "\n", - "If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "3aNmRsDrX3-2", - "outputId": "a8ff4b77-f4cb-442f-8cfb-ea7251c66274", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: Cannot `vmap` ordered IO callback.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Cannot `vmap` ordered IO callback.\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def numpy_random_like_ordered(x):\n", - " return io_callback(host_side_random_like, x, x, ordered=True)\n", - "\n", - "jax.vmap(numpy_random_like_ordered)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fD2FTHlUYAZH" - }, - "source": [ - "On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced:" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "lMVzZlIEWL7F", - "outputId": "f9741c18-a30d-4d46-b706-8102849286b5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544], dtype=float32)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def body_fun(_, x):\n", - " return _, numpy_random_like_ordered(x)\n", - "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_sf8mCbbo8K" - }, - "source": [ - "Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "Cn6_RG4JcKZm", - "outputId": "336ae5d2-e35b-4fe5-cbfb-14a7aef28c07", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: IO callbacks do not support JVP.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m IO callbacks do not support JVP.\n" - ] - } - ], - "source": [ - "jax.grad(numpy_random_like)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "plvfn9lWcKu4" - }, - "source": [ - "However, if the callback is not dependent on a differentiated variable, it will execute:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "wxgfDmDfb5bx", - "outputId": "d8c0285c-cd04-4b4d-d15a-1b07f778882d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hello\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def f(x):\n", - " io_callback(lambda: print('hello'), None)\n", - " return x\n", - "\n", - "jax.grad(f)(1.0);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "STLI40EZcVIY" - }, - "source": [ - "Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pkkM1ZmqclV-" - }, - "source": [ - "### Exploring `debug.callback`\n", - "\n", - "Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "74TdWyu9eqBa", - "outputId": "d8551dab-2e61-492e-9ac3-dc3db51b2c18" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "log: 1.0\n" - ] - } - ], - "source": [ - "from jax import debug\n", - "\n", - "def log_value(x):\n", - " # This could be an actual logging call; we'll use\n", - " # print() for demonstration\n", - " print(\"log:\", x)\n", - "\n", - "@jax.jit\n", - "def f(x):\n", - " debug.callback(log_value, x)\n", - " return x\n", - "\n", - "f(1.0);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "P848STlsfzmW" - }, - "source": [ - "The debug callback is compatible with `vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "2sSNsPB-fGVI", - "outputId": "fff58575-d94c-48fb-b88a-c1c395595fd0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "log: 0.0\n", - "log: 1.0\n", - "log: 2.0\n", - "log: 3.0\n", - "log: 4.0\n" - ] - } - ], - "source": [ - "x = jnp.arange(5.0)\n", - "jax.vmap(f)(x);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VDMacqpXf3La" - }, - "source": [ - "And is also compatible with `grad` and other autodiff transformations" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "wkFRle-tfTDe", - "outputId": "4e8a81d0-5012-4c51-d843-3fbdc498df31" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "log: 1.0\n" - ] - } - ], - "source": [ - "jax.grad(f)(1.0);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w8t-SDZ3gRzE" - }, - "source": [ - "This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dF7hoWGQUneJ" - }, - "source": [ - "## Example: `pure_callback` with `custom_jvp`\n", - "\n", - "One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).\n", - "Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.\n", - "\n", - "Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.\n", - "We can start by defining a straightforward `pure_callback`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Ge4fNPZdVSJY" - }, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import scipy.special\n", - "\n", - "def jv(v, z):\n", - " v, z = jnp.asarray(v), jnp.asarray(z)\n", - "\n", - " # Require the order v to be integer type: this simplifies\n", - " # the JVP rule below.\n", - " assert jnp.issubdtype(v.dtype, jnp.integer)\n", - "\n", - " # Promote the input to inexact (float/complex).\n", - " # Note that jnp.result_type() accounts for the enable_x64 flag.\n", - " z = z.astype(jnp.result_type(float, z.dtype))\n", - "\n", - " # Wrap scipy function to return the expected dtype.\n", - " _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)\n", - "\n", - " # Define the expected shape & dtype of output.\n", - " result_shape_dtype = jax.ShapeDtypeStruct(\n", - " shape=jnp.broadcast_shapes(v.shape, z.shape),\n", - " dtype=z.dtype)\n", - "\n", - " # We use vectorize=True because scipy.special.jv handles broadcasted inputs.\n", - " return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vyjQj-0QVuoN" - }, - "source": [ - "This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "f4e46670f4e4" - }, - "outputs": [], - "source": [ - "j1 = partial(jv, 1)\n", - "z = jnp.arange(5.0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6svImqFHWBwj", - "outputId": "bc8c778a-6c10-443b-9be2-c0f28e2ac1a9" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n" - ] - } - ], - "source": [ - "print(j1(z))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d48eb4f2d48e" - }, - "source": [ - "Here is the same result with `jit`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "txvRqR9DWGdC", - "outputId": "d25f3476-23b1-48e4-dda1-3c06d32c3b87" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n" - ] - } - ], - "source": [ - "print(jax.jit(j1)(z))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d861a472d861" - }, - "source": [ - "And here is the same result again with `vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BS-Ve5u_WU0C", - "outputId": "08cecd1f-6953-4853-e9db-25a03eb5b000" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n" - ] - } - ], - "source": [ - "print(jax.vmap(j1)(z))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SCH2ii_dWXP6" - }, - "source": [ - "However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "q3qh_4DrWxdQ", - "outputId": "c46b0bfa-96f3-4629-b9af-a4d4f3ccb870", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1090\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1091\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1092\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1166\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1167\u001b[0;31m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_axes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreduce_axes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1168\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_vjp\u001b[0;34m(fun, has_aux, reduce_axes, *primals)\u001b[0m\n\u001b[1;32m 2655\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun_nokwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2656\u001b[0;31m out_primal, out_vjp = ad.vjp(\n\u001b[0m\u001b[1;32m 2657\u001b[0m flat_fun, primals_flat, reduce_axes=reduce_axes)\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux, reduce_axes)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tangents_pvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_nounits\u001b[0;34m(fun, pvals, instantiate)\u001b[0m\n\u001b[1;32m 766\u001b[0m \u001b[0mfun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstantiate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 767\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3425\u001b[0m \"\"\"\n\u001b[0;32m-> 3426\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtypes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m 328\u001b[0m all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args\n\u001b[0;32m--> 329\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfind_top_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_primitive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mfull_lower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mprocess_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 310\u001b[0;31m \u001b[0mprimal_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangent_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimals_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangents_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 311\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m: ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 129\u001b[0m lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m result_avals=tuple(flat_result_avals), vectorized=vectorized)\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpure_callback_jvp_rule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \"Please use `jax.custom_jvp` to use callbacks while taking gradients.\")\n", - "\u001b[0;31mValueError\u001b[0m: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients." - ] - } - ], - "source": [ - "jax.grad(j1)(z)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PtYeJ_xUW09v" - }, - "source": [ - "Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:\n", - "\n", - "$$\n", - "d J_\\nu(z) = \\left\\{\n", - "\\begin{eqnarray}\n", - "-J_1(z),\\ &\\nu=0\\\\\n", - "[J_{\\nu - 1}(z) - J_{\\nu + 1}(z)]/2,\\ &\\nu\\ne 0\n", - "\\end{eqnarray}\\right.\n", - "$$\n", - "\n", - "The gradient with respect to $\\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.\n", - "\n", - "We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BOVQnt05XvLs" - }, - "outputs": [], - "source": [ - "jv = jax.custom_jvp(jv)\n", - "\n", - "@jv.defjvp\n", - "def _jv_jvp(primals, tangents):\n", - " v, z = primals\n", - " _, z_dot = tangents # Note: v_dot is always 0 because v is integer.\n", - " jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)\n", - " djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))\n", - " return jv(v, z), z_dot * djv_dz" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W1SxcvQSX44c" - }, - "source": [ - "Now computing the gradient of our function will work correctly:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sCGceBs-X8nL", - "outputId": "71c5589f-f996-44a0-f09a-ca8bb40c167a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-0.06447162\n" - ] - } - ], - "source": [ - "j1 = partial(jv, 1)\n", - "print(jax.grad(j1)(2.0))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gWQ4phN5YB26" - }, - "source": [ - "Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QTe5mRAvYQBh", - "outputId": "d58ecff3-9419-422a-fd0e-14a7d9cf2cc3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(-0.4003078, dtype=float32, weak_type=True)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.hessian(j1)(2.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QEXGxU4uYZii" - }, - "source": [ - "Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.\n", - "When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.\n", - "However, if you are running JAX on a single CPU (where the \"host\" and \"device\" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern is a relatively straightforward way extend JAX's capabilities." - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md deleted file mode 100644 index 910d47bd72ae..000000000000 --- a/docs/notebooks/external_callbacks.md +++ /dev/null @@ -1,515 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.4 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "7XNMxdTwURqI"} - -# External callbacks - - - -+++ {"id": "h6lXo6bSUYGq"} - -This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation. - -+++ {"id": "Xi_nhfpnlmbm"} - -## Why callbacks? - -A callback routine is a way to perform **host-side** execution of code at runtime. -As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation. -Using a simple Python `print` statement, it looks like this: - -```{code-cell} -:id: lz8rEL1Amb4r -:outputId: bbd37102-19f2-46d2-b794-3d4952c6fe97 - -import jax - -@jax.jit -def f(x): - y = x + 1 - print("intermediate value: {}".format(y)) - return y * 2 - -result = f(2) -``` - -+++ {"id": "yEy41sFAmxOp"} - -What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)). - -To print the value at runtime we need a callback, for example `jax.debug.print`: - -```{code-cell} -:id: wFfHmoQxnKDF -:outputId: 6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a - -@jax.jit -def f(x): - y = x + 1 - jax.debug.print("intermediate value: {}", y) - return y * 2 - -result = f(2) -``` - -+++ {"id": "CvWv3pudn9X5"} - -This works by passing the runtime value represented by `y` back to the host process, where the host can print the value. - -+++ {"id": "X0vR078znuT-"} - -## Flavors of Callback - -In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations: - -- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect. -- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk. -- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler. - -(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`). - -From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. - -|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution | -|-------------------------------------|----|----|----|----|----|----| -|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ | -|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ | -|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | - -¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff - -² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`. - -³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases. - -+++ {"id": "hE_M8DaPvoym"} - -### Exploring `jax.pure_callback` - -`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.). - -The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times. - -```{code-cell} -:id: 4lQDzXy6t_-k -:outputId: 279e4daf-0540-4eab-f535-d3bcbac74c44 - -import jax -import jax.numpy as jnp -import numpy as np - -def f_host(x): - # call a numpy (not jax.numpy) operation: - return np.sin(x).astype(x.dtype) - -def f(x): - result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype) - return jax.pure_callback(f_host, result_shape, x) - -x = jnp.arange(5.0) -f(x) -``` - -+++ {"id": "q7YCIr8qMrDs"} - -Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:" - -```{code-cell} -:id: bgoZ0fxsuoWV -:outputId: 901443bd-5cb4-4923-ce53-6f832ac22ca9 - -jax.jit(f)(x) -``` - -```{code-cell} -:id: ajBRGWGfupu2 -:outputId: b28e31ee-7457-4b92-872b-52d819f53ddf - -jax.vmap(f)(x) -``` - -```{code-cell} -:id: xe7AOGexvC13 -:outputId: 8fa77977-1f2b-41c5-cc5e-11993ee5aa3e - -def body_fun(_, x): - return _, f(x) -jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] -``` - -+++ {"id": "tMzAVs2VNj5G"} - -However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics: - -```{code-cell} -:id: 4QAF4VhUu5bb -:outputId: f8a06d02-47e9-4240-8077-d7be81e5a480 - -%xmode minimal -``` - -```{code-cell} -:id: qUpKPxlOurfY -:outputId: 11a665e8-40eb-4b0e-dc2e-a544a25fc57e -:tags: [raises-exception] - -jax.grad(f)(x) -``` - -+++ {"id": "y9DAibV4Nwpo"} - -For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below. - -+++ {"id": "LrvdAloMZbIe"} - -By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely: - -```{code-cell} -:id: mmFc_zawZrBq -:outputId: a4df7568-3f64-4b2f-9a2c-7adb2e0815e0 - -def print_something(): - print('printing something') - return np.int32(0) - -@jax.jit -def f1(): - return jax.pure_callback(print_something, np.int32(0)) -f1(); -``` - -```{code-cell} -:id: tTwE4kpmaNei - -@jax.jit -def f2(): - jax.pure_callback(print_something, np.int32(0)) - return 1.0 -f2(); -``` - -+++ {"id": "qfyGYbw4Z5U3"} - -In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output. -In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects. - -+++ {"id": "JHcJybr7OEBM"} - -### Exploring `jax.experimental.io_callback` - -In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects. - -As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!). - -```{code-cell} -:id: eAg5xIhrOiWV -:outputId: e3cfec21-d843-4852-a49d-69a69fba9fc1 - -from jax.experimental import io_callback -from functools import partial - -global_rng = np.random.default_rng(0) - -def host_side_random_like(x): - """Generate a random array like x using the global_rng state""" - # We have two side-effects here: - # - printing the shape and dtype - # - calling global_rng, thus updating its state - print(f'generating {x.dtype}{list(x.shape)}') - return global_rng.uniform(size=x.shape).astype(x.dtype) - -@jax.jit -def numpy_random_like(x): - return io_callback(host_side_random_like, x, x) - -x = jnp.zeros(5) -numpy_random_like(x) -``` - -+++ {"id": "mAIF31MlXj33"} - -The `io_callback` is compatible with `vmap` by default: - -```{code-cell} -:id: NY3o5dG6Vg6u -:outputId: a67a8a98-214e-40ca-ad98-a930cd3db85e - -jax.vmap(numpy_random_like)(x) -``` - -+++ {"id": "XXvSeeOXXquZ"} - -Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run. - -If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error: - -```{code-cell} -:id: 3aNmRsDrX3-2 -:outputId: a8ff4b77-f4cb-442f-8cfb-ea7251c66274 -:tags: [raises-exception] - -@jax.jit -def numpy_random_like_ordered(x): - return io_callback(host_side_random_like, x, x, ordered=True) - -jax.vmap(numpy_random_like_ordered)(x) -``` - -+++ {"id": "fD2FTHlUYAZH"} - -On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced: - -```{code-cell} -:id: lMVzZlIEWL7F -:outputId: f9741c18-a30d-4d46-b706-8102849286b5 - -def body_fun(_, x): - return _, numpy_random_like_ordered(x) -jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] -``` - -+++ {"id": "w_sf8mCbbo8K"} - -Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable: - -```{code-cell} -:id: Cn6_RG4JcKZm -:outputId: 336ae5d2-e35b-4fe5-cbfb-14a7aef28c07 -:tags: [raises-exception] - -jax.grad(numpy_random_like)(x) -``` - -+++ {"id": "plvfn9lWcKu4"} - -However, if the callback is not dependent on a differentiated variable, it will execute: - -```{code-cell} -:id: wxgfDmDfb5bx -:outputId: d8c0285c-cd04-4b4d-d15a-1b07f778882d - -@jax.jit -def f(x): - io_callback(lambda: print('hello'), None) - return x - -jax.grad(f)(1.0); -``` - -+++ {"id": "STLI40EZcVIY"} - -Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation. - -+++ {"id": "pkkM1ZmqclV-"} - -### Exploring `debug.callback` - -Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program. - -```{code-cell} -:id: 74TdWyu9eqBa -:outputId: d8551dab-2e61-492e-9ac3-dc3db51b2c18 - -from jax import debug - -def log_value(x): - # This could be an actual logging call; we'll use - # print() for demonstration - print("log:", x) - -@jax.jit -def f(x): - debug.callback(log_value, x) - return x - -f(1.0); -``` - -+++ {"id": "P848STlsfzmW"} - -The debug callback is compatible with `vmap`: - -```{code-cell} -:id: 2sSNsPB-fGVI -:outputId: fff58575-d94c-48fb-b88a-c1c395595fd0 - -x = jnp.arange(5.0) -jax.vmap(f)(x); -``` - -+++ {"id": "VDMacqpXf3La"} - -And is also compatible with `grad` and other autodiff transformations - -```{code-cell} -:id: wkFRle-tfTDe -:outputId: 4e8a81d0-5012-4c51-d843-3fbdc498df31 - -jax.grad(f)(1.0); -``` - -+++ {"id": "w8t-SDZ3gRzE"} - -This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`. - -+++ {"id": "dF7hoWGQUneJ"} - -## Example: `pure_callback` with `custom_jvp` - -One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`). -Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers. - -Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`. -We can start by defining a straightforward `pure_callback`: - -```{code-cell} -:id: Ge4fNPZdVSJY - -import jax -import jax.numpy as jnp -import scipy.special - -def jv(v, z): - v, z = jnp.asarray(v), jnp.asarray(z) - - # Require the order v to be integer type: this simplifies - # the JVP rule below. - assert jnp.issubdtype(v.dtype, jnp.integer) - - # Promote the input to inexact (float/complex). - # Note that jnp.result_type() accounts for the enable_x64 flag. - z = z.astype(jnp.result_type(float, z.dtype)) - - # Wrap scipy function to return the expected dtype. - _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype) - - # Define the expected shape & dtype of output. - result_shape_dtype = jax.ShapeDtypeStruct( - shape=jnp.broadcast_shapes(v.shape, z.shape), - dtype=z.dtype) - - # We use vectorize=True because scipy.special.jv handles broadcasted inputs. - return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True) -``` - -+++ {"id": "vyjQj-0QVuoN"} - -This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`: - -```{code-cell} -:id: f4e46670f4e4 - -j1 = partial(jv, 1) -z = jnp.arange(5.0) -``` - -```{code-cell} -:id: 6svImqFHWBwj -:outputId: bc8c778a-6c10-443b-9be2-c0f28e2ac1a9 - -print(j1(z)) -``` - -+++ {"id": "d48eb4f2d48e"} - -Here is the same result with `jit`: - -```{code-cell} -:id: txvRqR9DWGdC -:outputId: d25f3476-23b1-48e4-dda1-3c06d32c3b87 - -print(jax.jit(j1)(z)) -``` - -+++ {"id": "d861a472d861"} - -And here is the same result again with `vmap`: - -```{code-cell} -:id: BS-Ve5u_WU0C -:outputId: 08cecd1f-6953-4853-e9db-25a03eb5b000 - -print(jax.vmap(j1)(z)) -``` - -+++ {"id": "SCH2ii_dWXP6"} - -However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function: - -```{code-cell} -:id: q3qh_4DrWxdQ -:outputId: c46b0bfa-96f3-4629-b9af-a4d4f3ccb870 -:tags: [raises-exception] - -jax.grad(j1)(z) -``` - -+++ {"id": "PtYeJ_xUW09v"} - -Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`: - -$$ -d J_\nu(z) = \left\{ -\begin{eqnarray} --J_1(z),\ &\nu=0\\ -[J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 -\end{eqnarray}\right. -$$ - -The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example. - -We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function: - -```{code-cell} -:id: BOVQnt05XvLs - -jv = jax.custom_jvp(jv) - -@jv.defjvp -def _jv_jvp(primals, tangents): - v, z = primals - _, z_dot = tangents # Note: v_dot is always 0 because v is integer. - jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z) - djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1)) - return jv(v, z), z_dot * djv_dz -``` - -+++ {"id": "W1SxcvQSX44c"} - -Now computing the gradient of our function will work correctly: - -```{code-cell} -:id: sCGceBs-X8nL -:outputId: 71c5589f-f996-44a0-f09a-ca8bb40c167a - -j1 = partial(jv, 1) -print(jax.grad(j1)(2.0)) -``` - -+++ {"id": "gWQ4phN5YB26"} - -Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free: - -```{code-cell} -:id: QTe5mRAvYQBh -:outputId: d58ecff3-9419-422a-fd0e-14a7d9cf2cc3 - -jax.hessian(j1)(2.0) -``` - -+++ {"id": "QEXGxU4uYZii"} - -Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device. -When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called. -However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern is a relatively straightforward way extend JAX's capabilities. diff --git a/docs/tutorials.rst b/docs/tutorials.rst index be70c6d41654..a31517155e1a 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -16,3 +16,13 @@ Tutorials working-with-pytrees sharded-computation stateful-computations + +.. toctree:: + :maxdepth: 1 + :caption: Advanced tutorials + + advanced-autodiff + external-callbacks + gradient-checkpointing + jax-primitives + jaxpr diff --git a/docs/user_guides.rst b/docs/user_guides.rst index e917cf2fee38..6481da7a31dd 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -33,7 +33,6 @@ or deployed codebases. :maxdepth: 1 :caption: Custom operations - notebooks/external_callbacks pallas/index ffi From d63afd8438a023598d8847a357c3b4105e3d7d16 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 20 Sep 2024 16:17:38 -0700 Subject: [PATCH 598/702] [Pallas GPU] Enable Pallas `OpsExtraTest` in 64-bit mode This is a follow-up of https://github.com/jax-ml/jax/pull/23747, which enables Pallas `OpsTest` in 64-bit mode. In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the code in the tests need to be modified. There are three kinds of modifications: 1. Most of the modifications are just changing `jnp.int32` to `intx` and `jnp.float32` to `floatx`, which uses the same approach as the previous PR https://github.com/jax-ml/jax/pull/23747. `intx` and `floatx` are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode. 2. For the test `test_array_reduce`, the original code uses a simple approach to determine `out_dtype` from `dtype`, which no longer works in 64-bit mode. Therefore, I modified the code to deduct `out_dtype` by executing the operation on a single element first. 3. For the test `test_masked_load_store`, the `idx` variable is expected to be an `int32` array, which is calculated based on `pl.program_id()` and `block_size`. In 64-bit mode, the computation will give out an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I modified the computation to produce `int32` result. I also modified the `pl.program_id()` docstring to document the behaviour that `pl.program_id()` always returns an `int32` result. PiperOrigin-RevId: 677007613 --- jax/_src/pallas/primitives.py | 2 + tests/pallas/ops_test.py | 69 ++++++++++++++++++++++------------- 2 files changed, 46 insertions(+), 25 deletions(-) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 89b6c6e14acd..40caae76bd8f 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -59,6 +59,8 @@ def program_id(axis: int) -> jax.Array: grid coordinates `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`. + The returned value is an array of shape `()` and dtype `int32`. + Args: axis: the axis of the grid along which to count the program. """ diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index d8f890c06c32..65cdde30f7d9 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -752,8 +752,6 @@ class OpsExtraTest(PallasBaseTest): def setUp(self): super().setUp() - if jax.config.x64_enabled: - self.skipTest("Only works in 32-bit") if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: # TODO: most tests fail on TPU in non-interpret mode self.skipTest("On TPU the test works only in interpret mode") @@ -800,7 +798,7 @@ def kernel(x_ref, o_ref): def test_abs_weak_type(self): # see https://github.com/jax-ml/jax/issues/23191 @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), floatx), ) def kernel(x_ref, o_ref): o_ref[...] = jnp.abs(x_ref[...]) @@ -1145,20 +1143,20 @@ def f(x_ref, o_ref): def test_num_programs(self): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + out_shape=jax.ShapeDtypeStruct((4,), intx), grid=4, ) def kernel(o_ref): o_ref[pl.program_id(0)] = pl.num_programs(0) np.testing.assert_array_equal( - kernel(), np.asarray([4, 4, 4, 4], dtype=np.int32) + kernel(), jnp.array([4, 4, 4, 4], dtype=intx) ) def test_where_broadcasting(self): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32), + out_shape=jax.ShapeDtypeStruct((4, 2, 2), floatx), grid=1, ) def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): @@ -1225,11 +1223,12 @@ def dot(x_ref, y_ref, o_ref): def test_masked_load_store(self, size, block_size): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((size,), jnp.float32)), + out_shape=(jax.ShapeDtypeStruct((size,), floatx)), grid=pl.cdiv(size, block_size), ) def kernel(x_ref, o_ref): - idx = pl.program_id(0) * block_size + jnp.arange(block_size) + idx = pl.program_id(0) * block_size + jnp.arange( + block_size, dtype=jnp.int32) mask = idx < x_ref.shape[0] x = pl.load(x_ref, (idx,), mask=mask) pl.store(o_ref, (idx,), x + 1.0, mask=mask) @@ -1243,7 +1242,7 @@ def test_masked_oob_load_store_slice(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32)), + out_shape=(jax.ShapeDtypeStruct((n,), floatx)), grid=1, ) def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): @@ -1276,7 +1275,7 @@ def test_broadcasted_load_store(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32)), + out_shape=(jax.ShapeDtypeStruct((m, n), floatx)), grid=1, ) def load(x_ref, o_ref): @@ -1319,7 +1318,7 @@ def test_swap(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, + out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2, grid=1, input_output_aliases={0: 0, 1: 1}, ) @@ -1339,7 +1338,7 @@ def test_masked_swap(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, + out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2, grid=1, input_output_aliases={0: 0, 1: 1}, ) @@ -1360,8 +1359,8 @@ def test_masked_oob_swap_slice(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32), - jax.ShapeDtypeStruct((m,), jnp.float32)), + out_shape=(jax.ShapeDtypeStruct((n,), floatx), + jax.ShapeDtypeStruct((m,), floatx)), grid=1, input_output_aliases={0: 0, 1: 1}, ) @@ -1430,7 +1429,7 @@ def test_array_atomic_add(self, axis): grid = m else: grid = n - out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), jnp.float32) + out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), floatx) @functools.partial( self.pallas_call, @@ -1464,8 +1463,8 @@ def reduce(x_ref, _, y_ref): def test_atomic_cas(self, init_value, cmp, new_value): @functools.partial( self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), jnp.int32), - jax.ShapeDtypeStruct((), jnp.int32)), + jax.ShapeDtypeStruct((), intx), + jax.ShapeDtypeStruct((), intx)), input_output_aliases={0: 0}) def swap(_, lock_ref, out_ref): out_ref[()] = pl.atomic_cas(lock_ref, cmp, new_value) @@ -1528,14 +1527,31 @@ def reduce(x_ref, y_ref): ("argmin", jnp.argmin), ] for axis in [0, 1, (1,), (0, 1)] - for dtype in ["float16", "float32", "int32", "uint32"] + for dtype in [ + "float16", + "float32", + "float64", + "int32", + "int64", + "uint32", + "uint64", + ] if isinstance(axis, int) or "arg" not in op_name ]) def test_array_reduce(self, op, dtype, axis): m, n = 32, 8 - out_dtype = dtype - if op in {jnp.argmin, jnp.argmax}: - out_dtype = jnp.int32 + + if not jax.config.x64_enabled and dtype in ("float64", "int64", "uint64"): + self.skipTest("64-bit types require x64_enabled") + + # Skip argmin/argmax on GPU in 64-bit mode because Pallas expects + # `index_type` to be i32 + if ( + jax.config.x64_enabled + and jtu.test_device_matches(["gpu"]) + and op in {jnp.argmin, jnp.argmax} + ): + self.skipTest("Not supported on GPU in 64-bit mode") def make_x(key): if jnp.issubdtype(dtype, jnp.integer): @@ -1545,9 +1561,10 @@ def make_x(key): else: return random.normal(key, (m, n), dtype=dtype) + # deduct `out_dtype` by executing the op on a single element + out_dtype = op(jnp.arange(1, dtype=dtype)).dtype out_shape = jax.ShapeDtypeStruct( - op(make_x(random.key(0)), axis=axis).shape, out_dtype - ) + op(make_x(random.key(0)), axis=axis).shape, out_dtype) if isinstance(axis, int): grid = tuple(a for i, a in enumerate((m, n)) if i != axis) else: @@ -1555,9 +1572,11 @@ def make_x(key): @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) def reduce(x_ref, y_ref): - x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None])) + x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None], + jnp.arange(n, dtype=jnp.int32)[None])) y = op(x, axis=axis) - pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y) + pl.store(y_ref, + tuple(jnp.arange(d, dtype=jnp.int32) for d in y.shape), y) for i, key in enumerate(random.split(random.key(0), 20)): x = make_x(key) From aa551e66c59730997f0abc4e88e8aba0a14634de Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sat, 21 Sep 2024 07:39:17 -0700 Subject: [PATCH 599/702] Test that jax.numpy docstrings include examples --- jax/_src/numpy/polynomial.py | 12 ++--- jax/_src/numpy/ufuncs.py | 87 +++++++++++++++++++++++++++++++++++- jax/_src/numpy/vectorize.py | 85 ++++++++++++++++++----------------- tests/lax_numpy_test.py | 2 + 4 files changed, 137 insertions(+), 49 deletions(-) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index ca4e3ebaf6a2..cce8bb8e6f7f 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -317,7 +317,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array: - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given coefficients. - Example: + Examples: Scalar inputs: @@ -407,7 +407,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given coefficients. - Example: + Examples: >>> p = jnp.array([2, 5, 1]) >>> jnp.polyval(p, 3) Array(34., dtype=float32) @@ -455,7 +455,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: - :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial division. - Example: + Examples: >>> x1 = jnp.array([2, 3]) >>> x2 = jnp.array([5, 4, 1]) >>> jnp.polyadd(x1, x2) @@ -637,7 +637,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - - :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial division. - Example: + Examples: >>> x1 = np.array([2, 1, 0]) >>> x2 = np.array([0, 5, 0, 3]) >>> np.polymul(x1, x2) @@ -702,7 +702,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> - :func:`jax.numpy.polysub`: Computes the difference of two polynomials. - :func:`jax.numpy.polymul`: Computes the product of two polynomials. - Example: + Examples: >>> x1 = jnp.array([5, 7, 9]) >>> x2 = jnp.array([4, 1]) >>> np.polydiv(x1, x2) @@ -755,7 +755,7 @@ def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: - :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial division. - Example: + Examples: >>> x1 = jnp.array([2, 3]) >>> x2 = jnp.array([5, 4, 1]) >>> jnp.polysub(x1, x2) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 857ed8668d59..2455817bf054 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -652,12 +652,26 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.add`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``+`` operator for + JAX arrays. Args: x, y: arrays to add. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise addition. + + Examples: + Calling ``add`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.add(x, 10) + Array([10, 11, 12, 13], dtype=int32) + + Calling ``add`` via the ``+`` operator: + + >>> x + 10 + Array([10, 11, 12, 13], dtype=int32) """ x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) @@ -668,12 +682,26 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.multiply`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``*`` operator for + JAX arrays. Args: x, y: arrays to multiply. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise multiplication. + + Examples: + Calling ``multiply`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.multiply(x, 10) + Array([ 0, 10, 20, 30], dtype=int32) + + Calling ``multiply`` via the ``*`` operator: + + >>> x * 10 + Array([ 0, 10, 20, 30], dtype=int32) """ x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) @@ -684,12 +712,26 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``&`` operator for + JAX arrays. Args: x, y: integer or boolean arrays. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise bitwise AND. + + Examples: + Calling ``bitwise_and`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.bitwise_and(x, 1) + Array([0, 1, 0, 1], dtype=int32) + + Calling ``bitwise_and`` via the ``&`` operator: + + >>> x & 1 + Array([0, 1, 0, 1], dtype=int32) """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) @@ -699,12 +741,26 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``|`` operator for + JAX arrays. Args: x, y: integer or boolean arrays. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise bitwise OR. + + Examples: + Calling ``bitwise_or`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.bitwise_or(x, 1) + Array([1, 1, 3, 3], dtype=int32) + + Calling ``bitwise_or`` via the ``|`` operator: + + >>> x | 1 + Array([1, 1, 3, 3], dtype=int32) """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) @@ -714,12 +770,26 @@ def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function, and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``^`` operator for + JAX arrays. Args: x, y: integer or boolean arrays. Must be broadcastable to a common shape. Returns: Array containing the result of the element-wise bitwise XOR. + + Examples: + Calling ``bitwise_xor`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.bitwise_xor(x, 1) + Array([1, 0, 3, 2], dtype=int32) + + Calling ``bitwise_xor`` via the ``^`` operator: + + >>> x ^ 1 + Array([1, 0, 3, 2], dtype=int32) """ return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) @@ -958,6 +1028,11 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: Returns: Array containing the result of the element-wise logical AND. + + Examples: + >>> x = jnp.arange(4) + >>> jnp.logical_and(x, 1) + Array([False, True, True, True], dtype=bool) """ return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) @@ -973,6 +1048,11 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: Returns: Array containing the result of the element-wise logical OR. + + Examples: + >>> x = jnp.arange(4) + >>> jnp.logical_or(x, 1) + Array([ True, True, True, True], dtype=bool) """ return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) @@ -988,6 +1068,11 @@ def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: Returns: Array containing the result of the element-wise logical XOR. + + Examples: + >>> x = jnp.arange(4) + >>> jnp.logical_xor(x, 1) + Array([ True, False, False, False], dtype=bool) """ return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) @@ -1373,7 +1458,7 @@ def rint(x: ArrayLike, /) -> Array: If an element of x is exactly half way, e.g. ``0.5`` or ``1.5``, rint will round to the nearest even integer. - Example: + Examples: >>> x1 = jnp.array([5, 4, 7]) >>> jnp.rint(x1) Array([5., 4., 7.], dtype=float32) diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index dc368367e14e..e7a0e2142327 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -215,48 +215,49 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None): Returns: Vectorized version of the given function. - Here are a few examples of how one could write vectorized linear algebra - routines using :func:`vectorize`: - - >>> from functools import partial - - >>> @partial(jnp.vectorize, signature='(k),(k)->(k)') - ... def cross_product(a, b): - ... assert a.shape == b.shape and a.ndim == b.ndim == 1 - ... return jnp.array([a[1] * b[2] - a[2] * b[1], - ... a[2] * b[0] - a[0] * b[2], - ... a[0] * b[1] - a[1] * b[0]]) - - >>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)') - ... def matrix_vector_product(matrix, vector): - ... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape - ... return matrix @ vector - - These functions are only written to handle 1D or 2D arrays (the ``assert`` - statements will never be violated), but with vectorize they support - arbitrary dimensional inputs with NumPy style broadcasting, e.g., - - >>> cross_product(jnp.ones(3), jnp.ones(3)).shape - (3,) - >>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape - (2, 3) - >>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape - (2, 2, 3) - >>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ValueError: input with shape (3,) does not have enough dimensions for all - core dimensions ('n', 'k') on vectorized function with excluded=frozenset() - and signature='(n,k),(k)->(k)' - >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape - (2,) - >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape - (4, 2) - - Note that this has different semantics than `jnp.matmul`: - - >>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4]. + Examples: + Here are a few examples of how one could write vectorized linear algebra + routines using :func:`vectorize`: + + >>> from functools import partial + + >>> @partial(jnp.vectorize, signature='(k),(k)->(k)') + ... def cross_product(a, b): + ... assert a.shape == b.shape and a.ndim == b.ndim == 1 + ... return jnp.array([a[1] * b[2] - a[2] * b[1], + ... a[2] * b[0] - a[0] * b[2], + ... a[0] * b[1] - a[1] * b[0]]) + + >>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)') + ... def matrix_vector_product(matrix, vector): + ... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape + ... return matrix @ vector + + These functions are only written to handle 1D or 2D arrays (the ``assert`` + statements will never be violated), but with vectorize they support + arbitrary dimensional inputs with NumPy style broadcasting, e.g., + + >>> cross_product(jnp.ones(3), jnp.ones(3)).shape + (3,) + >>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape + (2, 3) + >>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape + (2, 2, 3) + >>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ValueError: input with shape (3,) does not have enough dimensions for all + core dimensions ('n', 'k') on vectorized function with excluded=frozenset() + and signature='(n,k),(k)->(k)' + >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape + (2,) + >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape + (4, 2) + + Note that this has different semantics than `jnp.matmul`: + + >>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4]. """ if any(not isinstance(exclude, (str, int)) for exclude in excluded): raise TypeError("jax.numpy.vectorize can only exclude integer or string arguments, " diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a1d1a0292338..9dc2e079bb3f 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6341,6 +6341,8 @@ def test_lax_numpy_docstrings(self): self.assertNotEmpty(doc) self.assertIn("Args:", doc, msg=f"'Args:' not found in docstring of jnp.{name}") self.assertIn("Returns:", doc, msg=f"'Returns:' not found in docstring of jnp.{name}") + if name not in ["frompyfunc", "isdtype", "promote_types"]: + self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}") @parameterized.named_parameters( {"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False]) From a2b39192d285130401d7d2a38c9fdcc715cd6264 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sat, 21 Sep 2024 10:22:36 -0700 Subject: [PATCH 600/702] Make `make_array_from_process_local_data` go via `device_put` if there is only 1 process. PiperOrigin-RevId: 677232996 --- jax/_src/array.py | 9 ++++++--- tests/array_test.py | 18 ++++++------------ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 9e5595aacca3..83be3d418c50 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -892,17 +892,20 @@ def make_array_from_process_local_data( setting it to (4, 4) in this case. Args: - sharding: sharding of the global tensor. - local_data: data on the host to be placed on local devices. Each + sharding: Sharding of the global array. + local_data: Data on the host to be placed on local devices. Each dimension should either match global_shape, or match num_addressable_indices(dim). - global_shape: the target shape of the global tensor. If None, + global_shape: The target shape of the global array. If None, will infer from local_data and sharding. Returns: Tensor that will have sharding=sharding and of shape global_shape. """ # pyformat: enable + if xla_bridge.process_count() == 1: + return api.device_put(local_data, sharding) + # TODO(sandler): consider supporting partially specified global_shape or # making local_to_global_shape available in the api. local_shape = local_data.shape diff --git a/tests/array_test.py b/tests/array_test.py index 9260e800258e..080356d3490a 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -822,18 +822,12 @@ def test_make_array_from_callback_global_array(self): self.assertEqual(out2.sharding, sharding2) def test_make_array_from_process_data_single_host_data_sharding(self): - data = np.ones((1, 512)) - mesh = jtu.create_mesh((1, 1), ('x', 'unused')) - sharding_spec = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec('x') - ) - global_shape = data.shape - result = jax.make_array_from_process_local_data( - sharding_spec, data, global_shape - ) - self.assertIsInstance(result, jax.Array) - self.assertEqual(result.shape, data.shape) - self.assertEqual(result.sharding, sharding_spec) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + data = np.ones((256, 512)) + s = jax.NamedSharding(mesh, P('x')) + result = jax.make_array_from_process_local_data(s, data) + self.assertArraysEqual(result, data) + self.assertEqual(result.sharding, s) class ShardingTest(jtu.JaxTestCase): From 43cc70b7a15fb7fd584fda6714f4c38f0ea7dd76 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 20 Sep 2024 22:58:01 +0000 Subject: [PATCH 601/702] add jax.experimental.primal_tangent_dtype helper useful for constructing new dtypes which have a distinct tangent type (e.g. for quantization) --- jax/_src/core.py | 21 +++++++++++---------- jax/_src/dtypes.py | 22 ++++++++++++++++++++++ jax/_src/interpreters/mlir.py | 12 ++++-------- jax/_src/interpreters/pxla.py | 2 +- jax/_src/lax/lax.py | 3 +-- jax/_src/lax/slicing.py | 4 ++-- jax/_src/sharding_impls.py | 10 +++++----- jax/experimental/__init__.py | 3 +++ tests/dtypes_test.py | 28 ++++++++++++++++++++++++++++ 9 files changed, 77 insertions(+), 28 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index bff59625b702..467f2b63d390 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1596,16 +1596,17 @@ def physical_aval(aval: DShapedArray) -> DShapedArray: ... def physical_aval(aval: AbstractValue) -> AbstractValue: ... def physical_aval(aval): - aval_dtype = getattr(aval, 'dtype', None) - if aval_dtype and isinstance(aval_dtype, dtypes.ExtendedDType): - ctor = type(aval) - aval_shape = getattr(aval, 'shape', None) - assert aval_shape is not None, (ctor, aval) - elt_aval = aval_dtype._rules.physical_element_aval(aval_dtype) - assert type(elt_aval) is ShapedArray - return ctor((*aval_shape, *elt_aval.shape), elt_aval.dtype) # pytype: disable=wrong-arg-count - else: - return aval + if (isinstance(aval, (ShapedArray, DShapedArray)) and + isinstance(aval.dtype, dtypes.ExtendedDType)): + elt_aval = physical_element_aval(aval.dtype) + if isinstance(aval, ShapedArray): + return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) + return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) + return aval + +def physical_element_aval(edtype: dtypes.ExtendedDType) -> ShapedArray: + duck = edtype._rules.physical_element_aval(edtype) # type: ignore + return ShapedArray(duck.shape, dtypes.dtype(duck.dtype)) def _short_dtype_name(dtype) -> str: if isinstance(dtype, dtypes.ExtendedDType): diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 352a3e550112..9865632d8975 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -23,7 +23,9 @@ import abc import builtins +import dataclasses import functools +import types from typing import cast, overload, Any, Literal, Union import warnings @@ -834,3 +836,23 @@ def safe_to_cast(input_dtype_or_value: Any, # We deliberately use output_dtype rather than output_dtype_or_value here: # this effectively treats the output dtype as always strongly-typed. return result_type(input_dtype_or_value, output_dtype) == output_dtype + +def primal_tangent_dtype(primal_dtype, tangent_dtype, + name: str | None = None) -> ExtendedDType: + name_ = name or f'PTDtype{{{primal_dtype}:{tangent_dtype}}}' + rules = types.SimpleNamespace( + physical_element_aval= + lambda dtype: types.SimpleNamespace(shape=(), dtype=primal_dtype), + tangent_dtype=lambda dtype: tangent_dtype, + convert_from=lambda _, other: other == primal_dtype, + convert_to=lambda other, _: other == primal_dtype) + + class primal_tangent_dtype_scalar(extended): ... + + @dataclasses.dataclass(frozen=True) + class PrimalTangentDType(ExtendedDType): + name = name_ + _rules = rules + type = primal_tangent_dtype_scalar + + return PrimalTangentDType() diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index c4c77c72b88b..b65899830282 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2158,8 +2158,7 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, # op is broadcast. # Lower a possibly-dynamic broadcast_in_dim if dtypes.issubdtype(aval_out.dtype, dtypes.extended): # type: ignore - elt_shape = aval_out.dtype._rules.physical_element_aval( # type: ignore - aval_out.dtype).shape # type: ignore + elt_shape = core.physical_element_aval(aval_out.dtype).shape # type: ignore trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] # type: ignore broadcast_dimensions = [*broadcast_dimensions, *trailing_dims] physical_aval_out = core.physical_aval(aval_out) @@ -2213,8 +2212,7 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va def slice_op(ctx: LoweringRuleContext, x, aval_out, *, start_indices, limit_indices, strides) -> ir.Value: if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - elt_shape = aval_out.dtype._rules.physical_element_aval( - aval_out.dtype).shape + elt_shape = core.physical_element_aval(aval_out.dtype).shape trailing_zeros = [0] * len(elt_shape) trailing_ones = [1] * len(elt_shape) start_indices = (*start_indices, *trailing_zeros) @@ -2241,8 +2239,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, start_indices) -> ir.Value: x_aval = ctx.avals_in[0] if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - elt_shape = aval_out.dtype._rules.physical_element_aval( - aval_out.dtype).shape + elt_shape = core.physical_element_aval(aval_out.dtype).shape index_avals = ctx.avals_in[1:] dtype = dtypes.canonicalize_dtype( index_avals[0].dtype if index_avals else 'int64') # type: ignore @@ -2275,8 +2272,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *, start_indices) -> ir.Value: if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - elt_shape = aval_out.dtype._rules.physical_element_aval( - aval_out.dtype).shape + elt_shape = core.physical_element_aval(aval_out.dtype).shape index_avals = ctx.avals_in[2:] dtype = dtypes.canonicalize_dtype( index_avals[0].dtype if index_avals else 'int64') # type: ignore diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index de668090eaa1..4c134f266da5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1429,7 +1429,7 @@ def _hlo_shard(aval, axis_env, x, in_axis): return x elif isinstance(aval, core.ShapedArray): if dtypes.issubdtype(aval.dtype, dtypes.extended): - aval = aval.dtype._rules.physical_element_aval(aval.dtype) + aval = core.physical_element_aval(aval.dtype) dims = list(aval.shape) zero = mlir.ir_constant(np.zeros((), dtype=np.uint32)) idxs = [zero] * len(dims) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 394a54c357b8..7ffab8d8c8c9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3841,8 +3841,7 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation): def _transpose_lower(ctx, x, *, permutation): aval_out, = ctx.avals_out if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - elt_shape = aval_out.dtype._rules.physical_element_aval( - aval_out.dtype).shape + elt_shape = core.physical_element_aval(aval_out.dtype).shape trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] permutation = [*permutation, *trailing_dims] return [hlo.transpose(x, mlir.dense_int_array(permutation))] diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 5ed1945ecb96..60dfa0e1b3d2 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1783,7 +1783,7 @@ def _gather_lower_opaque(ctx, operand, indices, *, indices_are_sorted, mode, fill_value) -> ir.Value: aval_x, aval_indices = ctx.avals_in aval_y, = ctx.avals_out - elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape + elt_shape = core.physical_element_aval(aval_x.dtype).shape trailing_offset_dims = [aval_y.ndim + i for i in range(len(elt_shape))] dimension_numbers = dimension_numbers._replace( offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims)) @@ -2436,7 +2436,7 @@ def _scatter_lower_opaque(ctx, operand, indices, updates, *, unique_indices, indices_are_sorted, mode): aval_x, aval_indices, aval_updates = ctx.avals_in aval_y, = ctx.avals_out - elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape + elt_shape = core.physical_element_aval(aval_x.dtype).shape trailing_window_dims = [aval_updates.ndim + i for i in range(len(elt_shape))] dimension_numbers = dimension_numbers._replace( update_window_dims=(*dimension_numbers.update_window_dims, diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 310ff38b7247..b69e78fe9ddf 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1509,7 +1509,7 @@ def num_addressable_indices( def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) new_op_sharding = hlo_sharding.to_proto().clone() partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding) suffix = [] if num_replicas == 1 else [num_replicas] @@ -1526,7 +1526,7 @@ def make_key_array_phys_sharding(aval, sharding): if is_single_device_sharding(sharding): return sharding elif isinstance(sharding, PmapSharding): - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) trailing_sharding = [sharding_specs.NoSharding()] * elt_aval.ndim phys_sharding_spec = sharding_specs.ShardingSpec( sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), @@ -1534,7 +1534,7 @@ def make_key_array_phys_sharding(aval, sharding): return PmapSharding(devices=sharding.devices, sharding_spec=phys_sharding_spec) elif isinstance(sharding, NamedSharding): - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) trailing_spec = [None] * elt_aval.ndim return NamedSharding( sharding.mesh, @@ -1551,7 +1551,7 @@ def physical_sharding( def get_logical_gspmd_sharding(aval, phys_sharding): - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( aval.ndim + elt_aval.ndim) partitions, num_replicas = get_num_ways_dim_sharded(phys_hlo_sharding) @@ -1583,7 +1583,7 @@ def logical_sharding(aval, phys_sharding) -> sharding.Sharding: if is_single_device_sharding(phys_sharding): return phys_sharding elif isinstance(phys_sharding, PmapSharding): - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) logical_sharding_spec = sharding_specs.ShardingSpec( sharding=phys_sharding.sharding_spec.sharding[:-elt_aval.ndim], mesh_mapping=phys_sharding.sharding_spec.mesh_mapping) diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 1b22c2c2ada9..375d058d0edc 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -22,6 +22,9 @@ from jax._src.callback import ( io_callback as io_callback ) +from jax._src.dtypes import ( + primal_tangent_dtype as primal_tangent_dtype, +) from jax._src.earray import ( EArray as EArray ) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index e736e06da2d0..e9d05a9eaaa5 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -578,6 +578,34 @@ def test_check_dtype_array(self): with self.assertWarnsRegex(DeprecationWarning, msg): jax.jit(dtypes.check_user_dtype_supported)(x) + @parameterized.parameters([True]) # TODO(mattjj): make jit=False work + def test_primal_tangent_dtype(self, jit): + dt = dtypes.primal_tangent_dtype(jnp.int8, jnp.bfloat16) + + x = jax.random.uniform(jax.random.key(0), (3,), minval=0, maxval=10 + ).astype(jnp.int8) + g = jax.random.uniform(jax.random.key(0), (3,), minval=0, maxval=10 + ).astype(jnp.bfloat16) + + @jax.custom_gradient + def f(x): + def bwd(g): + return 2 * g, + return jnp.int8(x).astype(g.dtype) * 2 + 1, bwd + + def h(): + result, bwd = jax.vjp(f, x.astype(dt)) + bwd_result, = bwd(g) + return result, bwd_result + + if jit: + h = jax.jit(h) + + result, bwd_result = h() + self.assertEqual(result.dtype, jnp.bfloat16) + self.assertEqual(bwd_result.dtype, jnp.bfloat16) + self.assertAllClose(bwd_result, 2 * g) + class EArrayTest(jtu.JaxTestCase): From ba74490e6fc8e9ff0af20ce9a3bdf233c16cdcb9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 21 Sep 2024 13:51:03 -0700 Subject: [PATCH 602/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e0eff722043ace468aa64efccdbb98b473a4e6ed. PiperOrigin-RevId: 677269125 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 376f2167e35c..ce7f0747d0c4 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "44d14566fc5d298d5d410efa24ed8630ce137137" -XLA_SHA256 = "2aa4d49121faa95c063c413e28fa87e0a5af64177588acd35459401f0e76f2ea" +XLA_COMMIT = "e0eff722043ace468aa64efccdbb98b473a4e6ed" +XLA_SHA256 = "4ba04f6a9f1273ee5fe149c556a620c175035b8f31b5f4620b3f814771659a1c" def repo(): tf_http_archive( From 02994d6bbbc51c2785462c77ea5a4f2441e36b05 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 22 Sep 2024 13:26:21 -0700 Subject: [PATCH 603/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/2101ae888f054bfbd13d5ac42af8aea3b1600749. PiperOrigin-RevId: 677526024 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index ce7f0747d0c4..56fac1fa611b 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "e0eff722043ace468aa64efccdbb98b473a4e6ed" -XLA_SHA256 = "4ba04f6a9f1273ee5fe149c556a620c175035b8f31b5f4620b3f814771659a1c" +XLA_COMMIT = "2101ae888f054bfbd13d5ac42af8aea3b1600749" +XLA_SHA256 = "3392793681e186f8ee5e901bd047bb688c9337970244d0c07ea173889db3d837" def repo(): tf_http_archive( From 48c29f62e10b7b29d096848214701e4ae58cf31d Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Sun, 22 Sep 2024 14:29:56 -0700 Subject: [PATCH 604/702] [pallas:mosaic_gpu] Fragmented array debug printing. PiperOrigin-RevId: 677537364 --- jax/_src/pallas/mosaic_gpu/lowering.py | 17 +++++++++-- .../mosaic/gpu/fragmented_array.py | 30 ++++++++++++++++--- tests/pallas/mosaic_gpu_test.py | 17 +++++++++++ 3 files changed, 57 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6eae64b7affa..9eaca527d4ef 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -695,10 +695,21 @@ def _debug_print_lowering_rule( has_placeholders: bool, ): del has_placeholders # Unused. - if any(aval.shape for aval in ctx.avals_in): - raise NotImplementedError("Only scalar values are supported") primitives.check_debug_print_format(fmt, *args) - mgpu.debug_print(fmt, *args) + if not any(aval.shape for aval in ctx.avals_in): + mgpu.debug_print(fmt, *args) + elif len(ctx.avals_in) == 1: + @args[0].foreach + def _(val, idx): + idx_fmt = ", ".join(["{}"] * len(idx)) + fmt_str = fmt.format(f"[{idx_fmt}]/{list(args[0].shape)}: {{}}") + mgpu.debug_print(fmt_str, *idx, val, uniform=False) + else: + raise NotImplementedError( + "debug_print only supports printing of scalar values, or a single array" + " value when using the Mosaic GPU backend." + ) + return () diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index d5a6e9eb69d1..b0352bf4489f 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -110,6 +110,17 @@ def from_memref_type(cls, memref_ty: ir.Type): ) def thread_vec_idxs(self): + index = ir.IndexType.get() + for v in self.linear_thread_vec_idxs(): + res = [] + for dim in reversed(self.shape): + dim = c(dim, index) + res.append(arith.remui(v, dim)) + v = arith.divui(v, dim) + res.reverse() + yield res + + def linear_thread_vec_idxs(self): """The indexes to be used for vector load/store WGStridedFragLayout. Yields: @@ -122,7 +133,7 @@ def thread_vec_idxs(self): tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE, index)) off = arith.muli(tidx, c(self.vec_size, tidx.type)) for i in range(reg_num): - yield [arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))] + yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout @@ -182,7 +193,7 @@ def load_strided(cls, ref: ir.Value): ref_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) layout = WGStridedFragLayout.from_memref_type(ref_ty) vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) - vecs = [vector.load(vec_ty, ref_1d, vec_idx) for vec_idx in layout.thread_vec_idxs()] + vecs = [vector.load(vec_ty, ref_1d, [vec_idx]) for vec_idx in layout.linear_thread_vec_idxs()] return cls(_registers=np.array(vecs), _layout=layout) @classmethod @@ -623,6 +634,17 @@ def broadcast_minor(self, n): ) return FragmentedArray(_registers=new_regs, _layout=WGMMA_LAYOUT) + def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): + """Call a function for each value and index.""" + if not isinstance(self.layout, WGStridedFragLayout): + raise NotImplementedError(self.layout) + index = ir.IndexType.get() + for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat): + assert len(idx) == len(self.shape), (idx, self.shape) + for i in range(self.layout.vec_size): + i = c(i, index) + fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i))) + def store_untiled(self, ref: ir.Value): if not ir.MemRefType.isinstance(ref.type): raise ValueError(ref) @@ -658,8 +680,8 @@ def _store_untiled_wg_strided(self, ref: ir.Value): if ref_shape != self.shape: raise ValueError((ref_shape, self.shape)) smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) - for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat): - vector.store(reg, smem_1d, idx) + for idx, reg in zip(self.layout.linear_thread_vec_idxs(), self.registers.flat): + vector.store(reg, smem_1d, [idx]) def _store_untiled_wgmma(self, ref: ir.Value): """Stores accumulator to a 2D memref. Not optimized at the moment.""" diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 17ef26c7f9b3..6aa74263cd94 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import math from absl.testing import absltest from absl.testing import parameterized @@ -240,6 +241,22 @@ def kernel(x_ref, o_ref): # TODO(slebedev): Remove assertRaises() once we support indexing. kernel(x) + def test_print_array(self): + in_shape = [2, 1, 64, 64] + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(in_shape, jnp.float32), + ) + def kernel(x_ref, o_ref): + del o_ref + pl.debug_print("x: {}", x_ref[...]) + + x = jnp.arange(math.prod(in_shape)).reshape(in_shape).astype(jnp.float32) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn(f"x: [1, 0, 43, 23]/{list(in_shape)}: 6871.000000\n", output()) + def test_scoped_allocation(self): def kernel(x_ref, o_ref): def body(tmp_ref): From b6fe7939090b0c4692f295bcdfa6255a52a6758e Mon Sep 17 00:00:00 2001 From: Ayaka Date: Sun, 22 Sep 2024 18:54:48 -0700 Subject: [PATCH 605/702] [Pallas] Skip `atomic_cas` and `atomic_counter` tests on GPU in 64-bit mode These tests are failing on GPU in 64-bit mode. This fixes test failures introduced by https://github.com/jax-ml/jax/pull/23798 PiperOrigin-RevId: 677583606 --- tests/pallas/ops_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 65cdde30f7d9..346dde0dd79e 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1461,6 +1461,9 @@ def reduce(x_ref, _, y_ref): (2, 1, 1), ) def test_atomic_cas(self, init_value, cmp, new_value): + if jax.config.x64_enabled and jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU in 64-bit mode") + @functools.partial( self.pallas_call, out_shape=( jax.ShapeDtypeStruct((), intx), @@ -1479,10 +1482,13 @@ def test_atomic_counter(self, num_threads): if self.INTERPRET: self.skipTest("While loop not supported in interpret mode.") + if jax.config.x64_enabled and jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU in 64-bit mode") + @functools.partial( self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), jnp.int32), - jax.ShapeDtypeStruct((), jnp.int32)), + jax.ShapeDtypeStruct((), intx), + jax.ShapeDtypeStruct((), intx)), input_output_aliases={0: 0, 1: 1}, grid=(num_threads,)) def increment(_, __, lock_ref, counter_ref): From 2199685437da8f2521cf2af688853e14c081d6a1 Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Sun, 22 Sep 2024 22:25:50 -0700 Subject: [PATCH 606/702] Ignore scipy.stats._axis_nan_policy.SmallSampleWarning for LaxBackedScipyStatsTests.testMode It is to fix our CI, the warning itself started occurring on scipy 1.14 due to this change https://github.com/scipy/scipy/pull/20694, which introduced SmallSampleWarning and started emitting it if the input is an empty array (the `a` variable in the randomized parametrized test LaxBackedScipyStatsTests.testMode sometimes happens to be an empty array). Note, the actual ignored warning is RungimeWarning (the superclass of SmallSampleWarning) to make it backward compatible (scipy.stats._axis_nan_policy.SmallSampleWarning does not exist in scipy prior 1.14, not to mention it being under private declared in a private (_axis_nan_policy) namespace. PiperOrigin-RevId: 677629866 --- tests/scipy_stats_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 983cb6bdc37a..91563f698ad4 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1568,6 +1568,10 @@ def evaluate_kde(kde, x): contains_nans=[True, False], keepdims=[True, False] ) + @jtu.ignore_warning( + category=RuntimeWarning, + message="One or more sample arguments is too small; all returned values will be NaN" + ) def testMode(self, shape, dtype, axis, contains_nans, keepdims): if scipy_version < (1, 9, 0) and keepdims != True: self.skipTest("scipy < 1.9.0 only support keepdims == True") From 653f07a7e1f5f1f5e14b5fcc01077e770ba9c15c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 23 Sep 2024 05:58:00 -0700 Subject: [PATCH 607/702] Updated Pallas Mosaic GPU lowering post Mosaic GPU restructuring PiperOrigin-RevId: 677758519 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 9eaca527d4ef..e45acc3f62ac 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -493,7 +493,7 @@ def _(step, _): jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) ) - module, out_structs_smem, _ = mgpu._lower_as_gpu_kernel( + module, out_structs_smem, _ = mgpu.core._lower_as_gpu_kernel( body, grid=grid, cluster=(), From f311e81c0299c1d9b787202e57c60ac93bf5df21 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 23 Sep 2024 06:58:55 -0700 Subject: [PATCH 608/702] Added `is_signed` to `mgpu.FragmentedArray` The registers within a fragmented array always use signless types, and instead the signedness is tracked on the fragmented arrays itself (i.e. in Python). PiperOrigin-RevId: 677776009 --- jax/_src/pallas/mosaic_gpu/lowering.py | 39 ++- jax/experimental/mosaic/gpu/core.py | 4 +- .../mosaic/gpu/fragmented_array.py | 183 +++++++++----- jax/experimental/mosaic/gpu/utils.py | 17 +- jax/experimental/mosaic/gpu/wgmma.py | 14 +- tests/mosaic/gpu_test.py | 231 +++++++++--------- 6 files changed, 300 insertions(+), 188 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e45acc3f62ac..80222dbaea22 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -28,7 +28,6 @@ from jax._src import core as jax_core from jax._src import pjit from jax._src import util -from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect @@ -40,6 +39,8 @@ from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import primitives as sp import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import core as mgpu_core +from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp import numpy as np @@ -137,7 +138,7 @@ def scratch_view( for s in structs: scratch_ty = ir.MemRefType.get( s.shape, - mlir.dtype_to_ir_type(s.dtype), + mgpu_utils.dtype_to_ir_type(s.dtype), memory_space=smem, ) views.append( @@ -493,7 +494,7 @@ def _(step, _): jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) ) - module, out_structs_smem, _ = mgpu.core._lower_as_gpu_kernel( + module, out_structs_smem, _ = mgpu_core._lower_as_gpu_kernel( body, grid=grid, cluster=(), @@ -598,20 +599,26 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): @register_lowering_rule(sp.get_p) def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree): - del ctx, tree # Unused. + del tree # Unused. if indexers: raise NotImplementedError("No support for indexers yet") - return mgpu.FragmentedArray.load_strided(x_smem) + [x_aval] = ctx.avals_in + return mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) @register_lowering_rule(sp.swap_p) def _swap_lowering_rule( ctx: LoweringRuleContext, x_smem, value, *indexers, tree ): - del ctx, tree # Unused. + del tree # Unused. if indexers: raise NotImplementedError("No support for indexers yet") - old_value = mgpu.FragmentedArray.load_strided(x_smem) + x_aval, _ = ctx.avals_in + old_value = mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) value.store_untiled(x_smem) return old_value @@ -645,7 +652,9 @@ def _convert_element_type_lowering_rule( ): del weak_type, sharding [x_aval] = ctx.avals_in - return _ensure_fa(x, x_aval.dtype).astype(mlir.dtype_to_ir_type(new_dtype)) + return _ensure_fa(x, x_aval.dtype).astype( + mgpu_utils.dtype_to_ir_type(new_dtype), is_signed=mgpu_utils.is_signed(new_dtype) + ) def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): @@ -673,7 +682,7 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): @register_lowering_rule(lax.rsqrt_p) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - return _ensure_fa(x, x_aval.dtype).rsqrt(ctx.module_ctx.approx_math) + return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) @register_lowering_rule(lax.reduce_sum_p) @@ -684,7 +693,9 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): _, [scratch] = ctx.module_ctx.scratch_view( [jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)] ) - return mgpu.FragmentedArray.splat(x.reduce_sum(scratch), ()) + return mgpu.FragmentedArray.splat( + x.reduce_sum(scratch), (), is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) @register_lowering_rule(primitives.debug_print_p) @@ -832,11 +843,13 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: return x elif isinstance(x, (np.number, np.ndarray, int, float)): return mgpu.FragmentedArray.splat( - _ir_constant(x, mlir.dtype_to_ir_type(dtype)), () + _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)), + (), + is_signed=mgpu_utils.is_signed(dtype), ) elif isinstance(x, ir.Value): if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)): - return mgpu.FragmentedArray.splat(x, ()) + return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype)) raise NotImplementedError(f"Unsupported type: {type(x)}") @@ -844,7 +857,7 @@ def _ensure_ir_value(x: object, aval: jax_core.ShapedArray) -> ir.Value: if isinstance(x, ir.Value): return x elif isinstance(x, (np.number, np.ndarray, int, float)): - return _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)) + return _ir_constant(x, mgpu_utils.dtype_to_ir_type(aval.dtype)) raise NotImplementedError(f"Unsupported type: {type(x)}") diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 0e263844b18e..9a4570ec1673 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -594,7 +594,7 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value: cluster_shape, ) case _: - mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype) + mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype) tile_smem = memref.view( ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), dynamic_smem, c(dynamic_smem_offset, index), [], @@ -734,7 +734,7 @@ def _lower_as_gpu_kernel( i64 = ir.IntegerType.get_signless(64) def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: - return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) + return ir.MemRefType.get(shape.shape, utils.dtype_to_ir_type(shape.dtype)) in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index b0352bf4489f..37deee6130d5 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -144,48 +144,71 @@ def linear_thread_vec_idxs(self): @jax.tree_util.register_pytree_node_class +@dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True) class FragmentedArray: - registers: np.ndarray # of ir.Value, see checks in init for shapes. + # An array of ir.Value, see checks in init for shapes. + registers: np.ndarray = dataclasses.field(repr=False) layout: FragmentedLayout - - def __init__(self, *, _registers: np.ndarray, _layout: FragmentedLayout): - self.registers = _registers - self.layout = _layout + is_signed: bool | None + + def __init__( + self, + *, + _registers: np.ndarray, + _layout: FragmentedLayout, + _is_signed: bool | None, + ): + """Initializes a fragmented array. + + This is a low-level API. Prefer using classmethods to construct fragmented + arrays instead. + """ + # We need to use ``object.__setattr__`` here because of ``frozen=True``. + object.__setattr__(self, "registers", _registers) + object.__setattr__(self, "layout", _layout) + object.__setattr__(self, "is_signed", _is_signed) + + if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype): + raise TypeError( + "is_signed must only be non-None if the MLIR type is an integer" + f" type, got {_is_signed=} for {self.mlir_dtype}" + ) match self.layout: # Registers are [m_tiles, n_tiles, 2 rows, 1 cols] in WGMMA layout # Each element is a vector<2xdtype> case WGMMAFragLayout(): - if self.registers.ndim != 4 or self.registers.shape[2:] != (2, 1): - raise ValueError("Invalid register array shape") + if _registers.ndim != 4 or _registers.shape[2:] != (2, 1): + raise ValueError(f"Invalid register array shape: {_registers.shape}") # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout # Each element is a dtype scalar case WGMMARowFragLayout(): - if self.registers.ndim != 2 or self.registers.shape[-1] != 2: - raise ValueError("Invalid register array shape") + if _registers.ndim != 2 or _registers.shape[-1] != 2: + raise ValueError(f"Invalid register array shape: {_registers.shape}") # Registers are flat case WGStridedFragLayout(shape): - (reg_size,) = ir.VectorType(_registers.flat[0].type).shape - if np.prod(shape) != np.prod(_registers.shape) * WARPGROUP_SIZE * reg_size: - raise ValueError((reg_size, shape, _registers.shape, WARPGROUP_SIZE), _registers.flat[0].type) + [reg_size] = ir.VectorType(_registers.flat[0].type).shape + if ( + math.prod(shape) + != math.prod(_registers.shape) * WARPGROUP_SIZE * reg_size + ): + raise ValueError( + "Invalid register array shape: math.prod({_registers.shape}) *" + " {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" + ) # Just a single register case WGSplatFragLayout(): if _registers.size != 1: - raise ValueError(f"WGStridedFragLayout requires a single value {_registers.shape} ({_registers.size})") + raise ValueError(f"Invalid register array shape: {_registers.shape}") case _: raise NotImplementedError - def __repr__(self): - return ( - f"FragmentedArray(layout={self.layout}, shape={self.shape})" - ) - @classmethod - def load_strided(cls, ref: ir.Value): + def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None): if not ir.MemRefType.isinstance(ref.type): raise TypeError(ref.type) @@ -194,10 +217,10 @@ def load_strided(cls, ref: ir.Value): layout = WGStridedFragLayout.from_memref_type(ref_ty) vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) vecs = [vector.load(vec_ty, ref_1d, [vec_idx]) for vec_idx in layout.linear_thread_vec_idxs()] - return cls(_registers=np.array(vecs), _layout=layout) + return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) @classmethod - def splat(cls, value, shape, layout=None): + def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) match layout: case WGMMARowFragLayout(): @@ -227,6 +250,7 @@ def splat(cls, value, shape, layout=None): return cls( _registers=np.full(reg_shape, value, dtype=object), _layout=layout, + _is_signed=is_signed, ) @property @@ -261,12 +285,19 @@ def _pointwise(self, op, *other): elif not isinstance(o, ir.Value): raise NotImplementedError(o) - o = FragmentedArray.splat(o, shape=self.shape, layout=self.layout) + o = FragmentedArray.splat( + o, shape=self.shape, layout=self.layout, is_signed=self.is_signed + ) if isinstance(o.layout, WGSplatFragLayout): if not o.layout.can_broadcast_to(self.shape): raise ValueError("Can't broadcast shape.") - o = FragmentedArray.splat(o.registers.flat[0], shape=self.shape, layout=self.layout) + o = FragmentedArray.splat( + o.registers.flat[0], + shape=self.shape, + layout=self.layout, + is_signed=self.is_signed, + ) else: if self.layout != o.layout: raise ValueError("Incompatible FragmentedArray layouts") @@ -278,15 +309,20 @@ def _pointwise(self, op, *other): for idx, reg in np.ndenumerate(self.registers): new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) - return FragmentedArray(_registers=new_regs, _layout=self.layout) + return FragmentedArray( + _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed + ) + + def __pos__(self): + return self def __neg__(self): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise(arith.negf) elif ir.IntegerType.isinstance(self.mlir_dtype): - return self._pointwise(arith.negsi) + return 0 - self else: - raise NotImplementedError(self.mlir_dtype) + return NotImplemented def __add__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): @@ -294,7 +330,7 @@ def __add__(self, other): elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.addi, other) else: - raise NotImplementedError(self.mlir_dtype) + return NotImplemented def __radd__(self, other): return self + other @@ -305,20 +341,26 @@ def __mul__(self, other): elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.muli, other) else: - raise NotImplementedError(self.mlir_dtype) + return NotImplemented def __rmul__(self, other): return self * other def __sub__(self, other): - if not ir.FloatType.isinstance(self.mlir_dtype): - raise NotImplementedError - return self._pointwise(arith.subf, other) + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.subf, other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise(arith.subi, other) + else: + return NotImplemented def __rsub__(self, other): - if not ir.FloatType.isinstance(self.mlir_dtype): - raise NotImplementedError - return self._pointwise(lambda s, o: arith.subf(o, s), other) + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(lambda s, o: arith.subf(o, s), other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise(lambda s, o: arith.subi(o, s), other) + else: + return NotImplemented def __truediv__(self, other): if not ir.FloatType.isinstance(self.mlir_dtype): @@ -331,11 +373,16 @@ def __rtruediv__(self, other): return self._pointwise(lambda s, o: arith.divf(o, s), other) def max(self, other): - if not ir.FloatType.isinstance(self.mlir_dtype): - raise NotImplementedError - return self._pointwise(arith.maximumf, other) + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.maximumf, other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise( + arith.maxsi if self.is_signed else arith.maxui, other + ) + else: + return NotImplemented - def exp(self, approx: bool = False): + def exp(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx: @@ -349,7 +396,7 @@ def fast_exp(x): return self._pointwise(self._lift_fast_unary(fast_exp)) return self._pointwise(mlir_math.exp) - def sin(self, approx: bool = False): + def sin(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): @@ -358,7 +405,7 @@ def sin(self, approx: bool = False): self._lift_fast_unary("sin.approx.f32") if approx else mlir_math.sin ) - def cos(self, approx: bool = False): + def cos(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): @@ -367,7 +414,7 @@ def cos(self, approx: bool = False): self._lift_fast_unary("cos.approx.f32") if approx else mlir_math.cos ) - def rsqrt(self, approx: bool = False): + def rsqrt(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): @@ -439,10 +486,12 @@ def __getitem__(self, idx): base_idx[0] : base_idx[0] + slice_shape[0], base_idx[1] : base_idx[1] + slice_shape[1], ] - return FragmentedArray(_registers=new_regs, _layout=self.layout) + return FragmentedArray( + _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed + ) # TODO(apaszke): Support JAX dtypes here as well? - def astype(self, new_dtype: ir.Type): + def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): i8 = ir.IntegerType.get_signless(8) i16 = ir.IntegerType.get_signless(16) i32 = ir.IntegerType.get_signless(32) @@ -450,7 +499,11 @@ def astype(self, new_dtype: ir.Type): cur_dtype = self.mlir_dtype if cur_dtype == new_dtype: - return self + if self.is_signed == is_signed: + return self + return FragmentedArray( + _registers=self.registers, _layout=self.layout, _is_signed=is_signed + ) reg_type = self.registers.flat[0].type is_vector_reg = ir.VectorType.isinstance(reg_type) reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else () @@ -485,7 +538,9 @@ def astype(self, new_dtype: ir.Type): new_registers[idx] = vector.bitcast( ir.VectorType.get((2,), new_dtype), new_vec ) - return FragmentedArray(_registers=new_registers, _layout=self.layout) + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=is_signed + ) # Generic path. from_float = ir.FloatType.isinstance(cur_dtype) to_float = ir.FloatType.isinstance(new_dtype) @@ -519,7 +574,9 @@ def astype(self, new_dtype: ir.Type): raise NotImplementedError(f"Unsupported layout {self.layout}") for idx, reg in np.ndenumerate(self.registers): new_registers[idx] = convert(new_reg_ty, reg) - return FragmentedArray(_registers=new_registers, _layout=self.layout) + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=is_signed + ) def reduce_sum(self, scratch) -> ir.Value: index = ir.IndexType.get() @@ -593,7 +650,9 @@ def reduce(self, op, axis): ) result = op(result, other_result) new_regs[row_tile, row_subtile] = result - return FragmentedArray(_registers=new_regs, _layout=WGMMA_ROW_LAYOUT) + return FragmentedArray( + _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed + ) def broadcast(self, shape): if not isinstance(self.layout, WGSplatFragLayout): @@ -605,7 +664,11 @@ def broadcast(self, shape): if not self.layout.can_broadcast_to(shape): raise ValueError(f"Can't broadcast {self.shape} to {shape}") - return FragmentedArray(_registers=self.registers, _layout=WGSplatFragLayout(shape)) + return FragmentedArray( + _registers=self.registers, + _layout=WGSplatFragLayout(shape), + _is_signed=self.is_signed, + ) def reshape(self, shape): if self.shape == shape: @@ -617,7 +680,11 @@ def reshape(self, shape): if np.prod(shape) != np.prod(self.shape): raise ValueError(f"Can't reshape {self.shape} to {shape}") - return FragmentedArray(_registers=self.registers, _layout=WGSplatFragLayout(shape)) + return FragmentedArray( + _registers=self.registers, + _layout=WGSplatFragLayout(shape), + _is_signed=self.is_signed, + ) def broadcast_minor(self, n): if self.layout != WGMMA_ROW_LAYOUT: @@ -632,7 +699,9 @@ def broadcast_minor(self, n): new_regs[row_tile, :, row_subtile, :] = vector.splat( ir.VectorType.get((2,), dtype), reg ) - return FragmentedArray(_registers=new_regs, _layout=WGMMA_LAYOUT) + return FragmentedArray( + _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed + ) def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): """Call a function for each value and index.""" @@ -671,6 +740,7 @@ def _store_untiled_splat(self, ref: ir.Value): self.registers.flat[0], self.shape, layout=WGStridedFragLayout(shape=self.shape, vec_size=vec_size), + is_signed=self.is_signed, ) fa.store_untiled(ref) @@ -728,7 +798,9 @@ def store_tiled(self, ref, swizzle: int | None): vector.store(get(self.registers), ref, idxs) @classmethod - def load_tiled(cls, ref, swizzle: int | None): + def load_tiled( + cls, ref, swizzle: int | None, *, is_signed: bool | None = None + ): ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type bw = mgpu.bytewidth(dtype) @@ -744,7 +816,7 @@ def load_tiled(cls, ref, swizzle: int | None): ) for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle): update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs)) - return cls(_registers=registers, _layout=WGMMA_LAYOUT) + return cls(_registers=registers, _layout=WGMMA_LAYOUT, _is_signed=is_signed) @staticmethod def transfer_tiled(shape, dtype, swizzle: int | None): @@ -831,10 +903,11 @@ def update_registers(regs, new, left_idx=left_idx, right_idx=right_idx): yield get_register, update_registers, idx def tree_flatten(self): - return list(self.registers.flat), (self.layout, self.registers.shape) + aux = self.layout, self.registers.shape, self.is_signed + return list(self.registers.flat), aux @classmethod def tree_unflatten(cls, aux, flat_registers): - layout, reg_shape = aux + layout, reg_shape, is_signed = aux registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape) - return cls(_registers=registers, _layout=layout) + return cls(_registers=registers, _layout=layout, _is_signed=is_signed) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 30b8ca5cfb14..8d1d48eb94d1 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -23,13 +23,14 @@ from typing import Any, Literal import jax +from jax import numpy as jnp +from jax.interpreters import mlir from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import memref -from jaxlib.mlir.dialects import nvgpu from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import scf from jaxlib.mlir.dialects import vector @@ -939,3 +940,17 @@ def cluster_collective_mask( for idx in np.ndindex(collective_shape): mask_unshifted |= 1 << sum(i * s for i, s in zip(idx, collective_strides)) return arith.shli(c(mask_unshifted, i32), mask_shift) + + +def dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type: + dtype = jnp.dtype(dtype) + if jnp.issubdtype(dtype, jnp.integer): + # All integer types in Mosaic GPU are signless. + return ir.IntegerType.get_signless(dtype.itemsize * 8) + return mlir.dtype_to_ir_type(dtype) + + +def is_signed(dtype: jax.typing.DTypeLike) -> bool | None: + if jnp.issubdtype(dtype, jnp.integer): + return jnp.issubdtype(dtype, jnp.signedinteger) + return None diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 5b0282080c55..ba0f130364ff 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -53,15 +53,19 @@ def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True): self.value = wgmma_fence(_value) @classmethod - def zero(cls, m, n, dtype=None): + def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None): if m % 64 or n % 8: raise ValueError + if is_signed is False: + raise TypeError("PTX does not support unsigned WGMMA accumulators") f32 = ir.F32Type.get() if dtype is None: dtype = f32 zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) return cls( - _value=mgpu.FragmentedArray.splat(zero, (m, n), mgpu.WGMMA_LAYOUT) + _value=mgpu.FragmentedArray.splat( + zero, (m, n), mgpu.WGMMA_LAYOUT, is_signed=is_signed + ) ) @classmethod @@ -430,7 +434,9 @@ def wgmma( ) return WGMMAAccumulator( _value=mgpu.FragmentedArray( - _registers=new_acc_regs, _layout=mgpu.WGMMA_LAYOUT + _registers=new_acc_regs, + _layout=mgpu.WGMMA_LAYOUT, + _is_signed=acc.value.is_signed, ), _sync=False, ) @@ -490,7 +496,7 @@ def wgmma_fence(array: mgpu.FragmentedArray): registers = np.asarray(regs, dtype=object).reshape(array.registers.shape) else: raise NotImplementedError(dtype) - return mgpu.FragmentedArray(_registers=registers, _layout=array.layout) + return mgpu.FragmentedArray(_registers=registers, _layout=array.layout, _is_signed=array.is_signed) def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 9f2f1222bde4..f6e0bc07e4ae 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -15,7 +15,6 @@ """Tests for Mosaic GPU DSL functions and utilities.""" import enum -from functools import partial import itertools import math import operator @@ -121,7 +120,7 @@ def body(*idx): nvvm.fence_proxy(nvvm.ProxyKind.async_) -def iota_tensor(m, n, mlir_dtype): +def iota_tensor(m, n, dtype: jax.typing.DTypeLike): assert m % 64 == 0 assert n % 8 == 0 def c(i): @@ -145,8 +144,12 @@ def c(i): value = arith.index_cast(i32, value) vec = vector.insertelement(value, vec, position=c(col_offset)) registers[row_tile, col_tile, row_subtile, 0] = vec - t = mgpu.FragmentedArray(_registers=registers, _layout=mgpu.WGMMA_LAYOUT) - return t.astype(mlir_dtype) + t = mgpu.FragmentedArray( + _registers=registers, _layout=mgpu.WGMMA_LAYOUT, _is_signed=True + ) + return t.astype( + utils.dtype_to_ir_type(dtype), is_signed=utils.is_signed(dtype) + ) class TestCase(parameterized.TestCase): @@ -199,7 +202,7 @@ def test_iota_tensor(self): def kernel(ctx, dst, _): f32 = ir.F32Type.get() index = ir.IndexType.get() - registers = iota_tensor(m, n, f32).registers + registers = iota_tensor(m, n, jnp.float32).registers assert registers.size == 16, registers.size for i, vec_reg in enumerate(registers.flat): for j in range(2): @@ -361,30 +364,30 @@ def get_packed_shape(strides, shape): class WGMMATest(TestCase): - @parameterized.named_parameters( - ("f32", ir.F32Type, jnp.float32), ("f16", ir.F16Type, jnp.float16) - ) - def test_store_untiled(self, mlir_dtype_cls, jax_dtype): - mlir_dtype = mlir_dtype_cls.get() + @parameterized.named_parameters(("f32", jnp.float32), ("f16", jnp.float16)) + def test_store_untiled(self, dtype): def kernel(ctx, out, _): del ctx - iota_tensor(64, 64, mlir_dtype).store_untiled(out) - expected = np.arange(64 * 64, dtype=jax_dtype).reshape(64, 64) + iota_tensor(64, 64, dtype).store_untiled(out) + expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64) iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() np.testing.assert_array_equal(iota, expected) @parameterized.named_parameters( - ("f32", ir.F32Type, jnp.float32, 256), - ("f16", ir.F16Type, jnp.float16, 256), - ("f16_small", ir.F16Type, jnp.float16, 128), + ("f32", jnp.float32, 256), + ("f16", jnp.float16, 256), + ("f16_small", jnp.float16, 128), ) - def test_store_untiled_splat(self, mlir_dtype_cls, jax_dtype, size): - mlir_dtype = mlir_dtype_cls.get() + def test_store_untiled_splat(self, jax_dtype, size): + mlir_dtype = utils.dtype_to_ir_type(jax_dtype) def kernel(ctx, out, _): del ctx - mgpu.FragmentedArray.splat(c(1., mlir_dtype), (size,)).store_untiled(out) + arr = mgpu.FragmentedArray.splat( + c(1.0, mlir_dtype), (size,), is_signed=utils.is_signed(jax_dtype) + ) + arr.store_untiled(out) expected = np.ones((size,), jax_dtype) mosaic_ones = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () @@ -392,17 +395,12 @@ def kernel(ctx, out, _): np.testing.assert_array_equal(mosaic_ones, expected) @parameterized.product( - dtypes=( - (ir.F32Type.get, jnp.float32), - (ir.F16Type.get, jnp.float16), - (partial(ir.IntegerType.get_signless, 8), jnp.int8), - ), + dtype=[jnp.float32, jnp.float16, jnp.int8], swizzle=(32, 64, 128), num_col_tiles=(1, 2, 3), ) - def test_store_tiled(self, dtypes, swizzle, num_col_tiles): - mlir_dtype_cls, jax_dtype = dtypes - mlir_dtype = mlir_dtype_cls() + def test_store_tiled(self, dtype, swizzle, num_col_tiles): + mlir_dtype = utils.dtype_to_ir_type(dtype) if bytewidth(mlir_dtype) > 2 and swizzle == 32: self.skipTest("Not implemented") col_tiling = swizzle // bytewidth(mlir_dtype) @@ -411,10 +409,10 @@ def test_store_tiled(self, dtypes, swizzle, num_col_tiles): tiling = (64, col_tiling) def kernel(ctx, out, smem): del ctx - iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=swizzle) + iota_tensor(m, n, dtype).store_tiled(smem, swizzle=swizzle) copy(smem, out, swizzle=swizzle) expected = ( - np.arange(m * n, dtype=jax_dtype) + np.arange(m * n, dtype=dtype) .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) .transpose(0, 2, 1, 3) ) @@ -424,21 +422,17 @@ def kernel(ctx, out, smem): np.testing.assert_array_equal(iota, expected) @parameterized.product( - dtypes=( - (ir.F16Type.get, jnp.float16), - (partial(ir.IntegerType.get_signless, 8), jnp.int8), - ), + dtype=[jnp.float16, jnp.int8], swizzle=(32, 64, 128), ) - def test_store_tiled_short_n(self, dtypes, swizzle): - mlir_dtype_cls, jax_dtype = dtypes - mlir_dtype = mlir_dtype_cls() + def test_store_tiled_short_n(self, dtype, swizzle): + mlir_dtype = utils.dtype_to_ir_type(dtype) col_tiling = swizzle // bytewidth(mlir_dtype) m = 128 n = 16 // bytewidth(mlir_dtype) tiling = (64, col_tiling) def kernel(ctx, out, smem): - iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=swizzle) + iota_tensor(m, n, dtype).store_tiled(smem, swizzle=swizzle) ctx.async_copy( src_ref=smem, dst_ref=out, @@ -447,37 +441,31 @@ def kernel(ctx, out, smem): gmem_transform=mgpu.TileTransform(tiling), ) ctx.await_async_copy(0) - smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), jax_dtype) - expected = np.arange(m * n, dtype=jax_dtype).reshape(m, n) + smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), dtype) + expected = np.arange(m * n, dtype=dtype).reshape(m, n) iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, smem_shape )() np.testing.assert_array_equal(iota, expected) @parameterized.named_parameters( - ("bf16_i8", - ir.BF16Type.get, jnp.bfloat16, - lambda: ir.IntegerType.get_signless(8), jnp.int8), - ("i8_bf16", - lambda: ir.IntegerType.get_signless(8), jnp.int8, - ir.BF16Type.get, jnp.bfloat16), - ("i8_i8", - lambda: ir.IntegerType.get_signless(8), jnp.int8, - lambda: ir.IntegerType.get_signless(8), jnp.int8), + ("bf16_i8", jnp.bfloat16, jnp.int8), + ("i8_bf16", jnp.int8, jnp.bfloat16), + ("i8_i8", jnp.int8, jnp.int8), ) - def test_convert_tiled(self, - mlir_dtype_cls_from, jax_dtype_from, - mlir_dtype_cls_to, jax_dtype_to): - mlir_dtype_from = mlir_dtype_cls_from() - mlir_dtype_to = mlir_dtype_cls_to() + def test_convert_tiled(self, jax_dtype_from, jax_dtype_to): + mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from) + mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) m = 128 n = 256 // bytewidth(mlir_dtype_from) def kernel(ctx, inp, out, smem): del ctx smem_from, smem_to = smem copy(inp, smem_from, swizzle=128) - t = mgpu.FragmentedArray.load_tiled(smem_from, swizzle=128) - t = t.astype(mlir_dtype_to) + t = mgpu.FragmentedArray.load_tiled( + smem_from, swizzle=128, is_signed=utils.is_signed(jax_dtype_from) + ) + t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to)) t.store_tiled(smem_to, swizzle=128) copy(smem_to, out, swizzle=128) @@ -503,12 +491,12 @@ def kernel(ctx, inp, out, smem): np.testing.assert_array_equal(res, expected_to) @parameterized.named_parameters( - ("f32", ir.F32Type.get, jnp.float32), - ("f16", ir.F16Type.get, jnp.float16), - ("i8", partial(ir.IntegerType.get_signless, 8), jnp.int8), + ("f32", jnp.float32), + ("f16", jnp.float16), + ("i8", jnp.int8), ) - def test_load_tiled(self, mlir_dtype_cls, jax_dtype): - mlir_dtype = mlir_dtype_cls() + def test_load_tiled(self, jax_dtype): + mlir_dtype = utils.dtype_to_ir_type(jax_dtype) m = 128 n = 256 // bytewidth(mlir_dtype) tiling = (64, 128 // bytewidth(mlir_dtype)) @@ -516,7 +504,9 @@ def kernel(ctx, in_, out, smem): del ctx smem1, smem2 = smem copy(in_, smem1, swizzle=128) - t = mgpu.FragmentedArray.load_tiled(smem1, swizzle=128) + t = mgpu.FragmentedArray.load_tiled( + smem1, swizzle=128, is_signed=utils.is_signed(jax_dtype) + ) t.store_tiled(smem2, swizzle=128) copy(smem2, out, swizzle=128) expected = ( @@ -560,7 +550,7 @@ def test_wgmma_basic( raise self.skipTest("Copy with non-128B swizzles not implemented") in_mlir_dtype = in_mlir_dtype_cls.get() - out_mlir_dtype = mlir.dtype_to_ir_type(jnp.dtype(jax_out_dtype)) + out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype) if ir.F32Type.isinstance(in_mlir_dtype): # We actually use tf32 instead in_jax_dtype = jnp.float32 if lhs_transpose or not rhs_transpose: @@ -680,11 +670,9 @@ def quantize(x): k_steps=(1, 2), rhs_transpose=(False, True), swizzle=(32, 64, 128), - mlir_dtype_cls=(ir.F16Type, ir.BF16Type), + dtype=[jnp.float16, jnp.bfloat16], ) - def test_wgmma_reg_lhs( - self, m, n, k_steps, rhs_transpose, swizzle, mlir_dtype_cls - ): + def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, swizzle, dtype): index = ir.IndexType.get() row_major = mgpu.WGMMALayout.ROW_MAJOR @@ -710,23 +698,22 @@ def kernel(ctx, rhs, out, rhs_smem): swizzle=swizzle, ) init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) - lhs_regs = iota_tensor(m, k, mlir_dtype_cls.get()) + lhs_regs = iota_tensor(m, k, dtype) acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) acc.value.store_untiled(out) - jax_dtype = jnp.float16 if mlir_dtype_cls == ir.F16Type else jnp.bfloat16 y_shape = (n, k) if rhs_transpose else (k, n) - y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(dtype) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) scratch_shape = jax.ShapeDtypeStruct( - (k_steps, n // nk_tile, nk_tile, nk_tile), jax_dtype + (k_steps, n // nk_tile, nk_tile, nk_tile), dtype ) z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape )(y) - x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) + x = np.arange(m * k, dtype=dtype).reshape(m, k) ref = jax.lax.dot( x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 ) @@ -766,7 +753,7 @@ def kernel(ctx, rhs, out, smem): ) barrier.wait() init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) - lhs_regs = iota_tensor(m, k, ir.F16Type.get()) + lhs_regs = iota_tensor(m, k, jnp.float16) rhs_smem = memref_slice(rhs_smem, smem_slice) acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() @@ -804,17 +791,22 @@ def kernel(ctx, dst, scratch): arith.addi(wg_idx, c(1, i32)), (128,), mgpu.WGStridedFragLayout((128,), 1), + is_signed=False, ) with ir.InsertionPoint(scf.IfOp(is_first_wg).then_block): arr.store_untiled(tmp) barriers[0].arrive() # Signal that tmp is ready. barriers[1].wait() # Wait for the other warp to produce tmp. - final_arr = arr + mgpu.FragmentedArray.load_strided(tmp) + final_arr = arr + mgpu.FragmentedArray.load_strided( + tmp, is_signed=False + ) final_arr.store_untiled(memref_slice(dst, 0)) scf.yield_([]) with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block): barriers[0].wait() - final_arr = arr + mgpu.FragmentedArray.load_strided(tmp) + final_arr = arr + mgpu.FragmentedArray.load_strided( + tmp, is_signed=False + ) barriers[2].arrive() barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp. arr.store_untiled(tmp) @@ -1154,18 +1146,16 @@ def test_tma_small_tile_store(self, small_dim): else: raise ValueError("small_dim must be 0 or 1") tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64) - padded_shape = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2])) + m, n = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2])) def kernel(ctx, dst, tmp): - vals = iota_tensor( - m=padded_shape[0], n=padded_shape[1], mlir_dtype=ir.F16Type.get() - ) + vals = iota_tensor(m, n, jnp.float16) vals.store_tiled(tmp, swizzle=128) ctx.async_copy( src_ref=tmp, dst_ref=dst, swizzle=128, gmem_transform=mgpu.TileTransform((64, 64)), - gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])), + gmem_slice=(ds(0, m), ds(0, n)), ) ctx.await_async_copy(0) tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16) @@ -1173,9 +1163,7 @@ def kernel(ctx, dst, tmp): y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out, tiled, )() - iota = np.arange(np.prod(padded_shape), dtype=jnp.float16).reshape( - padded_shape - ) + iota = np.arange(m * n, dtype=jnp.float16).reshape([m, n]) np.testing.assert_array_equal(y, iota[:shape[0], :shape[1]]) def test_tma_invalid(self): @@ -1207,61 +1195,78 @@ class FragmentedArrayTest(TestCase): operator.truediv, (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum), ), + dtype=[jnp.float32, jnp.int32, jnp.uint32], m=(64, 128), n=(8, 16, 32, 64, 80, 128, 256), ) - def test_binary(self, op, m=64, n=32): + def test_binary(self, op, dtype, m=64, n=32): if isinstance(op, tuple): op, np_op = op else: np_op = op + if not jnp.issubdtype(dtype, jnp.floating) and op is operator.truediv: + self.skipTest("Unsupported for integer types") + for scalar_rhs in [None, 2]: def kernel(ctx, dst, _): - f32 = ir.F32Type.get() - iota = iota_tensor(m=m, n=n, mlir_dtype=f32) - rhs = iota if scalar_rhs is None else c(scalar_rhs, iota.mlir_dtype) + mlir_dtype = utils.dtype_to_ir_type(dtype) + iota = iota_tensor(m, n, dtype) + rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype) op(iota, rhs).store_untiled(dst) - out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() - ref_x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + ref_x = np.arange(m * n, dtype=dtype).reshape(m, n) ref_rhs = scalar_rhs or ref_x - if op == operator.truediv: + if op is operator.truediv: np.testing.assert_allclose(result, np_op(ref_x, ref_rhs), atol=2e-7) else: np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) @parameterized.product( ops=( - (lambda x: mgpu.FragmentedArray.exp(x), np.exp, False), - (lambda x: mgpu.FragmentedArray.exp(x, approx=True), np.exp, True), - (lambda x: mgpu.FragmentedArray.sin(x), np.sin, False), - (lambda x: mgpu.FragmentedArray.sin(x, approx=True), np.sin, True), - (lambda x: mgpu.FragmentedArray.cos(x), np.cos, False), - (lambda x: mgpu.FragmentedArray.cos(x, approx=True), np.cos, True), - (lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt, False), - (lambda x: mgpu.FragmentedArray.rsqrt(x, approx=True), jax.lax.rsqrt, True), - (lambda x: -x, jax.lax.neg, False), - (lambda x: x + 42.0, lambda x: x + 42.0, False), + (lambda x: -x, jax.lax.neg), + (lambda x: x + 42, lambda x: x + 42), ), - m=(64, 128), - n=(8, 16, 32, 64, 80, 128, 256), + dtype=[jnp.float32, jnp.int32, jnp.uint32], ) - def test_unary(self, ops, m=64, n=32): - op, np_op, is_approx = ops + def test_unary(self, ops, dtype, m=64, n=32): + op, np_op = ops + def kernel(ctx, dst, _): - f32 = ir.F32Type.get() - iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + iota = iota_tensor(m, n, dtype) + op(iota).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + x = np.arange(m * n, dtype=dtype).reshape(m, n) + np.testing.assert_allclose(result, np_op(x), atol=2e-7, rtol=2e-7) + + @parameterized.product( + ops=[ + (lambda x: mgpu.FragmentedArray.exp(x), np.exp), + (lambda x: mgpu.FragmentedArray.sin(x), np.sin), + (lambda x: mgpu.FragmentedArray.cos(x), np.cos), + (lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt), + ], + approx=[False, True], + ) + def test_math(self, ops, approx, m=64, n=32): + op, np_op = ops + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, jnp.float32) op(iota).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) - atol = 5e-3 if is_approx else 2e-7 - rtol = 4e-6 if is_approx else 2e-7 + atol = 5e-3 if approx else 2e-7 + rtol = 4e-6 if approx else 2e-7 np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol) @parameterized.product( @@ -1271,8 +1276,7 @@ def kernel(ctx, dst, _): ) def test_reduce(self, op, m=64, n=32): def kernel(ctx, dst, _): - f32 = ir.F32Type.get() - iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + iota = iota_tensor(m, n, jnp.float32) iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( @@ -1290,8 +1294,7 @@ def kernel(ctx, dst, _): def test_splat_layout(self): m, n = 64, 8 def kernel(ctx, dst, _): - f32 = ir.F32Type.get() - iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + iota = iota_tensor(m, n, jnp.float32) cte = c(1, iota.mlir_dtype) cte_arr = mgpu.FragmentedArray.splat(cte, ()) cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) @@ -1307,7 +1310,9 @@ def test_splat(self): def kernel(ctx, dst, _): f32 = ir.F32Type.get() v = arith.constant(f32, ir.FloatAttr.get(f32, 3.14)) - t = mgpu.FragmentedArray.splat(v, (128,), mgpu.WGMMA_ROW_LAYOUT) + t = mgpu.FragmentedArray.splat( + v, (128,), mgpu.WGMMA_ROW_LAYOUT + ) t.broadcast_minor(32).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) result = mgpu.as_gpu_kernel( @@ -1354,10 +1359,10 @@ def kernel(ctx, out, *_): def test_fast_i8_convert(self, jax_dtype_to): jax_dtype_to = jnp.dtype(jax_dtype_to) jax_dtype_from = jnp.dtype(jnp.int8) - mlir_dtype_to = mlir.dtype_to_ir_type(jax_dtype_to) + mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) def kernel(ctx, inp, out, smem): del ctx, smem - arr = mgpu.FragmentedArray.load_strided(inp) + arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) arr.astype(mlir_dtype_to).store_untiled(out) x = jnp.arange(-128, 128, dtype=jax_dtype_from) From 6a72c5229258e082d07bd39a2957ff6108bffb59 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 23 Sep 2024 19:40:09 +0530 Subject: [PATCH 609/702] Improve docs for jax.numpy: conjugate, conj, imag and real --- jax/_src/numpy/ufuncs.py | 82 ++++++++++++++++++++++++++++++++++++++-- tests/lax_numpy_test.py | 4 +- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 2455817bf054..3d3dc3588fc8 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2248,24 +2248,98 @@ def radians(x: ArrayLike, /) -> Array: return deg2rad(x) -@implements(np.conjugate, module='numpy') @partial(jit, inline=True) def conjugate(x: ArrayLike, /) -> Array: + """Return element-wise complex-conjugate of the input. + + JAX implementation of :obj:`numpy.conjugate`. + + Args: + x: inpuat array or scalar. + + Returns: + An array containing the complex-conjugate of ``x``. + + See also: + - :func:`jax.numpy.real`: Returns the element-wise real part of the complex + argument. + - :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the + complex argument. + + Examples: + >>> jnp.conjugate(3) + Array(3, dtype=int32, weak_type=True) + >>> x = jnp.array([2-1j, 3+5j, 7]) + >>> jnp.conjugate(x) + Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64) + """ check_arraylike("conjugate", x) return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) -conj = conjugate -@implements(np.imag) +def conj(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.conjugate`""" + return conjugate(x) + + @partial(jit, inline=True) def imag(val: ArrayLike, /) -> Array: + """Return element-wise imaginary of part of the complex argument. + + JAX implementation of :obj:`numpy.imag`. + + Args: + val: input array or scalar. + + Returns: + An array containing the imaginary part of the elements of ``val``. + + See also: + - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise + complex-conjugate of the input. + - :func:`jax.numpy.real`: Returns the element-wise real part of the complex + argument. + + Examples: + >>> jnp.imag(4) + Array(0, dtype=int32, weak_type=True) + >>> jnp.imag(5j) + Array(5., dtype=float32, weak_type=True) + >>> x = jnp.array([2+3j, 5-1j, -3]) + >>> jnp.imag(x) + Array([ 3., -1., 0.], dtype=float32) + """ check_arraylike("imag", val) return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) -@implements(np.real) @partial(jit, inline=True) def real(val: ArrayLike, /) -> Array: + """Return element-wise real part of the complex argument. + + JAX implementation of :obj:`numpy.real`. + + Args: + val: input array or scalar. + + Returns: + An array containing the real part of the elements of ``val``. + + See also: + - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise + complex-conjugate of the input. + - :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the + complex argument. + + Examples: + >>> jnp.real(5) + Array(5, dtype=int32, weak_type=True) + >>> jnp.real(2j) + Array(0., dtype=float32, weak_type=True) + >>> x = jnp.array([3-2j, 4+7j, -2j]) + >>> jnp.real(x) + Array([ 3., 4., -0.], dtype=float32) + """ check_arraylike("real", val) return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 9dc2e079bb3f..371a13f0cde6 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6308,8 +6308,8 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', - 'amax', 'amin', 'around', 'bitwise_right_shift', 'degrees', 'divide', - 'mod', 'pow', 'radians', 'round_'] + 'amax', 'amin', 'around', 'bitwise_right_shift', 'conj', 'degrees', + 'divide', 'mod', 'pow', 'radians', 'round_'] skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split'] for name in dir(jnp): From 1256e18fd45a081edca4f1c2a865546d317625a2 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 23 Sep 2024 07:37:15 -0700 Subject: [PATCH 610/702] Added comparison operators to `mgpu.FragmentedArray` PiperOrigin-RevId: 677788023 --- .../mosaic/gpu/fragmented_array.py | 73 ++++++++++++++++++- tests/mosaic/gpu_test.py | 23 ++++++ 2 files changed, 92 insertions(+), 4 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 37deee6130d5..6cdda1a4ff62 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -15,6 +15,7 @@ """Utilities for code generator.""" import dataclasses +import functools import math from typing import Callable @@ -276,7 +277,11 @@ def mlir_dtype(self): case WGMMARowFragLayout() | WGSplatFragLayout(): return reg_ty - def _pointwise(self, op, *other): + def _pointwise(self, op, *other, output_is_signed: bool | None = None): + is_signed = ( + output_is_signed if output_is_signed is not None else self.is_signed + ) + other_arrs = [] for o in other: if not isinstance(o, FragmentedArray): @@ -286,7 +291,7 @@ def _pointwise(self, op, *other): raise NotImplementedError(o) o = FragmentedArray.splat( - o, shape=self.shape, layout=self.layout, is_signed=self.is_signed + o, shape=self.shape, layout=self.layout, is_signed=is_signed ) if isinstance(o.layout, WGSplatFragLayout): @@ -296,7 +301,7 @@ def _pointwise(self, op, *other): o.registers.flat[0], shape=self.shape, layout=self.layout, - is_signed=self.is_signed, + is_signed=is_signed, ) else: if self.layout != o.layout: @@ -310,7 +315,7 @@ def _pointwise(self, op, *other): for idx, reg in np.ndenumerate(self.registers): new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) return FragmentedArray( - _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed + _registers=new_regs, _layout=self.layout, _is_signed=is_signed ) def __pos__(self): @@ -372,6 +377,66 @@ def __rtruediv__(self, other): raise NotImplementedError return self._pointwise(lambda s, o: arith.divf(o, s), other) + def __eq__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OEQ, + si_pred=arith.CmpIPredicate.eq, + ui_pred=arith.CmpIPredicate.eq, + ) + + def __ne__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.UNE, + si_pred=arith.CmpIPredicate.ne, + ui_pred=arith.CmpIPredicate.ne, + ) + + def __lt__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OLT, + si_pred=arith.CmpIPredicate.slt, + ui_pred=arith.CmpIPredicate.ult, + ) + + def __le__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OLE, + si_pred=arith.CmpIPredicate.sle, + ui_pred=arith.CmpIPredicate.ule, + ) + + def __gt__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OGT, + si_pred=arith.CmpIPredicate.sgt, + ui_pred=arith.CmpIPredicate.ugt, + ) + + def __ge__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OGE, + si_pred=arith.CmpIPredicate.sge, + ui_pred=arith.CmpIPredicate.uge, + ) + + def _compare(self, other, *, f_pred, si_pred, ui_pred): + if ir.FloatType.isinstance(self.mlir_dtype): + pred = functools.partial(arith.cmpf, f_pred) + elif ir.IntegerType.isinstance(self.mlir_dtype): + if ir.IntegerType(self.mlir_dtype).is_signed: + pred = functools.partial(arith.cmpi, si_pred) + else: + pred = functools.partial(arith.cmpi, ui_pred) + else: + raise NotImplementedError + return self._pointwise(pred, other, output_is_signed=False) + def max(self, other): if ir.FloatType.isinstance(self.mlir_dtype): return self._pointwise(arith.maximumf, other) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f6e0bc07e4ae..41725e6017f4 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1225,6 +1225,29 @@ def kernel(ctx, dst, _): else: np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) + @parameterized.product( + op=[ + operator.lt, + operator.le, + operator.gt, + operator.ge, + operator.eq, + operator.ne, + ], + dtype=[jnp.float32, jnp.int32, jnp.uint32], + ) + def test_comparison(self, op, dtype, m=64, n=32): + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, dtype) + op(iota, iota + 1).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), jnp.bool) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + iota = np.arange(m * n, dtype=dtype).reshape(m, n) + np.testing.assert_array_equal(result, op(iota, iota + 1)) + @parameterized.product( ops=( (lambda x: -x, jax.lax.neg), From 41eccd925d74ccc8a93ed54797b38e68326096c0 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 23 Sep 2024 20:09:12 +0530 Subject: [PATCH 611/702] Improve docs for jnp.logspace and jnp.geomspace --- jax/_src/numpy/lax_numpy.py | 108 +++++++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b71412e586bd..e9c5785a7715 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5896,10 +5896,69 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return (result, delta) if retstep else result -@util.implements(np.logspace) def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, base: ArrayLike = 10.0, dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Generate logarithmically-spaced values. + + JAX implementation of :func:`numpy.logspace`. + + Args: + start: scalar or array. Used to specify the start value. The start value is + ``base ** start``. + stop: scalar or array. Used to specify the stop value. The end value is + ``base ** stop``. + num: int, optional, default=50. Number of values to generate. + endpoint: bool, optional, default=True. If True, then include the ``stop`` value + in the result. If False, then exclude the ``stop`` value. + base: scalar or array, optional, default=10. Specifies the base of the logarithm. + dtype: optional. Specifies the dtype of the output. + axis: int, optional, default=0. Axis along which to generate the logspace. + + Returns: + An array of logarithm. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step value. + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100 + (``10 ** 2``): + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5) + Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32) + + List 5 logarithmically-spaced values between 1(``10 ** 0``) and 100 + (``10 ** 2``), excluding endpoint: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5, endpoint=False) + Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32) + + List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``) + with base 2: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 7, base=2) + Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32) + + Multi-dimensional logspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 0]) + >>> base = jnp.array([2, 3]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(start, stop, 5, base=base) + Array([[ 1. , 243. ], + [ 2.378, 61.547], + [ 5.657, 15.588], + [ 13.454, 3.948], + [ 32. , 1. ]], dtype=float32) + """ num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace") return _logspace(start, stop, num, endpoint, base, dtype, axis) @@ -5922,9 +5981,54 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return lax.convert_element_type(ufuncs.power(base, lin), dtype) -@util.implements(np.geomspace) def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Generate geometrically-spaced values. + + JAX implementation of :func:`numpy.geomspace`. + + Args: + start: scalar or array. Specifies the starting values. + stop: scalar or array. Specifies the stop values. + num: int, optional, default=50. Number of values to generate. + endpoint: bool, optional, default=True. If True, then include the ``stop`` value + in the result. If False, then exclude the ``stop`` value. + dtype: optional. Specifies the dtype of the output. + axis: int, optional, default=0. Axis along which to generate the geomspace. + + Returns: + An array containing the geometrically-spaced values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step value. + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. + + Examples: + List 5 geometrically-spaced values between 1 and 16: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(1, 16, 5) + Array([ 1., 2., 4., 8., 16.], dtype=float32) + + List 4 geomtrically-spaced values between 1 and 16, with ``endpoint=False``: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(1, 16, 4, endpoint=False) + Array([1., 2., 4., 8.], dtype=float32) + + Multi-dimensional geomspace: + + >>> start = jnp.array([1, 1000]) + >>> stop = jnp.array([27, 1]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(start, stop, 4) + Array([[ 1., 1000.], + [ 3., 100.], + [ 9., 10.], + [ 27., 1.]], dtype=float32) + """ num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace") return _geomspace(start, stop, num, endpoint, dtype, axis) From 6c52ddc97f1402c6eeeb5da1d99ac06381921d38 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 23 Sep 2024 08:10:29 -0700 Subject: [PATCH 612/702] [Checkify] Add checks for shard_map. PiperOrigin-RevId: 677798938 --- jax/_src/checkify.py | 59 ++++++++++++++++++++++++++++++++++++++++++ tests/checkify_test.py | 41 +++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index e67f624fc32e..32cc4feb9054 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -25,6 +25,7 @@ from jax import dtypes from jax import lax +from jax.experimental import shard_map from jax._src import api from jax._src import linear_util as lu from jax._src import config @@ -931,6 +932,64 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, return tree_unflatten(out_tree, err_and_out) error_checks[pjit.pjit_p] = pjit_error_check + +def shard_map_error_check( + error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs +): + if (mesh := kwargs.get('mesh')) is None: + raise ValueError('Mesh must be provided for shard_map with checkify.') + + err_vals, err_tree = jtu.tree_flatten(error) + num_error_vals = len(err_vals) + # Replicated sharding for in errors. + new_in_names = (*([{}] * num_error_vals), *in_names) + new_vals_in = [*err_vals, *vals_in] + in_avals = list(map(get_shaped_aval, new_vals_in)) + for i, v in enumerate(in_avals): + if not (sharder := core.shard_aval_handlers.get(type(v))): + raise ValueError(f'Unsupported aval type: {type(v)}') + in_avals[i] = sharder(mesh, new_in_names[i], v) + + if not isinstance(jaxpr, core.ClosedJaxpr): + jaxpr = core.ClosedJaxpr(jaxpr, ()) + with core.extend_axis_env_nd(mesh.shape.items()): + # jaxpr to checked_jaxpr + checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( + jaxpr, enabled_errors, err_tree, *in_avals + ) + num_out_error_vals = out_tree.num_leaves - len(out_names) + + @lu.wrap_init + def expand_errors_leading_dim(*xs): + outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs) + errs, outs = split_list(outs, [num_out_error_vals]) + errs = [lax.expand_dims(e, [0]) for e in errs] + return *errs, *outs + + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + expand_errors_leading_dim, checked_jaxpr.in_avals + ) + checked_jaxpr = core.ClosedJaxpr(jaxpr, consts) + + # Update shard_map params to account for extra error values. + # Use fully sharded partitioning for out errors. + new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names) + subfun = lu.hashable_partial( + lu.wrap_init(core.eval_jaxpr), checked_jaxpr.jaxpr, checked_jaxpr.consts + ) + new_params = dict( + jaxpr=checked_jaxpr.jaxpr, + in_names=new_in_names, + out_names=new_out_names, + **kwargs, + ) + _, new_params = shard_map.shard_map_p.get_bind_params(new_params) + + err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params) + return tree_unflatten(out_tree, err_and_out) +error_checks[shard_map.shard_map_p] = shard_map_error_check + def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts, jvp_jaxpr_thunk, call_jaxpr, **params): # The types to have in mind are: diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 726e89d1b3e9..24387a767659 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -23,6 +23,7 @@ from jax import lax from jax.experimental import checkify from jax.experimental import pjit +from jax.experimental import shard_map from jax.sharding import NamedSharding from jax._src import array from jax._src import config @@ -539,6 +540,46 @@ def g(x, y): self.assertIsNotNone(b_err.get()) self.assertStartsWith(b_err.get(), "division by zero") + @parameterized.parameters(True, False) + def test_shard_map(self, check_rep): + def f(x): + # unary func + return jax.lax.axis_index("dev") * x / x + + def g(x, y): + # binary func + return jax.lax.axis_index("dev") * x / y + + devices = jax.local_devices()[:8] # Taking up to 8 devices + mesh = jax.sharding.Mesh(np.array(devices), ["dev"]) + pspec = jax.sharding.PartitionSpec("dev") + ps = NamedSharding(mesh, pspec) + inp = np.tile(np.arange(4, dtype=np.int32), 2) + x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx]) + + f = shard_map.shard_map( + f, mesh, in_specs=pspec, out_specs=pspec, check_rep=check_rep + ) + f = jax.jit(f, in_shardings=ps, out_shardings=ps) + f = checkify.checkify(f, errors=checkify.float_checks) + g = shard_map.shard_map( + g, mesh, in_specs=(pspec, pspec), out_specs=pspec, check_rep=check_rep + ) + g = jax.jit(g, in_shardings=(ps, ps), out_shardings=ps) + g = checkify.checkify(g, errors=checkify.float_checks) + u_err, _ = f(x) + b_err, _ = g(x, x) + + divbyzero = "division by zero" + expected_err = f"at mapped index 0: {divbyzero}" + if (next_device_with_zero := len(devices) // 2) != 0: + expected_err += f"\nat mapped index {next_device_with_zero}: {divbyzero}" + + self.assertIsNotNone(u_err.get()) + self.assertEqual(u_err.get(), expected_err) + self.assertIsNotNone(b_err.get()) + self.assertEqual(b_err.get(), expected_err) + def test_empty_enabled_errors(self): def multi_errors(x): x = x/0 # DIV From e976dee4de39fc941c3fb68d043aaff9f47e1b60 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 23 Sep 2024 21:10:26 +0530 Subject: [PATCH 613/702] Improve docs for jax.numpy: square, sqrt and modf --- jax/_src/numpy/ufuncs.py | 97 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 94 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 3d3dc3588fc8..a0611f0bd6e1 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -636,9 +636,36 @@ def tanh(x: ArrayLike, /) -> Array: def arctanh(x: ArrayLike, /) -> Array: return lax.atanh(*promote_args_inexact('arctanh', x)) -@implements(np.sqrt, module='numpy') + @partial(jit, inline=True) def sqrt(x: ArrayLike, /) -> Array: + """Calculates element-wise non-negative square root of the input array. + + JAX implementation of :obj:`numpy.sqrt`. + + Args: + x: input array or scalar. + + Returns: + An array containing the non-negative square root of the elements of ``x``. + + Note: + - For real-valued negative inputs, ``jnp.sqrt`` produces a ``nan`` output. + - For complex-valued negative inputs, ``jnp.sqrt`` produces a ``complex`` output. + + See also: + - :func:`jax.numpy.square`: Calculates the element-wise square of the input. + - :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential + of ``x2``. + + Examples: + >>> x = jnp.array([-8-6j, 1j, 4]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.sqrt(x) + Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64) + >>> jnp.sqrt(-1) + Array(nan, dtype=float32, weak_type=True) + """ return lax.sqrt(*promote_args_inexact('sqrt', x)) @implements(np.cbrt, module='numpy') @@ -2162,9 +2189,50 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.rem(*promote_args_numeric("fmod", x1, x2)) -@implements(np.square, module='numpy') @partial(jit, inline=True) def square(x: ArrayLike, /) -> Array: + """Calculate element-wise square of the input array. + + JAX implementation of :obj:`numpy.square`. + + Args: + x: input array or scalar. + + Returns: + An array containing the square of the elements of ``x``. + + Note: + ``jnp.square`` is equivalent to computing ``jnp.power(x, 2)``. + + See also: + - :func:`jax.numpy.sqrt`: Calculates the element-wise non-negative square root + of the input array. + - :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential + of ``x2``. + - :func:`jax.lax.integer_pow`: Computes element-wise power :math:`x^y`, where + :math:`y` is a fixed integer. + - :func:`jax.numpy.float_power`: Computes the first array raised to the power + of second array, element-wise, by promoting to the inexact dtype. + + Examples: + >>> x = jnp.array([3, -2, 5.3, 1]) + >>> jnp.square(x) + Array([ 9. , 4. , 28.090002, 1. ], dtype=float32) + >>> jnp.power(x, 2) + Array([ 9. , 4. , 28.090002, 1. ], dtype=float32) + + For integer inputs: + + >>> x1 = jnp.array([2, 4, 5, 6]) + >>> jnp.square(x1) + Array([ 4, 16, 25, 36], dtype=int32) + + For complex-valued inputs: + + >>> x2 = jnp.array([1-3j, -1j, 2]) + >>> jnp.square(x2) + Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64) + """ check_arraylike("square", x) x, = promote_dtypes_numeric(x) return lax.integer_pow(x, 2) @@ -2343,9 +2411,32 @@ def real(val: ArrayLike, /) -> Array: check_arraylike("real", val) return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) -@implements(np.modf, module='numpy', skip_params=['out']) + @jit def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: + """Return element-wise fractional and integral parts of the input array. + + JAX implementation of :obj:`numpy.modf`. + + Args: + x: input array or scalar. + out: Not used by JAX. + + Returns: + An array containing the fractional and integral parts of the elements of ``x``, + promoting dtypes inexact. + + See also: + - :func:`jax.numpy.divmod`: Calculates the integer quotient and remainder of + ``x1`` by ``x2`` element-wise. + + Examples: + >>> jnp.modf(4.8) + (Array(0.8000002, dtype=float32, weak_type=True), Array(4., dtype=float32, weak_type=True)) + >>> x = jnp.array([-3.4, -5.7, 0.6, 1.5, 2.3]) + >>> jnp.modf(x) + (Array([-0.4000001 , -0.6999998 , 0.6 , 0.5 , 0.29999995], dtype=float32), Array([-3., -5., 0., 1., 2.], dtype=float32)) + """ check_arraylike("modf", x) x, = promote_dtypes_inexact(x) if out is not None: From 91f16419bb3b1670cf31c55a4e8dde194bb73f53 Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Mon, 23 Sep 2024 09:06:35 -0700 Subject: [PATCH 614/702] Fix errata in block-sparse kernel tutorial. Correct M//blk_M to N//blk_N. It was ok because both values happen to be same. In addition, grid order is (num_blocks, j) as 'num_blocks' replaces 'i'. PiperOrigin-RevId: 677817478 --- docs/pallas/tpu/sparse.ipynb | 8 ++++---- docs/pallas/tpu/sparse.md | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb index 909103273e1e..6666b9e3ec5a 100644 --- a/docs/pallas/tpu/sparse.ipynb +++ b/docs/pallas/tpu/sparse.ipynb @@ -312,13 +312,13 @@ " o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n", "\n", "\n", - "def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", + "def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", " del j, blk_idxs_i, blk_idxs_k\n", " return (blk_idx, 0, 0)\n", - "def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", + "def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", " del blk_idxs_i\n", " return (blk_idxs_k[blk_idx], j)\n", - "def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):\n", + "def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", " del blk_idxs_k\n", " return (blk_idxs_i[blk_idx], j)\n", "\n", @@ -333,7 +333,7 @@ " num_scalar_prefetch=2,\n", " # Note that while num_blocks is static here, Pallas does support\n", " # dynamic grid sizes.\n", - " grid=(M // blk_M, num_blocks),\n", + " grid=(num_blocks, N // blk_N),\n", " in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n", " pl.BlockSpec((blk_K, blk_N), y_map),\n", " # Placeholder for a zeros-array used by input_output_aliases.\n", diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md index 23e14bb9bc0b..3bc662895654 100644 --- a/docs/pallas/tpu/sparse.md +++ b/docs/pallas/tpu/sparse.md @@ -252,13 +252,13 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs. o_ref[...] = accum_scratch[...].astype(o_ref.dtype) -def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k): +def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k): del j, blk_idxs_i, blk_idxs_k return (blk_idx, 0, 0) -def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k): +def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k): del blk_idxs_i return (blk_idxs_k[blk_idx], j) -def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k): +def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k): del blk_idxs_k return (blk_idxs_i[blk_idx], j) @@ -273,7 +273,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=2, # Note that while num_blocks is static here, Pallas does support # dynamic grid sizes. - grid=(M // blk_M, num_blocks), + grid=(num_blocks, N // blk_N), in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map), pl.BlockSpec((blk_K, blk_N), y_map), # Placeholder for a zeros-array used by input_output_aliases. From 3134ece9b7e305234cd6c79a24ed4a08ba7370c8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 18 Sep 2024 14:24:38 -0700 Subject: [PATCH 615/702] ufuncs: improve jnp.add.at & jnp.multiply.at --- jax/_src/numpy/ufuncs.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 3d3dc3588fc8..24587b57a0d9 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -21,6 +21,7 @@ from collections.abc import Callable from functools import partial import operator +from typing import Any import numpy as np @@ -2493,6 +2494,20 @@ def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = No result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) return result if dtype is None else result.astype(dtype) +def _add_at(a: Array, indices: Any, b: ArrayLike): + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].add(b).astype(bool) + return a.at[indices].add(b) + +def _multiply_at(a: Array, indices: Any, b: ArrayLike): + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].mul(b).astype(bool) + else: + return a.at[indices].mul(b) # Generate ufunc interfaces for several common binary functions. # We start with binary ufuncs that have well-defined identities.' @@ -2501,8 +2516,8 @@ def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = No # - define add.at/multiply.at in terms of scatter_add/scatter_mul # - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod # - define all monoidal reductions in terms of lax.reduce -add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum) -multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod) +add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) +multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and) bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or) bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor) From d29a757e3051b3a114bf006b51151bbf04bbadd9 Mon Sep 17 00:00:00 2001 From: kaixih Date: Fri, 20 Sep 2024 18:16:20 +0000 Subject: [PATCH 616/702] fix bwd batcher for unsupported dbias --- jax/_src/cudnn/fused_attention_stablehlo.py | 10 +++++- tests/fused_attention_stablehlo_test.py | 39 +++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 171954bc86c5..e20271f66301 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -676,6 +676,11 @@ def _dot_product_attention_bwd_batcher( *_, S, _, _ = key.shape B = math.prod(Bs) has_bias, has_dbias = variadic_args + # Reset the has_dbias if the combined batch size is not 1, because cuDNN only + # supports dbias with a single batch. In this case, an all-zero dbias will be + # appended instead. + if B > 1: + variadic_args = (has_bias, False) original_query_shape = query.shape original_key_shape = key.shape original_value_shape = value.shape @@ -708,7 +713,10 @@ def _dot_product_attention_bwd_batcher( grads[2] = jnp.reshape(grads[2], original_value_shape) if has_dbias: assert has_bias - grads[3] = jnp.reshape(grads[3], original_bias_shape) + if variadic_args[1]: + grads[3] = jnp.reshape(grads[3], original_bias_shape) + else: + grads.append(jnp.zeros(original_bias_shape, bias.dtype)) out_bdims += (batch_dims[3],) return grads, out_bdims diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index c70aa2cd5a00..2cfcfa7c5ec6 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -419,6 +419,45 @@ def test_sdpa_broadcast_bias_and_dbias(self): self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) self.assertArraysAllClose(bias_grad_ref, bias_grad, rtol=1e-5, atol=1e-5) + @jtu.sample_product( + batch_size=[1, 16], + ) + @jtu.run_on_devices("cuda") + def test_sdpa_dbias(self, batch_size: int): + # cuDNN only supports dbias when batch size is 1. If the batch size is + # greater, dbias is silently set to all zeros. This test verifies this + # behavior for both vmap and regular use cases. + # TODO: Remove this test once cuDNN adds broader dbias support. + dtype = jnp.bfloat16 + x_shape = (batch_size, 512, 16, 48) + bias_shape = (batch_size, 16, 512, 512) + mask_shape = (1, 1, 512) + + keys = jax.random.split(jax.random.key(0), 2) + x = jax.random.normal(keys[0], x_shape, dtype=dtype) + bias = jax.random.normal(keys[1], bias_shape, dtype=dtype) + mask = jnp.ones(mask_shape, dtype=jnp.bool_) + + def attn(x, bias, mask): + return dot_product_attention(x, x, x, bias, mask) + + def attn_vjp(x, bias, mask, target_fn): + _, f_vjp = jax.vjp(target_fn, x, bias, mask) + return f_vjp(x) + + attn_vmap = jax.vmap(attn, in_axes=(0, 0, None)) + attn_ref = jax.jit(partial(attn_vjp, target_fn=attn)) + attn_ans = jax.jit(partial(attn_vjp, target_fn=attn_vmap)) + + _, dbias_ref, _ = attn_ref(x, bias, mask) + x = jnp.expand_dims(x, axis=1) + bias = jnp.expand_dims(bias, axis=1) + _, dbias_ans, _ = attn_ans(x, bias, mask) + dbias_ans = jnp.squeeze(dbias_ans, axis=1) + self.assertArraysAllClose(dbias_ans, dbias_ref) + if batch_size != 1: + self.assertTrue(not jnp.any(dbias_ans)) + @jtu.run_on_devices("cuda") def test_sdpa_sliding_window_length(self): k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) From 3e19a28b09a43b24d78f190c53bfbf50a6b906fd Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Mon, 23 Sep 2024 11:05:41 -0700 Subject: [PATCH 617/702] [pallas:mosaic_gpu] Basic implementation of wgmma. PiperOrigin-RevId: 677864187 --- jax/_src/pallas/mosaic_gpu/BUILD | 3 + jax/_src/pallas/mosaic_gpu/__init__.py | 5 +- jax/_src/pallas/mosaic_gpu/core.py | 4 +- jax/_src/pallas/mosaic_gpu/primitives.py | 175 ++++++++++++++++++++++- tests/pallas/mosaic_gpu_test.py | 38 +++++ 5 files changed, 221 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 171ff0439085..fd291b201fa1 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -88,7 +88,10 @@ pytype_strict_library( ":lowering", "//jax", "//jax:core", + "//jax:effects", + "//jax:mlir", "//jax:mosaic_gpu", + "//jax/_src/lib", "//jax/_src/pallas", ], ) diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index 1bd512834ce5..ddf27361493a 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -18,10 +18,13 @@ from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace -from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem +from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wait_barrier from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import zero_accumulator +from jax._src.pallas.mosaic_gpu.primitives import wgmma +from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait GMEM = GPUMemorySpace.GMEM SMEM = GPUMemorySpace.SMEM diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 6ef4cd1621f4..b6d2ada5ee28 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -68,6 +68,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: ... +@dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): """Represents a tiling transformation for memory refs. @@ -76,8 +77,7 @@ class TilingTransform(MemoryRefTransform): tiling of (64, 32) will be tiled as (4, 8, 64, 32). """ - def __init__(self, tiling: tuple[int, ...]): - self.tiling = tiling + tiling: tuple[int, ...] def __call__( self, block_aval: pallas_core.AbstractMemoryRef diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index e96574612bfa..ef30dd0956ec 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -16,12 +16,18 @@ from __future__ import annotations +import dataclasses + from jax._src import core as jax_core +from jax._src import effects from jax._src import state +from jax._src.interpreters import mlir +from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering - +import jax.experimental.mosaic.gpu as mgpu +import jax.numpy as jnp async_copy_p = jax_core.Primitive("async_copy") async_copy_p.multiple_results = True @@ -103,3 +109,170 @@ def wait_smem_to_gmem(allow_groups: int) -> None: def wait_barrier(barrier: pallas_core.AbstractMemoryRef) -> None: """Waits on the given barrier.""" wait_p.bind(barrier) + + +class _WGMMAPipelineEffect(effects.Effect): + pass + + +_wgmma_pipeline_effect = _WGMMAPipelineEffect() +effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect) + + +# Not a shaped array to avoid unexpected operations. +class WGMMAAbstractAccumulator(jax_core.AbstractValue): + __slots__ = ['shape', 'dtype'] + + def __init__(self, shape, dtype): + self.shape = shape + self.dtype = dtype + + def __eq__(self, other): + return (type(self) is type(other) + and self.dtype == other.dtype and self.shape == other.shape) + + def __hash__(self): + return hash((self.shape, self.dtype)) + + def update(self, shape=None, dtype=None): + if shape is None: + shape = self.shape + if dtype is None: + dtype = self.dtype + return WGMMAAbstractAccumulator(shape, dtype) + + def str_short(self, short_dtypes=False) -> str: + del short_dtypes + shapestr = ",".join(map(str, self.shape)) + return f"Accumulator{{{self.dtype.name}}}[{shapestr}]" + +@dataclasses.dataclass(frozen=True) +class WGMMAAccumulator: + inner_aval: WGMMAAbstractAccumulator + + shape = property(lambda self: self.inner_aval.shape) + dtype = property(lambda self: self.inner_aval.dtype) + + def as_array(self) -> jax_core.ShapedArray: + return acc_to_shaped_array_p.bind(self.inner_aval) + + +jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulator] = lambda aval, _: aval + +acc_to_shaped_array_p = jax_core.Primitive("acc_to_shaped_array") + +@acc_to_shaped_array_p.def_abstract_eval +def _acc_to_shaped_array_abstract_eval(acc) -> jax_core.ShapedArray: + return jax_core.ShapedArray(shape=acc.shape, dtype=acc.dtype) + + +@lowering.register_lowering_rule(acc_to_shaped_array_p) +def _acc_to_shaped_array_lowering_rule( + ctx: lowering.LoweringRuleContext, acc +): + del ctx + return acc.value + +wgmma_p = jax_core.Primitive("wgmma") + +def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128): + """Asynchronous warp group matmul. + + The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires + that accumulator is an accumualtion register reference. + + Args: + acc: The accumulator register. + a: The left hand side operand. + b: The right hand side operand. + transpose: Whether to transpose b. + n_tile: The number of tiles to use. + swizzle: The swizzle pattern. + """ + if not isinstance(acc, WGMMAAccumulator): + raise TypeError(acc) + + rhs_transpose = ( + (jnp.dtype(b.dtype).itemsize == 2) + if rhs_transpose is None + else rhs_transpose + ) + + ma, ka, tma, tka = a.shape + kb, nb, tkb, tnb = b.shape + mc, nc = acc.shape + + if rhs_transpose: + kb, nb, tkb, tnb = nb, kb, tnb, tkb + + if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb: + raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}") + + outval = wgmma_p.bind(acc.inner_aval, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose) + return WGMMAAccumulator(outval) + +@wgmma_p.def_effectful_abstract_eval +def _wgmma_effectful_abstract_eval(acc, *args, **kwargs): + del args, kwargs + return acc, { + _wgmma_pipeline_effect, + state.ReadEffect(1), + state.ReadEffect(2), + } + +@lowering.register_lowering_rule(wgmma_p) +def _wgmma_lowering_rule( + ctx: lowering.LoweringRuleContext, + acc, + a, + b, + swizzle, + rhs_transpose, +): + del ctx + new_acc = mgpu.wgmma( + acc, + a, + b, + swizzle=swizzle, + b_order=mgpu.WGMMALayout.COL_MAJOR + if rhs_transpose + else mgpu.WGMMALayout.ROW_MAJOR, + ) + nvvm_dialect.wgmma_commit_group_sync_aligned() + return new_acc + +wgmma_wait_p = jax_core.Primitive("wgmma_wait") +wgmma_wait_p.multiple_results = True + +def wgmma_wait(i: int): + """Wait until all but the last `i` WGMMA operations are done.""" + return wgmma_wait_p.bind(i) + + +@wgmma_wait_p.def_effectful_abstract_eval +def wgmma_wait_effectful_abstract_eval(_): + return [], {_wgmma_pipeline_effect} + +@lowering.register_lowering_rule(wgmma_wait_p) +def _wgmma_wait_lowering_rule(ctx: lowering.LoweringRuleContext, allow_groups): + del ctx + nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) + return () + +zero_accumulator_p = jax_core.Primitive("zero_accumulator") +def zero_accumulator(shape, dtype): + return WGMMAAccumulator(zero_accumulator_p.bind(shape=shape, dtype=dtype)) + +@zero_accumulator_p.def_abstract_eval +def _zero_accumulator_abstract_eval(shape, dtype): + return WGMMAAbstractAccumulator(shape=shape, dtype=dtype) + + +@lowering.register_lowering_rule(zero_accumulator_p) +def _zero_accumulator_lowering_rule( + ctx: lowering.LoweringRuleContext, shape, dtype +): + del ctx + m, n = shape + return mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=mlir.dtype_to_ir_type(jnp.dtype(dtype))) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6aa74263cd94..f34beeeb8166 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -340,6 +340,44 @@ def kernel(x_ref, o_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 2.0 + 3.0) + def test_wgmma(self): + dtype = jnp.float16 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + def kernel(a_ref, b_ref, o_ref): + acc = plgpu.zero_accumulator((64, 128), jnp.float32) + acc = plgpu.wgmma(acc, a_ref, b_ref, rhs_transpose=False) + plgpu.wgmma_wait(0) + # TODO(cperivol): turn acc into a reference so we can reason about effects. + o_ref[...] = acc.as_array() + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype) + b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype) + + res = pl.pallas_call( + kernel, + in_specs=[ + plgpu.GPUBlockSpec( + (64, 128), + lambda i, j: (i, j), + tiling=(64, elems_128b), + swizzle=128, + ), + plgpu.GPUBlockSpec( + (128, 128), + lambda *i: i, + tiling=(elems_128b, elems_128b), + swizzle=128, + ), + ], + out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i), + out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32), + grid=(1, 1), + )(a, b) + np.testing.assert_allclose( + res, a @ b, rtol=1e-3 + ) if __name__ == "__main__": absltest.main() From 93203c757413aee4de99b0e557af3b00137662ea Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 23 Sep 2024 11:09:16 -0700 Subject: [PATCH 618/702] [Pallas] Simplify sign and erf_inv tests Removed the method to locally enabling x64 using: ```python with contextlib.ExitStack() as stack: if jnp.dtype(dtype).itemsize == 8: stack.enter_context(config.enable_x64(True)) ``` This is because we can determine whether a test is running in x64 environment by checking the value of `jax.config.x64_enabled`. There is no need to locally enabling x64. PiperOrigin-RevId: 677865574 --- tests/pallas/ops_test.py | 70 +++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 346dde0dd79e..8d242617efbb 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -669,21 +669,31 @@ def run(interpret=False): actual = run(False) self.assertAllClose(actual, expected) - SIGN_PARAMS = [ - (jnp.int32, (-3, 0, 5)), - (jnp.uint32, (0, 5)), - (jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)), - (jnp.float64, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)), - ] - @parameterized.named_parameters( (f"{dtype.__name__}_{value}", dtype, value) - for dtype, values in SIGN_PARAMS + for dtypes, values in ( + ((jnp.uint16, jnp.uint32, jnp.uint64), (0, 5)), + ((jnp.int16, jnp.int32, jnp.int64), (-3, 0, 5)), + ( + (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64), + (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf), + ), + ) + for dtype in dtypes for value in values ) def test_sign(self, dtype, value): - if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64: - self.skipTest("float64 is not supported on TPU") + if ( + not jax.config.x64_enabled + and dtype in (jnp.uint64, jnp.int64, jnp.float64) + ): + self.skipTest("64-bit types require x64_enabled") + + if ( + jtu.test_device_matches(["tpu"]) + and dtype in (jnp.uint16, jnp.int16, jnp.bfloat16, jnp.float16) + ): + self.skipTest("16-bit types are not supported on TPU") @functools.partial( self.pallas_call, @@ -692,38 +702,32 @@ def test_sign(self, dtype, value): def kernel(x_ref, o_ref): o_ref[...] = jnp.sign(x_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) + x = jnp.full((8, 128,), value, dtype=dtype) + out = kernel(x) + expected = jnp.sign(x) - x = jnp.full((8, 128,), value, dtype=dtype) - out = kernel(x) - expected = jnp.sign(x) - np.testing.assert_array_equal(out, expected) + # `.astype(jnp.float32)` is a workaround for dtype=bfloat16 and value=nan, + # see https://github.com/jax-ml/ml_dtypes/issues/206 + np.testing.assert_array_equal( + out.astype(jnp.float32), + expected.astype(jnp.float32), + ) - @parameterized.product( - dtype=[jnp.float32, jnp.float64], - value=[-3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4], + @parameterized.parameters( + -3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4, ) - def test_erf_inv(self, dtype, value): - if jtu.test_device_matches(["tpu"]) and dtype == jnp.float64: - self.skipTest("float64 is not supported on TPU") - + def test_erf_inv(self, value): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8, 128), dtype), + out_shape=jax.ShapeDtypeStruct((8, 128), floatx), ) def kernel(x_ref, o_ref): o_ref[...] = lax.erf_inv(x_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) - - x = jnp.full((8, 128), value, dtype=dtype) - out = kernel(x) - expected = lax.erf_inv(x) - np.testing.assert_array_equal(out, expected) + x = jnp.full((8, 128), value, dtype=floatx) + out = kernel(x) + expected = lax.erf_inv(x) + np.testing.assert_array_equal(out, expected) class OpsInterpretTest(OpsTest): From 712e638ca4913498382572d41f72aaa027b1cfc7 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 23 Sep 2024 11:21:01 -0700 Subject: [PATCH 619/702] [pallas] Add support for `unblocked` mode (without padding) in Triton lowering. PiperOrigin-RevId: 677870258 --- jax/_src/pallas/triton/lowering.py | 11 +-- tests/pallas/pallas_test.py | 140 +++++++++++++++++++++++++++++ tests/pallas/tpu_pallas_test.py | 119 ------------------------ 3 files changed, 146 insertions(+), 124 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 6a16156271d7..15f0d265b836 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -121,6 +121,12 @@ def _eval_index_map( _ensure_ir_value(i, jax_core.ShapedArray((), jnp.int32)) for i in block_indices ) + if isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): + if block_mapping.indexing_mode.padding is not None: + raise NotImplementedError( + "Unblocked indexing with padding is not supported in Triton lowering." + ) + return tuple(block_indices) return tuple( i if b is pallas_core.mapped else _mul(i, _ir_constant(b, i.type)) for i, b in zip(block_indices, block_mapping.block_shape) @@ -324,11 +330,6 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "Scalar prefetch not supported in Triton lowering." ) - if not all(isinstance(bm.indexing_mode, Blocked) - for bm in grid_mapping.block_mappings): - raise NotImplementedError( - "Only Blocked indexing mode is supported in Triton lowering." - ) start_indices = map( functools.partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings, diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index aec7fd54c925..6df31b55f8e7 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -683,6 +683,146 @@ class PallasCallInterpretTest(PallasCallTest): INTERPRET = True +class PallasCallUnblockedIndexingTest(PallasBaseTest): + + def test_block_spec_unblocked(self): + def show_program_ids( + *, shape, block_shape, grid, indexing_mode: pl.IndexingMode + ): + def kernel(o1_ref): + assert o1_ref.shape == block_shape + o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0)) + + return self.pallas_call( + kernel, + jax.ShapeDtypeStruct(shape, dtype=np.int32), + grid=grid, + out_specs=pl.BlockSpec( + block_shape, lambda i: (8 * i, 0), indexing_mode=indexing_mode + ), + )() + + # No padding + pids = show_program_ids( + shape=(16, 128), + block_shape=(8, 128), + grid=(2,), + indexing_mode=pl.Unblocked(), + ) + expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 8, dtype=np.int32) + self.assertAllClose(pids, expected_pids) + + if jtu.test_device_matches(["gpu"]) and not self.INTERPRET: + self.skipTest("TODO: padding not implemented on GPU yet") + + # Only high padding + pids = show_program_ids( + shape=(14, 128), + block_shape=(8, 128), + grid=(2,), + indexing_mode=pl.Unblocked(((0, 2), (0, 0))), + ) + expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 6, dtype=np.int32) + self.assertAllClose(pids, expected_pids) + + # Both low and high padding + self.skipTest("TODO: low padding not supported yet") + pids = show_program_ids( + shape=(11, 128), + block_shape=(8, 128), + grid=(2,), + indexing_mode=pl.Unblocked(((3, 2), (0, 0))), + ) + expected_pids = np.array([[0] * 128] * 5 + [[1] * 128] * 6, dtype=np.int32) + self.assertAllClose(pids, expected_pids) + + @parameterized.parameters("int32", "float32") + def test_block_spec_unblocked_padding_is_nan(self, dtype_name): + if not self.INTERPRET: + self.skipTest("Only applicable for the interpret mode") + + dtype = np.dtype(dtype_name) + + def copy_kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + res = self.pallas_call( + copy_kernel, + jax.ShapeDtypeStruct((6,), dtype=dtype), + grid=(1,), + in_specs=[ + pl.BlockSpec( + (6,), lambda i: 0, indexing_mode=pl.Unblocked(((1, 2),)) + ) + ], + )(np.full((3,), 42, dtype=dtype)) + expected_pad = {"int32": jnp.iinfo(np.int32).min, "float32": np.nan}[ + dtype_name + ] + self.assertAllClose( + res, + np.array( + [expected_pad, 42, 42, 42, expected_pad, expected_pad], dtype=dtype + ), + ) + + def test_unblocked_indexing(self): + shape = (16 * 8, 128) + result_ty = jax.ShapeDtypeStruct((15 * 8, 128), jnp.float32) + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[pl.ds(0, 8)] + x_ref[pl.ds(8, 8)] + + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + y = self.pallas_call( + kernel, + grid=(15,), + in_specs=( + pl.BlockSpec( + (2 * 8, 128), lambda i: (i * 8, 0), indexing_mode=pl.unblocked + ), + ), + out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), + out_shape=result_ty, + )(x) + ref = [] + for i in range(15): + block = x[i * 8 : i * 8 + 2 * 8] + ref.append(block[0:8] + block[8:16]) + ref = np.concatenate(ref, axis=0) + np.testing.assert_array_equal(y, ref) + + def test_unblocked_indexing_with_padding(self): + if jtu.test_device_matches(["gpu"]) and not self.INTERPRET: + self.skipTest("TODO: padding not implemented on GPU yet") + + shape = (8, 128) + result_ty = jax.ShapeDtypeStruct((8, 128), jnp.float32) + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[pl.ds(0, 8)] + + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + y = self.pallas_call( + kernel, + grid=(1,), + in_specs=( + pl.BlockSpec( + (2 * 8, 128), + lambda i: (0, 0), + indexing_mode=pl.Unblocked(((0, 8), (0, 0))), + ), + ), + out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), + out_shape=result_ty, + )(x) + np.testing.assert_array_equal(y, x) + + +class PallasCallUnblockedIndexingInterpretTest(PallasCallUnblockedIndexingTest): + INTERPRET = True + + class ApiErrorTest(PallasBaseTest): def test_pallas_call_kernel_args_mismatch(self): a = np.arange(256, dtype=np.int32) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 87ccaa644e8c..9a81f3196ba2 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1652,125 +1652,6 @@ def kernel(x_ref, y_ref): )(x) -class PallasCallUnblockedIndexingTest(PallasBaseTest): - - def test_block_spec_unblocked(self): - def show_program_ids(*, shape, block_shape, grid, - indexing_mode: pl.IndexingMode): - def kernel(o1_ref): - assert o1_ref.shape == block_shape - o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0)) - - return self.pallas_call(kernel, - jax.ShapeDtypeStruct(shape, dtype=np.int32), - grid=grid, - out_specs=pl.BlockSpec(block_shape, - lambda i: (8 * i, 0), - indexing_mode=indexing_mode))() - # No padding - pids = show_program_ids(shape=(16, 128), block_shape=(8, 128), - grid=(2,), - indexing_mode=pl.Unblocked()) - expected_pids = np.array( - [[0] * 128] * 8 + [[1] * 128] * 8, - dtype=np.int32) - self.assertAllClose(pids, expected_pids) - - # Only high padding - pids = show_program_ids(shape=(14, 128), block_shape=(8, 128), - grid=(2,), - indexing_mode=pl.Unblocked(((0, 2), (0, 0)))) - expected_pids = np.array( - [[0] * 128] * 8 + [[1] * 128] * 6, - dtype=np.int32) - self.assertAllClose(pids, expected_pids) - - # Both low and high padding - self.skipTest("TODO: TPU low padding not supported yet") - pids = show_program_ids(shape=(11, 128), block_shape=(8, 128), - grid=(2,), - indexing_mode=pl.Unblocked(((3, 2), (0, 0)))) - expected_pids = np.array( - [[0] * 128] * 5 + [[1] * 128] * 6, - dtype=np.int32) - self.assertAllClose(pids, expected_pids) - - @parameterized.parameters("int32", "float32") - def test_block_spec_unblocked_padding_is_nan(self, dtype_name): - if not self.INTERPRET: - self.skipTest("Only applicable for the interpret mode") - - dtype = np.dtype(dtype_name) - def copy_kernel(x_ref, o_ref): - o_ref[...] = x_ref[...] - res = self.pallas_call(copy_kernel, - jax.ShapeDtypeStruct((6,), dtype=dtype), - grid=(1,), - in_specs=[pl.BlockSpec((6,), lambda i: 0, - indexing_mode=pl.Unblocked(((1, 2),)))])( - np.full((3,), 42, dtype=dtype) - ) - expected_pad = {"int32": jnp.iinfo(np.int32).min, - "float32": np.nan}[dtype_name] - self.assertAllClose(res, np.array([expected_pad, 42, 42, 42, - expected_pad, expected_pad], dtype=dtype)) - - def test_unblocked_indexing(self): - shape = (16 * 8, 128) - result_ty = jax.ShapeDtypeStruct((15 * 8, 128), jnp.float32) - - def kernel(x_ref, o_ref): - o_ref[...] = x_ref[pl.ds(0, 8)] + x_ref[pl.ds(8, 8)] - - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - y = self.pallas_call( - kernel, - grid=(15,), - in_specs=( - pl.BlockSpec( - (2 * 8, 128), lambda i: (i * 8, 0), indexing_mode=pl.unblocked - ), - ), - out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), - out_shape=result_ty, - )(x) - ref = [] - for i in range(15): - block = x[i * 8:i * 8 + 2 * 8] - ref.append(block[0:8] + block[8:16]) - ref = np.concatenate(ref, axis=0) - np.testing.assert_array_equal(y, ref) - - def test_unblocked_indexing_with_padding(self): - shape = (8, 128) - result_ty = jax.ShapeDtypeStruct((8, 128), jnp.float32) - - def kernel(x_ref, y_ref): - y_ref[...] = x_ref[pl.ds(0, 8)] - - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - y = self.pallas_call( - kernel, - grid=(1,), - in_specs=( - pl.BlockSpec( - (2 * 8, 128), - lambda i: (0, 0), - indexing_mode=pl.Unblocked(((0, 8), (0, 0))), - ), - ), - out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), - out_shape=result_ty, - )(x) - np.testing.assert_array_equal(y, x) - - -class PallasCallUnblockedIndexingInterpretTest( - PallasCallUnblockedIndexingTest -): - INTERPRET = True - - class PallasUXTest(PallasBaseTest): def test_mlir_location(self): From dc1ace59922312eb6247369be167d9ef9f1c0ef1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 23 Sep 2024 12:25:56 -0700 Subject: [PATCH 620/702] Re-enable tsan tests after fix. PiperOrigin-RevId: 677895934 --- tests/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index 4635a48cede1..e64889cc3b1f 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1189,7 +1189,6 @@ jax_test( }, tags = [ "noasan", # Times out. - "notsan", # TODO(b/309111150): Re-enable after rolling forward cl/666056414. ], deps = [ "//jax:experimental", From 4fccd64c8bd6657fe5cd90d5d784a3667a5d7cc8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 23 Sep 2024 12:30:59 -0700 Subject: [PATCH 621/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/1162b7e30d12d00aa4d004a71217ef958d8aa290. PiperOrigin-RevId: 677897482 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 56fac1fa611b..f8bddcd740bf 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "2101ae888f054bfbd13d5ac42af8aea3b1600749" -XLA_SHA256 = "3392793681e186f8ee5e901bd047bb688c9337970244d0c07ea173889db3d837" +XLA_COMMIT = "1162b7e30d12d00aa4d004a71217ef958d8aa290" +XLA_SHA256 = "706d360fa2f82174fb7210cf7b87470faa2440f7614efc57136f47879d0032ed" def repo(): tf_http_archive( From 29a1cb766e1744995245300f790ea7b82c61631c Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Mon, 23 Sep 2024 14:42:01 -0500 Subject: [PATCH 622/702] [ROCM] add missing typename keyword to work with gcc --- jaxlib/gpu/solver_kernels_ffi.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 9191a0ff8dec..32cd97565f5e 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -408,7 +408,7 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto w_data = static_cast::value*>(w->untyped_data()); + auto w_data = static_cast::value*>(w->untyped_data()); auto info_data = info->typed_data(); if (a_data != out_data) { JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( @@ -692,7 +692,7 @@ ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols, AllocateWorkspace(scratch, lwork, "gesvd")); auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto s_data = static_cast::value*>(s->untyped_data()); + auto s_data = static_cast::value*>(s->untyped_data()); auto u_data = compute_uv ? static_cast(u->untyped_data()) : nullptr; auto vt_data = compute_uv ? static_cast(vt->untyped_data()) : nullptr; auto info_data = info->typed_data(); From cc885ff8757afbe4a1cea093483e76b98dcf6187 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 20 Sep 2024 08:25:54 -0700 Subject: [PATCH 623/702] Better docs for jnp.meshgrid --- jax/_src/numpy/lax_numpy.py | 65 +++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index e9c5785a7715..559d17cd9514 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -819,11 +819,6 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, return hist, bin_edges_by_dim -_ARRAY_VIEW_DOC = """ -The JAX version of this function may in some cases return a copy rather than a -view of the input. -""" - def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: """Return a transposed version of an N-dimensional array. @@ -6055,9 +6050,67 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool return lax.convert_element_type(res, dtype) -@util.implements(np.meshgrid, lax_description=_ARRAY_VIEW_DOC) def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> list[Array]: + """Construct N-dimensional grid arrays from N 1-dimensional vectors. + + JAX implementation of :func:`numpy.meshgrid`. + + Args: + xi: N arrays to convert to a grid. + copy: whether to copy the input arrays. JAX supports only ``copy=True``, + though under JIT compilation the compiler may opt to avoid copies. + sparse: if False (default), then each returned arrays will be of shape + ``[len(x1), len(x2), ..., len(xN)]``. If False, then returned arrays + will be of shape ``[1, 1, ..., len(xi), ..., 1, 1]``. + indexing: options are ``'xy'`` for cartesian indexing (default) or ``'ij'`` + for matrix indexing. + + Returns: + A length-N list of grid arrays. + + See also: + - :obj:`jax.numpy.mgrid`: create a meshgrid using indexing syntax. + - :obj:`jax.numpy.ogrid`: create an open meshgrid using indexing syntax. + + Examples: + For the following examples, we'll use these 1D arrays as inputs: + + >>> x = jnp.array([1, 2]) + >>> y = jnp.array([10, 20, 30]) + + 2D cartesian mesh grid: + + >>> x_grid, y_grid = jnp.meshgrid(x, y) + >>> print(x_grid) + [[1 2] + [1 2] + [1 2]] + >>> print(y_grid) + [[10 10] + [20 20] + [30 30]] + + 2D sparse cartesian mesh grid: + + >>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True) + >>> print(x_grid) + [[1 2]] + >>> print(y_grid) + [[10] + [20] + [30]] + + 2D matrix-index mesh grid: + + >>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij') + >>> print(x_grid) + [[1 1 1] + [2 2 2]] + >>> print(y_grid) + [[10 20 30] + [10 20 30]] + """ util.check_arraylike("meshgrid", *xi) args = [asarray(x) for x in xi] if not copy: From e4091a67524b4555e45f17f4d02ab16bee5d520d Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Mon, 23 Sep 2024 15:03:32 -0700 Subject: [PATCH 624/702] Fix another errata in block-sparse kernel tutorial. PiperOrigin-RevId: 677952796 --- docs/pallas/tpu/sparse.ipynb | 2 +- docs/pallas/tpu/sparse.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb index 6666b9e3ec5a..a80ba4ebedbb 100644 --- a/docs/pallas/tpu/sparse.ipynb +++ b/docs/pallas/tpu/sparse.ipynb @@ -297,7 +297,7 @@ " ):\n", " \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n", " del idxs_k_ref\n", - " blk_idx = pl.program_id(1)\n", + " blk_idx = pl.program_id(0)\n", " is_start = blk_idx == 0\n", " changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n", " @pl.when(is_start | changed_blocks)\n", diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md index 3bc662895654..2ac25edb5064 100644 --- a/docs/pallas/tpu/sparse.md +++ b/docs/pallas/tpu/sparse.md @@ -237,7 +237,7 @@ def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs. ): """A DSD (Dense = Sparse @ Dense) matmul kernel.""" del idxs_k_ref - blk_idx = pl.program_id(1) + blk_idx = pl.program_id(0) is_start = blk_idx == 0 changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)]) @pl.when(is_start | changed_blocks) From a99ea733367bf98e002f698661470f18ac1bc71a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 23 Sep 2024 16:02:51 -0700 Subject: [PATCH 625/702] Use `jax.make_array_from_process_local_data` API in distributed data loading doc PiperOrigin-RevId: 677973689 --- docs/distributed_data_loading.md | 74 +++----------------------------- 1 file changed, 6 insertions(+), 68 deletions(-) diff --git a/docs/distributed_data_loading.md b/docs/distributed_data_loading.md index d7b88be44178..4f4dd7839c37 100644 --- a/docs/distributed_data_loading.md +++ b/docs/distributed_data_loading.md @@ -243,35 +243,10 @@ ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index()) # Grab just the first batch from the Dataset for this example per_process_batch = ds.as_numpy_iterator().next() -per_process_batch_size = per_process_batch.shape[0] # adjust if your batch dim - # isn't 0 - -per_replica_batch_size = per_process_batch_size // jax.local_device_count() -assert per_process_batch_size % per_replica_batch_size == 0, \ - "This example doesn't implement padding." -per_replica_batches = np.split(per_process_batch, jax.local_device_count()) - -# Thanks to the very important trick about data parallelism, no need to care what -# order the devices appear in the sharding. -sharding = jax.sharding.PositionalSharding(jax.devices()) -# PositionalSharding must have same rank as data being sharded. -sharding = sharding.reshape((jax.device_count(),) + - (1,) * (per_process_batch.ndim - 1)) - -global_batch_size = per_replica_batch_size * jax.device_count() -global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:]) - -global_batch_array = jax.make_array_from_single_device_arrays( - global_batch_shape, sharding, - # Thanks again to the very important trick, no need to care which device gets - # which per-replica batch. - arrays=[jax.device_put(batch, device) - for batch, device - in zip(per_replica_batches, sharding.addressable_devices)]) - -assert global_batch_array.shape == global_batch_shape -assert (global_batch_array.addressable_shards[0].data.shape == - per_replica_batches[0].shape) +mesh = jax.make_mesh((jax.device_count(),), ('batch',)) +sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('batch')) +global_batch_array = jax.make_array_from_process_local_data( + sharding, per_process_batch) ``` ## Data + model parallelism @@ -366,16 +341,6 @@ per_process_batch = ds.as_numpy_iterator().next() num_model_replicas_per_process = 2 # set according to your parallelism strategy num_model_replicas_total = num_model_replicas_per_process * jax.process_count() -per_process_batch_size = per_process_batch.shape[0] # adjust if your batch dim - # isn't 0 - -per_replica_batch_size = (per_process_batch_size // - num_model_replicas_per_process) -assert per_process_batch_size % per_replica_batch_size == 0, \ - "This example doesn't implement padding." -per_replica_batches = np.split(per_process_batch, - num_model_replicas_per_process) - # Create an example `Mesh` for per-process data parallelism. Make sure all devices # are grouped by process, and then resize so each row is a model replica. mesh_devices = np.array([jax.local_devices(process_idx) @@ -393,35 +358,8 @@ mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"]) sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec("model_replicas")) -global_batch_size = per_replica_batch_size * num_model_replicas_total -global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:]) - -# Create the final jax.Array using jax.make_array_from_callback. The callback -# will be called for each local device, and passed the N-D numpy-style index -# that describes what shard of the global data that device should receive. -# -# You don't need care exactly which index is passed in due to the very important data -# parallelism, but you do use the index argument to make sure you replicate each -# per-replica batch correctly -- the `index` argument will be the same for -# devices in the same model replica, and different for devices in different -# model replicas. - -index_to_batch = {} -def callback(index: tuple[slice, ...]) -> np.ndarray: - # Python `slice` objects aren't hashable, so manually create dict key. - index_key = tuple((slice_.start, slice_.stop) for slice_ in index) - if index_key not in index_to_batch: - # You don't care which per-replica batch goes to which replica, just take the - # next unused one. - index_to_batch[index_key] = per_replica_batches[len(index_to_batch)] - return index_to_batch[index_key] - -global_batch_array = jax.make_array_from_callback( - global_batch_shape, sharding, callback) - -assert global_batch_array.shape == global_batch_shape -assert (global_batch_array.addressable_shards[0].data.shape == - per_replica_batches[0].shape) +global_batch_array = jax.make_array_from_process_local_data( + sharding, per_process_batch) ``` ### Model parallelism across processes From 45af8742caf6e250619074c5aca7c3d0ee67d44e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 23 Sep 2024 18:18:15 -0700 Subject: [PATCH 626/702] trigger array API tests for all PRs. We should have done this when we deprecated jax.experimental.array_api --- .github/workflows/jax-array-api.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index cbe383f21ffe..bbbe53732a69 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -1,13 +1,12 @@ name: JAX Array API on: - workflow_dispatch: # allows triggering the workflow run manually - pull_request: # Automatically trigger on pull requests affecting particular files + push: + branches: + - main + pull_request: branches: - main - paths: - - '**workflows/jax-array-api.yml' - - '**experimental/array_api/**' concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} From a0e4448393991837494eb941aa7a73911204cd9f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Sep 2024 01:05:15 +0000 Subject: [PATCH 627/702] Remove warning filters from pyproject.toml, add local warning suppressions. We want to support running Bazel tests with PYTHONWARNINGS=error. In preparation for that change, move warning suppressions from pyproject.toml into the individual test cases that generate them, which is a reasonable cleanup anyway. --- pyproject.toml | 11 ----------- tests/array_interoperability_test.py | 2 ++ tests/compilation_cache_test.py | 2 ++ tests/scipy_stats_test.py | 4 ++++ 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a69adbfae2fd..b629762feff9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,23 +54,12 @@ markers = [ ] filterwarnings = [ "error", - "default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'", - "default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'", - "default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning", # TODO(jakevdp): remove when array_api_tests stabilize "default:.*not machine-readable.*:UserWarning", "default:Special cases found for .* but none were parsed.*:UserWarning", "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning", "default:The .* method is good for exploring strategies.*", - - # These are transitive warnings coming from TensorFlow dependencies. - # TODO(slebedev): Remove once we bump the minimum TensorFlow version. - "default:The key path API is deprecated .*", - "default:jax.xla_computation is deprecated.*:DeprecationWarning", - - # TODO(slebedev): Remove once we drop the legacy DLPack import path. - "default:.*from_dlpack with a DLPack tensor is deprecated.*:DeprecationWarning", ] doctest_optionflags = [ "NUMBER", diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 5585f1bcc005..3560241530c5 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -221,6 +221,8 @@ def testJaxToNumpy(self, shape, dtype): x_np = np.from_dlpack(x_jax) self.assertAllClose(x_np, x_jax) + @jtu.ignore_warning(message="Calling from_dlpack.*", + category=DeprecationWarning) def testNondefaultLayout(self): # Generate numpy array with nonstandard layout a = np.arange(4).reshape(2, 2) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index fd02f59826cc..75c52822a223 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -243,6 +243,7 @@ def test_cache_write_warning(self): mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put, warnings.catch_warnings(record=True) as w, ): + warnings.simplefilter("always") mock_put.side_effect = RuntimeError("test error") self.assertEqual(f(2).item(), 4) if len(w) != 1: @@ -265,6 +266,7 @@ def test_cache_read_warning(self): mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get, warnings.catch_warnings(record=True) as w, ): + warnings.simplefilter("always") mock_get.side_effect = RuntimeError("test error") # Calling assertEqual with the jitted f will generate two PJIT # executables: Equal and the lambda function itself. diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 91563f698ad4..f02ed0fc04bb 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -1572,6 +1572,10 @@ def evaluate_kde(kde, x): category=RuntimeWarning, message="One or more sample arguments is too small; all returned values will be NaN" ) + @jtu.ignore_warning( + category=RuntimeWarning, + message="All axis-slices of one or more sample arguments are too small", + ) def testMode(self, shape, dtype, axis, contains_nans, keepdims): if scipy_version < (1, 9, 0) and keepdims != True: self.skipTest("scipy < 1.9.0 only support keepdims == True") From adaf54a4bbe10ce05edcfeb29039c6948444c641 Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Mon, 23 Sep 2024 12:54:32 -0700 Subject: [PATCH 628/702] enable the activation offloading test --- tests/memories_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index 3e0f444a1e66..63b21e2d3e6d 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1567,8 +1567,6 @@ def g(ys, _): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_layout_change_offloadable(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Remat scan does not work on GPU backend.") mesh = jtu.create_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -1602,6 +1600,10 @@ def g(ys, _): self.assertIn('S(5)', compiled_text) self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-slice-done.*S\(5\)") compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: From 8196c8bf365a29b83746c550b3f0083eb1345104 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 24 Sep 2024 05:18:40 -0700 Subject: [PATCH 629/702] Added support for % and `select` to `mgpu.FragmentedArray` PiperOrigin-RevId: 678200940 --- .../mosaic/gpu/fragmented_array.py | 28 +++++++++++++++++-- tests/mosaic/gpu_test.py | 18 +++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 6cdda1a4ff62..1c1ec18d3cf2 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -369,14 +369,30 @@ def __rsub__(self, other): def __truediv__(self, other): if not ir.FloatType.isinstance(self.mlir_dtype): - raise NotImplementedError + return NotImplemented return self._pointwise(arith.divf, other) def __rtruediv__(self, other): if not ir.FloatType.isinstance(self.mlir_dtype): - raise NotImplementedError + return NotImplemented return self._pointwise(lambda s, o: arith.divf(o, s), other) + def __mod__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + if self.is_signed: + return self._pointwise(arith.remsi, other) + else: + return self._pointwise(arith.remui, other) + + def __rmod__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + if self.is_signed: + return self._pointwise(lambda s, o: arith.remsi(o, s), other) + else: + return self._pointwise(lambda s, o: arith.remui(o, s), other) + def __eq__(self, other): return self._compare( other, @@ -768,6 +784,14 @@ def broadcast_minor(self, n): _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed ) + def select(self, x, y): + if ( + not ir.IntegerType.isinstance(self.mlir_dtype) + or ir.IntegerType(self.mlir_dtype).width != 1 + ): + raise NotImplementedError + return self._pointwise(arith.select, x, y) + def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): """Call a function for each value and index.""" if not isinstance(self.layout, WGStridedFragLayout): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 41725e6017f4..2eacf7c9984c 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1193,6 +1193,7 @@ class FragmentedArrayTest(TestCase): operator.mul, operator.sub, operator.truediv, + operator.mod, (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum), ), dtype=[jnp.float32, jnp.int32, jnp.uint32], @@ -1205,8 +1206,10 @@ def test_binary(self, op, dtype, m=64, n=32): else: np_op = op - if not jnp.issubdtype(dtype, jnp.floating) and op is operator.truediv: + if jnp.issubdtype(dtype, jnp.integer) and op is operator.truediv: self.skipTest("Unsupported for integer types") + if jnp.issubdtype(dtype, jnp.floating) and op is operator.mod: + self.skipTest("Unsupported for floating types") for scalar_rhs in [None, 2]: def kernel(ctx, dst, _): @@ -1269,6 +1272,19 @@ def kernel(ctx, dst, _): x = np.arange(m * n, dtype=dtype).reshape(m, n) np.testing.assert_allclose(result, np_op(x), atol=2e-7, rtol=2e-7) + def test_select(self, m=64, n=32): + + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, jnp.int32) + (iota < 16).select(iota * 2, iota * 3).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + x = np.arange(m * n, dtype=jnp.int32).reshape(m, n) + np.testing.assert_array_equal(result, np.where(x < 16, x * 2, x * 3)) + @parameterized.product( ops=[ (lambda x: mgpu.FragmentedArray.exp(x), np.exp), From a44e129ae791663cf391b35f4e26ff559a2b289a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 24 Sep 2024 05:22:18 -0700 Subject: [PATCH 630/702] Add more informative error when static argument is passed to non-static JIT parameter --- jax/_src/pjit.py | 12 ++++++++++++ tests/api_test.py | 24 +++++++++++++----------- tests/lax_test.py | 5 +++-- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ac1318ed7810..0a75128477ce 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -645,6 +645,18 @@ def _infer_params_impl( "An overflow was encountered while parsing an argument to a jitted " f"computation, whose {arg_path}." ) from e + except TypeError as e: + arg_description = (f"path {dbg.arg_names[i]}" if dbg + else f"flattened argument number {i}") + raise TypeError( + f"Error interpreting argument to {fun} as an abstract array." + f" The problematic value is of type {type(a)} and was passed to" + f" the function at {arg_description}.\n" + "This typically means that a jit-wrapped function was called with a non-array" + " argument, and this argument was not marked as static using the" + " static_argnums or static_argnames parameters of jax.jit." + ) from e + in_type = in_avals = tuple(avals) else: in_type = in_avals diff --git a/tests/api_test.py b/tests/api_test.py index adce61d650d6..d1c77c75eb79 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -733,13 +733,12 @@ def test_jit_bad_input(self): def f(x): return x - with self.assertRaisesRegex( - TypeError, r".* 'foo' of type <.*'str'> is not a valid JAX type"): + err_str = ("Error interpreting argument to .* as an abstract array. The problematic " + "value is of type .* and was passed to the function at path x.") + with self.assertRaisesRegex(TypeError, err_str): jit(f)("foo") # Jax type objects aren't valid data arguments. - err_str = "JAX scalar type .*int32.* cannot be interpreted as a JAX array." - with self.assertRaisesRegex(TypeError, err_str): jit(f)(jnp.int32) @@ -1576,13 +1575,14 @@ def test_bad_input(self): def f(x): return x - self.assertRaisesRegex( - TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type", - lambda: grad(f)("foo")) + with self.assertRaisesRegex(TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type"): + grad(f)("foo") - self.assertRaisesRegex( - TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type", - lambda: jit(f)("foo")) + + err_str = ("Error interpreting argument to .* as an abstract array. The problematic " + "value is of type .* and was passed to the function at path x.") + with self.assertRaisesRegex(TypeError, err_str): + jit(f)("foo") def test_grad_tuple_output(self): jtu.check_raises(lambda: grad(lambda x: (x,x))(1.0), TypeError, @@ -2959,8 +2959,10 @@ def check_warning(warn, nowarn): lambda: jnp.arange(1.0).astype(int)) def test_error_for_invalid_dtype(self): + err_str = ("Error interpreting argument to .* as an abstract array. The problematic " + r"value is of type .* and was passed to the function at path args\[1\].") with jax.enable_checks(False): - with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"): + with self.assertRaisesRegex(TypeError, err_str): lax.add(jnp.array(7), np.array("hello")) with jax.enable_checks(True): with self.assertRaises(AssertionError): diff --git a/tests/lax_test.py b/tests/lax_test.py index 0ae5f77afbdb..3f43773a8ec7 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2844,9 +2844,10 @@ def testDynamicUpdateSliceTypeErrors(self): (np.int32(1), np.int16(2)))) def test_primitive_jaxtype_error(self): + err_str = ("Error interpreting argument to .* as an abstract array. The problematic " + r"value is of type .* and was passed to the function at path args\[1\].") with jax.enable_checks(False): - with self.assertRaisesRegex( - TypeError, "Argument .* of type .* is not a valid JAX type"): + with self.assertRaisesRegex(TypeError, err_str): lax.add(1, 'hi') def test_reduction_with_repeated_axes_error(self): From 6229511f6a77a83ab7dca4414085866a438329b9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 18 Sep 2024 15:21:29 -0700 Subject: [PATCH 631/702] Make jnp.negative a ufunc & add unary ufunc tests --- jax/_src/numpy/array_methods.py | 8 +-- jax/_src/numpy/ufuncs.py | 3 +- tests/lax_numpy_ufuncs_test.py | 94 ++++++++++++++++++++++++--------- 3 files changed, 75 insertions(+), 30 deletions(-) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 547fe1247459..95d681cad8e5 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -909,15 +909,15 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, "setitem": _unimplemented_setitem, "copy": _copy, "deepcopy": _deepcopy, - "neg": ufuncs.negative, - "pos": ufuncs.positive, + "neg": lambda self: ufuncs.negative(self), + "pos": lambda self: ufuncs.positive(self), "eq": _defer_to_unrecognized_arg("==", ufuncs.equal), "ne": _defer_to_unrecognized_arg("!=", ufuncs.not_equal), "lt": _defer_to_unrecognized_arg("<", ufuncs.less), "le": _defer_to_unrecognized_arg("<=", ufuncs.less_equal), "gt": _defer_to_unrecognized_arg(">", ufuncs.greater), "ge": _defer_to_unrecognized_arg(">=", ufuncs.greater_equal), - "abs": ufuncs.abs, + "abs": lambda self: ufuncs.abs(self), "add": _defer_to_unrecognized_arg("+", ufuncs.add), "radd": _defer_to_unrecognized_arg("+", ufuncs.add, swap=True), "sub": _defer_to_unrecognized_arg("-", ufuncs.subtract), @@ -944,7 +944,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, "ror": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or, swap=True), "xor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor), "rxor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor, swap=True), - "invert": ufuncs.bitwise_not, + "invert": lambda self: ufuncs.bitwise_not(self), "lshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift), "rshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift), "rlshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift, swap=True), diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index a88ea3d760dc..2598b2183534 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -117,7 +117,7 @@ def invert(x: ArrayLike, /) -> Array: @partial(jit, inline=True) -def negative(x: ArrayLike, /) -> Array: +def _negative(x: ArrayLike, /) -> Array: """Return element-wise negative values of the input. JAX implementation of :obj:`numpy.negative`. @@ -2615,3 +2615,4 @@ def _multiply_at(a: Array, indices: Any, b: ArrayLike): logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce) logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce) logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor) +negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative) diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 630d89f53c5a..537146a215c7 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -58,7 +58,7 @@ def _jnp_ufunc_props(name): jnp_func = getattr(jnp, name) assert isinstance(jnp_func, jnp.ufunc) np_func = getattr(np, name) - dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types] + dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types] return [dict(name=name, dtype=dtype) for dtype in dtypes] @@ -66,10 +66,27 @@ def _jnp_ufunc_props(name): name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc) ] +BINARY_UFUNCS = [ + name for name in JAX_NUMPY_UFUNCS if getattr(jnp, name).nin == 2 +] + +UNARY_UFUNCS = [ + name for name in JAX_NUMPY_UFUNCS if getattr(jnp, name).nin == 1 +] + JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable( _jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS )) +BINARY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable( + _jnp_ufunc_props(name) for name in BINARY_UFUNCS +)) + +UNARY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable( + _jnp_ufunc_props(name) for name in UNARY_UFUNCS +)) + + broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] nonscalar_shapes = [(3,), (4,), (4, 3)] @@ -144,12 +161,25 @@ def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - JAX_NUMPY_UFUNCS_WITH_DTYPES, + UNARY_UFUNCS_WITH_DTYPES, + shape=broadcast_compatible_shapes, + ) + def test_unary_ufunc_call(self, name, dtype, shape): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, lhs_shape=broadcast_compatible_shapes, rhs_shape=broadcast_compatible_shapes, ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def test_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): + def test_bimary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) rng = jtu.rand_default(self.rng()) @@ -177,15 +207,13 @@ def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - JAX_NUMPY_UFUNCS_WITH_DTYPES, + BINARY_UFUNCS_WITH_DTYPES, lhs_shape=broadcast_compatible_shapes, rhs_shape=broadcast_compatible_shapes, ) - def test_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): + def test_binary_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) - if (jnp_fun.nin, jnp_fun.nout) != (2, 1): - self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] @@ -213,16 +241,15 @@ def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - JAX_NUMPY_UFUNCS_WITH_DTYPES, + BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes for axis in [None, *range(-len(shape), len(shape))]], ) - def test_ufunc_reduce(self, name, shape, axis, dtype): + def test_binary_ufunc_reduce(self, name, shape, axis, dtype): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) - if (jnp_fun.nin, jnp_fun.nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis) np_fun_reduce = partial(np_fun.reduce, axis=axis) @@ -266,16 +293,15 @@ def np_fun(arr, where): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - JAX_NUMPY_UFUNCS_WITH_DTYPES, + BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes for axis in [None, *range(-len(shape), len(shape))]], ) - def test_ufunc_reduce_where(self, name, shape, axis, dtype): + def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) - if (jnp_fun.nin, jnp_fun.nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + if jnp_fun.identity is None: self.skipTest("reduce with where requires identity") @@ -309,16 +335,14 @@ def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dty self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - JAX_NUMPY_UFUNCS_WITH_DTYPES, + BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes for axis in range(-len(shape), len(shape))] ) - def test_ufunc_accumulate(self, name, shape, axis, dtype): + def test_binary_ufunc_accumulate(self, name, shape, axis, dtype): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) - if (jnp_fun.nin, jnp_fun.nout) != (2, 1): - self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -355,15 +379,35 @@ def np_fun(x, idx, y): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - JAX_NUMPY_UFUNCS_WITH_DTYPES, + UNARY_UFUNCS_WITH_DTYPES, shape=nonscalar_shapes, idx_shape=[(), (2,)], ) - def test_ufunc_at(self, name, shape, idx_shape, dtype): + def test_unary_ufunc_at(self, name, shape, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')] + + jnp_fun_at = partial(jnp_fun.at, inplace=False) + def np_fun_at(x, idx): + x_copy = x.copy() + np_fun.at(x_copy, idx) + return x_copy + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + self._CompileAndCheck(jnp_fun_at, args_maker) + + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + shape=nonscalar_shapes, + idx_shape=[(), (2,)], + ) + def test_binary_ufunc_at(self, name, shape, idx_shape, dtype): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) - if (jnp_fun.nin, jnp_fun.nout) != (2, 1): - self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") rng = jtu.rand_default(self.rng()) idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0]) @@ -413,13 +457,13 @@ def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_s self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - JAX_NUMPY_UFUNCS_WITH_DTYPES, + BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes for axis in [*range(-len(shape), len(shape))]], idx_shape=[(0,), (3,), (5,)], ) - def test_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype): + def test_binary_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) if (jnp_fun.nin, jnp_fun.nout) != (2, 1): From ae86ef16c7a03409cb444da3d477b2adb8134e6f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 24 Sep 2024 06:12:46 -0700 Subject: [PATCH 632/702] [Mosaic GPU] Add support for input_output_aliases PiperOrigin-RevId: 678217775 --- .../mosaic_gpu/pallas_call_registration.py | 5 +---- jax/experimental/mosaic/gpu/core.py | 11 ++++++++++- tests/pallas/mosaic_gpu_test.py | 17 +++++++++++++++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 510d4032f3dd..960fe7d71856 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -44,10 +44,6 @@ def pallas_call_lowering( raise NotImplementedError( "dynamic grid bounds not supported in the Mosaic GPU backend" ) - if input_output_aliases: - raise NotImplementedError( - "input_output_aliases not supported in the Mosaic GPU backend" - ) if debug: print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") @@ -72,4 +68,5 @@ def pallas_call_lowering( *args, module=module.operation.get_asm(binary=True, enable_debug_info=True), out_types=lowering_result.out_structs, + input_output_aliases=input_output_aliases, ) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 9a4570ec1673..9a03afeb4a8c 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -96,7 +96,14 @@ def _mosaic_gpu_abstract_eval(*_, module, out_types): # TODO(apaszke): Implement a proper system for managing kernel lifetimes KNOWN_KERNELS = {} -def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): + +def _mosaic_gpu_lowering_rule( + ctx, + *args, + module, + out_types, + input_output_aliases: tuple[tuple[int, int], ...] = (), +): del out_types # Unused. kernel_id = hashlib.sha256(module).digest() # Note that this is technically only a half measure. Someone might load a @@ -114,9 +121,11 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types): operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], backend_config=kernel_id + module, + operand_output_aliases=dict(input_output_aliases), ) return op.results + mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index f34beeeb8166..2f247ca60cff 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -379,5 +379,22 @@ def kernel(a_ref, b_ref, o_ref): res, a @ b, rtol=1e-3 ) + def test_input_output_aliases(self): + # Note that we're writing to the input pointer, which should alias b_ptr. + def kernel(a_ref, b_ref): + del b_ref + a_ref[...] = jnp.ones_like(a_ref) + + a = np.zeros((64, 64), dtype=jnp.float32) + b = pl.pallas_call( + kernel, + in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM), + input_output_aliases={0: 0}, + out_shape=a, + )(a) + np.testing.assert_array_equal(b, np.ones_like(a)) + + if __name__ == "__main__": absltest.main() From 562e9e8dff57129d8ba10298aa2dbb4ba0d6b4e0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Sep 2024 13:13:29 +0000 Subject: [PATCH 633/702] Fix an incorrect output for jnp.cumsum. If dtype=bool but a non-bool input is passed, we should test for non-equality with zero rather than performing a cast to integer. --- CHANGELOG.md | 3 +++ jax/_src/numpy/reductions.py | 6 +++++- tests/lax_numpy_reducers_test.py | 4 ++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 43db6e197b5d..5bdcd1c20106 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. The argument was only used by `xmap` which was removed in 0.4.31. +* Bug fixes + * Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs + if a non-boolean input was provided and `dtype=bool` was specified. ## jax 0.4.33 (September 16, 2024) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 043c976ef6f5..3436b00cfce1 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -1810,16 +1810,20 @@ def _cumulative_reduction( if fill_nan: a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a) + a_type: DType = dtypes.dtype(a) result_type: DTypeLike = dtypes.dtype(dtype or a) if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_): result_type = _promote_integer_dtype(result_type) result_type = dtypes.canonicalize_dtype(result_type) + if a_type != np.bool_ and dtype == np.bool_: + a = lax_internal.asarray(a).astype(np.bool_) + a = lax.convert_element_type(a, result_type) result = reduction(a, axis) # We downcast to boolean because we accumulate in integer types - if dtypes.issubdtype(dtype, np.bool_): + if dtype is not None and dtypes.issubdtype(dtype, np.bool_): result = lax.convert_element_type(result, np.bool_) return result diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 33830c541fb9..623c11a51998 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -861,6 +861,10 @@ def testCumulativeSumErrors(self, shape, dtype, include_initial): with self.assertRaisesRegex(ValueError, msg): jnp.cumulative_sum(x, include_initial=include_initial) + def testCumulativeSumBool(self): + out = jnp.cumulative_sum(jnp.array([[0.1], [0.1], [0.0]]), axis=-1, + dtype=jnp.bool_) + np.testing.assert_array_equal(np.array([[True], [True], [False]]), out) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 8d86a04727a4c17d498f198f39688b1c1b310e1c Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 24 Sep 2024 08:16:48 -0700 Subject: [PATCH 634/702] [pallas] Allow `TransformedRef` to be passed to `pl.load` / `pl.store`, when `idx = ()`. PiperOrigin-RevId: 678257485 --- jax/_src/state/primitives.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 988f362290f0..773302c9f637 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -62,10 +62,13 @@ get_p = core.Primitive("get") get_p.def_impl(partial(dispatch.apply_primitive, get_p)) -Indexer = tuple[Union[int, slice, Array, types.EllipsisType], ...] +Indexer = Union[int, slice, Array, types.EllipsisType] + def get_ref_and_transforms( - ref_or_view: Any, idx: Indexer | None, function_name: str + ref_or_view: Any, + idx: Indexer | tuple[Indexer, ...] | None, + function_name: str, ) -> tuple[Any, tuple[Transform, ...]]: if isinstance(ref_or_view, TransformedRef): ref, transforms = ref_or_view.ref, ref_or_view.transforms @@ -76,18 +79,27 @@ def get_ref_and_transforms( raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.") if not isinstance(ref_aval.inner_aval, core.ShapedArray): return ref, () - if idx is None: + + if idx is None or idx is Ellipsis: + idx = () + elif not isinstance(idx, tuple): + idx = (idx,) + + if not idx and transforms and isinstance(transforms[-1], indexing.NDIndexer): return ref, transforms nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape) return ref, (*transforms, nd_indexer) -def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array: +def ref_get( + ref_or_view: Any, idx: Indexer | tuple[Indexer, ...] | None = None +) -> Array: """Reads a value from a `Ref`, a.k.a. value <- ref[idx].""" ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_get") flat_transforms, tree = tree_util.tree_flatten(transforms) return get_p.bind(ref, *flat_transforms, tree=tree) + # `swap` mutates a `Ref`, setting its value and returns its previous value. # b = swap_p.bind(x, a) # It generalizes the setting operation for a `Ref` as we can ignore the return @@ -110,7 +122,7 @@ def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array: def ref_swap( ref_or_view: AbstractRef | TransformedRef, - idx: Indexer | None, + idx: Indexer | tuple[Indexer, ...] | None, value: Array, _function_name: str = "ref_swap", ) -> Array: @@ -121,11 +133,14 @@ def ref_swap( def ref_set( - ref_or_view: AbstractRef | TransformedRef, idx: Indexer | None, value: Array + ref_or_view: AbstractRef | TransformedRef, + idx: Indexer | tuple[Indexer, ...] | None, + value: Array, ) -> None: """Sets a `Ref`'s value, a.k.a. ref[idx] <- value.""" ref_swap(ref_or_view, idx, value, _function_name="ref_set") + # `addupdate_p` mutates a `Ref`, adding a value to its existing value. # Semantically, # ``` @@ -141,12 +156,18 @@ def ref_set( addupdate_p.multiple_results = True addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p)) -def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None: + +def ref_addupdate( + ref_or_view: AbstractRef, + idx: Indexer | tuple[Indexer, ...] | None, + x: Array, +) -> None: """Mutates a ref with an additive update i.e. `ref[idx] += x`.""" ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_addupdate") flat_transforms, tree = tree_util.tree_flatten(transforms) return addupdate_p.bind(ref, x, *flat_transforms, tree=tree) + ## get/set/addupdate abstract evaluation rules From 9114b084fc0ac898d3a3f2de9e0266c6c1546842 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 24 Sep 2024 09:17:18 -0700 Subject: [PATCH 635/702] [Pallas] Update export compatibility tests The old test was generated before our IR was really stable, which has started to cause problems when trying to test with Trillium. PiperOrigin-RevId: 678277755 --- .../pallas/mosaic_matmul.py | 288 +++++++++--------- .../pallas/export_back_compat_pallas_test.py | 4 +- 2 files changed, 148 insertions(+), 144 deletions(-) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py index 065db82453f3..2c94cb777b46 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py @@ -16,72 +16,72 @@ from numpy import array, float32 -# Pasted from the test output (see back_compat_test_util.py module docstring) -data_2023_09_22 = dict( +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_24 = dict( testdata_version=1, platform='tpu', custom_call_targets=['tpu_custom_call'], - serialized_date=datetime.date(2023, 9, 22), + serialized_date=datetime.date(2024, 9, 24), inputs=(), - expected_outputs=(array([[ 90458.2 , 90470.875, 90480.85 , 90491.11 , - 90500.945, 90510.95 , 90521.18 , 90530.95 , - 90540.78 , 90551.16 , 90560.68 , 90570.734, - 90580.73 , 90590.58 , 90600.66 , 90610.61 ], - [ 643341.75 , 643434.25 , 643509.75 , 643587.06 , - 643660.1 , 643735.9 , 643813.5 , 643886. , - 643960.6 , 644039.56 , 644110.25 , 644186.75 , - 644262.5 , 644336.06 , 644412.9 , 644488.4 ], + expected_outputs=(array([[ 90458.2 , 90470.875, 90480.85 , 90491.1 , + 90500.945, 90510.945, 90521.19 , 90530.95 , + 90540.78 , 90551.16 , 90560.67 , 90570.734, + 90580.73 , 90590.586, 90600.66 , 90610.61 ], + [ 643341.75 , 643434.25 , 643509.75 , 643587.1 , + 643660.1 , 643735.94 , 643813.5 , 643886. , + 643960.6 , 644039.5 , 644110.25 , 644186.75 , + 644262.5 , 644336.06 , 644412.9 , 644488.3 ], [ 1196323.2 , 1196495.6 , 1196636.8 , 1196781. , 1196917.5 , 1197059. , 1197203.9 , 1197339.2 , 1197478.5 , 1197625.8 , 1197757.8 , 1197900.5 , - 1198042. , 1198179.4 , 1198323. , 1198464. ], - [ 1749075.5 , 1749327.9 , 1749534.4 , 1749745.9 , - 1749945.5 , 1750152.8 , 1750365.1 , 1750563.1 , - 1750767.1 , 1750983.1 , 1751176.2 , 1751385.4 , + 1198042. , 1198179.4 , 1198323.1 , 1198464. ], + [ 1749075.5 , 1749327.8 , 1749534.4 , 1749746. , + 1749945.5 , 1750152.8 , 1750365.1 , 1750563. , + 1750767.2 , 1750983.1 , 1751176.2 , 1751385.5 , 1751592.8 , 1751793.8 , 1752004.2 , 1752210.8 ], [ 2302500.5 , 2302832.5 , 2303104.8 , 2303383.5 , 2303646.2 , 2303919.5 , 2304199. , 2304459.8 , - 2304728.5 , 2305013. , 2305267.2 , 2305543. , + 2304728.5 , 2305013. , 2305267.2 , 2305542.8 , 2305816.2 , 2306081. , 2306358.5 , 2306630.5 ], - [ 2855440.2 , 2855852.5 , 2856190.2 , 2856535.5 , - 2856861.5 , 2857200.5 , 2857547.2 , 2857870.5 , + [ 2855440.2 , 2855852.2 , 2856190.2 , 2856535.5 , + 2856861.5 , 2857200.5 , 2857547.2 , 2857870.8 , 2858204.5 , 2858557. , 2858872.5 , 2859214.5 , 2859553.2 , 2859882. , 2860226. , 2860563.5 ], [ 3407472. , 3407964.2 , 3408367.5 , 3408780.2 , - 3409169.5 , 3409574.5 , 3409988.5 , 3410374.5 , - 3410773. , 3411194. , 3411570.5 , 3411979. , - 3412383.5 , 3412776. , 3413186.5 , 3413590. ], + 3409169.5 , 3409574.2 , 3409988.5 , 3410374.5 , + 3410772.8 , 3411194. , 3411570.5 , 3411978.8 , + 3412383.5 , 3412776. , 3413186.8 , 3413590. ], [ 3959847.5 , 3960419. , 3960888. , 3961367.8 , - 3961820.2 , 3962290.8 , 3962772.5 , 3963221.2 , - 3963684.8 , 3964174.2 , 3964612.2 , 3965086.8 , + 3961820.2 , 3962291. , 3962772.5 , 3963221.2 , + 3963684.8 , 3964174.5 , 3964612. , 3965086.8 , 3965557.2 , 3966013.2 , 3966491. , 3966959.5 ], - [ 4515869.5 , 4516521.5 , 4517056. , 4517602. , + [ 4515869.5 , 4516521.5 , 4517056. , 4517602.5 , 4518118. , 4518654.5 , 4519203. , 4519715. , 4520243. , 4520801. , 4521300. , 4521841. , - 4522378. , 4522897. , 4523441.5 , 4523975.5 ], + 4522377.5 , 4522897. , 4523441.5 , 4523975.5 ], [ 5061659. , 5062390. , 5062990. , 5063603.5 , 5064182. , 5064784.5 , 5065401. , 5065975. , 5066567.5 , 5067194. , 5067754. , 5068362. , - 5068964. , 5069547. , 5070159. , 5070759. ], + 5068964. , 5069547. , 5070159. , 5070758.5 ], [ 5621329. , 5622141. , 5622806.5 , 5623487.5 , 5624129.5 , 5624797. , 5625481. , 5626118. , 5626775. , 5627470.5 , 5628092. , 5628765. , - 5629433.5 , 5630080.5 , 5630758.5 , 5631424. ], - [ 6172821. , 6173712. , 6174443. , 6175191. , - 6175896. , 6176630. , 6177381. , 6178080.5 , - 6178803. , 6179566. , 6180248.5 , 6180988. , - 6181722. , 6182432.5 , 6183178. , 6183908. ], + 5629433. , 5630080.5 , 5630758.5 , 5631424. ], + [ 6172820.5 , 6173712. , 6174443. , 6175191. , + 6175896.5 , 6176630. , 6177381. , 6178080.5 , + 6178803. , 6179566. , 6180248.5 , 6180988.5 , + 6181722. , 6182432.5 , 6183177.5 , 6183908. ], [ 6723343.5 , 6724315. , 6725111.5 , 6725927. , 6726696. , 6727495.5 , 6728313.5 , 6729076.5 , - 6729863.5 , 6730696. , 6731440. , 6732246. , + 6729864. , 6730696. , 6731440. , 6732246. , 6733046. , 6733820.5 , 6734632. , 6735428.5 ], [ 7280537. , 7281587.5 , 7282449.5 , 7283331.5 , - 7284163.5 , 7285028.5 , 7285914. , 7286739.5 , - 7287591. , 7288492. , 7289296.5 , 7290169.5 , - 7291035. , 7291873.5 , 7292752.5 , 7293614. ], + 7284163.5 , 7285029. , 7285914. , 7286739. , + 7287591. , 7288492. , 7289297. , 7290169.5 , + 7291035. , 7291873.5 , 7292752. , 7293614. ], [ 7828292. , 7829423. , 7830350. , 7831299.5 , 7832194.5 , 7833125.5 , 7834078.5 , 7834966. , - 7835883. , 7836852. , 7837717.5 , 7838657. , + 7835883. , 7836852. , 7837718. , 7838657. , 7839588. , 7840490. , 7841436. , 7842363.5 ], [ 8384808.5 , 8386019.5 , 8387012.5 , 8388029.5 , 8388988. , 8389985. , 8391005. , 8391956. , @@ -98,54 +98,54 @@ [10055416. , 10056868. , 10058060. , 10059279. , 10060428. , 10061624. , 10062848. , 10063988. , 10065166. , 10066410. , 10067522. , 10068729. , - 10069924. , 10071083. , 10072298. , 10073489. ], - [10595886. , 10597417. , 10598673. , 10599958. , + 10069925. , 10071083. , 10072298. , 10073489. ], + [10595886. , 10597416. , 10598672. , 10599958. , 10601170. , 10602431. , 10603721. , 10604923. , - 10606164. , 10607477. , 10608649. , 10609921. , - 10611182. , 10612404. , 10613684. , 10614941. ], + 10606164. , 10607477. , 10608650. , 10609922. , + 10611182. , 10612404. , 10613684. , 10614940. ], [11135804. , 11137412. , 11138732. , 11140083. , - 11141357. , 11142682. , 11144038. , 11145301. , + 11141357. , 11142682. , 11144038. , 11145302. , 11146606. , 11147985. , 11149218. , 11150554. , - 11151880. , 11153163. , 11154509. , 11155829. ], + 11151880. , 11153164. , 11154509. , 11155829. ], [11686791. , 11688480. , 11689864. , 11691282. , 11692618. , 11694007. , 11695430. , 11696756. , - 11698123. , 11699571. , 11700864. , 11702265. , + 11698124. , 11699570. , 11700864. , 11702265. , 11703656. , 11705003. , 11706414. , 11707799. ], [12263420. , 12265190. , 12266642. , 12268128. , 12269529. , 12270986. , 12272478. , 12273868. , 12275303. , 12276820. , 12278176. , 12279646. , - 12281104. , 12282516. , 12283996. , 12285447. ], + 12281104. , 12282516. , 12283996. , 12285446. ], [12821178. , 12823029. , 12824548. , 12826102. , - 12827567. , 12829092. , 12830652. , 12832105. , - 12833606. , 12835193. , 12836610. , 12838149. , + 12827567. , 12829092. , 12830652. , 12832106. , + 12833606. , 12835192. , 12836610. , 12838148. , 12839673. , 12841150. , 12842699. , 12844217. ], [13362964. , 13364895. , 13366479. , 13368100. , 13369628. , 13371218. , 13372846. , 13374362. , 13375927. , 13377582. , 13379061. , 13380665. , - 13382255. , 13383796. , 13385411. , 13386995. ], - [13902882. , 13904891. , 13906539. , 13908225. , + 13382256. , 13383796. , 13385411. , 13386995. ], + [13902882. , 13904890. , 13906538. , 13908225. , 13909815. , 13911470. , 13913163. , 13914740. , - 13916369. , 13918091. , 13919629. , 13921298. , - 13922953. , 13924556. , 13926236. , 13927884. ], + 13916368. , 13918090. , 13919629. , 13921298. , + 13922952. , 13924556. , 13926236. , 13927884. ], [14443848. , 14445934. , 14447646. , 14449398. , 14451050. , 14452769. , 14454528. , 14456166. , - 14457858. , 14459647. , 14461245. , 14462979. , + 14457858. , 14459647. , 14461246. , 14462979. , 14464698. , 14466363. , 14468108. , 14469820. ], - [15024407. , 15026576. , 15028355. , 15030176. , + [15024406. , 15026576. , 15028355. , 15030176. , 15031893. , 15033679. , 15035507. , 15037210. , - 15038969. , 15040827. , 15042490. , 15044291. , + 15038968. , 15040828. , 15042490. , 15044291. , 15046077. , 15047808. , 15049621. , 15051400. ], [15586096. , 15588347. , 15590193. , 15592082. , 15593863. , 15595716. , 15597613. , 15599380. , - 15601204. , 15603133. , 15604857. , 15606726. , + 15601204. , 15603133. , 15604856. , 15606726. , 15608579. , 15610375. , 15612257. , 15614103. ], [16130043. , 16132373. , 16134285. , 16136242. , 16138087. , 16140006. , 16141970. , 16143800. , 16145690. , 16147688. , 16149473. , 16151409. , - 16153328. , 16155188. , 16157138. , 16159049. ], - [16669961. , 16672369. , 16674345. , 16676367. , - 16678274. , 16680257. , 16682287. , 16684178. , - 16686131. , 16688196. , 16690041. , 16692042. , + 16153328. , 16155188. , 16157138. , 16159050. ], + [16669960. , 16672369. , 16674345. , 16676367. , + 16678274. , 16680258. , 16682287. , 16684178. , + 16686132. , 16688196. , 16690041. , 16692042. , 16694026. , 16695948. , 16697962. , 16699938. ], [17209878. , 17212364. , 17214404. , 17216492. , 17218460. , 17220508. , 17222604. , 17224556. , @@ -174,35 +174,35 @@ [20516874. , 20519838. , 20522270. , 20524760. , 20527106. , 20529548. , 20532046. , 20534374. , 20536776. , 20539318. , 20541588. , 20544052. , - 20546492. , 20548860. , 20551338. , 20553770. ], + 20546492. , 20548858. , 20551338. , 20553770. ], [21056792. , 21059834. , 21062330. , 21064884. , - 21067292. , 21069798. , 21072364. , 21074752. , + 21067292. , 21069800. , 21072364. , 21074752. , 21077218. , 21079826. , 21082156. , 21084684. , - 21087190. , 21089618. , 21092162. , 21094658. ], + 21087190. , 21089618. , 21092164. , 21094660. ], [21596710. , 21599830. , 21602390. , 21605010. , 21607480. , 21610050. , 21612680. , 21615130. , 21617660. , 21620336. , 21622724. , 21625318. , 21627888. , 21630378. , 21632988. , 21635548. ], [22218698. , 22221906. , 22224536. , 22227228. , 22229768. , 22232408. , 22235108. , 22237628. , - 22240228. , 22242976. , 22245434. , 22248094. , - 22250734. , 22253292. , 22255972. , 22258602. ], + 22240228. , 22242976. , 22245432. , 22248094. , + 22250736. , 22253292. , 22255972. , 22258602. ], [22802946. , 22806238. , 22808938. , 22811700. , 22814306. , 22817016. , 22819790. , 22822374. , - 22825044. , 22827864. , 22830386. , 22833120. , - 22835830. , 22838456. , 22841208. , 22843906. ], + 22825044. , 22827864. , 22830384. , 22833120. , + 22835830. , 22838456. , 22841208. , 22843908. ], [23351442. , 23354816. , 23357584. , 23360416. , 23363088. , 23365866. , 23368710. , 23371360. , - 23374094. , 23376988. , 23379572. , 23382374. , + 23374096. , 23376988. , 23379572. , 23382374. , 23385154. , 23387846. , 23390668. , 23393436. ], [23891360. , 23894812. , 23897644. , 23900542. , - 23903274. , 23906118. , 23909028. , 23911738. , + 23903276. , 23906118. , 23909028. , 23911738. , 23914536. , 23917496. , 23920140. , 23923008. , - 23925850. , 23928606. , 23931492. , 23934324. ], + 23925850. , 23928604. , 23931492. , 23934324. ], [24431278. , 24434808. , 24437704. , 24440668. , 24443462. , 24446368. , 24449344. , 24452116. , 24454978. , 24458004. , 24460708. , 24463640. , - 24466548. , 24469364. , 24472318. , 24475214. ], + 24466548. , 24469364. , 24472316. , 24475212. ], [24971196. , 24974804. , 24977764. , 24980792. , 24983648. , 24986620. , 24989662. , 24992494. , 24995420. , 24998512. , 25001276. , 25004274. , @@ -212,56 +212,56 @@ 25535860. , 25539020. , 25541844. , 25544906. , 25547942. , 25550884. , 25553966. , 25556990. ], [26051032. , 26054796. , 26057884. , 26061044. , - 26064022. , 26067122. , 26070296. , 26073250. , - 26076302. , 26079530. , 26082412. , 26085540. , - 26088640. , 26091642. , 26094792. , 26097880. ], - [26590950. , 26594790. , 26597942. , 26601168. , + 26064024. , 26067124. , 26070296. , 26073250. , + 26076302. , 26079528. , 26082412. , 26085540. , + 26088640. , 26091644. , 26094792. , 26097880. ], + [26590950. , 26594792. , 26597944. , 26601168. , 26604210. , 26607374. , 26610612. , 26613628. , 26616744. , 26620038. , 26622980. , 26626172. , 26629336. , 26632402. , 26635616. , 26638768. ], - [27130866. , 27134786. , 27138002. , 27141294. , - 27144396. , 27147626. , 27150930. , 27154008. , - 27157186. , 27160546. , 27163548. , 27166806. , + [27130868. , 27134786. , 27138002. , 27141294. , + 27144396. , 27147624. , 27150930. , 27154008. , + 27157184. , 27160546. , 27163548. , 27166804. , 27170034. , 27173162. , 27176440. , 27179656. ], [27723244. , 27727248. , 27730532. , 27733892. , 27737062. , 27740358. , 27743732. , 27746876. , 27750120. , 27753552. , 27756618. , 27759944. , - 27763240. , 27766436. , 27769782. , 27773064. ], + 27763240. , 27766436. , 27769780. , 27773064. ], [28323220. , 28327310. , 28330664. , 28334094. , - 28337330. , 28340696. , 28344142. , 28347352. , + 28337332. , 28340696. , 28344142. , 28347352. , 28350664. , 28354168. , 28357300. , 28360696. , 28364062. , 28367324. , 28370744. , 28374096. ], - [28885444. , 28889618. , 28893040. , 28896544. , + [28885444. , 28889616. , 28893040. , 28896544. , 28899848. , 28903284. , 28906802. , 28910078. , - 28913462. , 28917038. , 28920234. , 28923702. , - 28927138. , 28930468. , 28933958. , 28937382. ], + 28913460. , 28917038. , 28920236. , 28923702. , + 28927138. , 28930468. , 28933960. , 28937382. ], [29425518. , 29429768. , 29433256. , 29436826. , - 29440192. , 29443694. , 29447276. , 29450614. , - 29454062. , 29457706. , 29460962. , 29464496. , + 29440192. , 29443692. , 29447276. , 29450614. , + 29454062. , 29457706. , 29460964. , 29464496. , 29467996. , 29471390. , 29474946. , 29478434. ], [29965436. , 29969764. , 29973316. , 29976952. , 29980378. , 29983944. , 29987594. , 29990992. , - 29994504. , 29998214. , 30001532. , 30005128. , + 29994504. , 29998216. , 30001532. , 30005128. , 30008694. , 30012148. , 30015770. , 30019322. ], [30505352. , 30509760. , 30513376. , 30517076. , 30520566. , 30524196. , 30527910. , 30531372. , 30534944. , 30538724. , 30542100. , 30545760. , - 30549392. , 30552908. , 30556594. , 30560210. ], + 30549392. , 30552908. , 30556596. , 30560212. ], [31045270. , 31049756. , 31053436. , 31057202. , - 31060752. , 31064446. , 31068228. , 31071750. , + 31060752. , 31064448. , 31068228. , 31071750. , 31075386. , 31079232. , 31082668. , 31086394. , 31090088. , 31093668. , 31097420. , 31101100. ], [31585188. , 31589752. , 31593496. , 31597328. , 31600940. , 31604698. , 31608544. , 31612128. , - 31615828. , 31619740. , 31623236. , 31627026. , + 31615828. , 31619740. , 31623236. , 31627028. , 31630786. , 31634428. , 31638244. , 31641988. ], [32125106. , 32129748. , 32133556. , 32137452. , 32141126. , 32144950. , 32148862. , 32152506. , - 32156270. , 32160248. , 32163804. , 32167660. , + 32156268. , 32160248. , 32163804. , 32167660. , 32171482. , 32175186. , 32179068. , 32182876. ], - [32665024. , 32669742. , 32673614. , 32677578. , + [32665024. , 32669744. , 32673616. , 32677578. , 32681314. , 32685200. , 32689178. , 32692884. , - 32696710. , 32700756. , 32704372. , 32708292. , + 32696712. , 32700756. , 32704372. , 32708292. , 32712180. , 32715946. , 32719894. , 32723766. ], [33221238. , 33226038. , 33229974. , 33234004. , 33237804. , 33241756. , 33245802. , 33249570. , @@ -274,66 +274,72 @@ [34414896. , 34419864. , 34423944. , 34428112. , 34432048. , 34436140. , 34440328. , 34444232. , 34448260. , 34452520. , 34456324. , 34460456. , - 34464548. , 34468512. , 34472672. , 34476744. ], + 34464548. , 34468512. , 34472672. , 34476748. ], [34824696. , 34829728. , 34833856. , 34838080. , 34842064. , 34846208. , 34850448. , 34854396. , 34858476. , 34862792. , 34866644. , 34870824. , 34874968. , 34878984. , 34883192. , 34887320. ]], dtype=float32),), mlir_module_text=r""" -#loc4 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":33:0) -#loc11 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=matmul keep_unused=False inline=False]"(#loc4)) -#loc16 = loc("jit(func)/jit(main)/jit(matmul)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=wrapped keep_unused=False inline=False]"(#loc4)) -#loc17 = loc("jit(func)/jit(main)/jit(matmul)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=apply_kernel keep_unused=False inline=False]"(#loc4)) +#loc6 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":74:12) +#loc14 = loc("jit(func)/jit(main)/pjit"(#loc6)) module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<64x16xf32> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512, 528, 544, 560, 576, 592, 608, 624, 640, 656, 672, 688, 704, 720, 736, 752, 768, 784, 800, 816, 832, 848, 864, 880, 896, 912, 928, 944, 960, 976, 992, 1008]> : tensor<64xi32> loc(#loc) - %1 = stablehlo.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : tensor<16xi32> loc(#loc) - %2 = stablehlo.iota dim = 0 : tensor<524288xf32> loc(#loc6) - %3 = stablehlo.reshape %2 : (tensor<524288xf32>) -> tensor<1024x512xf32> loc(#loc7) - %4 = stablehlo.constant dense<1.000000e-03> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1024x512xf32> loc(#loc8) - %6 = stablehlo.multiply %5, %3 : tensor<1024x512xf32> loc(#loc8) - %7 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1024x512xf32> loc(#loc9) - %9 = stablehlo.add %8, %6 : tensor<1024x512xf32> loc(#loc9) - %10 = stablehlo.slice %9 [0:512, 0:256] : (tensor<1024x512xf32>) -> tensor<512x256xf32> loc(#loc10) - %11 = call @matmul(%9, %10) : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc11) - %12 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<64xi32>) -> tensor<64x16x1xi32> loc(#loc12) - %13 = stablehlo.broadcast_in_dim %1, dims = [1] : (tensor<16xi32>) -> tensor<64x16x1xi32> loc(#loc13) - %14 = stablehlo.concatenate %12, %13, dim = 2 : (tensor<64x16x1xi32>, tensor<64x16x1xi32>) -> tensor<64x16x2xi32> loc(#loc14) - %15 = "stablehlo.gather"(%11, %14) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<1024x256xf32>, tensor<64x16x2xi32>) -> tensor<64x16xf32> loc(#loc15) - return %15 : tensor<64x16xf32> loc(#loc) + func.func public @main() -> (tensor<64x16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %c_0 = stablehlo.constant dense<16> : tensor loc(#loc) + %cst = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc) + %cst_1 = stablehlo.constant dense<1.000000e-03> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<524288xf32> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<524288xf32>) -> tensor<1024x512xf32> loc(#loc10) + %2 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<1024x512xf32> loc(#loc11) + %3 = stablehlo.multiply %2, %1 : tensor<1024x512xf32> loc(#loc11) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<1024x512xf32> loc(#loc12) + %5 = stablehlo.add %4, %3 : tensor<1024x512xf32> loc(#loc12) + %6 = stablehlo.slice %5 [0:512, 0:256] : (tensor<1024x512xf32>) -> tensor<512x256xf32> loc(#loc13) + %7 = call @matmul(%5, %6) : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc14) + %8 = stablehlo.iota dim = 0 : tensor<64xi32> loc(#loc15) + %9 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor<64xi32> loc(#loc16) + %10 = stablehlo.multiply %9, %8 : tensor<64xi32> loc(#loc16) + %11 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<64xi32> loc(#loc17) + %12 = stablehlo.add %11, %10 : tensor<64xi32> loc(#loc17) + %13 = stablehlo.iota dim = 0 : tensor<16xi32> loc(#loc15) + %14 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor<16xi32> loc(#loc16) + %15 = stablehlo.multiply %14, %13 : tensor<16xi32> loc(#loc16) + %16 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<16xi32> loc(#loc17) + %17 = stablehlo.add %16, %15 : tensor<16xi32> loc(#loc17) + %18 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<64xi32>) -> tensor<64x16x1xi32> loc(#loc18) + %19 = stablehlo.broadcast_in_dim %17, dims = [1] : (tensor<16xi32>) -> tensor<64x16x1xi32> loc(#loc18) + %20 = stablehlo.concatenate %18, %19, dim = 2 : (tensor<64x16x1xi32>, tensor<64x16x1xi32>) -> tensor<64x16x2xi32> loc(#loc19) + %21 = "stablehlo.gather"(%7, %20) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = array}> : (tensor<1024x256xf32>, tensor<64x16x2xi32>) -> tensor<64x16xf32> loc(#loc20) + return %21 : tensor<64x16xf32> loc(#loc) } loc(#loc) - func.func private @matmul(%arg0: tensor<1024x512xf32> loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=matmul keep_unused=False inline=False]"(#loc4)), %arg1: tensor<512x256xf32> loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=matmul keep_unused=False inline=False]"(#loc4))) -> tensor<1024x256xf32> { - %0 = call @wrapped(%arg0, %arg1) : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc16) - return %0 : tensor<1024x256xf32> loc(#loc11) - } loc(#loc11) - func.func private @wrapped(%arg0: tensor<1024x512xf32> loc("jit(func)/jit(main)/jit(matmul)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=wrapped keep_unused=False inline=False]"(#loc4)), %arg1: tensor<512x256xf32> loc("jit(func)/jit(main)/jit(matmul)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=wrapped keep_unused=False inline=False]"(#loc4))) -> tensor<1024x256xf32> { - %0 = call @apply_kernel(%arg0, %arg1) : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc17) - return %0 : tensor<1024x256xf32> loc(#loc16) - } loc(#loc16) - func.func private @apply_kernel(%arg0: tensor<1024x512xf32> loc("jit(func)/jit(main)/jit(matmul)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=apply_kernel keep_unused=False inline=False]"(#loc4)), %arg1: tensor<512x256xf32> loc("jit(func)/jit(main)/jit(matmul)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=apply_kernel keep_unused=False inline=False]"(#loc4))) -> tensor<1024x256xf32> { - %0 = stablehlo.custom_call @tpu_custom_call(%arg0, %arg1) {backend_config = "{\22custom_call_config\22: {\22body\22: \22TUzvUgFNTElSZ29vZ2xlMy10cnVuawABLwkBAwUHAQMJAwUDCwUNDQ8RExUXBwMZA44DIgMhAfkbDw8LKxMTBxcjEwsLCwsTCwsLhQsLCxsLMwsPEw87CxMLC1MLDwsLFxsLUxsLUxsLUxsbGw8TEwsLExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTCxMTExMXBQthkWlpeQGPExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXDxcPFw8XDxcPFw8XDxcTHxMfDxcfCwsLCwtTCxMBIRsHHw8HHw8nJycLIx8nJycCwhEDBTEzNTcd7R8dcR8FGwMHEgMWAzEzNTcdGgMfHQIDHx8DAwYDOQMFCgM7DgM7AwMXOQUdBR8FIQEBF3NDAQUjBSUNGWFmZmluZV9tYXA8KGQwLCBkMSkgLT4gKGQwLCBkMSk+AAUnBSkFKwMFFx0HawUtIxUREQEBAQEBAQEBBS8RBwUBAwICERUAAw0/QRlDRUdJSxtNT1EFMQEF+/sNFwUzIw0FIQQAAAAAAAAAAQAAAAAAAAAFNRENAQU3BTkBB1NZXwMFIVUjVwkpIw0FIQABAAAAAAAAAAIAAAAAAAADBSFbI10JKyMNBSEAAgAAAAAAAAABAAAAAAAAAwUhYSNjCS0jDQUhAAEAAAAAAAAAAQAAAAAAAAMFGSUbKQMFGSUbKwMFGSUbLREHAQMDB28RA8IPBTsFPQMDB3cRA4IPAwMHexEDQg8DAwd/EQMCDwMDB4MRA8IOAwMHhxEDgg4DAweLEQNCDgMDB48RAwIOAwMHkxEDwg0DAweXEQOCDQMDB5sRA0INAwMHnxEDAg0DAwejEQPCDAMDB6cRA4IMAwMHqxEDQgwDAwevEQPCCwMDB7MRA4ILAwMHtxEDQgsDAwe7EQMCCwMDB78RA8IKAwMHwxEDggoDAwfHEQNCCgMDB8sRAwIKAwMHzxEDwgkDAwfTEQOCCQMDB9cRA0IJAwMH2xEDAgkDAwffEQPCCAMDB+MRA4IIAwMH5xEDQggDAwfrEQMCDAU/AwMH8REDwgcDAwf1EQOCBwMDBwYCI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS50aWxlZDwoOCwxMjgpLFsyLDFdPgAjdHB1LnRpbGVkPCg4LDEyOCksWzQsMV0+ACN0cHUudnBhZDwiMzIsezAsMH0sKDgsMTI4KSI+ABEDQgcDAwcOAhEDAgcDAwcWAhEDwgYDAwceAhEDggYDAwcmAhEDQgYDAwcuAhEDAgYDAwc2AhEDwgUDAwc+AhEDggUDAwdGAhEDQgUDAwdOAhEDAgUDAwdWAhEDwgQDAwdeAhEDggQDAwdmAhEDQgQDAwduAhEDwgMDAwd2AhEDggMDAwd+AhEDQgMDAweGAhEDAgMDAweOAhEDwgIDAweWAhEDggIDAweeAhEDQgIDAwemAhEDAgIDAweuAhED4QMDB7YCEQPBAwMHvgIRA6EDAwfGAhEDgQMDB84CEQNhAwMH1gIRA0EDAwfeAhEDIQMDB+YCEQMCBAMFFx0H7gIRAwIIAwUXHQf2AhEDAQMDB/4CJQEJAAAAAAVBBUMFRQVHBUkjBwkhAQAAAAEAAAACAAAAAAAAAAVLAwMXHScFIQIECQMnBQIIAgQJAQICCycFAgQCBAkBAgQX+QUCCAIQCf8X+QUCEAIICf0X+QUCCAIICf0BCQULBwcPERMBBQUHBwUHBxf5BQIIAhAJJxf5BQIQAggJJxf5BQIIAggJJwR6UQUBEA8HAwERAxEPPQcDlglODQsHDwcPDw8RDxMPEwMFbQMDEwMFdQMDEwMFeQMDEwMFfQMDEwMFgQMDEwMFhQMDEwMFiQMDEwMFjQMDEwMFkQMDEwMFlQMDEwMFmQMDEwMFnQMDEwMFoQMDEwMFpQMDEwMFqQMDEwMFrQMDEwMFsQMDEwMFtQMDEwMFuQMDEwMFvQMDEwMFwQMDEwMFxQMDEwMFyQMDEwMFzQMDEwMF0QMDEwMF1QMDEwMF2QMDEwMF3QMDEwMF4QMDEwMF5QMDEwMD6QMDEwMD7wMDEwMD8wMDEwMD9wMDEwMDCgIDAxMDAxICAwMTAwMaAgMDEwMDIgIDAxMDAyoCAwMTAwMyAgMDEwMDOgIDAxMDA0ICAwMTAwNKAgMDEwMDUgIDAxMDA1oCAwMTAwNiAgMDEwMDagIDAxMDA3ICAwMTAwN6AgMDEwMDggIDAxMDA4oCAwMTAwOSAgMDEwMDmgIDAxMDA6ICAwMTAwOqAgMDEwMDsgIDAxMDA7oCAwMTAwPCAgMDEwMDygIDAxMDA9ICAwMTAwPaAgMDEwMD4gIDAxMDA+oCAwMTAwPyAgMDEwMN+gIDAREGDwMbAwURBg8DHQMHEQYPAx8DCQcHAwEDAQeNiYkHBwMBAwEHjYmFBwcDAQMBB42DiQcHAwEDAQeNg4UHBwMBAwEHjYGJBwcDAQMBB42BhQcHAwEDAQeNf4kHBwMBAwEHjX+FBwcDAQMBB419iQcHAwEDAQeNfYUHBwMBAwEHjXuJBwcDAQMBB417hQcHAwEDAQeNeYkHBwMBAwEHjXmFBwcDAQMBB413iQcHAwEDAQeNd4UHBwMBAwEHjXWJBwcDAQMBB411hQcHAwEDAQeNc4kHBwMBAwEHjXOFBwcDAQMBB41xiQcHAwEDAQeNcYUHBwMBAwEHjW+JBwcDAQMBB41vhQcHAwEDAQeNbYkHBwMBAwEHjW2FBwcDAQMBB41riQcHAwEDAQeNa4UHBwMBAwEHjWmJBwcDAQMBB41phQcHAwEDAQeNZ4kHBwMBAwEHjWeFBwcDAQMBB42FiQcHAwEDAQeNhYUHBwMBAwEHjWWJBwcDAQMBB41lhQcHAwEDAQeNY4kHBwMBAwEHjWOFBwcDAQMBB41hiQcHAwEDAQeNYYUHBwMBAwEHjV+JBwcDAQMBB41fhQcHAwEDAQeNXYkHBwMBAwEHjV2FBwcDAQMBB41biQcHAwEDAQeNW4UHBwMBAwEHjVmJBwcDAQMBB41ZhQcHAwEDAQeNV4kHBwMBAwEHjVeFBwcDAQMBB41ViQcHAwEDAQeNVYUHBwMBAwEHjVOJBwcDAQMBB41ThQcHAwEDAQeNUYkHBwMBAwEHjVGFBwcDAQMBB41PiQcHAwEDAQeNT4UHBwMBAwEHjU2JBwcDAQMBB41NhQcHAwEDAQeNS4kHBwMBAwEHjUuFBwcDAQMBB41JiQcHAwEDAQeNSYUHBwUBAwEHj4mJBwcFAQMBB4+JhQcHBQEDAQePg4kHBwUBAwEHj4OFBwcFAQMBB4+BiQcHBQEDAQePgYUHBwUBAwEHj3+JBwcFAQMBB49/hQcHBQEDAQePfYkHBwUBAwEHj32FBwcFAQMBB497iQcHBQEDAQePe4UHBwUBAwEHj3mJBwcFAQMBB495hQcHBQEDAQePd4kHBwUBAwEHj3eFBwcFAQMBB491iQcHBQEDAQePdYUHBwUBAwEHj3OJBwcFAQMBB49zhQcHBQEDAQePcYkHBwUBAwEHj3GFBwcFAQMBB49viQcHBQEDAQePb4UHBwUBAwEHj22JBwcFAQMBB49thQcHBQEDAQePa4kHBwUBAwEHj2uFBwcFAQMBB49piQcHBQEDAQePaYUHBwUBAwEHj2eJBwcFAQMBB49nhQcHBQEDAQePhYkHBwUBAwEHj4WFBwcFAQMBB49liQcHBQEDAQePZYUHBwUBAwEHj2OJBwcFAQMBB49jhQcHBQEDAQePYYkHBwUBAwEHj2GFBwcFAQMBB49fiQcHBQEDAQePX4UHBwUBAwEHj12JBwcFAQMBB49dhQcHBQEDAQePW4kHBwUBAwEHj1uFBwcFAQMBB49ZiQcHBQEDAQePWYUHBwUBAwEHj1eJBwcFAQMBB49XhQcHBQEDAQePVYkHBwUBAwEHj1WFBwcFAQMBB49TiQcHBQEDAQePU4UHBwUBAwEHj1GJBwcFAQMBB49RhQcHBQEDAQePT4kHBwUBAwEHj0+FBwcFAQMBB49NiQcHBQEDAQePTYUHBwUBAwEHj0uJBwcFAQMBB49LhQcHBQEDAQePSYkHBwUBAwEHj0mFBwcDAQMBB42JhwcHAwEDAQeNiUcHBwMBAwEHjYOHBwcDAQMBB42DRwcHAwEDAQeNgYcHBwMBAwEHjYFHBwcDAQMBB41/hwcHAwEDAQeNf0cHBwMBAwEHjX2HBwcDAQMBB419RwcHAwEDAQeNe4cHBwMBAwEHjXtHBwcDAQMBB415hwcHAwEDAQeNeUcHBwMBAwEHjXeHBwcDAQMBB413RwcHAwEDAQeNdYcHBwMBAwEHjXVHBwcDAQMBB41zhwcHAwEDAQeNc0cHBwMBAwEHjXGHBwcDAQMBB41xRwcHAwEDAQeNb4cHBwMBAwEHjW9HBwcDAQMBB41thwcHAwEDAQeNbUcHBwMBAwEHjWuHBwcDAQMBB41rRwcHAwEDAQeNaYcHBwMBAwEHjWlHBwcDAQMBB41nhwcHAwEDAQeNZ0cHBwMBAwEHjYWHBwcDAQMBB42FRwcHAwEDAQeNZYcHBwMBAwEHjWVHBwcDAQMBB41jhwcHAwEDAQeNY0cHBwMBAwEHjWGHBwcDAQMBB41hRwcHAwEDAQeNX4cHBwMBAwEHjV9HBwcDAQMBB41dhwcHAwEDAQeNXUcHBwMBAwEHjVuHBwcDAQMBB41bRwcHAwEDAQeNWYcHBwMBAwEHjVlHBwcDAQMBB41XhwcHAwEDAQeNV0cHBwMBAwEHjVWHBwcDAQMBB41VRwcHAwEDAQeNU4cHBwMBAwEHjVNHBwcDAQMBB41RhwcHAwEDAQeNUUcHBwMBAwEHjU+HBwcDAQMBB41PRwcHAwEDAQeNTYcHBwMBAwEHjU1HBwcDAQMBB41LhwcHAwEDAQeNS0cHBwMBAwEHjUmHBwcDAQMBB41JRwcHBQEDAQePh4kHBwUBAwEHj4eFBwcFAQMBB49FiQcHBQEDAQePRYUHBwUBAwEHj0OJBwcFAQMBB49DhQcHBQEDAQePQYkHBwUBAwEHj0GFBwcFAQMBB48/iQcHBQEDAQePP4UHBwUBAwEHjz2JBwcFAQMBB489hQcHBQEDAQePO4kHBwUBAwEHjzuFBwcFAQMBB485iQcHBQEDAQePOYUHBwUBAwEHjzeJBwcFAQMBB483hQcHBQEDAQePNYkHBwUBAwEHjzWFBwcFAQMBB48ziQcHBQEDAQePM4UHBwUBAwEHjzGJBwcFAQMBB48xhQcHBQEDAQePL4kHBwUBAwEHjy+FBwcFAQMBB48tiQcHBQEDAQePLYUHBwUBAwEHjyuJBwcFAQMBB48rhQcHBQEDAQePKYkHBwUBAwEHjymFBwcFAQMBB49HiQcHBQEDAQePR4UHBwUBAwEHjyeJBwcFAQMBB48nhQcHBQEDAQePJYkHBwUBAwEHjyWFBwcFAQMBB48jiQcHBQEDAQePI4UHBwUBAwEHjyGJBwcFAQMBB48hhQcHBQEDAQePH4kHBwUBAwEHjx+FBwcFAQMBB48diQcHBQEDAQePHYUHBwUBAwEHjxuJBwcFAQMBB48bhQcHBQEDAQePGYkHBwUBAwEHjxmFBwcFAQMBB48XiQcHBQEDAQePF4UHBwUBAwEHjxWJBwcFAQMBB48VhQcHBQEDAQePE4kHBwUBAwEHjxOFBwcFAQMBB48RiQcHBQEDAQePEYUHBwUBAwEHjw+JBwcFAQMBB48PhQcHBQEDAQePDYkHBwUBAwEHjw2FBwcFAQMBB48LiQcHBQEDAQePC4ULBw0RAwVBJgMuAzYDPgNGA04DVgNeA2YDbgN2A34DhgOOA5YDngOmA64DtgO+A8YDzgPWA94D5gPuA/YD/gMGBA4EFgQeBAsHDREDBUEqAzIDOgNCA0oDUgNaA2IDagNyA3oDggOKA5IDmgOiA6oDsgO6A8IDygPSA9oD4gPqA/ID+gMCBAoEEgQaBCIECwcNEQMLISYELgQ2BD4ERgROBFYEXgRmBG4EdgR+BIYEjgSWBJ4ECwcNEQMFQYuLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLDQcNEwMFByYFLgUyBQ8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEDNgULBw0RAwshpgSuBLYEvgTGBM4E1gTeBOYE7gT2BP4EBgUOBRYFHgULBw0RAwVBOgU+BUIFRgVKBU4FUgVWBVoFXgViBWYFagVuBXIFdgV6BX4FggWGBYoFjgWSBZYFmgWeBaIFpgWqBa4FsgW2BQ0HDRMDBQcqBboFvgUPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA8IFCwcNEQMLISoEMgQ6BEIESgRSBFoEYgRqBHIEegSCBIoEkgSaBKIECwcNEQMFQYuLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLDQcNEwMFByYFRgZKBg8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEDTgYLBw0RAwshqgSyBLoEwgTKBNIE2gTiBOoE8gT6BAIFCgUSBRoFIgULBw0RAwVBUgZWBloGXgZiBmYGagZuBnIGdgZ6Bn4GggaGBooGjgaSBpYGmgaeBqIGpgaqBq4Gsga2BroGvgbCBsYGygbOBg0HDRMDBQcqBdIG1gYPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA9oGCwcNEQMFQZOXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/8GAg4CFgIeAgsHDREDBUGVmZ2hpamtsbW5vcHFyc3R1dnd4eXp7fH1+f0CAgoCEgIaAiICCwcNEQMLISYCLgI2Aj4CRgJOAlYCXgJmAm4CdgJ+AoYCjgKWAp4CCwcNEQMFQcYFygXOBdIF1gXaBd4F4gXmBeoF7gXyBfYF+gX+BQIGBgYKBg4GEgYWBhoGHgYiBiYGKgYuBjIGNgY6Bj4GQgYNBw0TAwUHXgdmB2oHDwcNFUEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQNuBwsHDREDCyGmAq4CtgK+AsYCzgLWAt4C5gLuAvYC/gIGAw4DFgMeAwsHDREDBUFyB3YHegd+B4IHhgeKB44HkgeWB5oHngeiB6YHqgeuB7IHtge6B74HwgfGB8oHzgfSB9YH2gfeB+IH5gfqB+4HDQcNEwMFB2IH8gf2Bw8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQED+gcLBw0RAwshKgIyAjoCQgJKAlICWgJiAmoCcgJ6AoICigKSApoCogILBw0RAwVB3gbiBuYG6gbuBvIG9gb6Bv4GAgcGBwoHDgcSBxYHGgceByIHJgcqBy4HMgc2BzoHPgdCB0YHSgdOB1IHVgdaBw0HDRMDBQdeB34IgggPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA4YICwcNEQMLIaoCsgK6AsICygLSAtoC4gLqAvIC+gICAwoDEgMaAyIDCwcNEQMFQYoIjgiSCJYImgieCKIIpgiqCK4Isgi2CLoIvgjCCMYIygjOCNII1gjaCN4I4gjmCOoI7gjyCPYI+gj+CAIJBgkNBw0TAwUHYgcKCQ4JDwcNFUEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQMSCQkFCwkJ/geRiYkJBQsJCRYJkYmFCQULCQkCCJGDiQkFCwkJGgmRg4UJBQsJCQYIkYGJCQULCQkeCZGBhQkFCwkJCgiRf4kJBQsJCSIJkX+FCQULCQkOCJF9iQkFCwkJJgmRfYUJBQsJCRIIkXuJCQULCQkqCZF7hQkFCwkJFgiReYkJBQsJCS4JkXmFCQULCQkaCJF3iQkFCwkJMgmRd4UJBQsJCR4IkXWJCQULCQk2CZF1hQkFCwkJIgiRc4kJBQsJCToJkXOFCQULCQkmCJFxiQkFCwkJPgmRcYUJBQsJCSoIkW+JCQULCQlCCZFvhQkFCwkJLgiRbYkJBQsJCUYJkW2FCQULCQkyCJFriQkFCwkJSgmRa4UJBQsJCTYIkWmJCQULCQlOCZFphQkFCwkJOgiRZ4kJBQsJCVIJkWeFCQULCQk+CJGFiQkFCwkJVgmRhYUJBQsJCUIIkWWJCQULCQlaCZFlhQkFCwkJRgiRY4kJBQsJCV4JkWOFCQULCQlKCJFhiQkFCwkJYgmRYYUJBQsJCU4IkV+JCQULCQlmCZFfhQkFCwkJUgiRXYkJBQsJCWoJkV2FCQULCQlWCJFbiQkFCwkJbgmRW4UJBQsJCVoIkVmJCQULCQlyCZFZhQkFCwkJXgiRV4kJBQsJCXYJkVeFCQULCQliCJFViQkFCwkJegmRVYUJBQsJCWYIkVOJCQULCQl+CZFThQkFCwkJagiRUYkJBQsJCYIJkVGFCQULCQluCJFPiQkFCwkJhgmRT4UJBQsJCXIIkU2JCQULCQmKCZFNhQkFCwkJdgiRS4kJBQsJCY4JkUuFCQULCQl6CJFJiQkFCwkJkgmRSYUFAQ8eAwMRD2UHAwcLBQcPBw8TAw8vAwcFBA8FAQUDEQ9nBwMHCwUHDwcPEwMPLwMHBQQPBQUDAxEPaQcDBQcFBw8HDwUEDwUBAwYDAQUBAE4VTXYDKR0dF04C/gOB/gMdCyEjKR8bGRkZHSUTHRUNEykfDxsNCw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAbW9kdWxlAHJldHVybgBsb2FkAHN0b3JlAHJvbGxfdmVjdG9ycwBtYXRtdWwAdW5yb2xsX3ZlY3RvcnMAZXJhc2VfbWVtcmVmX2xheW91dABjb25zdGFudAB2YWx1ZQBpbl9sYXlvdXQAZnVuY3Rpb25fdHlwZQBzeW1fbmFtZQB0cmFuc2Zvcm1faW5kaWNlcwB3aW5kb3dfYm91bmRzAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHN1YmxhbmVfbWFzawBzdWJsYW5lX3N0cmlkZQBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAG1haW4Ad2luZG93X3BhcmFtcwAvbWFza2VkX2xvYWRbbWFza2VkPUZhbHNlIGNhY2hlX21vZGlmaWVyPSBldmljdGlvbl9wb2xpY3k9IGlzX3ZvbGF0aWxlPUZhbHNlIGFyZ3NfdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhbXSwgUHlUcmVlRGVmKFtDdXN0b21Ob2RlKFNsaWNlWyhGYWxzZSwgMjU2KV0sIFsqXSksIEN1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSldKSwgW1RydWUsIFRydWVdLCAoNTEyLCAyNTYpLCAoKSldLCBbKl0pLCkpXQB0aGlyZF9wYXJ0eS9weS9qYXhfdHJpdG9uL2dvb2dsZS9wYWxsYXNfdHB1L2JhY2tfY29tcGF0X3Rlc3QucHkAL21hc2tlZF9sb2FkW21hc2tlZD1GYWxzZSBjYWNoZV9tb2RpZmllcj0gZXZpY3Rpb25fcG9saWN5PSBpc192b2xhdGlsZT1GYWxzZSBhcmdzX3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoW10sIFB5VHJlZURlZihbQ3VzdG9tTm9kZShTbGljZVsoVHJ1ZSwgMCwgMjU2KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoRmFsc2UsIDI1NildLCBbKl0pXSksIFtUcnVlLCBUcnVlXSwgKDI1NiwgNTEyKSwgKCkpXSwgWypdKSwpKV0AL2RvdF9nZW5lcmFsW2RpbWVuc2lvbl9udW1iZXJzPSgoKDEsKSwgKDAsKSksICgoKSwgKCkpKSBwcmVjaXNpb249KDxQcmVjaXNpb24uREVGQVVMVDogMD4sIDxQcmVjaXNpb24uREVGQVVMVDogMD4pIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0Ab3V0X2xheW91dAB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAb3BlcmFuZFNlZ21lbnRTaXplcwAvbWFza2VkX3N3YXBbbWFza2VkPUZhbHNlIGV2aWN0aW9uX3BvbGljeT0gYXJnc190cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFtdLCBQeVRyZWVEZWYoW0N1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSksIEN1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSldKSwgW1RydWUsIFRydWVdLCAoMjU2LCAyNTYpLCAoKSldLCBbXSksKSldAA==\22}}", kernel_name = "func", operand_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc18) - return %0 : tensor<1024x256xf32> loc(#loc17) - } loc(#loc17) + func.func private @matmul(%arg0: tensor<1024x512xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc6)), %arg1: tensor<512x256xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc6))) -> (tensor<1024x256xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @tpu_custom_call(%arg0, %arg1) {backend_config = "{\22custom_call_config\22: {\22body\22: \22TUzvUgFNTElSZ29vZ2xlMy10cnVuawABLQkBAwUHAQMJAxkLDQ8RExUXGRsdHyED58ETAbkHEwsTCwsLDwsPDw8LC1MLDw8PDwsPDw8LCwsLExMPDxMPGwsPC0MLFwuFC3MLCwsLFxsLGwsbCxsbGw8LExMPEw8LCxMPExMTHwsTGwsLEwsPCxMTEwsTDwsTEwUHjZFhBwNZARMPBx8nDwcLKyMCZggfAwMLiwUjAwMLdwUlBScFKR15ewUrHSmnHSmrHSm3BS0FLyMJBSEAAQAAAAAAAAABAAAAAAAADREdhzkdETsdEY0dEY8FMR0RqREJAREJBQUzBTUFNwU5FwU7BxcFQyMdlZcRDQAXrRcLHbO1AwVHSQlLBTsRCQ0FPQMPT1ENU1dZWy1dLwlfYWMFPwEHubm7DQ9hZmZpbmVfbWFwPChkMCwgZDEpIC0+IChkMCwgZDEpPgAFQSMJBzEEAAAAAAAAAAEAAAAAAAAAAgAAAAAAAAAFQwVFBUcFSQEHZWltAwUZZxsdCTEDBRlrGx0JMwMFGW8bHQk1AwUNHwkxAwUNHwkzAwUNHwk1EQEBBUsXBTsXAwMLfxEBBQMDNy0dhTkFTQVPAwM3LxEDARcFRQ0XBUcNAwMLkyUFCQAAAAAFURcFQ0EDBZs/nT8FUwVVAwOhvwVXHaU7BVkXBUMFFwVRKRcFUQUFWwMDC7ETCwEFXRcFPycXBT8JI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPHBhcmFsbGVsPgAjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8YXJiaXRyYXJ5PgAjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAjYXJpdGguZmFzdG1hdGg8bm9uZT4AAQICAycFAggCCAsXvQUCCAIIC1UBAgQLAQkFDwEBAQcHBwcBBQcBAQEFAQEEIgYFAREBRQcDAREHEQFNBwNHgw8BAQEBAQEHAQcBBwEHAQMDDwcDAQMDD30DAQMDDwcDAQ0HD4EDDQUFDxEGgwMBAxUDAyEHAwENByGJAw0FFxkTFCEDGwkDCx0DA0OvAwsZBkMDBQNHAwMXAwMDAwMXAwMDBQYXAwUHDUtNCwQXCUkNS00PAEEDAQUPAEEDAyMDAwMDAyMDAwMFBiMDBQcNHR8DAyUDAwMDAyUDAwMFBiUDBQcHIyUDAycDAwMDAycDAwMFBicDBQcJKSsDAz2RAwUVBz2ZAwUHJy0vFwejnwMFBSExAwMTAwMDAwMTAwMDBQYTAwUHDTU3CwQTCTMNNTcDAysDAwMDAysDAwMFBisDBQcNOz0DAxUDAwMDAxUDAwMFBhUDBQcLQUMLBBUJPwtBQwkAAQcRAXEHAwkLBwEBAQEBAQMDAQcDAQkEAQUBBQcRAXMHAwkLBwEBAQEBAQMDAQcDAQkEAQUFAwcRAXUHAwkLBwEBAQEBAQMDAQcDAQkEAQUBAwYDAQUBAO4JXyUFCxMdHRsNLQkdCyMhIykdLRUZGRkNHSULHQ0TcyMXFw8ZFRcbGRUZHw8NCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBhcml0aABtb2R1bGUAYXJpdGguY29uc3RhbnQAdmVjdG9yLmxvYWQAZnVuYy5mdW5jAGZ1bmMucmV0dXJuAHZlY3Rvci5zdG9yZQBhcml0aC5jbXBpAHNjZi55aWVsZABhcml0aC5leHR1aQBzY2YuaWYAdHB1Lm1hdG11bABhcml0aC5hZGRmAHZlY3Rvci5icm9hZGNhc3QAdGhpcmRfcGFydHkvcHkvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9tYXRtdWwucHkAc3ltX25hbWUAdmFsdWUAZnVuY3Rpb25fdHlwZQAvZ2V0AHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAL3N3YXAAdHJhbnNmb3JtXzAAdHJhbnNmb3JtXzEAdHJhbnNmb3JtXzIAcHJlZGljYXRlAHN0YWJsZV9tb3NhaWMudmVyc2lvbgBtYXRtdWxfa2VybmVsAGRpbWVuc2lvbl9zZW1hbnRpY3MAaXRlcmF0aW9uX2JvdW5kcwBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBtYWluAHdpbmRvd19wYXJhbXMAL2VxAC9jb252ZXJ0X2VsZW1lbnRfdHlwZQAvY29uZAAvZG90X2dlbmVyYWwAdHJhbnNwb3NlX2xocwB0cmFuc3Bvc2VfcmhzAGZhc3RtYXRoAC9hZGQALQAvYnJvYWRjYXN0X2luX2RpbQA=\22, \22serialization_format\22: 1, \22needs_layout_passes\22: true}, \22implicit_sharding\22: {\22type\22: \22MANUAL\22}}", kernel_name = "matmul_kernel", operand_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc21) + return %0 : tensor<1024x256xf32> loc(#loc14) + } loc(#loc14) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":30:0) -#loc2 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":31:0) -#loc3 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":32:0) -#loc5 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":35:0) -#loc6 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(524288,) dimension=0]"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/reshape[new_sizes=(1024, 512) dimensions=None]"(#loc2)) -#loc8 = loc("jit(func)/jit(main)/mul"(#loc1)) -#loc9 = loc("jit(func)/jit(main)/add"(#loc1)) -#loc10 = loc("jit(func)/jit(main)/slice[start_indices=(0, 0) limit_indices=(512, 256) strides=None]"(#loc3)) -#loc12 = loc("jit(func)/jit(main)/broadcast_in_dim[shape=(64, 16, 1) broadcast_dimensions=(0,)]"(#loc5)) -#loc13 = loc("jit(func)/jit(main)/broadcast_in_dim[shape=(64, 16, 1) broadcast_dimensions=(1,)]"(#loc5)) -#loc14 = loc("jit(func)/jit(main)/concatenate[dimension=2]"(#loc5)) -#loc15 = loc("jit(func)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1)) slice_sizes=(1, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc5)) -#loc18 = loc("jit(func)/jit(main)/jit(matmul)/jit(wrapped)/jit(apply_kernel)/tpu_custom_call[config=CustomCallBackendConfig() kernel_name=func kernel_regeneration_metadata=None out_avals=(ShapedArray(float32[1024,256]),)]"(#loc4)) +#loc1 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":71:25) +#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":72:43) +#loc3 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":71:17) +#loc4 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":71:10) +#loc5 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":73:10) +#loc7 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":76:13) +#loc8 = loc("third_party/py/jax/experimental/pallas/ops/tpu/matmul.py":68:9) +#loc9 = loc("jit(func)/jit(main)/iota"(#loc1)) +#loc10 = loc("jit(func)/jit(main)/reshape"(#loc2)) +#loc11 = loc("jit(func)/jit(main)/mul"(#loc3)) +#loc12 = loc("jit(func)/jit(main)/add"(#loc4)) +#loc13 = loc("jit(func)/jit(main)/slice"(#loc5)) +#loc15 = loc("jit(func)/jit(main)/iota"(#loc7)) +#loc16 = loc("jit(func)/jit(main)/mul"(#loc7)) +#loc17 = loc("jit(func)/jit(main)/add"(#loc7)) +#loc18 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc7)) +#loc19 = loc("jit(func)/jit(main)/concatenate"(#loc7)) +#loc20 = loc("jit(func)/jit(main)/gather"(#loc7)) +#loc21 = loc("jit(func)/jit(main)/jit(matmul)/pallas_call"(#loc8)) """, - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01+\x05\x01\x03\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03~\x02\xfb-\x01\xaf\x07\x0b\x0f\x0b\x0f\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0b\x13\x0b\x0f\x13\x0f\x0f+\x0b\x0f\x0b\x0b\x0b33\x0b3\x0b3\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x13\x13\x13\x13\x13\x0b\x0f\x0b\x0f\x0b\x13\x13\x0b\x13\x0b#\x0b\x0b\x0b\x0f\x0b\x13\x13\x13\x0f\x0b\x13\x0f\x0b\x13\x0b\x0f\x0b;\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x03M\x0b\x13\x0b\x0b\x0f\x0bO\x0b\x0b\x0b\x0f/\x0fO\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f&\x08\x1e\x02\x0f\x1f\x1fO///\x0b\x01\x05\x0b\x0f\x03)\x1f\x07\x1f\x1f\x07\x07\x0f\x13\x1b\x17\x13\x1f\x13\x13\x1b\x13\x07\x1b\x13\x1f\x022\x0e\x1f\x05!\x1d9\x15\x05#\x1d=\x15\x1dA\x15\x05%\x05\'\x05)\x05+\x17\x07C\x01\x05-\x17\x07G\x01\x05/\x17\x07=\x01\x051\x11\x03\x05\x03\x03\x1f\xc3\x1ds\x1d\x1dw\x1d\x03\t+-/!1!\x033\x053\x11\x01\x00\x055\x057\x059\x03\x0b\r\xaf\x0f\xcb\x11\xcd\x03\xd5\x13\xd7\x03\x0b\r\xb1\x0f\xb5\x11\xb7\x03\xbd\x13\xb9\x05;\x03\x0b\r\xb1\x0f\xb5\x11\xb7\x03\xbf\x13\xb9\x05=\x03\x0b\r\xb1\x0f\xb5\x11\xb7\x03\xc1\x13\xb9\x05?\x03\x13E\xd9G\xdbI\xddK\xafM\xdfO\xe1Q\xe3S\xafU\xe5\x05A\x05C\x05E\x05G\x05I\x05K\x05M\x05O\x05Q\x1dY\x15\x05S\x03\x03\x1b\xc1\x03\x03\x1b\xbf\x03\x03\x17\xe7\x03\x03\x17\xe9\x03\x03e\xeb\x05U\x1di\x1d\x05W\x1dmo\x05Y\x17\x07?\x01\x03\x03\x17\xed\x05[\x03\x03\x17\xef\x05]\x03\x07{\xf1}\xf3\x7f\xc5\x05_\x05a\x05c\x1d\x83\x85\x05e\x17\x07A\x01\x03\x03\x1b\xbd\x03\x03\x1f\xf5\x1d\x8d\x19\x05g\x03\x03\x1f\xf7\x1d\x93\x19\x05i\x03\x03\x97\xc7\x05k\x1d\x9b\x19\x05m\x03\r\x9f\xc9\xa1\xc7\xa3\xf9\xa5\xc3\xa7\xc5\xa9\xc9\x05o\x05q\x05s\x05u\x05w\x05y\x1d\xad\x19\x05{\x03\x01\x03\x05\xb3\xb3\r\x01#!\x03\x03\xb3\x1d}\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x7f\x1d\x81\x1d\x83\x1f)\x01\x1f\x13\x11\x01\x00\x00\x00\x00\x00\x00\x00\x13\r\t\x1f\x13!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x03\xcf\r\x03\xd1\xd3\x1d\x85\x1d\x87\x1d\x89\x1d\x8b\x0b\x03\x1d\x8d\x1d\x8f\x05\x01\x1d\x91\x03\x05\xbb\xbb\x03\x03\xbb\x1f\x17\x02\x04\x00\x00\x00\x00\x10\x00\x00\x00 \x00\x00\x000\x00\x00\x00@\x00\x00\x00P\x00\x00\x00`\x00\x00\x00p\x00\x00\x00\x80\x00\x00\x00\x90\x00\x00\x00\xa0\x00\x00\x00\xb0\x00\x00\x00\xc0\x00\x00\x00\xd0\x00\x00\x00\xe0\x00\x00\x00\xf0\x00\x00\x00\x00\x01\x00\x00\x10\x01\x00\x00 \x01\x00\x000\x01\x00\x00@\x01\x00\x00P\x01\x00\x00`\x01\x00\x00p\x01\x00\x00\x80\x01\x00\x00\x90\x01\x00\x00\xa0\x01\x00\x00\xb0\x01\x00\x00\xc0\x01\x00\x00\xd0\x01\x00\x00\xe0\x01\x00\x00\xf0\x01\x00\x00\x00\x02\x00\x00\x10\x02\x00\x00 \x02\x00\x000\x02\x00\x00@\x02\x00\x00P\x02\x00\x00`\x02\x00\x00p\x02\x00\x00\x80\x02\x00\x00\x90\x02\x00\x00\xa0\x02\x00\x00\xb0\x02\x00\x00\xc0\x02\x00\x00\xd0\x02\x00\x00\xe0\x02\x00\x00\xf0\x02\x00\x00\x00\x03\x00\x00\x10\x03\x00\x00 \x03\x00\x000\x03\x00\x00@\x03\x00\x00P\x03\x00\x00`\x03\x00\x00p\x03\x00\x00\x80\x03\x00\x00\x90\x03\x00\x00\xa0\x03\x00\x00\xb0\x03\x00\x00\xc0\x03\x00\x00\xd0\x03\x00\x00\xe0\x03\x00\x00\xf0\x03\x00\x00\x1f\x19\x81\x00\x00\x00\x00\x10\x00\x00\x00 \x00\x00\x000\x00\x00\x00@\x00\x00\x00P\x00\x00\x00`\x00\x00\x00p\x00\x00\x00\x80\x00\x00\x00\x90\x00\x00\x00\xa0\x00\x00\x00\xb0\x00\x00\x00\xc0\x00\x00\x00\xd0\x00\x00\x00\xe0\x00\x00\x00\xf0\x00\x00\x00\x13\r\x01\x1f\x11\to\x12\x83:\x1f\x11\t\x00\x00\x80?\x1f\x13!\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00\x05\x03\x01\t\x01\x02\x02)\x05\x02 \x02\x10\x07\t)\x05\x02\x10\x02\x08\x07)\x05\x02 \x02\x08\x07\x1d\x1b)\x01\x07)\x03\t\r)\x05\x02\x02A\x07)\x03\x02\x02\x0f)\x03A\x0f)\x07\x02\x02A\x05\x0f)\x03\x05\r\x11\x01\x03\x15\x11\x05\x05\t\x03\x0b)\x03\t%\x13)\x03\x04\x00\x80\x07)\x03\x01\r)\x07\x02\x02A\t\x0f\x04~\x03\x05\x01\x11\x01)\x07\x03\x01\x11\x03\x11\x015\x07\x03!E\x07\x03\x01_\x03\x17\x07\x03\x01a\x03\x19\x0f\x03gc\x03\'\x11\x06k\x03\x05\x03\x05\x07\x03\x01q\x03\x11\t\x07%#\x03\x05\x03\t\x13\x06%\x03\x05\x05\x0b\x07\x07\x03\x01u\x03\x11\t\x07\'#\x03\x05\x03\x0f\x15\x06\'\x03\x05\x05\x11\r\x17\x07\x81y\x03\t\x03\x13\x0b\x07\x05\x87\x03\x0b\x05\x13\x15\t\x07\x8b\x89\x03\x1b\x03\x01\t\x07\x91\x8f\x03\x1b\x03\x03\x19\x07\x99\x95\x03+\x05\x19\x1b\x1b\x07\xab\x9d\x03\x15\x05\x17\x1d\x05\x04\x01\x03\x1f\x03\x11\x057\x07\x03\x07\x0b\x05\x05\x05\t\x05\x0b\x07\t]\x03\x0b\x05\x01\x03\x05\x04\x05\x03\x05\x03\x11\t;\x07\x03\x07\x0b\x05\x05\t\t\t\x0b\x07\x0b[\x03\x0b\x05\x01\x03\x05\x04\t\x03\x05\x03\x11\x0b?\x07\x03\x07\x0b\x05\x05\x0b\t\x0b\r\x07WC\x03\x0b\x05\x01\x03\x05\x04\x0b\x03\x05\x06\x03\x01\x05\x01\x00\xee\xcd\x93\x0b!f\xa7\x0f\x0b\x03!\x1b\x11\x0f\x11\n\x04!\x19\x19\'#+[\x15\xa5\xa5\xad\x11\x1d\x1d11\x87\x89\x1ff\x03\x1f/!\x19!)#\x1f\x19\xa2\x03Z\x03&\x03\x13%)9+\x0f\r\x1f\x15\x1d\x15\x81\x13\x15\x1f\x13\x0f\x19\x17\x11\x1f\x11)\x19\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00constant_v1\x00broadcast_in_dim_v1\x00call_v1\x00custom_call_v1\x00iota_v1\x00reshape_v1\x00multiply_v1\x00add_v1\x00slice_v1\x00concatenate_v1\x00gather_v1\x00sym_name\x00third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00value\x00callee\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=matmul keep_unused=False inline=False]\x00jit(func)/jit(main)/jit(matmul)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=wrapped keep_unused=False inline=False]\x00jit(func)/jit(main)/jit(matmul)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue, UnspecifiedValue) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False, False) name=apply_kernel keep_unused=False inline=False]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00kernel_name\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/jit(matmul)/jit(wrapped)/jit(apply_kernel)/tpu_custom_call[config=CustomCallBackendConfig() kernel_name=func kernel_regeneration_metadata=None out_avals=(ShapedArray(float32[1024,256]),)]\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(524288,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(1024, 512) dimensions=None]\x00jit(func)/jit(main)/mul\x00jit(func)/jit(main)/add\x00limit_indices\x00start_indices\x00strides\x00jit(func)/jit(main)/slice[start_indices=(0, 0) limit_indices=(512, 256) strides=None]\x00jit(func)/jit(main)/broadcast_in_dim[shape=(64, 16, 1) broadcast_dimensions=(0,)]\x00jit(func)/jit(main)/broadcast_in_dim[shape=(64, 16, 1) broadcast_dimensions=(1,)]\x00dimension\x00jit(func)/jit(main)/concatenate[dimension=2]\x00collapsed_slice_dims\x00index_vector_dim\x00indices_are_sorted\x00offset_dims\x00slice_sizes\x00start_index_map\x00jit(func)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1)) slice_sizes=(1, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00private\x00matmul\x00wrapped\x00apply_kernel\x00jax.result_info\x00\x00main\x00public\x00{"custom_call_config": {"body": "TUzvUgFNTElSZ29vZ2xlMy10cnVuawABLwkBAwUHAQMJAwUDCwUNDQ8RExUXBwMZA44DIgMhAfkbDw8LKxMTBxcjEwsLCwsTCwsLhQsLCxsLMwsPEw87CxMLC1MLDwsLFxsLUxsLUxsLUxsbGw8TEwsLExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTCxMTExMXBQthkWlpeQGPExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXExcTFxMXDxcPFw8XDxcPFw8XDxcTHxMfDxcfCwsLCwtTCxMBIRsHHw8HHw8nJycLIx8nJycCwhEDBTEzNTcd7R8dcR8FGwMHEgMWAzEzNTcdGgMfHQIDHx8DAwYDOQMFCgM7DgM7AwMXOQUdBR8FIQEBF3NDAQUjBSUNGWFmZmluZV9tYXA8KGQwLCBkMSkgLT4gKGQwLCBkMSk+AAUnBSkFKwMFFx0HawUtIxUREQEBAQEBAQEBBS8RBwUBAwICERUAAw0/QRlDRUdJSxtNT1EFMQEF+/sNFwUzIw0FIQQAAAAAAAAAAQAAAAAAAAAFNRENAQU3BTkBB1NZXwMFIVUjVwkpIw0FIQABAAAAAAAAAAIAAAAAAAADBSFbI10JKyMNBSEAAgAAAAAAAAABAAAAAAAAAwUhYSNjCS0jDQUhAAEAAAAAAAAAAQAAAAAAAAMFGSUbKQMFGSUbKwMFGSUbLREHAQMDB28RA8IPBTsFPQMDB3cRA4IPAwMHexEDQg8DAwd/EQMCDwMDB4MRA8IOAwMHhxEDgg4DAweLEQNCDgMDB48RAwIOAwMHkxEDwg0DAweXEQOCDQMDB5sRA0INAwMHnxEDAg0DAwejEQPCDAMDB6cRA4IMAwMHqxEDQgwDAwevEQPCCwMDB7MRA4ILAwMHtxEDQgsDAwe7EQMCCwMDB78RA8IKAwMHwxEDggoDAwfHEQNCCgMDB8sRAwIKAwMHzxEDwgkDAwfTEQOCCQMDB9cRA0IJAwMH2xEDAgkDAwffEQPCCAMDB+MRA4IIAwMH5xEDQggDAwfrEQMCDAU/AwMH8REDwgcDAwf1EQOCBwMDBwYCI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS50aWxlZDwoOCwxMjgpLFsyLDFdPgAjdHB1LnRpbGVkPCg4LDEyOCksWzQsMV0+ACN0cHUudnBhZDwiMzIsezAsMH0sKDgsMTI4KSI+ABEDQgcDAwcOAhEDAgcDAwcWAhEDwgYDAwceAhEDggYDAwcmAhEDQgYDAwcuAhEDAgYDAwc2AhEDwgUDAwc+AhEDggUDAwdGAhEDQgUDAwdOAhEDAgUDAwdWAhEDwgQDAwdeAhEDggQDAwdmAhEDQgQDAwduAhEDwgMDAwd2AhEDggMDAwd+AhEDQgMDAweGAhEDAgMDAweOAhEDwgIDAweWAhEDggIDAweeAhEDQgIDAwemAhEDAgIDAweuAhED4QMDB7YCEQPBAwMHvgIRA6EDAwfGAhEDgQMDB84CEQNhAwMH1gIRA0EDAwfeAhEDIQMDB+YCEQMCBAMFFx0H7gIRAwIIAwUXHQf2AhEDAQMDB/4CJQEJAAAAAAVBBUMFRQVHBUkjBwkhAQAAAAEAAAACAAAAAAAAAAVLAwMXHScFIQIECQMnBQIIAgQJAQICCycFAgQCBAkBAgQX+QUCCAIQCf8X+QUCEAIICf0X+QUCCAIICf0BCQULBwcPERMBBQUHBwUHBxf5BQIIAhAJJxf5BQIQAggJJxf5BQIIAggJJwR6UQUBEA8HAwERAxEPPQcDlglODQsHDwcPDw8RDxMPEwMFbQMDEwMFdQMDEwMFeQMDEwMFfQMDEwMFgQMDEwMFhQMDEwMFiQMDEwMFjQMDEwMFkQMDEwMFlQMDEwMFmQMDEwMFnQMDEwMFoQMDEwMFpQMDEwMFqQMDEwMFrQMDEwMFsQMDEwMFtQMDEwMFuQMDEwMFvQMDEwMFwQMDEwMFxQMDEwMFyQMDEwMFzQMDEwMF0QMDEwMF1QMDEwMF2QMDEwMF3QMDEwMF4QMDEwMF5QMDEwMD6QMDEwMD7wMDEwMD8wMDEwMD9wMDEwMDCgIDAxMDAxICAwMTAwMaAgMDEwMDIgIDAxMDAyoCAwMTAwMyAgMDEwMDOgIDAxMDA0ICAwMTAwNKAgMDEwMDUgIDAxMDA1oCAwMTAwNiAgMDEwMDagIDAxMDA3ICAwMTAwN6AgMDEwMDggIDAxMDA4oCAwMTAwOSAgMDEwMDmgIDAxMDA6ICAwMTAwOqAgMDEwMDsgIDAxMDA7oCAwMTAwPCAgMDEwMDygIDAxMDA9ICAwMTAwPaAgMDEwMD4gIDAxMDA+oCAwMTAwPyAgMDEwMN+gIDAREGDwMbAwURBg8DHQMHEQYPAx8DCQcHAwEDAQeNiYkHBwMBAwEHjYmFBwcDAQMBB42DiQcHAwEDAQeNg4UHBwMBAwEHjYGJBwcDAQMBB42BhQcHAwEDAQeNf4kHBwMBAwEHjX+FBwcDAQMBB419iQcHAwEDAQeNfYUHBwMBAwEHjXuJBwcDAQMBB417hQcHAwEDAQeNeYkHBwMBAwEHjXmFBwcDAQMBB413iQcHAwEDAQeNd4UHBwMBAwEHjXWJBwcDAQMBB411hQcHAwEDAQeNc4kHBwMBAwEHjXOFBwcDAQMBB41xiQcHAwEDAQeNcYUHBwMBAwEHjW+JBwcDAQMBB41vhQcHAwEDAQeNbYkHBwMBAwEHjW2FBwcDAQMBB41riQcHAwEDAQeNa4UHBwMBAwEHjWmJBwcDAQMBB41phQcHAwEDAQeNZ4kHBwMBAwEHjWeFBwcDAQMBB42FiQcHAwEDAQeNhYUHBwMBAwEHjWWJBwcDAQMBB41lhQcHAwEDAQeNY4kHBwMBAwEHjWOFBwcDAQMBB41hiQcHAwEDAQeNYYUHBwMBAwEHjV+JBwcDAQMBB41fhQcHAwEDAQeNXYkHBwMBAwEHjV2FBwcDAQMBB41biQcHAwEDAQeNW4UHBwMBAwEHjVmJBwcDAQMBB41ZhQcHAwEDAQeNV4kHBwMBAwEHjVeFBwcDAQMBB41ViQcHAwEDAQeNVYUHBwMBAwEHjVOJBwcDAQMBB41ThQcHAwEDAQeNUYkHBwMBAwEHjVGFBwcDAQMBB41PiQcHAwEDAQeNT4UHBwMBAwEHjU2JBwcDAQMBB41NhQcHAwEDAQeNS4kHBwMBAwEHjUuFBwcDAQMBB41JiQcHAwEDAQeNSYUHBwUBAwEHj4mJBwcFAQMBB4+JhQcHBQEDAQePg4kHBwUBAwEHj4OFBwcFAQMBB4+BiQcHBQEDAQePgYUHBwUBAwEHj3+JBwcFAQMBB49/hQcHBQEDAQePfYkHBwUBAwEHj32FBwcFAQMBB497iQcHBQEDAQePe4UHBwUBAwEHj3mJBwcFAQMBB495hQcHBQEDAQePd4kHBwUBAwEHj3eFBwcFAQMBB491iQcHBQEDAQePdYUHBwUBAwEHj3OJBwcFAQMBB49zhQcHBQEDAQePcYkHBwUBAwEHj3GFBwcFAQMBB49viQcHBQEDAQePb4UHBwUBAwEHj22JBwcFAQMBB49thQcHBQEDAQePa4kHBwUBAwEHj2uFBwcFAQMBB49piQcHBQEDAQePaYUHBwUBAwEHj2eJBwcFAQMBB49nhQcHBQEDAQePhYkHBwUBAwEHj4WFBwcFAQMBB49liQcHBQEDAQePZYUHBwUBAwEHj2OJBwcFAQMBB49jhQcHBQEDAQePYYkHBwUBAwEHj2GFBwcFAQMBB49fiQcHBQEDAQePX4UHBwUBAwEHj12JBwcFAQMBB49dhQcHBQEDAQePW4kHBwUBAwEHj1uFBwcFAQMBB49ZiQcHBQEDAQePWYUHBwUBAwEHj1eJBwcFAQMBB49XhQcHBQEDAQePVYkHBwUBAwEHj1WFBwcFAQMBB49TiQcHBQEDAQePU4UHBwUBAwEHj1GJBwcFAQMBB49RhQcHBQEDAQePT4kHBwUBAwEHj0+FBwcFAQMBB49NiQcHBQEDAQePTYUHBwUBAwEHj0uJBwcFAQMBB49LhQcHBQEDAQePSYkHBwUBAwEHj0mFBwcDAQMBB42JhwcHAwEDAQeNiUcHBwMBAwEHjYOHBwcDAQMBB42DRwcHAwEDAQeNgYcHBwMBAwEHjYFHBwcDAQMBB41/hwcHAwEDAQeNf0cHBwMBAwEHjX2HBwcDAQMBB419RwcHAwEDAQeNe4cHBwMBAwEHjXtHBwcDAQMBB415hwcHAwEDAQeNeUcHBwMBAwEHjXeHBwcDAQMBB413RwcHAwEDAQeNdYcHBwMBAwEHjXVHBwcDAQMBB41zhwcHAwEDAQeNc0cHBwMBAwEHjXGHBwcDAQMBB41xRwcHAwEDAQeNb4cHBwMBAwEHjW9HBwcDAQMBB41thwcHAwEDAQeNbUcHBwMBAwEHjWuHBwcDAQMBB41rRwcHAwEDAQeNaYcHBwMBAwEHjWlHBwcDAQMBB41nhwcHAwEDAQeNZ0cHBwMBAwEHjYWHBwcDAQMBB42FRwcHAwEDAQeNZYcHBwMBAwEHjWVHBwcDAQMBB41jhwcHAwEDAQeNY0cHBwMBAwEHjWGHBwcDAQMBB41hRwcHAwEDAQeNX4cHBwMBAwEHjV9HBwcDAQMBB41dhwcHAwEDAQeNXUcHBwMBAwEHjVuHBwcDAQMBB41bRwcHAwEDAQeNWYcHBwMBAwEHjVlHBwcDAQMBB41XhwcHAwEDAQeNV0cHBwMBAwEHjVWHBwcDAQMBB41VRwcHAwEDAQeNU4cHBwMBAwEHjVNHBwcDAQMBB41RhwcHAwEDAQeNUUcHBwMBAwEHjU+HBwcDAQMBB41PRwcHAwEDAQeNTYcHBwMBAwEHjU1HBwcDAQMBB41LhwcHAwEDAQeNS0cHBwMBAwEHjUmHBwcDAQMBB41JRwcHBQEDAQePh4kHBwUBAwEHj4eFBwcFAQMBB49FiQcHBQEDAQePRYUHBwUBAwEHj0OJBwcFAQMBB49DhQcHBQEDAQePQYkHBwUBAwEHj0GFBwcFAQMBB48/iQcHBQEDAQePP4UHBwUBAwEHjz2JBwcFAQMBB489hQcHBQEDAQePO4kHBwUBAwEHjzuFBwcFAQMBB485iQcHBQEDAQePOYUHBwUBAwEHjzeJBwcFAQMBB483hQcHBQEDAQePNYkHBwUBAwEHjzWFBwcFAQMBB48ziQcHBQEDAQePM4UHBwUBAwEHjzGJBwcFAQMBB48xhQcHBQEDAQePL4kHBwUBAwEHjy+FBwcFAQMBB48tiQcHBQEDAQePLYUHBwUBAwEHjyuJBwcFAQMBB48rhQcHBQEDAQePKYkHBwUBAwEHjymFBwcFAQMBB49HiQcHBQEDAQePR4UHBwUBAwEHjyeJBwcFAQMBB48nhQcHBQEDAQePJYkHBwUBAwEHjyWFBwcFAQMBB48jiQcHBQEDAQePI4UHBwUBAwEHjyGJBwcFAQMBB48hhQcHBQEDAQePH4kHBwUBAwEHjx+FBwcFAQMBB48diQcHBQEDAQePHYUHBwUBAwEHjxuJBwcFAQMBB48bhQcHBQEDAQePGYkHBwUBAwEHjxmFBwcFAQMBB48XiQcHBQEDAQePF4UHBwUBAwEHjxWJBwcFAQMBB48VhQcHBQEDAQePE4kHBwUBAwEHjxOFBwcFAQMBB48RiQcHBQEDAQePEYUHBwUBAwEHjw+JBwcFAQMBB48PhQcHBQEDAQePDYkHBwUBAwEHjw2FBwcFAQMBB48LiQcHBQEDAQePC4ULBw0RAwVBJgMuAzYDPgNGA04DVgNeA2YDbgN2A34DhgOOA5YDngOmA64DtgO+A8YDzgPWA94D5gPuA/YD/gMGBA4EFgQeBAsHDREDBUEqAzIDOgNCA0oDUgNaA2IDagNyA3oDggOKA5IDmgOiA6oDsgO6A8IDygPSA9oD4gPqA/ID+gMCBAoEEgQaBCIECwcNEQMLISYELgQ2BD4ERgROBFYEXgRmBG4EdgR+BIYEjgSWBJ4ECwcNEQMFQYuLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLDQcNEwMFByYFLgUyBQ8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEDNgULBw0RAwshpgSuBLYEvgTGBM4E1gTeBOYE7gT2BP4EBgUOBRYFHgULBw0RAwVBOgU+BUIFRgVKBU4FUgVWBVoFXgViBWYFagVuBXIFdgV6BX4FggWGBYoFjgWSBZYFmgWeBaIFpgWqBa4FsgW2BQ0HDRMDBQcqBboFvgUPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA8IFCwcNEQMLISoEMgQ6BEIESgRSBFoEYgRqBHIEegSCBIoEkgSaBKIECwcNEQMFQYuLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLi4uLDQcNEwMFByYFRgZKBg8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEDTgYLBw0RAwshqgSyBLoEwgTKBNIE2gTiBOoE8gT6BAIFCgUSBRoFIgULBw0RAwVBUgZWBloGXgZiBmYGagZuBnIGdgZ6Bn4GggaGBooGjgaSBpYGmgaeBqIGpgaqBq4Gsga2BroGvgbCBsYGygbOBg0HDRMDBQcqBdIG1gYPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA9oGCwcNEQMFQZOXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/8GAg4CFgIeAgsHDREDBUGVmZ2hpamtsbW5vcHFyc3R1dnd4eXp7fH1+f0CAgoCEgIaAiICCwcNEQMLISYCLgI2Aj4CRgJOAlYCXgJmAm4CdgJ+AoYCjgKWAp4CCwcNEQMFQcYFygXOBdIF1gXaBd4F4gXmBeoF7gXyBfYF+gX+BQIGBgYKBg4GEgYWBhoGHgYiBiYGKgYuBjIGNgY6Bj4GQgYNBw0TAwUHXgdmB2oHDwcNFUEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQNuBwsHDREDCyGmAq4CtgK+AsYCzgLWAt4C5gLuAvYC/gIGAw4DFgMeAwsHDREDBUFyB3YHegd+B4IHhgeKB44HkgeWB5oHngeiB6YHqgeuB7IHtge6B74HwgfGB8oHzgfSB9YH2gfeB+IH5gfqB+4HDQcNEwMFB2IH8gf2Bw8HDRVBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQED+gcLBw0RAwshKgIyAjoCQgJKAlICWgJiAmoCcgJ6AoICigKSApoCogILBw0RAwVB3gbiBuYG6gbuBvIG9gb6Bv4GAgcGBwoHDgcSBxYHGgceByIHJgcqBy4HMgc2BzoHPgdCB0YHSgdOB1IHVgdaBw0HDRMDBQdeB34IgggPBw0VQQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBA4YICwcNEQMLIaoCsgK6AsICygLSAtoC4gLqAvIC+gICAwoDEgMaAyIDCwcNEQMFQYoIjgiSCJYImgieCKIIpgiqCK4Isgi2CLoIvgjCCMYIygjOCNII1gjaCN4I4gjmCOoI7gjyCPYI+gj+CAIJBgkNBw0TAwUHYgcKCQ4JDwcNFUEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQMSCQkFCwkJ/geRiYkJBQsJCRYJkYmFCQULCQkCCJGDiQkFCwkJGgmRg4UJBQsJCQYIkYGJCQULCQkeCZGBhQkFCwkJCgiRf4kJBQsJCSIJkX+FCQULCQkOCJF9iQkFCwkJJgmRfYUJBQsJCRIIkXuJCQULCQkqCZF7hQkFCwkJFgiReYkJBQsJCS4JkXmFCQULCQkaCJF3iQkFCwkJMgmRd4UJBQsJCR4IkXWJCQULCQk2CZF1hQkFCwkJIgiRc4kJBQsJCToJkXOFCQULCQkmCJFxiQkFCwkJPgmRcYUJBQsJCSoIkW+JCQULCQlCCZFvhQkFCwkJLgiRbYkJBQsJCUYJkW2FCQULCQkyCJFriQkFCwkJSgmRa4UJBQsJCTYIkWmJCQULCQlOCZFphQkFCwkJOgiRZ4kJBQsJCVIJkWeFCQULCQk+CJGFiQkFCwkJVgmRhYUJBQsJCUIIkWWJCQULCQlaCZFlhQkFCwkJRgiRY4kJBQsJCV4JkWOFCQULCQlKCJFhiQkFCwkJYgmRYYUJBQsJCU4IkV+JCQULCQlmCZFfhQkFCwkJUgiRXYkJBQsJCWoJkV2FCQULCQlWCJFbiQkFCwkJbgmRW4UJBQsJCVoIkVmJCQULCQlyCZFZhQkFCwkJXgiRV4kJBQsJCXYJkVeFCQULCQliCJFViQkFCwkJegmRVYUJBQsJCWYIkVOJCQULCQl+CZFThQkFCwkJagiRUYkJBQsJCYIJkVGFCQULCQluCJFPiQkFCwkJhgmRT4UJBQsJCXIIkU2JCQULCQmKCZFNhQkFCwkJdgiRS4kJBQsJCY4JkUuFCQULCQl6CJFJiQkFCwkJkgmRSYUFAQ8eAwMRD2UHAwcLBQcPBw8TAw8vAwcFBA8FAQUDEQ9nBwMHCwUHDwcPEwMPLwMHBQQPBQUDAxEPaQcDBQcFBw8HDwUEDwUBAwYDAQUBAE4VTXYDKR0dF04C/gOB/gMdCyEjKR8bGRkZHSUTHRUNEykfDxsNCw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAbW9kdWxlAHJldHVybgBsb2FkAHN0b3JlAHJvbGxfdmVjdG9ycwBtYXRtdWwAdW5yb2xsX3ZlY3RvcnMAZXJhc2VfbWVtcmVmX2xheW91dABjb25zdGFudAB2YWx1ZQBpbl9sYXlvdXQAZnVuY3Rpb25fdHlwZQBzeW1fbmFtZQB0cmFuc2Zvcm1faW5kaWNlcwB3aW5kb3dfYm91bmRzAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHN1YmxhbmVfbWFzawBzdWJsYW5lX3N0cmlkZQBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAG1haW4Ad2luZG93X3BhcmFtcwAvbWFza2VkX2xvYWRbbWFza2VkPUZhbHNlIGNhY2hlX21vZGlmaWVyPSBldmljdGlvbl9wb2xpY3k9IGlzX3ZvbGF0aWxlPUZhbHNlIGFyZ3NfdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhbXSwgUHlUcmVlRGVmKFtDdXN0b21Ob2RlKFNsaWNlWyhGYWxzZSwgMjU2KV0sIFsqXSksIEN1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSldKSwgW1RydWUsIFRydWVdLCAoNTEyLCAyNTYpLCAoKSldLCBbKl0pLCkpXQB0aGlyZF9wYXJ0eS9weS9qYXhfdHJpdG9uL2dvb2dsZS9wYWxsYXNfdHB1L2JhY2tfY29tcGF0X3Rlc3QucHkAL21hc2tlZF9sb2FkW21hc2tlZD1GYWxzZSBjYWNoZV9tb2RpZmllcj0gZXZpY3Rpb25fcG9saWN5PSBpc192b2xhdGlsZT1GYWxzZSBhcmdzX3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoW10sIFB5VHJlZURlZihbQ3VzdG9tTm9kZShTbGljZVsoVHJ1ZSwgMCwgMjU2KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoRmFsc2UsIDI1NildLCBbKl0pXSksIFtUcnVlLCBUcnVlXSwgKDI1NiwgNTEyKSwgKCkpXSwgWypdKSwpKV0AL2RvdF9nZW5lcmFsW2RpbWVuc2lvbl9udW1iZXJzPSgoKDEsKSwgKDAsKSksICgoKSwgKCkpKSBwcmVjaXNpb249KDxQcmVjaXNpb24uREVGQVVMVDogMD4sIDxQcmVjaXNpb24uREVGQVVMVDogMD4pIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0Ab3V0X2xheW91dAB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAb3BlcmFuZFNlZ21lbnRTaXplcwAvbWFza2VkX3N3YXBbbWFza2VkPUZhbHNlIGV2aWN0aW9uX3BvbGljeT0gYXJnc190cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFtdLCBQeVRyZWVEZWYoW0N1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSksIEN1c3RvbU5vZGUoU2xpY2VbKFRydWUsIDAsIDI1NildLCBbXSldKSwgW1RydWUsIFRydWVdLCAoMjU2LCAyNTYpLCAoKSldLCBbXSksKSldAA=="}}\x00tpu_custom_call\x00func\x00', - xla_call_module_version=7, + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.5.0\x00\x01-\x05\x01\x05\x1d\x01\x03\x0b\x03\x1b\x0f\x13\x17\x1b\x1f#\'+/37;?\x03\xe5\xa3/\x01W\x07\x0b\x13\x0f\x0f\x0f\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x13\x13\x0b\x0f\x0b\x13\x0b\x0f\x13\x0f\x0b\x13\x13\x13\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x03M\x0f\x0b\x13O\x0f\x0b\x0b\x0b/\x0fO\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x1f\x1f\x1f\x1fO///\x0b\x01\x05\x0b\x0f\x03+\x1f\x07\x07\x07\x17\x13\x0f\x0f\x13\x1f\x1f\x1b\x1f\x13\x13\x1b\x13\x07\x1b\x13\x1f\x02\x92\x06\x1f\x05!\x17\x03\x99\x1b\x1d)+\x1d\x13\x05\x1d\x17\x05\x11\x03\x05\x05#\x1d\x13C\x05%\x1d\x17E\x05\'\x1d\x0f\x05\x1dM\x05\x03\x07\x1f!#\r%\r\x05)\x11\x01\x00\x05+\x05-\x05/\x051\x17\x03\x95\x19\x03\x03/\x83\x053\x1d35\x055\x177\x89\x13\x057\x1d\x0f;\x17\x03\x8f3\x1d?A\x059\x17\x03\x91W\x17\x03\x8f#\x17\x03\x8f\x15\x1dIK\x05;\x17\x03\x93\x15\x05=\x1dQ\x05\x05?\x1dU\x05\x05A\x1f+\x01\x03\x01\r\x03ac\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x01\x1dC\x1dE\x1dG\x1f\x15\x11\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0b\t\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#!\x03\x03q\r\x05suac\x1dI\x1dK\x1dM\x1dO\x03\x05[[##\x03\x03[\x1dQ\x1dS\x0b\x03\x1dU\x1dW\x05\x01\x03\x05]]\x03\x03]\x1f\x11\t\x00\x00\x00\x00\x1f\x11\t\x10\x00\x00\x00\x1f\x13\t\x00\x00\x80?\x1f\x13\to\x12\x83:\x1f\x15!\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x1f\x15\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x01\x00\x00\x00\x00\x00\x00\x00\x05\x03\x01\t\x01\x02\x02)\x05\x02 \x02\x10\x07\t\x1b\x1d)\x03\x02\x02\t)\x03A\t)\x01\t)\x01\x07)\x03\t\x0b)\x05\x02\x10\x02\x08\x07)\x05\x02 \x02\x08\x07)\x05\x02\x02A\x07)\x07\x02\x02A\x05\t)\x03\x05\x0b\x11\x01\x03\x1b\x11\x05\x05\x17\x03\x19)\x03\t\'\x13)\x03\x04\x00\x80\x07)\x03\x01\x0b)\x07\x02\x02A\t\t\x04\x02\x04\x05\x01Q\x01\x1d\x01\x07\x04\xda\x03\x03\x01\t\rP\x01\x03\x07\x042\x03\x035m\x05B\x01\x05\x03\x11\x05B\x01\x07\x03\x11\x05B\x01\t\x03\x13\x05B\x01\x0b\x03\x13\x07B9\r\x03)\x13\x06=\x03\x05\x03\t\x03F\x11\x0f\x03\x05\x03\x07\t\x06\x11\x03\x05\x05\r\x0b\x03F\x15\x0f\x03\x05\x03\x05\x0b\x06\x15\x03\x05\x05\x11\x0f\x15FG\x11\x03\x17\x03\x13\x17F\x07\x13\x03\x19\x05\x13\x15\x07B\x19\r\x03\r\x03F\t\x0f\x03\r\x03\x03\t\x06\t\x03\r\x05\x1b\x19\x03F\x0b\x0f\x03\r\x03\x01\x0b\x06\x0b\x03\r\x05\x1f\x1d\x07B\x19\r\x03\x0f\x03F\t\x0f\x03\x0f\x03\x03\t\x06\t\x03\x0f\x05%#\x03F\x0b\x0f\x03\x0f\x03\x01\x0b\x06\x0b\x03\x0f\x05)\'\x03F\x1b\x15\x03\x1d\x03!\x03F\x1b\x17\x03\x1d\x03+\x19FO\x19\x03-\x05-/\x1bFS\x1b\x03\x1b\x05\x171\x0f\x04\x01\x033\rP\x07\x1d\x07\x041\x03\x07\x0b\x05\x0b\x07/\x07\x00\x11G1-\x1f\x03\x19\x05\x01\x03\x0f\x04\x07\x03\x05\x06\x03\x01\x05\x01\x00\x1a3Y!j&\x1d\x11\x0f\x0b\x03!\x0f\x11#7AK59sY\x193\x13%)9113\x85\x15\x1f\x11\x13\x17\x1f\x15\x11\x0f\x19\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00multiply_v1\x00add_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00reshape_v1\x00slice_v1\x00call_v1\x00concatenate_v1\x00gather_v2\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00jit(func)/jit(main)/iota\x00jit(func)/jit(main)/mul\x00jit(func)/jit(main)/add\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00kernel_name\x00jit(func)/jit(main)/jit(matmul)/pallas_call\x00third_party/py/jax/experimental/pallas/ops/tpu/matmul.py\x00jit(func)/jit(main)/reshape\x00jit(func)/jit(main)/slice\x00jit(func)/jit(main)/broadcast_in_dim\x00jit(func)/jit(main)/concatenate\x00jit(func)/jit(main)/gather\x00mhlo.layout_mode\x00default\x00matmul\x00jax.result_info\x00\x00main\x00public\x00private\x00matmul_kernel\x00{"custom_call_config": {"body": "TUzvUgFNTElSZ29vZ2xlMy10cnVuawABLQkBAwUHAQMJAxkLDQ8RExUXGRsdHyED58ETAbkHEwsTCwsLDwsPDw8LC1MLDw8PDwsPDw8LCwsLExMPDxMPGwsPC0MLFwuFC3MLCwsLFxsLGwsbCxsbGw8LExMPEw8LCxMPExMTHwsTGwsLEwsPCxMTEwsTDwsTEwUHjZFhBwNZARMPBx8nDwcLKyMCZggfAwMLiwUjAwMLdwUlBScFKR15ewUrHSmnHSmrHSm3BS0FLyMJBSEAAQAAAAAAAAABAAAAAAAADREdhzkdETsdEY0dEY8FMR0RqREJAREJBQUzBTUFNwU5FwU7BxcFQyMdlZcRDQAXrRcLHbO1AwVHSQlLBTsRCQ0FPQMPT1ENU1dZWy1dLwlfYWMFPwEHubm7DQ9hZmZpbmVfbWFwPChkMCwgZDEpIC0+IChkMCwgZDEpPgAFQSMJBzEEAAAAAAAAAAEAAAAAAAAAAgAAAAAAAAAFQwVFBUcFSQEHZWltAwUZZxsdCTEDBRlrGx0JMwMFGW8bHQk1AwUNHwkxAwUNHwkzAwUNHwk1EQEBBUsXBTsXAwMLfxEBBQMDNy0dhTkFTQVPAwM3LxEDARcFRQ0XBUcNAwMLkyUFCQAAAAAFURcFQ0EDBZs/nT8FUwVVAwOhvwVXHaU7BVkXBUMFFwVRKRcFUQUFWwMDC7ETCwEFXRcFPycXBT8JI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPHBhcmFsbGVsPgAjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8YXJiaXRyYXJ5PgAjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAjYXJpdGguZmFzdG1hdGg8bm9uZT4AAQICAycFAggCCAsXvQUCCAIIC1UBAgQLAQkFDwEBAQcHBwcBBQcBAQEFAQEEIgYFAREBRQcDAREHEQFNBwNHgw8BAQEBAQEHAQcBBwEHAQMDDwcDAQMDD30DAQMDDwcDAQ0HD4EDDQUFDxEGgwMBAxUDAyEHAwENByGJAw0FFxkTFCEDGwkDCx0DA0OvAwsZBkMDBQNHAwMXAwMDAwMXAwMDBQYXAwUHDUtNCwQXCUkNS00PAEEDAQUPAEEDAyMDAwMDAyMDAwMFBiMDBQcNHR8DAyUDAwMDAyUDAwMFBiUDBQcHIyUDAycDAwMDAycDAwMFBicDBQcJKSsDAz2RAwUVBz2ZAwUHJy0vFwejnwMFBSExAwMTAwMDAwMTAwMDBQYTAwUHDTU3CwQTCTMNNTcDAysDAwMDAysDAwMFBisDBQcNOz0DAxUDAwMDAxUDAwMFBhUDBQcLQUMLBBUJPwtBQwkAAQcRAXEHAwkLBwEBAQEBAQMDAQcDAQkEAQUBBQcRAXMHAwkLBwEBAQEBAQMDAQcDAQkEAQUFAwcRAXUHAwkLBwEBAQEBAQMDAQcDAQkEAQUBAwYDAQUBAO4JXyUFCxMdHRsNLQkdCyMhIykdLRUZGRkNHSULHQ0TcyMXFw8ZFRcbGRUZHw8NCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBhcml0aABtb2R1bGUAYXJpdGguY29uc3RhbnQAdmVjdG9yLmxvYWQAZnVuYy5mdW5jAGZ1bmMucmV0dXJuAHZlY3Rvci5zdG9yZQBhcml0aC5jbXBpAHNjZi55aWVsZABhcml0aC5leHR1aQBzY2YuaWYAdHB1Lm1hdG11bABhcml0aC5hZGRmAHZlY3Rvci5icm9hZGNhc3QAdGhpcmRfcGFydHkvcHkvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9tYXRtdWwucHkAc3ltX25hbWUAdmFsdWUAZnVuY3Rpb25fdHlwZQAvZ2V0AHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAL3N3YXAAdHJhbnNmb3JtXzAAdHJhbnNmb3JtXzEAdHJhbnNmb3JtXzIAcHJlZGljYXRlAHN0YWJsZV9tb3NhaWMudmVyc2lvbgBtYXRtdWxfa2VybmVsAGRpbWVuc2lvbl9zZW1hbnRpY3MAaXRlcmF0aW9uX2JvdW5kcwBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBtYWluAHdpbmRvd19wYXJhbXMAL2VxAC9jb252ZXJ0X2VsZW1lbnRfdHlwZQAvY29uZAAvZG90X2dlbmVyYWwAdHJhbnNwb3NlX2xocwB0cmFuc3Bvc2VfcmhzAGZhc3RtYXRoAC9hZGQALQAvYnJvYWRjYXN0X2luX2RpbQA=", "serialization_format": 1, "needs_layout_passes": true}, "implicit_sharding": {"type": "MANUAL"}}\x00tpu_custom_call\x00\x08u!\x05O\x01\x0bYmowy\x03\x91\x03\x93\x03\x95\x03\x97\x03_\x03W\x07\x99\x9bg\x03e\x03\x9d\x03\x9f\x03i\x11ki\xa1WWgkW\x0b{}\x7fe\x81\x11\x85\x87\x89Y\x8b\x8dY\x8f', + xla_call_module_version=9, + nr_devices=1, ) # End paste diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index ff1d828f27fc..9e9935884b3a 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -62,8 +62,6 @@ def add_one(x_ref, o_ref): @jax.default_matmul_precision("bfloat16") def test_mosaic_matmul(self): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(apaszke): Test fails on TPU v6e") dtype = jnp.float32 def func(): # Build the inputs here, to reduce the size of the golden inputs. @@ -77,7 +75,7 @@ def func(): # Keep only slices of the output, to reduce the size of the goldens. return res[::16, ::16] - data = self.load_testdata(mosaic_matmul.data_2023_09_22) + data = self.load_testdata(mosaic_matmul.data_2024_09_24) self.run_one_test(func, data, rtol=2e-7) def test_mosaic_semaphore_dma(self): From d2ac88c1936cb2269dcd4e543c9178b209453236 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 24 Sep 2024 09:48:06 -0700 Subject: [PATCH 636/702] Expose some APIs for querying trace state. This will let us move users away from depending on our internals. Prep work for "stackless". PiperOrigin-RevId: 678288660 --- jax/_src/core.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++++ jax/core.py | 6 ++++++ 2 files changed, 56 insertions(+) diff --git a/jax/_src/core.py b/jax/_src/core.py index 467f2b63d390..85361ab5a7b8 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3410,3 +3410,53 @@ def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any], # Used in shard_map for converting avals shard_aval_handlers = {} # type: ignore unshard_aval_handlers = {} # type: ignore + +# ----------------- external APIs for querying tracing context ----------------- + +# TODO(dougalm, jakevdp): expose these via jax.extend + +# Comparable object for checking whether JAX's trace state has changed. +class OpaqueTraceState: + def __init__(self, trace_info, convention): + self._trace_info = trace_info + self._convention = convention + + def __eq__(self, other): + if isinstance(other, OpaqueTraceState): + if self._convention in ["nnx"]: + return self._trace_info is other._trace_info + elif self._convention in ["haiku", "flax"]: + return self._trace_info == other._trace_info + else: + raise Exception(f"unrecognized convention: {self._convention}") + + +# Each library has its own opinion about what the important fragment of jax's +# internal state is. TODO: reconcile the differences and remove the flag. +def get_opaque_trace_state(convention="flax"): + if convention == "flax": + trace_info = find_top_trace(()).level + elif convention == "haiku": + trace_stack = thread_local_state.trace_state.trace_stack.stack + top_type = trace_stack[0].trace_type + level = trace_stack[-1].level + sublevel = cur_sublevel() + trace_info = (top_type, level, sublevel) + elif convention == "nnx": + trace_info = thread_local_state.trace_state.trace_stack.dynamic + else: + raise Exception(f"unrecognized convention: {convention}") + + return OpaqueTraceState(trace_info, convention) + +def nonempty_axis_env() -> bool: + return bool(thread_local_state.trace_state.axis_env) + +def unsafe_am_i_under_a_jit() -> bool: + return 'DynamicJaxprTrace' in str(thread_local_state.trace_state.trace_stack) + +def unsafe_am_i_under_a_vmap() -> bool: + return 'BatchTrace' in str(thread_local_state.trace_state.trace_stack) + +def unsafe_get_axis_names() -> list[str]: + return [axis.name for axis in thread_local_state.trace_state.axis_env] diff --git a/jax/core.py b/jax/core.py index cdf8d76558d9..035dcdcdb7e8 100644 --- a/jax/core.py +++ b/jax/core.py @@ -29,6 +29,7 @@ Effect as Effect, Effects as Effects, EvalTrace as EvalTrace, + get_opaque_trace_state as get_opaque_trace_state, InDBIdx as InDBIdx, InconclusiveDimensionOperation as InconclusiveDimensionOperation, InputType as InputType, @@ -41,6 +42,8 @@ Literal as Literal, MainTrace as MainTrace, MapPrimitive as MapPrimitive, + nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, + OpaqueTraceState as OpaqueTraceState, NameGatheringSubst as NameGatheringSubst, OutDBIdx as OutDBIdx, OutputType as OutputType, @@ -55,6 +58,9 @@ TraceStack as TraceStack, TraceState as TraceState, Tracer as Tracer, + unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, + unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, + unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, UnshapedArray as UnshapedArray, Value as Value, Var as Var, From be7fe878c3e6d7928ca354fc33a820692499d956 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 24 Sep 2024 11:10:44 -0700 Subject: [PATCH 637/702] [pallas:triton] Elide `program_id` calls where launch grid dimension is 1. This may allow for parts of indexing calculations to be optimized away. PiperOrigin-RevId: 678321871 --- jax/_src/pallas/triton/lowering.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 15f0d265b836..4722a31db92c 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -220,19 +220,18 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): if len(collapse_dims) == 0: prog_ids = [None] * len(prog_id_dims) for i in range(len(prog_id_dims)): - out_idx = launch_grid_to_pallas_grid[i] - prog_ids[out_idx] = _program_id(i) + prog_ids[launch_grid_to_pallas_grid[i]] = _program_id(i, prog_id_dims) return prog_id_dims, prog_ids - else: - new_grid = [math.prod(collapse_dims), *prog_id_dims] + + new_grid = [math.prod(collapse_dims), *prog_id_dims] assert new_grid[0] < 2**31 - 1, \ "Cannot fix pallas kernel launch grid within CUDA limits" out_indices = [None] * len(grid_mapping.grid) - grid0 = _program_id(0) + grid0 = _program_id(0, new_grid) for i, s in enumerate(collapse_dims): out_idx = launch_grid_to_pallas_grid[i] s = _i32_constant(s) @@ -241,7 +240,7 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): for i in range(len(prog_id_dims)): out_idx = launch_grid_to_pallas_grid[num_collapse + i] - out_indices[out_idx] = _program_id(i + 1) + out_indices[out_idx] = _program_id(i + 1, new_grid) assert len(out_indices) == len(grid_mapping.grid) return new_grid, out_indices @@ -428,9 +427,11 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): # ## Programming model primitives -def _program_id(axis: int) -> ir.Value: +def _program_id(axis: int, launch_grid: Sequence[int]) -> ir.Value: if axis not in range(3): raise ValueError(f"axis must be in [0, 3), but got: {axis}") + if launch_grid[axis] == 1: + return _i32_constant(0) return tt_dialect.get_program_id(axis) From 407dc774f7e28226ceada76829ddfdacddb91ece Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 24 Sep 2024 11:34:08 -0700 Subject: [PATCH 638/702] [Mosaic TPU] Support all cases for extui. PiperOrigin-RevId: 678331795 --- .../tpu/transforms/apply_vector_layout.cc | 97 ++++++++++++++++--- .../tpu/transforms/infer_vector_layout.cc | 16 +-- 2 files changed, 92 insertions(+), 21 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index d3e1b59afe16..a1714fc8090b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -764,10 +764,10 @@ using rule_type = std::function, ArrayRef)>; template -LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, - const VectorLayout &layout_in, - const VectorLayout &layout_out) { - ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); +FailureOr> ext_op_rule_impl(RewriteContext &ctx, + OpBuilder &builder, OpTy op, + const VectorLayout &layout_in, + const VectorLayout &layout_out) { const auto result_ty = cast(op.getResult().getType()); auto source = cast>(op.getIn()); const auto source_ty = source.getType(); @@ -801,7 +801,7 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, int64_t vreg_part = *(input_vreg_idxs.end() - 2) % packing; *(input_vreg_idxs.end() - 2) /= packing; *v = builder.create( - res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); }); } else { if (layout_in.tiling() != layout_out.tiling()) { @@ -817,17 +817,13 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, input_vreg_idxs.back() /= packing; const int64_t vreg_part = idxs.back() % packing; *v = builder.create( - res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); }); } if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { output_vregs.Reshape(output_vregs_shape); } - op.replaceAllUsesWith(assemble(builder, result_ty, layout_out, - std::move(output_vregs), ctx.target_shape) - .getResult()); - op.erase(); - return success(); + return output_vregs; } LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, @@ -842,8 +838,17 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, return op.emitOpError( "Not implemented: Only 16-bit to 32-bit conversion supported"); } - return ext_op_rule_impl(ctx, extf_op, *layouts_in.front(), - *layouts_out.front()); + ImplicitLocOpBuilder builder(op.getLoc(), &op); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array output_vregs, + ext_op_rule_impl(ctx, builder, extf_op, *layouts_in.front(), + *layouts_out.front())); + const auto result_ty = cast(extf_op.getResult().getType()); + extf_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), + std::move(output_vregs), ctx.target_shape) + .getResult()); + extf_op.erase(); + return success(); } LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op, @@ -854,8 +859,69 @@ LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(layouts_out.size(), 1); TPU_ASSERT_OP(layouts_out.front().has_value()); auto extsi_op = cast(op); - return ext_op_rule_impl(ctx, extsi_op, *layouts_in.front(), - *layouts_out.front()); + ImplicitLocOpBuilder builder(op.getLoc(), &op); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array output_vregs, + ext_op_rule_impl(ctx, builder, extsi_op, *layouts_in.front(), + *layouts_out.front())); + const auto result_ty = cast(extsi_op.getResult().getType()); + extsi_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), + std::move(output_vregs), ctx.target_shape) + .getResult()); + extsi_op.erase(); + return success(); +} + +LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + TPU_ASSERT_EQ_OP(layouts_in.size(), 1); + TPU_ASSERT_OP(layouts_in.front().has_value()); + TPU_ASSERT_EQ_OP(layouts_out.size(), 1); + TPU_ASSERT_OP(layouts_out.front().has_value()); + auto extui_op = cast(op); + auto in_ty = dyn_cast(extui_op.getIn().getType()); + auto out_ty = dyn_cast(extui_op.getType()); + CHECK(in_ty && out_ty); + auto in_bitwidth = in_ty ? in_ty.getElementTypeBitWidth() + : extui_op.getIn().getType().getIntOrFloatBitWidth(); + if (in_bitwidth == 1) { + return elementwise_op_rule(ctx, op, layouts_in, layouts_out); + } + ImplicitLocOpBuilder builder(op.getLoc(), &op); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array output_vregs, + ext_op_rule_impl(ctx, builder, extui_op, *layouts_in.front(), + *layouts_out.front())); + const auto source_ty = cast(extui_op.getIn().getType()); + const auto result_ty = cast(extui_op.getResult().getType()); + auto src_bitwidth = source_ty.getElementTypeBitWidth(); + auto dst_bitwidth = result_ty.getElementTypeBitWidth(); + // Generate a mask to mask out the sign extension. e.g., for u8 -> u16, + // the mask is 0x00ff00ff. + unsigned mask = (1 << src_bitwidth) - 1; + while (dst_bitwidth < 32) { + mask = (mask << dst_bitwidth) | mask; + dst_bitwidth *= 2; + } + const VectorType i32_vreg_ty = + getNativeVregType(builder.getI32Type(), ctx.target_shape); + auto mask_const = builder.create( + op.getLoc(), i32_vreg_ty, DenseIntElementsAttr::get(i32_vreg_ty, {mask})); + const VectorType res_vreg_ty = + getNativeVregType(result_ty.getElementType(), ctx.target_shape); + output_vregs.Each([&](absl::Span _, Value *v) { + Value unpacked = + builder.create(op.getLoc(), i32_vreg_ty, *v); + unpacked = builder.create(op.getLoc(), i32_vreg_ty, unpacked, + mask_const); + *v = builder.create(op.getLoc(), res_vreg_ty, unpacked); + }); + extui_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), + std::move(output_vregs), ctx.target_shape) + .getResult()); + extui_op.erase(); + return success(); } template @@ -4352,6 +4418,7 @@ const llvm::StringMap &rules() { {arith::ConstantOp::getOperationName(), arith_constant_rule}, {arith::ExtFOp::getOperationName(), arith_extf_rule}, {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, + {arith::ExtUIOp::getOperationName(), arith_extui_rule}, {arith::TruncFOp::getOperationName(), arith_truncf_rule}, {arith::TruncIOp::getOperationName(), arith_trunci_rule}, {func::ReturnOp::getOperationName(), func_return_rule}, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index c5a43e898cd0..2894b0797e7b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -196,12 +196,16 @@ class VectorLayoutInferer { auto out_ty = dyn_cast(op.getType()); TPU_CHECK_OP(static_cast(in_ty) == static_cast(out_ty), "Input and output are not both vectors?"); - if (in_ty) { - TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1, - "Only extending i1 is supported"); - } - if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { - return failure(); + auto in_bitwidth = in_ty ? in_ty.getElementTypeBitWidth() + : op.getIn().getType().getIntOrFloatBitWidth(); + if (in_bitwidth == 1) { + if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { + return failure(); + } + } else { + if (inferExt(&any_op).failed()) { + return failure(); + } } } else if (isa(any_op) || isa(any_op)) { Operation *op = &any_op; // For TPU_CHECK_OP macros, which use the `op` From 6e116491c1bb31b5ccc6c17c0d9eb324d95fd38f Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 24 Sep 2024 11:36:13 -0700 Subject: [PATCH 639/702] Add `--use_cuda_nvcc` flag to enable or disable compilation of CUDA code using NVCC. If `--use_cuda_nvcc` flag is set the NVCC compiler driver will be used to build the CUDA code (default behavior). Otherwise, if the flag `--nouse_cuda_nvcc` is set, only the clang compiler will be used to build the CUDA code (effectively disabling NVCC). Mark `--use_clang` flag as deprecated. Refactor `.bazelrc` configs to match the new flag and to cleanup all previous confusing names. PiperOrigin-RevId: 678332548 --- .bazelrc | 52 ++++++++++++++++++++++++----------------------- build/build.py | 36 +++++++++++++++++++++----------- docs/developer.md | 3 ++- 3 files changed, 53 insertions(+), 38 deletions(-) diff --git a/.bazelrc b/.bazelrc index 948d92c29c26..458ce69fae8b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -57,6 +57,16 @@ build:native_arch_posix --host_copt=-march=native build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 +build:clang --action_env=CC="/usr/lib/llvm-18/bin/clang" +# Disable clang extention that rejects type definitions within offsetof. +# This was added in clang-16 by https://reviews.llvm.org/D133574. +# Can be removed once upb is updated, since a type definition is used within +# offset of in the current version of ubp. +# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. +build:clang --copt=-Wno-gnu-offsetof-extensions +# Disable clang extention that rejects unknown arguments. +build:clang --copt=-Qunused-arguments + build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 # "sm" means we emit only cubin, which is forward compatible within a GPU generation. @@ -68,14 +78,6 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true # Default hermetic CUDA and CUDNN versions. build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" -# This flag is needed to include CUDA libraries for bazel tests. -test:cuda --@local_config_cuda//cuda:include_cuda_libs=true - -# Requires MSVC and LLVM to be installed -build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl -build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl -build:win_clang --compiler=clang-cl - # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to # point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA @@ -89,23 +91,18 @@ build:win_clang --compiler=clang-cl # acceptable, because the workaround is "remove the nvidia-..." pip packages. # The list of CUDA pip packages that JAX depends on are present in setup.py. build:cuda --linkopt=-Wl,--disable-new-dtags +build:cuda --@local_config_cuda//:cuda_compiler=clang +build:cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang --@local_config_cuda//:cuda_compiler=clang -build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -# Disable clang extention that rejects type definitions within offsetof. -# This was added in clang-16 by https://reviews.llvm.org/D133574. -# Can be removed once upb is updated, since a type definition is used within -# offset of in the current version of ubp. -# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. -build:cuda_clang --copt=-Wno-gnu-offsetof-extensions -# Disable clang extention that rejects unknown arguments. -build:cuda_clang --copt=-Qunused-arguments +# This flag is needed to include CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_cuda_libs=true -# Build with nvcc for CUDA and clang for host -build:nvcc_clang --config=cuda -build:nvcc_clang --config=cuda_clang -build:nvcc_clang --action_env=TF_NVCC_CLANG="1" -build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc +# Build with NVCC for CUDA +build:cuda_nvcc --config=cuda +build:cuda_nvcc --config=clang +build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc +build:cuda_nvcc --action_env=TF_NVCC_CLANG="1" +build:cuda_nvcc --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true @@ -114,6 +111,11 @@ build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1 build:nonccl --define=no_nccl_support=true +# Requires MSVC and LLVM to be installed +build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl +build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl +build:win_clang --compiler=clang-cl + # Windows has a relatively short command line limit, which JAX has begun to hit. # See https://docs.bazel.build/versions/main/windows.html build:windows --features=compiler_param_file @@ -200,7 +202,7 @@ build:rbe_linux --host_linkopt=-lm # Use the GPU toolchain until the CPU one is ready. # https://github.com/bazelbuild/bazel/issues/13623 build:rbe_cpu_linux_base --config=rbe_linux -build:rbe_cpu_linux_base --config=cuda_clang +build:rbe_cpu_linux_base --config=clang build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain" build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" @@ -223,7 +225,7 @@ build:rbe_linux_cuda_base --config=cuda build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda12.3_nvcc_base --config=nvcc_clang +build:rbe_linux_cuda12.3_nvcc_base --config=cuda_nvcc build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2" build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" diff --git a/build/build.py b/build/build.py index c3a86278a427..42db37fd74af 100755 --- a/build/build.py +++ b/build/build.py @@ -218,7 +218,7 @@ def get_clang_path_or_exit(): return str(pathlib.Path(which_clang_output).resolve()) else: print( - "--use_clang set, but --clang_path is unset and clang cannot be found" + "--clang_path is unset and clang cannot be found" " on the PATH. Please pass --clang_path directly." ) sys.exit(-1) @@ -241,8 +241,9 @@ def write_bazelrc(*, remote_build, cpu, cuda_compute_capabilities, rocm_amdgpu_targets, target_cpu_features, wheel_cpu, enable_mkl_dnn, use_clang, clang_path, - clang_major_version, enable_cuda, enable_nccl, enable_rocm, - python_version): + clang_major_version, python_version, + enable_cuda, enable_nccl, enable_rocm, + use_cuda_nvcc): with open("../.jax_configure.bazelrc", "w") as f: if not remote_build: @@ -283,11 +284,11 @@ def write_bazelrc(*, remote_build, f.write("build --config=mkl_open_source_only\n") if enable_cuda: f.write("build --config=cuda\n") + f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") if not enable_nccl: f.write("build --config=nonccl\n") - if use_clang: - f.write("build --config=nvcc_clang\n") - f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") + if use_cuda_nvcc: + f.write("build --config=cuda_nvcc\n") if cuda_version: f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n" .format(cuda_version=cuda_version)) @@ -392,15 +393,14 @@ def main(): "use_clang", default = "true", help_str=( - "Should we build using clang as the host compiler? Requires " - "clang to be findable via the PATH, or a path to be given via " - "--clang_path." + "DEPRECATED: This flag is redundant because clang is " + "always used as default compiler." ), ) parser.add_argument( "--clang_path", help=( - "Path to clang binary to use if --use_clang is set. The default is " + "Path to clang binary to use. The default is " "to find clang via the PATH." ), ) @@ -413,7 +413,18 @@ def main(): add_boolean_argument( parser, "enable_cuda", - help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.") + help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN." + ) + add_boolean_argument( + parser, + "use_cuda_nvcc", + default=True, + help_str=( + "Should we build CUDA code using NVCC compiler driver? The default value " + "is true. If --nouse_cuda_nvcc flag is used then CUDA code is built " + "by clang compiler." + ), + ) add_boolean_argument( parser, "build_gpu_plugin", @@ -617,10 +628,11 @@ def main(): use_clang=args.use_clang, clang_path=clang_path, clang_major_version=clang_major_version, + python_version=python_version, enable_cuda=args.enable_cuda, enable_nccl=args.enable_nccl, enable_rocm=args.enable_rocm, - python_version=python_version, + use_cuda_nvcc=args.use_cuda_nvcc, ) if args.requirements_update or args.requirements_nightly_update: diff --git a/docs/developer.md b/docs/developer.md index 4f33614138ef..5f57b2499860 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -86,7 +86,8 @@ There are two ways to build `jaxlib` with CUDA support: (1) use support, or (2) use `python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and -jax-cuda-pjrt). +jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and +clang, but it can be restricted to clang via the `--nouse_cuda_nvcc` flag. See `python build/build.py --help` for configuration options. Here `python` should be the name of your Python 3 interpreter; on some systems, you From 5e3f7618fc14305ab5dc9820b2fc3ef486e14986 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 24 Sep 2024 11:45:49 -0700 Subject: [PATCH 640/702] Support pmin and pmax in check_rep. PiperOrigin-RevId: 678336530 --- jax/experimental/shard_map.py | 28 ++++++++++++++++++++++++++++ tests/shard_map_test.py | 16 ++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 35d665943792..10d4874d7329 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1050,6 +1050,10 @@ def register_standard_collective(prim): register_check(prim)(partial(_standard_collective_check, prim)) register_rewrite(prim)(partial(_standard_collective_rewrite, prim)) +def register_reduction_collective(prim): + register_check(prim)(partial(_reduction_collective_check, prim)) + register_rewrite(prim)(partial(_reduction_collective_rewrite, prim)) + def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params): # The standard collective check is varying -> varying over axis_name. del mesh, params @@ -1071,6 +1075,28 @@ def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params): out_val = prim.bind(x, axis_name=axis_name, **params) return [out_val], [x_rep - axis_name_set] +def _reduction_collective_check(prim, mesh, x_rep, *, axes, **params): + # The reduction collective check is varying -> replicated over axes. + del mesh, params + axes = (axes,) if not isinstance(axes, tuple) else axes + if x_rep is None or any(a in x_rep for a in axes): + raise Exception(f"Collective {prim} must be applied to a device-varying " + f"replication type, but got {x_rep} for collective acting " + f"over axis name {axes}. Please open an issue at " + "https://github.com/jax-ml/jax/issues and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return x_rep | set(axes) + +def _reduction_collective_rewrite(prim, mesh, in_rep, x, axes, **params): + # The standard collective rewrite may insert a pbroadcast on the input. + axes = (axes,) if not isinstance(axes, tuple) else axes + x_rep, = in_rep + axes_set = set(axes) + if pbroadcast_axes := axes_set & x_rep: + x = pbroadcast(x, tuple(pbroadcast_axes)) + out_val, = prim.bind(x, axes=axes, **params) + return [out_val], [x_rep | axes_set] + for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(), windowed_reductions.__dict__.values(), @@ -1140,6 +1166,8 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): register_standard_collective(lax_parallel.all_to_all_p) register_standard_collective(lax_parallel.ppermute_p) register_standard_collective(lax_parallel.reduce_scatter_p) +register_reduction_collective(lax_parallel.pmin_p) +register_reduction_collective(lax_parallel.pmax_p) @register_check(lax_parallel.axis_index_p) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index fbe9746513f5..397f2d94c7f7 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2179,6 +2179,22 @@ def check_rep(result): f(x, reduce_along=('x',), use_jit=use_jit) f(x, reduce_along=('x', 'y'), use_jit=use_jit) + def test_pmin(self): + mesh = jtu.create_mesh((4,), ('i',)) + x = jnp.arange(8., dtype=np.float32) + y = shard_map(lambda x: jax.lax.pmin(x, 'i'), + mesh=mesh, in_specs=P('i'), out_specs=P() + )(x) # don't crash + self.assertArraysEqual(y, np.array([0, 1], dtype=np.float32)) + + def test_pmax(self): + mesh = jtu.create_mesh((4,), ('i',)) + x = jnp.arange(8., dtype=np.float32) + y = shard_map(lambda x: jax.lax.pmax(x, 'i'), + mesh=mesh, in_specs=P('i'), out_specs=P() + )(x) # don't crash + self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) + class FunSpec(NamedTuple): name: str From d58a09faed78a2308706b362b4833f2cd80a0ab0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 24 Sep 2024 11:51:37 -0700 Subject: [PATCH 641/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/bcc98dcd1cd334a1aa833a1055a840bcd2ac87f5. PiperOrigin-RevId: 678338581 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f8bddcd740bf..dcc12e68eae3 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "1162b7e30d12d00aa4d004a71217ef958d8aa290" -XLA_SHA256 = "706d360fa2f82174fb7210cf7b87470faa2440f7614efc57136f47879d0032ed" +XLA_COMMIT = "bcc98dcd1cd334a1aa833a1055a840bcd2ac87f5" +XLA_SHA256 = "c69acc5dd6eef894a400a5ae9076d3b53c0586acbd7d5970e7f9556d28b28462" def repo(): tf_http_archive( From 70f91db853f5aa6fc353063a1ca1d5a36a73c379 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Sep 2024 12:28:32 -0700 Subject: [PATCH 642/702] Set PYTHONWARNINGS=error in bazel tests. The goal of this change is to catch PRs that introduce new warnings sooner. To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable. Add code to suppress some new warnings uncovered in CI. PiperOrigin-RevId: 678352286 --- benchmarks/mosaic/BUILD | 4 +- docs/cuda_custom_call/BUILD | 4 +- .../jax2tf/examples/keras_reuse_main.py | 5 +- .../jax2tf/examples/keras_reuse_main_test.py | 1 + jax/experimental/jax2tf/examples/mnist_lib.py | 5 +- jax/experimental/jax2tf/tests/call_tf_test.py | 34 +++ jax/experimental/jax2tf/tests/jax2tf_test.py | 27 ++ .../jax2tf/tests/savedmodel_test.py | 11 + .../jax2tf/tests/shape_poly_test.py | 6 +- .../jax2tf/tests/sharding_test.py | 12 +- jax/experimental/mosaic/gpu/examples/BUILD | 6 +- jaxlib/jax.bzl | 12 +- jaxlib/tools/BUILD.bazel | 4 +- pyproject.toml | 5 + tests/BUILD | 268 +++++++++--------- tests/array_interoperability_test.py | 6 + tests/host_callback_to_tf_test.py | 8 + tests/lax_numpy_test.py | 6 +- tests/lax_test.py | 1 + tests/mosaic/BUILD | 10 +- tests/mosaic/gpu_test.py | 3 + tests/pallas/BUILD | 50 ++-- tests/pmap_test.py | 4 + tests/pytorch_interoperability_test.py | 4 + tests/sparse_bcoo_bcsr_test.py | 1 + 25 files changed, 316 insertions(+), 181 deletions(-) diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 72aae09af4a2..4345e620a3ae 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -15,7 +15,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", "py_deps", ) @@ -42,7 +42,7 @@ DISABLED_CONFIGS = [ "gpu_pjrt_c_api", ] -jax_test( +jax_multiplatform_test( name = "matmul_bench", srcs = ["matmul_bench.py"], disable_backends = DISABLED_BACKENDS, diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD index 0591eed1fbec..0089b6b9fb0d 100644 --- a/docs/cuda_custom_call/BUILD +++ b/docs/cuda_custom_call/BUILD @@ -16,7 +16,7 @@ load( "//jaxlib:jax.bzl", "cuda_library", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", ) licenses(["notice"]) @@ -28,7 +28,7 @@ package( jax_generate_backend_suites() -jax_test( +jax_multiplatform_test( name = "cuda_custom_call_test", srcs = ["cuda_custom_call_test.py"], data = [":foo"], diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main.py b/jax/experimental/jax2tf/examples/keras_reuse_main.py index 77f882af6850..1806e8c4545d 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main.py @@ -18,13 +18,16 @@ See README.md. """ import logging +import warnings from absl import app from absl import flags from jax.experimental.jax2tf.examples import mnist_lib from jax.experimental.jax2tf.examples import saved_model_main import tensorflow as tf import tensorflow_datasets as tfds # type: ignore -import tensorflow_hub as hub # type: ignore +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import tensorflow_hub as hub # type: ignore FLAGS = flags.FLAGS diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py index 2934842912f0..e34282a76ff4 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py @@ -41,6 +41,7 @@ def setUp(self): @parameterized.named_parameters( dict(testcase_name=f"_{model}", model=model) for model in ["mnist_pure_jax", "mnist_flax"]) + @jtu.ignore_warning(message="the imp module is deprecated") def test_keras_reuse(self, model="mnist_pure_jax"): FLAGS.model = model keras_reuse_main.main(None) diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py index 41173c79a5b9..77432f9ebd92 100644 --- a/jax/experimental/jax2tf/examples/mnist_lib.py +++ b/jax/experimental/jax2tf/examples/mnist_lib.py @@ -27,6 +27,7 @@ import re import time from typing import Any +import warnings from absl import flags import flax @@ -70,7 +71,9 @@ def load_mnist(split: tfds.Split, batch_size: int): if _MOCK_DATA.value: with tfds.testing.mock_data(num_examples=batch_size): try: - ds = tfds.load("mnist", split=split) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + ds = tfds.load("mnist", split=split) except Exception as e: m = re.search(r'metadata files were not found in (.+/)mnist/', str(e)) if m: diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index e10c3fbfdff7..492dfad4c855 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -88,6 +88,17 @@ def setUp(self): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + self.warning_ctx = jtu.ignore_warning( + message=( + "(jax2tf.convert with native_serialization=False is deprecated" + "|Calling from_dlpack with a DLPack tensor is deprecated)" + ) + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() @_parameterized_jit def test_eval_scalar_arg(self, with_jit=True): @@ -862,6 +873,7 @@ def _transfer_guard(guard_level): class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase): "Reloading output of jax2tf into JAX with call_tf" + def setUp(self): if tf is None: raise unittest.SkipTest("Test requires tensorflow") @@ -869,6 +881,17 @@ def setUp(self): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + self.warning_ctx = jtu.ignore_warning( + message=( + "(jax2tf.convert with native_serialization=False is deprecated" + "|Calling from_dlpack with a DLPack tensor is deprecated)" + ) + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() def test_simple(self): f_jax = jnp.sin @@ -1157,6 +1180,17 @@ def setUp(self): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + self.warning_ctx = jtu.ignore_warning( + message=( + "(jax2tf.convert with native_serialization=False is deprecated" + "|Calling from_dlpack with a DLPack tensor is deprecated)" + ) + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() def test_alternate(self): # Alternate sin/cos with sin in TF and cos in JAX diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index ef7a5ee2c138..6411dc581424 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -76,6 +76,17 @@ def setUpClass(cls): super().setUpClass() + def setUp(self): + super().setUp() + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def test_empty(self): f_jax = lambda x, y: x self.ConvertAndCompare(f_jax, 0.7, 1) @@ -1621,6 +1632,8 @@ def f_jax(*many_args): res = jax2tf.convert(f_jax, native_serialization=True)(*many_args) self.assertAllClose(f_jax(*many_args), res) + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def test_nested_convert(self): # Test call sequence: convert -> call_tf -> convert. @@ -1677,6 +1690,17 @@ def f_jax(x): @jtu.with_config(jax_enable_custom_prng=True) class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): + def setUp(self): + super().setUp() + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def test_key_argument(self): func = lambda key: jax.random.uniform(key, ()) key = jax.random.PRNGKey(0) @@ -1709,6 +1733,9 @@ def setUp(self): self.use_max_serialization_version = False super().setUp() + @jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) def test_simple(self): self.ConvertAndCompare(jnp.sin, 0.7) diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py index bc19915d1644..aee15883332a 100644 --- a/jax/experimental/jax2tf/tests/savedmodel_test.py +++ b/jax/experimental/jax2tf/tests/savedmodel_test.py @@ -30,6 +30,17 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase): + def setUp(self): + super().setUp() + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def test_eval(self): f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) model = tf.Module() diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index a9ee1776222c..07bd9b5aed22 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -334,7 +334,6 @@ def f_jax(x): # x: i32[b] check_shape_poly(self, f_jax, arg_descriptors=[x], polymorphic_shapes=["b"]) - @jtu.parameterized_filterable( kwargs=[ dict(testcase_name=f"expr={name}", expr=expr) @@ -941,7 +940,7 @@ def test_grad_int(self, with_function=False): xi_yf = (xi, yf) zb = np.array([True, False], dtype=np.bool_) def f_jax(xi_yf, zb): # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2] - # results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4] + # results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4] xi, yf = xi_yf # Return a tuple: # (1) float constant, with 0 tangent; @@ -1032,6 +1031,9 @@ def f_jax(x): # A function whose gradient is a constant f_tf, input_signature=[tf.TensorSpec([None], x.dtype)]) self.assertAllClose(f_jax(x), restored_f(x)) + @jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) def test_readme_examples(self): """Some of the examples from the README.""" diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 9009c1586f15..24713539512c 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -61,7 +61,8 @@ def setUpModule(): global topology if jtu.test_device_matches(["tpu"]): - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + with jtu.ignore_warning(message="the imp module is deprecated"): + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') tf.config.experimental_connect_to_cluster(resolver) # Do TPU init at beginning since it will wipe out all HBMs. topology = tf.tpu.experimental.initialize_tpu_system(resolver) @@ -84,6 +85,15 @@ def setUp(self): raise unittest.SkipTest("Test requires at least 2 local devices") self.devices = np.array(jax.devices()[:2]) # use 2 devices + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def log_jax_hlo(self, f_jax, args: Sequence[Any], *, num_replicas=1, num_partitions=2): """Log the HLO generated from JAX before and after optimizations""" diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index 6f5af51fbf0f..57f78cb2c5c8 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_library", "py_test") -load("//jaxlib:jax.bzl", "py_deps") +load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "jax_py_test", "py_deps") licenses(["notice"]) @@ -48,7 +48,7 @@ py_library( ], ) -py_test( +jax_py_test( name = "run_matmul", srcs = ["matmul.py"], main = "matmul.py", diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index cf9047cc4e17..65ec572c7ee2 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -19,6 +19,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") +load("@rules_python//python:defs.bzl", "py_test") load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") @@ -222,7 +223,7 @@ def if_building_jaxlib( }) # buildifier: disable=function-docstring -def jax_test( +def jax_multiplatform_test( name, srcs, args = [], @@ -300,3 +301,12 @@ jax_test_file_visibility = [] def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable pass + +def jax_py_test( + name, + env = {}, + **kwargs): + env = dict(env) + if "PYTHONWARNINGS" not in env: + env["PYTHONWARNINGS"] = "error" + py_test(name = name, env = env, **kwargs) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 4642af12011d..4553dc1e3ea8 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -16,7 +16,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("//jaxlib:jax.bzl", "if_windows") +load("//jaxlib:jax.bzl", "if_windows", "jax_py_test") licenses(["notice"]) # Apache 2 @@ -52,7 +52,7 @@ py_binary( ], ) -py_test( +jax_py_test( name = "build_wheel_test", srcs = ["build_wheel_test.py"], data = [":build_wheel"], diff --git a/pyproject.toml b/pyproject.toml index b629762feff9..9ce13ea501e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,11 @@ filterwarnings = [ "default:Special cases found for .* but none were parsed.*:UserWarning", "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning", "default:The .* method is good for exploring strategies.*", + + # NOTE: this is probably not where you want to add code to suppress a + # warning. Only pytest tests look at this list, whereas Bazel tests also + # check for warnings and do not check this list. Most likely, you should + # add a @jtu.ignore_warning decorator to your test instead. ] doctest_optionflags = [ "NUMBER", diff --git a/tests/BUILD b/tests/BUILD index e64889cc3b1f..49dbf05125ca 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_test") load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", + "jax_py_test", "jax_test_file_visibility", "py_deps", "pytype_test", @@ -31,29 +31,29 @@ package( jax_generate_backend_suites() -jax_test( +jax_multiplatform_test( name = "api_test", srcs = ["api_test.py"], shard_count = 10, ) -jax_test( +jax_multiplatform_test( name = "device_test", srcs = ["device_test.py"], ) -jax_test( +jax_multiplatform_test( name = "dynamic_api_test", srcs = ["dynamic_api_test.py"], shard_count = 2, ) -jax_test( +jax_multiplatform_test( name = "api_util_test", srcs = ["api_util_test.py"], ) -py_test( +jax_py_test( name = "array_api_test", srcs = ["array_api_test.py"], deps = [ @@ -63,7 +63,7 @@ py_test( ] + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], disable_backends = ["tpu"], @@ -71,7 +71,7 @@ jax_test( deps = py_deps("tensorflow_core"), ) -jax_test( +jax_multiplatform_test( name = "batching_test", srcs = ["batching_test.py"], shard_count = { @@ -79,12 +79,12 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "config_test", srcs = ["config_test.py"], ) -jax_test( +jax_multiplatform_test( name = "core_test", srcs = ["core_test.py"], shard_count = { @@ -93,17 +93,17 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "custom_object_test", srcs = ["custom_object_test.py"], ) -jax_test( +jax_multiplatform_test( name = "debug_nans_test", srcs = ["debug_nans_test.py"], ) -py_test( +jax_py_test( name = "multiprocess_gpu_test", srcs = ["multiprocess_gpu_test.py"], args = [ @@ -116,12 +116,12 @@ py_test( ] + py_deps("portpicker"), ) -jax_test( +jax_multiplatform_test( name = "dtypes_test", srcs = ["dtypes_test.py"], ) -jax_test( +jax_multiplatform_test( name = "errors_test", srcs = ["errors_test.py"], # No need to test all other configs. @@ -130,13 +130,13 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "extend_test", srcs = ["extend_test.py"], deps = ["//jax:extend"], ) -jax_test( +jax_multiplatform_test( name = "fft_test", srcs = ["fft_test.py"], backend_tags = { @@ -152,12 +152,12 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "generated_fun_test", srcs = ["generated_fun_test.py"], ) -jax_test( +jax_multiplatform_test( name = "gpu_memory_flags_test_no_preallocation", srcs = ["gpu_memory_flags_test.py"], disable_backends = [ @@ -170,7 +170,7 @@ jax_test( main = "gpu_memory_flags_test.py", ) -jax_test( +jax_multiplatform_test( name = "gpu_memory_flags_test", srcs = ["gpu_memory_flags_test.py"], disable_backends = [ @@ -182,7 +182,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lobpcg_test", srcs = ["lobpcg_test.py"], env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, @@ -196,7 +196,7 @@ jax_test( ] + py_deps("matplotlib"), ) -jax_test( +jax_multiplatform_test( name = "svd_test", srcs = ["svd_test.py"], shard_count = { @@ -206,7 +206,7 @@ jax_test( }, ) -py_test( +jax_py_test( name = "xla_interpreter_test", srcs = ["xla_interpreter_test.py"], deps = [ @@ -215,7 +215,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "memories_test", srcs = ["memories_test.py"], shard_count = { @@ -226,7 +226,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pjit_test", srcs = ["pjit_test.py"], backend_tags = { @@ -249,7 +249,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "layout_test", srcs = ["layout_test.py"], backend_tags = { @@ -258,7 +258,7 @@ jax_test( tags = ["multiaccelerator"], ) -jax_test( +jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], deps = [ @@ -266,7 +266,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pgle_test", srcs = ["pgle_test.py"], backend_tags = { @@ -286,7 +286,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "mock_gpu_test", srcs = ["mock_gpu_test.py"], disable_backends = [ @@ -301,7 +301,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "array_test", srcs = ["array_test.py"], backend_tags = { @@ -314,7 +314,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "aot_test", srcs = ["aot_test.py"], tags = ["multiaccelerator"], @@ -323,7 +323,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "image_test", srcs = ["image_test.py"], shard_count = { @@ -335,7 +335,7 @@ jax_test( deps = py_deps("pil") + py_deps("tensorflow_core"), ) -jax_test( +jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], deps = [ @@ -343,13 +343,13 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "jax_jit_test", srcs = ["jax_jit_test.py"], main = "jax_jit_test.py", ) -py_test( +jax_py_test( name = "jax_to_ir_test", srcs = ["jax_to_ir_test.py"], deps = [ @@ -359,7 +359,7 @@ py_test( ] + py_deps("tensorflow_core"), ) -py_test( +jax_py_test( name = "jaxpr_util_test", srcs = ["jaxpr_util_test.py"], deps = [ @@ -369,7 +369,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "jet_test", srcs = ["jet_test.py"], shard_count = { @@ -382,7 +382,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "lax_control_flow_test", srcs = ["lax_control_flow_test.py"], shard_count = { @@ -392,7 +392,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "custom_root_test", srcs = ["custom_root_test.py"], shard_count = { @@ -402,7 +402,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "custom_linear_solve_test", srcs = ["custom_linear_solve_test.py"], shard_count = { @@ -412,7 +412,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], backend_tags = { @@ -429,7 +429,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_operators_test", srcs = ["lax_numpy_operators_test.py"], shard_count = { @@ -439,7 +439,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_reducers_test", srcs = ["lax_numpy_reducers_test.py"], shard_count = { @@ -449,7 +449,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_indexing_test", srcs = ["lax_numpy_indexing_test.py"], shard_count = { @@ -459,7 +459,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_einsum_test", srcs = ["lax_numpy_einsum_test.py"], shard_count = { @@ -469,7 +469,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_ufuncs_test", srcs = ["lax_numpy_ufuncs_test.py"], shard_count = { @@ -479,12 +479,12 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_vectorize_test", srcs = ["lax_numpy_vectorize_test.py"], ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_test", srcs = ["lax_scipy_test.py"], shard_count = { @@ -495,7 +495,7 @@ jax_test( deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_sparse_test", srcs = ["lax_scipy_sparse_test.py"], backend_tags = { @@ -508,7 +508,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_special_functions_test", srcs = ["lax_scipy_special_functions_test.py"], backend_tags = { @@ -522,7 +522,7 @@ jax_test( deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { @@ -535,7 +535,7 @@ jax_test( ] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_test", srcs = ["lax_test.py"], backend_tags = { @@ -552,7 +552,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "lax_metal_test", srcs = ["lax_metal_test.py"], disable_backends = [ @@ -567,7 +567,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "lax_autodiff_test", srcs = ["lax_autodiff_test.py"], shard_count = { @@ -577,7 +577,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_vmap_test", srcs = ["lax_vmap_test.py"], shard_count = { @@ -588,7 +588,7 @@ jax_test( deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_vmap_op_test", srcs = ["lax_vmap_op_test.py"], shard_count = { @@ -599,7 +599,7 @@ jax_test( deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), ) -py_test( +jax_py_test( name = "lazy_loader_test", srcs = [ "lazy_loader_test.py", @@ -610,7 +610,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "deprecation_test", srcs = [ "deprecation_test.py", @@ -621,7 +621,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "linalg_test", srcs = ["linalg_test.py"], backend_tags = { @@ -640,12 +640,12 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], ) -jax_test( +jax_multiplatform_test( name = "metadata_test", srcs = ["metadata_test.py"], disable_backends = [ @@ -654,7 +654,7 @@ jax_test( ], ) -py_test( +jax_py_test( name = "monitoring_test", srcs = ["monitoring_test.py"], deps = [ @@ -663,12 +663,12 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "multibackend_test", srcs = ["multibackend_test.py"], ) -jax_test( +jax_multiplatform_test( name = "multi_device_test", srcs = ["multi_device_test.py"], disable_backends = [ @@ -677,7 +677,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "nn_test", srcs = ["nn_test.py"], backend_tags = { @@ -695,13 +695,13 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "optimizers_test", srcs = ["optimizers_test.py"], deps = ["//jax:optimizers"], ) -jax_test( +jax_multiplatform_test( name = "pickle_test", srcs = ["pickle_test.py"], deps = [ @@ -709,7 +709,7 @@ jax_test( ] + py_deps("cloudpickle") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "pmap_test", srcs = ["pmap_test.py"], backend_tags = { @@ -729,7 +729,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "polynomial_test", srcs = ["polynomial_test.py"], # No implementation of nonsymmetric Eigendecomposition. @@ -749,7 +749,7 @@ jax_test( tags = ["nomsan"], ) -jax_test( +jax_multiplatform_test( name = "heap_profiler_test", srcs = ["heap_profiler_test.py"], disable_backends = [ @@ -758,7 +758,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "profiler_test", srcs = ["profiler_test.py"], disable_backends = [ @@ -767,7 +767,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], disable_backends = ["tpu"], @@ -786,7 +786,7 @@ jax_test( deps = py_deps("torch"), ) -jax_test( +jax_multiplatform_test( name = "qdwh_test", srcs = ["qdwh_test.py"], backend_tags = { @@ -799,7 +799,7 @@ jax_test( shard_count = 10, ) -jax_test( +jax_multiplatform_test( name = "random_test", srcs = ["random_test.py"], backend_tags = { @@ -821,7 +821,7 @@ jax_test( tags = ["noasan"], # Times out ) -jax_test( +jax_multiplatform_test( name = "random_lax_test", srcs = ["random_lax_test.py"], backend_tags = { @@ -847,7 +847,7 @@ jax_test( ) # TODO(b/199564969): remove once we always enable_custom_prng -jax_test( +jax_multiplatform_test( name = "random_test_with_custom_prng", srcs = ["random_test.py"], args = ["--jax_enable_custom_prng=true"], @@ -872,7 +872,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "scipy_fft_test", srcs = ["scipy_fft_test.py"], backend_tags = { @@ -885,22 +885,22 @@ jax_test( shard_count = 4, ) -jax_test( +jax_multiplatform_test( name = "scipy_interpolate_test", srcs = ["scipy_interpolate_test.py"], ) -jax_test( +jax_multiplatform_test( name = "scipy_ndimage_test", srcs = ["scipy_ndimage_test.py"], ) -jax_test( +jax_multiplatform_test( name = "scipy_optimize_test", srcs = ["scipy_optimize_test.py"], ) -jax_test( +jax_multiplatform_test( name = "scipy_signal_test", srcs = ["scipy_signal_test.py"], backend_tags = { @@ -925,13 +925,13 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "scipy_spatial_test", srcs = ["scipy_spatial_test.py"], deps = py_deps("scipy"), ) -jax_test( +jax_multiplatform_test( name = "scipy_stats_test", srcs = ["scipy_stats_test.py"], backend_tags = { @@ -948,7 +948,7 @@ jax_test( ], # Times out ) -jax_test( +jax_multiplatform_test( name = "sparse_test", srcs = ["sparse_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], @@ -981,7 +981,7 @@ jax_test( ] + py_deps("scipy"), ) -jax_test( +jax_multiplatform_test( name = "sparse_bcoo_bcsr_test", srcs = ["sparse_bcoo_bcsr_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], @@ -1014,7 +1014,7 @@ jax_test( ] + py_deps("scipy"), ) -jax_test( +jax_multiplatform_test( name = "sparse_nm_test", srcs = ["sparse_nm_test.py"], config_tags_overrides = { @@ -1037,7 +1037,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "sparsify_test", srcs = ["sparsify_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], @@ -1061,12 +1061,12 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "stack_test", srcs = ["stack_test.py"], ) -jax_test( +jax_multiplatform_test( name = "checkify_test", srcs = ["checkify_test.py"], shard_count = { @@ -1075,7 +1075,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], shard_count = { @@ -1085,18 +1085,18 @@ jax_test( deps = ["//jax:stax"], ) -jax_test( +jax_multiplatform_test( name = "linear_search_test", srcs = ["third_party/scipy/line_search_test.py"], main = "third_party/scipy/line_search_test.py", ) -jax_test( +jax_multiplatform_test( name = "blocked_sampler_test", srcs = ["blocked_sampler_test.py"], ) -py_test( +jax_py_test( name = "tree_util_test", srcs = ["tree_util_test.py"], deps = [ @@ -1114,7 +1114,7 @@ pytype_test( ], ) -py_test( +jax_py_test( name = "util_test", srcs = ["util_test.py"], deps = [ @@ -1123,7 +1123,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "version_test", srcs = ["version_test.py"], deps = [ @@ -1132,7 +1132,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "xla_bridge_test", srcs = ["xla_bridge_test.py"], data = ["testdata/example_pjrt_plugin_config.json"], @@ -1143,7 +1143,7 @@ py_test( ] + py_deps("absl/logging"), ) -py_test( +jax_py_test( name = "lru_cache_test", srcs = ["lru_cache_test.py"], deps = [ @@ -1153,7 +1153,7 @@ py_test( ] + py_deps("filelock"), ) -jax_test( +jax_multiplatform_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.py"], deps = [ @@ -1162,7 +1162,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "cache_key_test", srcs = ["cache_key_test.py"], deps = [ @@ -1171,7 +1171,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "ode_test", srcs = ["ode_test.py"], shard_count = { @@ -1180,7 +1180,7 @@ jax_test( deps = ["//jax:ode"], ) -jax_test( +jax_multiplatform_test( name = "host_callback_outfeed_test", srcs = ["host_callback_test.py"], args = ["--jax_host_callback_outfeed=true"], @@ -1197,7 +1197,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "host_callback_test", srcs = ["host_callback_test.py"], args = ["--jax_host_callback_outfeed=false"], @@ -1213,7 +1213,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "host_callback_to_tf_test", srcs = ["host_callback_to_tf_test.py"], tags = ["noasan"], # Linking TF causes a linker OOM. @@ -1223,12 +1223,12 @@ jax_test( ] + py_deps("tensorflow_core"), ) -jax_test( +jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], ) -jax_test( +jax_multiplatform_test( name = "x64_context_test", srcs = ["x64_context_test.py"], deps = [ @@ -1236,13 +1236,13 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "ann_test", srcs = ["ann_test.py"], shard_count = 10, ) -py_test( +jax_py_test( name = "mesh_utils_test", srcs = ["mesh_utils_test.py"], deps = [ @@ -1252,17 +1252,17 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "transfer_guard_test", srcs = ["transfer_guard_test.py"], ) -jax_test( +jax_multiplatform_test( name = "name_stack_test", srcs = ["name_stack_test.py"], ) -jax_test( +jax_multiplatform_test( name = "jaxpr_effects_test", srcs = ["jaxpr_effects_test.py"], backend_tags = { @@ -1275,7 +1275,7 @@ jax_test( tags = ["multiaccelerator"], ) -jax_test( +jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], enable_configs = [ @@ -1284,7 +1284,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "python_callback_test", srcs = ["python_callback_test.py"], backend_tags = { @@ -1296,7 +1296,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "debugger_test", srcs = ["debugger_test.py"], enable_configs = [ @@ -1305,7 +1305,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "state_test", srcs = ["state_test.py"], # Use fewer cases to prevent timeouts. @@ -1327,12 +1327,12 @@ jax_test( deps = py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "mutable_array_test", srcs = ["mutable_array_test.py"], ) -jax_test( +jax_multiplatform_test( name = "for_loop_test", srcs = ["for_loop_test.py"], shard_count = { @@ -1342,7 +1342,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], shard_count = { @@ -1362,12 +1362,12 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "clear_backends_test", srcs = ["clear_backends_test.py"], ) -jax_test( +jax_multiplatform_test( name = "attrs_test", srcs = ["attrs_test.py"], deps = [ @@ -1375,7 +1375,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "experimental_rnn_test", srcs = ["experimental_rnn_test.py"], disable_backends = [ @@ -1391,7 +1391,7 @@ jax_test( ], ) -py_test( +jax_py_test( name = "mosaic_test", srcs = ["mosaic_test.py"], deps = [ @@ -1401,7 +1401,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "source_info_test", srcs = ["source_info_test.py"], deps = [ @@ -1410,7 +1410,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "package_structure_test", srcs = ["package_structure_test.py"], deps = [ @@ -1419,12 +1419,12 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "logging_test", srcs = ["logging_test.py"], ) -jax_test( +jax_multiplatform_test( name = "export_test", srcs = ["export_test.py"], enable_configs = [ @@ -1436,7 +1436,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "shape_poly_test", srcs = ["shape_poly_test.py"], disable_configs = [ @@ -1461,7 +1461,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "export_harnesses_multi_platform_test", srcs = ["export_harnesses_multi_platform_test.py"], disable_configs = [ @@ -1484,7 +1484,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "export_back_compat_test", srcs = ["export_back_compat_test.py"], tags = [], @@ -1494,7 +1494,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], disable_backends = [ @@ -1507,13 +1507,13 @@ jax_test( tags = ["multiaccelerator"], ) -jax_test( +jax_multiplatform_test( name = "xla_metadata_test", srcs = ["xla_metadata_test.py"], deps = ["//jax:experimental"], ) -py_test( +jax_py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], deps = [ @@ -1522,7 +1522,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "sourcemap_test", srcs = ["sourcemap_test.py"], deps = [ @@ -1531,7 +1531,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "cudnn_fusion_test", srcs = ["cudnn_fusion_test.py"], disable_backends = [ diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 3560241530c5..02f5ad527c61 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -75,6 +75,8 @@ def setUp(self): use_stream=[False, True], ) @jtu.run_on_devices("gpu") + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testJaxRoundTrip(self, shape, dtype, copy, use_stream): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) @@ -142,6 +144,8 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu): dtype=dlpack_dtypes, ) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTensorFlowToJax(self, shape, dtype): if (not config.enable_x64.value and dtype in [jnp.int64, jnp.uint64, jnp.float64]): @@ -184,6 +188,8 @@ def testJaxToTensorFlow(self, shape, dtype): self.assertAllClose(np, y.numpy()) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTensorFlowToJaxInt64(self): # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( diff --git a/tests/host_callback_to_tf_test.py b/tests/host_callback_to_tf_test.py index fe80c90ace68..3a36ce1296a6 100644 --- a/tests/host_callback_to_tf_test.py +++ b/tests/host_callback_to_tf_test.py @@ -176,6 +176,8 @@ def supported_only_in_legacy_mode(self): testcase_name=f"_{ad=}", ad=ad) for ad in CALL_TF_IMPLEMENTATIONS.keys()) + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_impl(self, ad="simple"): self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] @@ -197,6 +199,8 @@ def f_outside(x): ad=ad) for ad in CALL_TF_IMPLEMENTATIONS.keys() if ad != "none") + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_grad(self, ad="simple"): self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] @@ -217,6 +221,8 @@ def f_outside(x): self.assertAllClose(jax.grad(f_jax)(x), grad_f, check_dtypes=False) + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_grad_pytree(self): self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad @@ -246,6 +252,8 @@ def f_outside(xy): testcase_name=f"_degree=_{degree}", degree=degree) for degree in [1, 2, 3, 4]) + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_higher_order_grad(self, degree=4): self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 371a13f0cde6..a10a7369721f 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2681,6 +2681,7 @@ def np_fun(x1, x2): shape=all_shapes, dtype=default_dtypes, ) + @jtu.ignore_warning(category=RuntimeWarning, message="overflow") def testFrexp(self, shape, dtype, rng_factory): # integer types are converted to float64 in numpy's implementation if (dtype not in [jnp.bfloat16, np.float16, np.float32] @@ -6270,7 +6271,8 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]: for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin): args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes) try: - with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"): + with jtu.ignore_warning( + category=RuntimeWarning, message="(divide by zero|invalid value)"): _ = func(*args) except TypeError: pass @@ -6292,7 +6294,7 @@ def testUfuncInputTypes(self, name, arg_dtypes): jnp_op = getattr(jnp, name) np_op = getattr(np, name) np_op = jtu.ignore_warning(category=RuntimeWarning, - message="divide by zero.*")(np_op) + message="(divide by zero|invalid value)")(np_op) args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes) with jtu.strict_promotion_if_dtypes_match(arg_dtypes): diff --git a/tests/lax_test.py b/tests/lax_test.py index 3f43773a8ec7..d82b35c6b711 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -110,6 +110,7 @@ def testOp(self, op_name, rng_factory, shapes, dtype): for shape_group in lax_test_util.compatible_shapes), dtype=rec.dtypes) for rec in lax_test_util.lax_ops())) + @jtu.ignore_warning(message="invalid value", category=RuntimeWarning) def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol): if (not config.enable_x64.value and op_name == "nextafter" and dtype == np.float64): diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 6e5c94982d47..9eadc08d4987 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -15,7 +15,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", "py_deps", ) @@ -43,7 +43,7 @@ DISABLED_CONFIGS = [ "gpu", ] -jax_test( +jax_multiplatform_test( name = "gpu_test", srcs = ["gpu_test.py"], config_tags_overrides = { @@ -63,7 +63,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "matmul_test", srcs = ["matmul_test.py"], disable_backends = DISABLED_BACKENDS, @@ -75,7 +75,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "flash_attention", srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"], disable_backends = DISABLED_BACKENDS, @@ -87,7 +87,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "flash_attention_test", srcs = ["flash_attention_test.py"], disable_backends = DISABLED_BACKENDS, diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 2eacf7c9984c..30f830c31ccf 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1200,6 +1200,8 @@ class FragmentedArrayTest(TestCase): m=(64, 128), n=(8, 16, 32, 64, 80, 128, 256), ) + @jtu.ignore_warning(message="(invalid value|divide by zero)", + category=RuntimeWarning) def test_binary(self, op, dtype, m=64, n=32): if isinstance(op, tuple): op, np_op = op @@ -1294,6 +1296,7 @@ def kernel(ctx, dst, _): ], approx=[False, True], ) + @jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning) def test_math(self, ops, approx, m=64, n=32): op, np_op = ops def kernel(ctx, dst, _): diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 6804d91675c1..e535f1f59dac 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -15,7 +15,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", "py_deps", ) @@ -28,7 +28,7 @@ package( jax_generate_backend_suites() -jax_test( +jax_multiplatform_test( name = "pallas_test", srcs = [ "pallas_test.py", @@ -62,7 +62,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "pallas_jumble_test", srcs = [ "pallas_jumble_test.py", @@ -85,7 +85,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "ops_test", srcs = [ "ops_test.py", @@ -125,7 +125,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "indexing_test", srcs = [ "indexing_test.py", @@ -144,7 +144,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "pallas_vmap_test", srcs = [ "pallas_vmap_test.py", @@ -176,7 +176,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "mosaic_gpu_test", srcs = [ "mosaic_gpu_test.py", @@ -213,7 +213,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "export_back_compat_pallas_test", srcs = ["export_back_compat_pallas_test.py"], config_tags_overrides = { @@ -244,7 +244,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "export_pallas_test", srcs = ["export_pallas_test.py"], config_tags_overrides = { @@ -272,7 +272,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pallas_shape_poly_test", srcs = ["pallas_shape_poly_test.py"], config_tags_overrides = { @@ -299,7 +299,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pallas_error_handling_test", srcs = [ "pallas_error_handling_test.py", @@ -317,7 +317,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_all_gather_test", srcs = [ "tpu_all_gather_test.py", @@ -331,7 +331,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "tpu_gmm_test", srcs = [ "tpu_gmm_test.py", @@ -356,7 +356,7 @@ jax_test( ]), ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_test", srcs = ["tpu_pallas_test.py"], # The flag is necessary for ``pl.debug_print`` tests to work on TPU. @@ -372,7 +372,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_ops_test", srcs = [ "tpu_ops_test.py", @@ -388,7 +388,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_distributed_test", srcs = ["tpu_pallas_distributed_test.py"], disable_backends = [ @@ -402,7 +402,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_pipeline_test", srcs = ["tpu_pallas_pipeline_test.py"], disable_backends = [ @@ -422,7 +422,7 @@ jax_test( ] + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_async_test", srcs = ["tpu_pallas_async_test.py"], disable_backends = [ @@ -436,7 +436,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_mesh_test", srcs = ["tpu_pallas_mesh_test.py"], disable_backends = [ @@ -454,7 +454,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_random_test", srcs = [ "tpu_pallas_random_test.py", @@ -472,7 +472,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], disable_backends = [ @@ -490,7 +490,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_splash_attention_kernel_test", srcs = [ "tpu_splash_attention_kernel_test.py", @@ -510,7 +510,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "tpu_splash_attention_mask_test", srcs = [ "tpu_splash_attention_mask_test.py", @@ -523,7 +523,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "gpu_attention_test", srcs = [ "gpu_attention_test.py", @@ -556,7 +556,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "gpu_ops_test", srcs = [ "gpu_ops_test.py", diff --git a/tests/pmap_test.py b/tests/pmap_test.py index d7dcc7ba3cc4..9a8d0b91272b 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3209,8 +3209,12 @@ def setUp(self): self.jit_disabled = config.disable_jit.value config.update('jax_disable_jit', True) config.update('jax_eager_pmap', True) + self.warning_ctx = jtu.ignore_warning( + message="Some donated buffers were not usable", category=UserWarning) + self.warning_ctx.__enter__() def tearDown(self): + self.warning_ctx.__exit__(None, None, None) config.update('jax_eager_pmap', self.eager_pmap_enabled) config.update('jax_disable_jit', self.jit_disabled) super().tearDown() diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index 2e0fc32238ae..e41c4329b95b 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -108,6 +108,8 @@ def testJaxArrayToTorch(self, shape, dtype): else: self.assertAllClose(np, y.cpu().numpy()) + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTorchToJaxInt64(self): # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( @@ -116,6 +118,8 @@ def testTorchToJaxInt64(self): self.assertEqual(x.dtype, dtype_expected) @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTorchToJax(self, shape, dtype): if not config.enable_x64.value and dtype in [ jnp.int64, diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index 38fde72f0440..12088db7fe18 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -973,6 +973,7 @@ def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape): self.assertArraysAllClose(out.todense(), expected_out) self.assertEqual(out.nse, expected_nse) + @jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available") def test_bcoo_spdot_general_ad_bug(self): # Regression test for https://github.com/jax-ml/jax/issues/10163 A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) From 85a466d730079a9daaa94486607816dc83701508 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Sep 2024 13:09:42 -0700 Subject: [PATCH 643/702] Lower the shard count for sparse_bcoo_bcsr_test on TPU as well. There are flaky timeouts in CI, and we've already lowered the shard count on multiple other platforms. PiperOrigin-RevId: 678367575 --- tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/BUILD b/tests/BUILD index 49dbf05125ca..0cc6ed6d9d8c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -997,6 +997,7 @@ jax_multiplatform_test( "cpu": ["--jax_num_generated_cases=40"], "cpu_x32": ["--jax_num_generated_cases=40"], "gpu": ["--jax_num_generated_cases=40"], + "tpu": ["--jax_num_generated_cases=40"], }, shard_count = { "cpu": 50, From e1a68eee5ecb732968fd44eb698e5db753f154bd Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 3 Sep 2024 14:31:15 -0400 Subject: [PATCH 644/702] Add FFI example project and test on CI. This PR includes an end-to-end example project which demonstrates the use of the FFI. This complements [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html) by putting all of the code in one place, as well as demonstrating how FFI extensions can be packaged. Alongside the example project, I have also added a new GitHub Actions workflow to test the example as part of CI. For now, the tests only run on CPU, but once we have GPU runners for GitHub Actions (soon!), I plan on migrating the custom call examples from `docs/gpu_ops` and `docs/cuda_custom_call` into this test case. Similarly, I wanted to start small and this example project only includes exactly the same functions as the tutorial for now, but I think this could be a good place to showcase more advanced examples (including custom calls with state). --- .github/workflows/ci-build.yaml | 35 ++++- examples/ffi/CMakeLists.txt | 15 ++ examples/ffi/README.md | 9 ++ examples/ffi/pyproject.toml | 11 ++ examples/ffi/src/jax_ffi_example/__init__.py | 13 ++ examples/ffi/src/jax_ffi_example/rms_norm.cc | 157 +++++++++++++++++++ examples/ffi/src/jax_ffi_example/rms_norm.py | 99 ++++++++++++ examples/ffi/tests/rms_norm_test.py | 46 ++++++ pyproject.toml | 2 +- 9 files changed, 385 insertions(+), 2 deletions(-) create mode 100644 examples/ffi/CMakeLists.txt create mode 100644 examples/ffi/README.md create mode 100644 examples/ffi/pyproject.toml create mode 100644 examples/ffi/src/jax_ffi_example/__init__.py create mode 100644 examples/ffi/src/jax_ffi_example/rms_norm.cc create mode 100644 examples/ffi/src/jax_ffi_example/rms_norm.py create mode 100644 examples/ffi/tests/rms_norm_test.py diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 0f90cd72e463..f5c1a1d6348c 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -210,4 +210,37 @@ jobs: echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py - \ No newline at end of file + + ffi: + name: FFI example + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - name: Set up Python 3.11 + uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + with: + python-version: 3.11 + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip wheel + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} + - name: Install JAX + run: pip install . + - name: Build and install example project + run: python -m pip install -v ./examples/ffi[test] + env: + # We test building using GCC instead of clang. All other JAX builds use + # clang, but it is useful to make sure that FFI users can compile using + # a different toolchain. GCC is the default compiler on the + # 'ubuntu-latest' runner, but we still set this explicitly just to be + # clear. + CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ + - name: Run tests + run: python -m pytest examples/ffi/tests diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt new file mode 100644 index 000000000000..343ae96f404f --- /dev/null +++ b/examples/ffi/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.15...3.30) +project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) + +find_package(Python 3.8 REQUIRED COMPONENTS Interpreter Development.Module) +execute_process( + COMMAND "${Python_EXECUTABLE}" + "-c" "from jax.extend import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) +message(STATUS "XLA include directory: ${XLA_DIR}") + +find_package(nanobind CONFIG REQUIRED) + +nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc") +target_include_directories(_rms_norm PUBLIC ${XLA_DIR}) +install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) diff --git a/examples/ffi/README.md b/examples/ffi/README.md new file mode 100644 index 000000000000..cc7018782a25 --- /dev/null +++ b/examples/ffi/README.md @@ -0,0 +1,9 @@ +# End-to-end example usage for JAX's foreign function interface + +This directory includes an example project demonstrating the use of JAX's +foreign function interface (FFI). The JAX docs provide more information about +this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html), +but the example in this directory explicitly demonstrates: + +1. One way to package and distribute FFI targets, and +2. Some more advanced use cases. diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml new file mode 100644 index 000000000000..aa006419652c --- /dev/null +++ b/examples/ffi/pyproject.toml @@ -0,0 +1,11 @@ +[build-system] +requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"] +build-backend = "scikit_build_core.build" + +[project] +name = "jax_ffi_example" +version = "0.0.1" +dependencies = ["jax"] + +[project.optional-dependencies] +test = ["pytest", "absl-py"] diff --git a/examples/ffi/src/jax_ffi_example/__init__.py b/examples/ffi/src/jax_ffi_example/__init__.py new file mode 100644 index 000000000000..862a661e24b9 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc new file mode 100644 index 000000000000..2fb8d96c8461 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -0,0 +1,157 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace nb = nanobind; +namespace ffi = xla::ffi; + +// This is the example "library function" that we want to expose to JAX. This +// isn't meant to be a particularly good implementation, it's just here as a +// placeholder for the purposes of this tutorial. +float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) { + float sm = 0.0f; + for (int64_t n = 0; n < size; ++n) { + sm += x[n] * x[n]; + } + float scale = 1.0f / std::sqrt(sm / float(size) + eps); + for (int64_t n = 0; n < size; ++n) { + y[n] = x[n] * scale; + } + return scale; +} + +// A helper function for extracting the relevant dimensions from `ffi::Buffer`s. +// In this example, we treat all leading dimensions as batch dimensions, so this +// function returns the total number of elements in the buffer, and the size of +// the last dimension. +template +std::pair GetDims(const ffi::Buffer &buffer) { + auto dims = buffer.dimensions(); + if (dims.size() == 0) { + return std::make_pair(0, 0); + } + return std::make_pair(buffer.element_count(), dims.back()); +} + +// A wrapper function providing the interface between the XLA FFI call and our +// library function `ComputeRmsNorm` above. This function handles the batch +// dimensions by calling `ComputeRmsNorm` within a loop. +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::Result> y) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNorm input must be an array"); + } + for (int64_t n = 0; n < totalSize; n += lastDim) { + ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); + } + return ffi::Error::Success(); +} + +// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare +// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL` +// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`. +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>() // x + .Ret>() // y +); + +ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, + ffi::Result> y, + ffi::Result> res) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormFwd input must be an array"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), + &(y->typed_data()[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>() // x + .Ret>() // y + .Ret>() // res +); + +void ComputeRmsNormBwd(int64_t size, float res, const float *x, + const float *ct_y, float *ct_x) { + float ct_res = 0.0f; + for (int64_t n = 0; n < size; ++n) { + ct_res += x[n] * ct_y[n]; + } + float factor = ct_res * res * res * res / float(size); + for (int64_t n = 0; n < size; ++n) { + ct_x[n] = res * ct_y[n] - factor * x[n]; + } +} + +ffi::Error RmsNormBwdImpl(ffi::Buffer res, ffi::Buffer x, + ffi::Buffer ct_y, + ffi::Result> ct_x) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormBwd inputs must be arrays"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), + &(ct_y.typed_data()[n]), &(ct_x->typed_data()[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl, + ffi::Ffi::Bind() + .Arg>() // res + .Arg>() // x + .Arg>() // ct_y + .Ret>() // ct_x +); + +template +nb::capsule EncapsulateFfiHandler(T *fn) { + static_assert(std::is_invocable_r_v, + "Encapsulated function must be and XLA FFI handler"); + return nb::capsule(reinterpret_cast(fn)); +} + +NB_MODULE(_rms_norm, m) { + m.def("registrations", []() { + nb::dict registrations; + registrations["rms_norm"] = EncapsulateFfiHandler(RmsNorm); + registrations["rms_norm_fwd"] = EncapsulateFfiHandler(RmsNormFwd); + registrations["rms_norm_bwd"] = EncapsulateFfiHandler(RmsNormBwd); + return registrations; + }); +} diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py new file mode 100644 index 000000000000..4e0ed1d195b4 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/rms_norm.py @@ -0,0 +1,99 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An example demontrating the basic end-to-end use of the JAX FFI. + +This example is exactly the same as the one in the `FFI tutorial +`, so more details can be found +on that page. But, the high level summary is that we implement our custom +extension in ``rms_norm.cc``, then call it usin ``jax.extend.ffi.ffi_call`` in +this module. The behavior under autodiff is implemented using +``jax.custom_vjp``. +""" + +from functools import partial + +import numpy as np + +import jax +import jax.extend as jex +import jax.numpy as jnp + +from jax_ffi_example import _rms_norm + +for name, target in _rms_norm.registrations().items(): + jex.ffi.register_ffi_target(name, target) + + +@partial(jax.custom_vjp, nondiff_argnums=(1,)) +def rms_norm(x, eps=1e-5): + # We only implemented the `float32` version of this function, so we start by + # checking the dtype. This check isn't strictly necessary because type + # checking is also performed by the FFI when decoding input and output + # buffers, but it can be useful to check types in Python to raise more + # informative errors. + if x.dtype != jnp.float32: + raise ValueError("Only the float32 dtype is implemented by rms_norm") + + # In this case, the output of our FFI function is just a single array with the + # same shape and dtype as the input. + out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + + return jex.ffi.ffi_call( + # The target name must be the same string as we used to register the target + # above in `register_ffi_target` + "rms_norm", + out_type, + x, + # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for + # the attribute `eps`. Our FFI function expects this to have the C++ `float` + # type (which corresponds to numpy's `float32` type), and it must be a + # static parameter (i.e. not a JAX array). + eps=np.float32(eps), + # The `vectorized` parameter controls this function's behavior under `vmap`. + vectorized=True, + ) + + +def rms_norm_fwd(x, eps=1e-5): + y, res = jex.ffi.ffi_call( + "rms_norm_fwd", + ( + jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), + ), + x, + eps=np.float32(eps), + vectorized=True, + ) + return y, (res, x) + + +def rms_norm_bwd(eps, res, ct): + del eps + res, x = res + assert res.shape == ct.shape[:-1] + assert x.shape == ct.shape + return ( + jex.ffi.ffi_call( + "rms_norm_bwd", + jax.ShapeDtypeStruct(ct.shape, ct.dtype), + res, + x, + ct, + vectorized=True, + ), + ) + + +rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) diff --git a/examples/ffi/tests/rms_norm_test.py b/examples/ffi/tests/rms_norm_test.py new file mode 100644 index 000000000000..aad5562629ed --- /dev/null +++ b/examples/ffi/tests/rms_norm_test.py @@ -0,0 +1,46 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu + +from jax_ffi_example import rms_norm + +jax.config.parse_flags_with_absl() + + +def rms_norm_ref(x, eps=1e-5): + scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) + return x / scale + + +class RmsNormTests(jtu.JaxTestCase): + def test_basic(self): + x = jnp.linspace(-0.5, 0.5, 15) + self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x)) + + def test_batching(self): + x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) + self.assertAllClose(jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x)) + + def test_grads(self): + x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) + jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/pyproject.toml b/pyproject.toml index 9ce13ea501e8..3423783e2407 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,7 @@ doctest_optionflags = [ "NUMBER", "NORMALIZE_WHITESPACE" ] -addopts = "--doctest-glob='*.rst'" +addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'" [tool.pylint.master] extension-pkg-whitelist = "numpy" From aa73aa0021c7dc0c9c514397c7e7774640774150 Mon Sep 17 00:00:00 2001 From: Enrique Piqueras Date: Tue, 24 Sep 2024 15:51:18 -0700 Subject: [PATCH 645/702] Pallas pipeline API tweaks for more advanced pipelining patterns. PiperOrigin-RevId: 678426679 --- jax/_src/pallas/mosaic/BUILD | 1 + jax/_src/pallas/mosaic/pipeline.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 071f09f3f567..ae76a00a6c17 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -107,6 +107,7 @@ py_library( ":primitives", "//jax", "//jax:api_util", + "//jax:pallas", "//jax:util", "//jax/_src/pallas", ] + py_deps("numpy"), diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 005e4acdd106..2c96aa512e41 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -178,6 +178,8 @@ class BufferType(enum.Enum): ACCUMULATOR = 3 INPUT_OUTPUT = 4 + MANUAL = 5 + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -234,6 +236,10 @@ def tree_flatten(self): def tree_unflatten(cls, meta, data): return cls(*meta, *data) + @staticmethod + def buffer_types() -> type[BufferType]: + return BufferType + @classmethod def create(cls, spec, dtype, buffer_type) -> BufferedRef: """Create a BufferedRef. @@ -1034,6 +1040,7 @@ def pipeline( prefetch=None, postyeet=None, schedule=None, + body_prologue=None, ): """ Run the pipeline. @@ -1056,6 +1063,8 @@ def pipeline( Called during the outputs phase in the first inner step. schedule: manually specified pipeline schedules for brefs, None indicates default schedule. + body_prologue: For running code within the grid environment before the + body is run. Useful for updating manual refs. """ if scratches is None: scratches = () @@ -1119,6 +1128,9 @@ def loop_body(step, _): lambda: None) # run the kernel! + if body_prologue is not None: + with scheduler.grid_env(): + body_prologue() current_refs = map_brefs(lambda x: x.current_ref, brefs) with scheduler._named_scope("ep_run_kernel"): with scheduler.grid_env(): From 02cfaa858f557ba5fd1f58327b592dd9750d1b2a Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 24 Sep 2024 17:01:01 -0700 Subject: [PATCH 646/702] [Pallas TPU] Improve error message when trying to store a scalar to VMEM Fixes https://github.com/jax-ml/jax/issues/23884 PiperOrigin-RevId: 678448445 --- jax/_src/pallas/mosaic/lowering.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index d4dc534d034a..c9ff21c49689 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1262,7 +1262,9 @@ def _masked_swap_lowering_rule( ) ref_type = ir.MemRefType(ref.type) - is_smem_store = str(ref_type.memory_space) == "#tpu.memory_space" + memory_space = str(ref_type.memory_space) + is_smem_store = memory_space == "#tpu.memory_space" + is_vmem_store = memory_space == "#tpu.memory_space" (aval_out,) = ctx.avals_out if not isinstance(val, ir.Value): val = ir_constant(val, mlir_type=_dtype_to_ir_type(val_aval.dtype)) @@ -1281,6 +1283,7 @@ def _masked_swap_lowering_rule( cast_to_index=True, ) need_stride = not all((s is None or s == 1) for s in strides) + if is_smem_store: if val_aval.shape: raise ValueError("Can only store scalars to SMEM") @@ -1289,13 +1292,19 @@ def _masked_swap_lowering_rule( val = _maybe_cast_store_to_memref_type(val_aval, val) memref.StoreOp(val, ref, starts) return result - elif str(ref_type.memory_space) != "#tpu.memory_space": + + if not is_vmem_store: extra = "" - if str(ref_type.memory_space) == "#tpu.memory_space": + if memory_space == "#tpu.memory_space": extra = " ANY memory space can only be accessed using async_copy." raise ValueError( "Loads and stores are only allowed on VMEM and SMEM references." + extra ) + + # handling VMEM store below + if not val_aval.shape: + raise ValueError("Cannot store scalars to VMEM") + mem_slice_shape = list(aval_out.shape) for i, a in enumerate(idx_aval.indices): if not isinstance(a, primitives.Slice): From 0a73d74a4ed499d5aaff6e9a995ff97a2c5714d6 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 20 Sep 2024 22:58:01 +0000 Subject: [PATCH 647/702] simplify conversion logic involving extended dtypes Previously, the idea was that we would use the `convert_element_type` primitive to cast to/from extended dtypes. Extended dtype rules specified `convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool` functions. They were meant to do something like indicate whether a convert_element_type was legal. But I'm not sure if they really made sense. The implementation was certainly buggy for non-scalar representation types (physical element types). This PR simplifies and fixes things: 1. Instead of overloading the `convert_element_type_p` primitive with more cases involving casts to/from extended dtypes, let's just have distinct `to_edtype_p` and `from_edtype_p` primitives, which can be much simpler. We still reuse the `jax.lax.convert_element_type` API function, so there's no API change to the few existing users who know about this stuff. 2. Instead of extended dtype rules including `convert_from`/`convert_to` functions with questionable semantics, let's only allow casts to/from the representation type, which is already specified by the rules' `physical_element_aval`. (Indeed that should be roughly _all_ we need, and this PR is just one step towards realizing that goal.) We still have a boolean `allow_conversion` on extended dtype rules just so we can handle the PRNGKey case, where we don't want to allow any casts. 3. Fix the conversion logic to handle non-scalar representation types (physical element types). --- jax/_src/core.py | 18 +-- jax/_src/dtypes.py | 15 ++- jax/_src/lax/lax.py | 150 +++++++++++++++++----- jax/_src/pallas/core.py | 7 +- jax/_src/prng.py | 9 +- jax/experimental/jax2tf/jax2tf.py | 2 + tests/dtypes_test.py | 198 +++++++++++++++++++++++------- tests/dynamic_api_test.py | 3 + 8 files changed, 295 insertions(+), 107 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 85361ab5a7b8..9ef19fbeccdc 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1608,12 +1608,6 @@ def physical_element_aval(edtype: dtypes.ExtendedDType) -> ShapedArray: duck = edtype._rules.physical_element_aval(edtype) # type: ignore return ShapedArray(duck.shape, dtypes.dtype(duck.dtype)) -def _short_dtype_name(dtype) -> str: - if isinstance(dtype, dtypes.ExtendedDType): - return str(dtype) - else: - return (dtype.name.replace('float', 'f').replace('uint' , 'u') - .replace('int' , 'i').replace('complex', 'c')) def _dtype_object(dtype): return dtype if isinstance(dtype, dtypes.ExtendedDType) else np.dtype(dtype) @@ -1672,7 +1666,7 @@ def join(self, other): raise TypeError(self, other) def str_short(self, short_dtypes=False) -> str: - return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name def strip_weak_type(self): """Returns a copy of the aval with weak_type=False.""" @@ -1811,7 +1805,7 @@ def join(self, other): raise TypeError(self, other) def str_short(self, short_dtypes=False): - dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name dt_str = dt_str.replace('void', 'float0') shapestr = ','.join(map(str, self.shape)) if hasattr(self, 'sharding'): @@ -1872,7 +1866,7 @@ def join(self, other) -> AbstractValue: raise TypeError(self, other) def str_short(self, short_dtypes=False) -> str: - dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name return f'{self.val}, dtype={dt_str}' _bool = partialmethod(_forward_to_value, bool) @@ -1922,7 +1916,7 @@ def __init__(self, shape, dtype, weak_type=False): def str_short(self, short_dtypes=False) -> str: del short_dtypes # ignored shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else '' - dtype = _short_dtype_name(self.dtype) + dtype = dtypes.short_dtype_name(self.dtype) return f'{dtype}[{shape}]' __str__ = __repr__ = str_short @@ -1989,7 +1983,7 @@ def __repr__(self) -> str: # special-case scalar bints return f'{int(self._data)}{{≤{self.dtype.bound}}}' - dtypestr = _short_dtype_name(self._aval.dtype) + dtypestr = dtypes.short_dtype_name(self._aval.dtype) shapestr = ','.join(map(str, self.shape)) data = self.data return f'{dtypestr}[{shapestr}] with value: {data}' @@ -3203,7 +3197,7 @@ def pp_var(v: Var | Literal, context: JaxprPpContext) -> str: def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str: if isinstance(a, DShapedArray): shape = [pp_var(d, context) if type(d) is Var else str(d) for d in a.shape] - dtype = _short_dtype_name(a.dtype) + dtype = dtypes.short_dtype_name(a.dtype) return f'{dtype}[{",".join(shape)}]' else: return a.str_short(short_dtypes=True) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 9865632d8975..82be38d1cb57 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -839,13 +839,14 @@ def safe_to_cast(input_dtype_or_value: Any, def primal_tangent_dtype(primal_dtype, tangent_dtype, name: str | None = None) -> ExtendedDType: - name_ = name or f'PTDtype{{{primal_dtype}:{tangent_dtype}}}' + primal_dtype, tangent_dtype = map(dtype, (primal_dtype, tangent_dtype)) + name_ = name or (f'PrimalTangentDType{{{short_dtype_name(primal_dtype)}' + f'/{short_dtype_name(tangent_dtype)}}}') rules = types.SimpleNamespace( physical_element_aval= lambda dtype: types.SimpleNamespace(shape=(), dtype=primal_dtype), tangent_dtype=lambda dtype: tangent_dtype, - convert_from=lambda _, other: other == primal_dtype, - convert_to=lambda other, _: other == primal_dtype) + allow_conversion=True) class primal_tangent_dtype_scalar(extended): ... @@ -854,5 +855,13 @@ class PrimalTangentDType(ExtendedDType): name = name_ _rules = rules type = primal_tangent_dtype_scalar + __repr__ = lambda _: name_ return PrimalTangentDType() + +def short_dtype_name(dtype) -> str: + if isinstance(dtype, ExtendedDType): + return str(dtype) + else: + return (dtype.name.replace('float', 'f').replace('uint' , 'u') + .replace('int' , 'i').replace('complex', 'c')) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 3748e8191607..7226ea25922b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -137,7 +137,7 @@ def asarray(x: ArrayLike) -> Array: if isinstance(x, Array): return x if isinstance(x, (np.ndarray, np.generic, bool, int, float, builtins.complex)): - return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x)) + return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x)) # type: ignore[unused-ignore,bad-return-type] else: raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.") @@ -520,7 +520,7 @@ def convert_element_type(operand: ArrayLike, Returns: An array with the same shape as `operand`, cast elementwise to `new_dtype`. """ - return _convert_element_type(operand, new_dtype, weak_type=False) + return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type] def _convert_element_type( operand: ArrayLike, @@ -530,17 +530,30 @@ def _convert_element_type( if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() - if (dtypes.issubdtype(new_dtype, dtypes.extended) or - dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)): - return convert_element_type_p.bind( - operand, new_dtype=new_dtype, weak_type=bool(weak_type), - sharding=sharding) - - new_dtype = type_cast(DTypeLike | None, new_dtype) - # Don't canonicalize old_dtype because x64 context might cause # un-canonicalized operands to be passed in. old_dtype = dtypes.dtype(operand, canonicalize=False) + + if (isinstance(new_dtype, dtypes.ExtendedDType) or + isinstance(old_dtype, dtypes.ExtendedDType)): + if sharding is not None or weak_type: raise NotImplementedError + if new_dtype == old_dtype: return operand + if (isinstance(new_dtype, dtypes.ExtendedDType) and + isinstance(old_dtype, dtypes.ExtendedDType)): + old_rep_dtype = core.physical_element_aval(old_dtype).dtype + new_rep_dtype = core.physical_element_aval(new_dtype).dtype + raise ValueError( + "cannot directly convert between extended dtypes: from " + f"{dtype_to_string(old_dtype)} to {dtype_to_string(new_dtype)}. " + "Instead, convert to and from their representation dtypes, e.g.:\n" + f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} " + f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}") + if isinstance(new_dtype, dtypes.ExtendedDType): + return to_edtype_p.bind(operand, edtype=new_dtype) + return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) + + new_dtype = type_cast(DTypeLike | None, new_dtype) + old_weak_type = dtypes.is_weakly_typed(operand) if new_dtype is None: new_dtype = old_dtype @@ -2560,14 +2573,6 @@ def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type, def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type, sharding): - if (operand.dtype != new_dtype and - ((dtypes.issubdtype(operand.dtype, dtypes.extended) and - not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or - (dtypes.issubdtype(new_dtype, dtypes.extended) and - not new_dtype._rules.convert_to(operand.dtype, new_dtype)))): - raise ValueError( - f"Cannot convert_element_type from {dtype_to_string(operand.dtype)} " - f"to {dtype_to_string(new_dtype)}") return new_dtype def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type, @@ -2587,13 +2592,13 @@ def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type, return [convert_element_type_p.bind( ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)] -def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type, - sharding): - if core.primal_dtype_to_tangent_dtype(new_dtype) == dtypes.float0: - tangent_aval = core.raise_to_shaped(core.get_aval(tangent)) - return ad_util.Zero(tangent_aval.update(dtype=dtypes.float0, weak_type=False)) +def _convert_element_type_jvp_rule(tangent, primal_result, operand, *, + new_dtype, weak_type, sharding): + new_tangent_dtype = core.primal_dtype_to_tangent_dtype(new_dtype) + if new_tangent_dtype == dtypes.float0: + return ad_util.Zero.from_primal_value(primal_result) else: - return convert_element_type_p.bind(tangent, new_dtype=new_dtype, + return convert_element_type_p.bind(tangent, new_dtype=new_tangent_dtype, weak_type=weak_type, sharding=sharding) def _convert_elt_type_folding_rule(consts, eqn): @@ -2653,7 +2658,7 @@ def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding): _convert_element_type_shape_rule, _convert_element_type_dtype_rule, _convert_element_type_weak_type_rule, _convert_element_type_sharding_rule)) -ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule) +ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule batching.defvectorized(convert_element_type_p) pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule @@ -2676,6 +2681,91 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) +def _to_edtype_abstract_eval(x, *, edtype): + assert (isinstance(edtype, dtypes.ExtendedDType) and + not isinstance(x.dtype, dtypes.ExtendedDType)) + # For backward compatibility, if the edtype rules have a `convert_to` method, + # use that rather than looking for an `allow_conversion: bool` attribute. + if convert_to := getattr(edtype._rules, 'convert_to', None): + allow_conversion = convert_to(x.dtype, edtype) + else: + allow_conversion = edtype._rules.allow_conversion + if not allow_conversion: + raise ValueError( + f"Cannot convert_element_type from {dtype_to_string(x.dtype)} " + f"to {dtype_to_string(edtype)}") + rep_aval = core.physical_element_aval(edtype) + if x.dtype != rep_aval.dtype: + raise ValueError( + "can only convert to extended dtype from its representation dtype, " + f"but tried to convert from {dtype_to_string(x.dtype)} to " + f"{dtype_to_string(edtype)} which doesn't match the representation type " + f"{dtype_to_string(rep_aval.dtype)}.") + if x.ndim < rep_aval.ndim: + raise ValueError( + "can only convert to extended dtype from an array of its " + f"representation type, but the extended dtype {dtype_to_string(edtype)}" + f" has a representation shape {rep_aval.shape} (rank {rep_aval.ndim}) " + f"while the given representation array has shape {x.shape} (rank " + f"{x.ndim} < {rep_aval.ndim}).") + n = x.ndim - rep_aval.ndim + shape_prefix, shape_suffix = x.shape[:n], x.shape[n:] + if shape_suffix != rep_aval.shape: + raise ValueError( + "can only convert to extended dtype from an array of its " + f"representation type, but the extended dtype {dtype_to_string(edtype)}" + f" has a representation shape {rep_aval.shape} while the given " + f"representation array has shape {x.shape}, so the shape suffix " + f"does not match: given {shape_suffix} but required {rep_aval.shape}.") + return core.raise_to_shaped(x).update(shape=shape_prefix, dtype=edtype) + +to_edtype_p = Primitive('to_edtype') +to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p)) +to_edtype_p.def_abstract_eval(_to_edtype_abstract_eval) +ad.defjvp(to_edtype_p, + lambda t, x, edtype: + convert_element_type(t, core.primal_dtype_to_tangent_dtype(edtype))) +ad.primitive_transposes[to_edtype_p] = \ + lambda ct, x, edtype: [from_edtype_p.bind(ct, dtype=x.aval.dtype)] # type: ignore +batching.defvectorized(to_edtype_p) +mlir.register_lowering(to_edtype_p, lambda _, x, **__: [x]) + + +def _from_edtype_abstract_eval(x, *, dtype): + assert (isinstance(x.dtype, dtypes.ExtendedDType) and + not isinstance(dtype, dtypes.ExtendedDType)) + if convert_from := getattr(x.dtype._rules, 'convert_from', None): + allow_conversion = convert_from(x.dtype, dtype) + else: + allow_conversion = x.dtype._rules.allow_conversion + if not allow_conversion: + raise ValueError( + f"Cannot convert_element_type from {dtype_to_string(x.dtype)} " + f"to {dtype_to_string(dtype)}") + rep_aval = core.physical_element_aval(x.dtype) + if rep_aval.dtype != dtype: + raise ValueError( + "can only convert from extended dtype to its representation dtype, " + f"but tried to convert from {dtype_to_string(x.dtype)} to " + f"{dtype_to_string(dtype)} which doesn't match the representation type " + f"{dtype_to_string(rep_aval.dtype)}.") + if all(isinstance(d, int) for d in x.shape): + return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype) + else: + raise NotImplementedError + +from_edtype_p = Primitive('from_edtype') +from_edtype_p.def_impl(partial(dispatch.apply_primitive, from_edtype_p)) +from_edtype_p.def_abstract_eval(_from_edtype_abstract_eval) +ad.defjvp(from_edtype_p, + lambda t, x, dtype: + convert_element_type(t, core.primal_dtype_to_tangent_dtype(dtype))) +ad.primitive_transposes[from_edtype_p] = \ + lambda ct, x, dtype: [to_edtype_p.bind(ct, edtype=x.dtype)] +batching.defvectorized(from_edtype_p) +mlir.register_lowering(from_edtype_p, lambda _, x, **__: [x]) + + def _bitcast_convert_type_shape_rule(operand, *, new_dtype): old_dtype = dtypes.canonicalize_dtype(operand.dtype) new_dtype = dtypes.canonicalize_dtype(new_dtype) @@ -5343,6 +5433,8 @@ def _empty_lower(ctx, *, dtype): class BIntRules: + allow_conversion: bool = True + @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((), np.dtype('int32')) @@ -5369,14 +5461,6 @@ def handler(bufs): return core.DArray(aval, phys_handler(bufs)) return handler - @staticmethod - def convert_from(bint_dtype, other_dtype) -> bool: - return other_dtype in (np.dtype('int32'), np.dtype('int64')) - - @staticmethod - def convert_to(other_dtype, bint_dtype) -> bool: - return other_dtype in (np.dtype('int32'), np.dtype('int64')) - core.bint._rules = BIntRules diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 1c3fa10e9e92..e817369a50c5 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -149,11 +149,8 @@ def join(self, other): raise NotImplementedError def str_short(self, short_dtypes=False): - dt_str = ( - jax_core._short_dtype_name(self.dtype) - if short_dtypes - else self.dtype.name - ) + dt_str = \ + dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name dt_str = dt_str.replace("void", "float0") shapestr = ",".join(map(str, self.shape)) if hasattr(self, "sharding"): diff --git a/jax/_src/prng.py b/jax/_src/prng.py index c4d6683c0262..a9dca0b4bffe 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -321,6 +321,7 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape): class KeyTyRules: + allow_conversion: bool = False @staticmethod def full(shape, fill_value, dtype): @@ -425,14 +426,6 @@ def tangent_dtype(_): def zero(_): return np.zeros((), dtypes.float0) - @staticmethod - def convert_from(key_dtype, other_dtype) -> bool: - return False - - @staticmethod - def convert_to(other_dtype, key_dtype) -> bool: - return False - class KeyTy(dtypes.ExtendedDType): _impl: PRNGImpl # TODO(mattjj,frostig): protocol really diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 8a90c491e526..6d7f2c2a1e2c 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1549,6 +1549,8 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "ragged_dot", "cholesky_update", "symmetric_update", + "from_edtype", + "to_edtype", # Pallas TPU primitives "bitcast", "repeat", diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index e9d05a9eaaa5..89d70871a8f9 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -19,6 +19,7 @@ from functools import partial import itertools import operator +import types from absl.testing import absltest from absl.testing import parameterized @@ -300,16 +301,6 @@ def testIsSubdtype(self): self.assertEqual(dtypes.issubdtype(t, category), np.issubdtype(np.dtype(t).type, category)) - def testIsSubdtypeExtended(self): - self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended)) - self.assertTrue(dtypes.issubdtype(dtypes.extended, np.generic)) - self.assertFalse(dtypes.issubdtype(dtypes.extended, np.number)) - - self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.prng_key)) - self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.extended)) - self.assertTrue(jnp.issubdtype(dtypes.prng_key, np.generic)) - self.assertFalse(dtypes.issubdtype(dtypes.prng_key, np.number)) - @parameterized.product(dtype=custom_float_dtypes) def testIsSubdtypeCustomFloats(self, dtype): for dt in [dtype, np.dtype(dtype), str(np.dtype(dtype))]: @@ -408,6 +399,34 @@ def testDefaultDtypes(self): self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64) self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128) + def test_check_dtype_non_hashable(self): + # regression test for issue with checking non-hashable custom dtype + class MyDtype: + __hash__ = None + dtype = np.dtype('float32') + dtypes.check_user_dtype_supported(MyDtype()) + + def test_check_dtype_array(self): + x = jnp.arange(4) + msg = "Passing an array as a dtype argument is deprecated" + with self.assertWarnsRegex(DeprecationWarning, msg): + dtypes.check_user_dtype_supported(x) + with self.assertWarnsRegex(DeprecationWarning, msg): + jax.jit(dtypes.check_user_dtype_supported)(x) + + +class ExtendedDTypeTest(jtu.JaxTestCase): + + def testIsSubdtypeExtended(self): + self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended)) + self.assertTrue(dtypes.issubdtype(dtypes.extended, np.generic)) + self.assertFalse(dtypes.issubdtype(dtypes.extended, np.number)) + + self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.prng_key)) + self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.extended)) + self.assertTrue(jnp.issubdtype(dtypes.prng_key, np.generic)) + self.assertFalse(dtypes.issubdtype(dtypes.prng_key, np.number)) + def test_custom_tangent_dtype(self): from jax._src import core @@ -415,6 +434,8 @@ class scale(dtypes.extended): pass class ScalesTyRules: + allow_conversion: bool = True + @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((), dtype.float_dtype) @@ -435,14 +456,6 @@ def zero(dt): else dtypes.finfo(dt.float_dtype).min, dt.float_dtype) return jax.lax.convert_element_type(neginf, dt) - @staticmethod - def convert_from(dtype, other_dtype) -> bool: - return dtype.float_dtype == other_dtype - - @staticmethod - def convert_to(other_dtype, dtype) -> bool: - return dtype.float_dtype == other_dtype - @dataclasses.dataclass(frozen=True) class ScaleTy(dtypes.ExtendedDType): float_dtype: dtypes.DType @@ -485,19 +498,13 @@ def test_custom_tangent_dtype_with_scan(self): from jax._src import core class ScalesTyRules: - # tell JAX how to lower this dtype to an HLO dtype + # tell JAX how to lower this dtype to an HLO representation dtype @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((), dtype.float_dtype) - # allow conversions to and from the corresponding float type - @staticmethod - def convert_from(scale_dtype, other_dtype) -> bool: - return scale_dtype.float_dtype == other_dtype - - @staticmethod - def convert_to(other_dtype, scale_dtype) -> bool: - return scale_dtype.float_dtype == other_dtype + # allow conversions to and from the corresponding representation type + allow_conversion: bool = True # define how autodiff should accumulate these values @staticmethod @@ -563,21 +570,6 @@ def inner_bwd(prev_scale, grads): _, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale) self.assertAllClose(new_scale, jnp.float32(1.0)) - def test_check_dtype_non_hashable(self): - # regression test for issue with checking non-hashable custom dtype - class MyDtype: - __hash__ = None - dtype = np.dtype('float32') - dtypes.check_user_dtype_supported(MyDtype()) - - def test_check_dtype_array(self): - x = jnp.arange(4) - msg = "Passing an array as a dtype argument is deprecated" - with self.assertWarnsRegex(DeprecationWarning, msg): - dtypes.check_user_dtype_supported(x) - with self.assertWarnsRegex(DeprecationWarning, msg): - jax.jit(dtypes.check_user_dtype_supported)(x) - @parameterized.parameters([True]) # TODO(mattjj): make jit=False work def test_primal_tangent_dtype(self, jit): dt = dtypes.primal_tangent_dtype(jnp.int8, jnp.bfloat16) @@ -605,6 +597,123 @@ def h(): self.assertEqual(result.dtype, jnp.bfloat16) self.assertEqual(bwd_result.dtype, jnp.bfloat16) self.assertAllClose(bwd_result, 2 * g) + self.assertEqual(repr(dt), 'PrimalTangentDType{i8/bf16}') + + @parameterized.parameters(itertools.product([(), (2,), (3, 4)], repeat=2)) + def test_edtype_conversion(self, shape_prefix, shape_suffix): + class scalar(dtypes.extended): ... + + @dataclasses.dataclass(frozen=True) + class DType(dtypes.ExtendedDType): + name = 'dt' + type = scalar + _rules = types.SimpleNamespace( + physical_element_aval= + lambda _: types.SimpleNamespace(shape=shape_suffix, dtype='int32'), + allow_conversion=True) + dtype = DType() + + @jax.jit + def f(x): + self.assertEqual(x.shape, shape_prefix + shape_suffix) + self.assertEqual(x.dtype, jnp.dtype('int32')) + x = jax.lax.convert_element_type(x, dtype) + self.assertEqual(x.shape, shape_prefix) + self.assertEqual(x.dtype, dtype) + x = jax.lax.convert_element_type(x, 'int32') + self.assertEqual(x.shape, shape_prefix + shape_suffix) + self.assertEqual(x.dtype, jnp.dtype('int32')) + f(jnp.zeros(shape_prefix + shape_suffix, dtype='int32')) + + def test_edtype_conversion_errors(self): + class scalar(dtypes.extended): ... + + @dataclasses.dataclass(frozen=True) + class DType(dtypes.ExtendedDType): + name = 'dt' + type = scalar + _rules = types.SimpleNamespace( + physical_element_aval= + lambda _: types.SimpleNamespace(shape=(3,), dtype='int32'), + allow_conversion=True) + dtype = DType() + + class scalar2(dtypes.extended): ... + + @dataclasses.dataclass(frozen=True) + class DType2(dtypes.ExtendedDType): + name = 'dt2' + type = scalar2 + _rules = types.SimpleNamespace( + physical_element_aval= + lambda _: types.SimpleNamespace(shape=(3,), dtype='int32'), + allow_conversion=True) + dtype2 = DType2() + + @jax.jit + def f(x): + y = jax.lax.convert_element_type(x, dtype) + with self.assertRaisesRegex(ValueError, "cannot directly"): + jax.lax.convert_element_type(y, dtype2) + with self.assertRaisesRegex(ValueError, "can only convert"): + jax.lax.convert_element_type(x.astype('float32'), dtype) + with self.assertRaisesRegex(ValueError, "can only convert"): + jax.lax.convert_element_type(x[:, :2], dtype) + with self.assertRaisesRegex(ValueError, "can only convert"): + jax.lax.convert_element_type(x[:, 0], dtype) + with self.assertRaisesRegex(ValueError, "can only convert"): + jax.lax.convert_element_type(y, 'float32') + f(jnp.zeros((5, 3), dtype='int32')) + + def test_edtype_conversion_autodiff(self): + + class scalar(dtypes.extended): ... + + @dataclasses.dataclass(frozen=True) + class DType(dtypes.ExtendedDType): + name = 'dt' + type = scalar + _rules = types.SimpleNamespace( + physical_element_aval= + lambda _: types.SimpleNamespace(shape=(), dtype='float32'), + tangent_dtype=lambda dtype: jnp.dtype('bfloat16'), + allow_conversion=True) + dtype = DType() + + @jax.jit + @jax.grad + def f(x): + x = jax.lax.convert_element_type(x, dtype) + + @jax.custom_jvp + def g(x): return x + @g.defjvp + def g_jvp(primals, tangents): + (x,), (x_dot,) = primals, tangents + self.assertEqual(x.shape, (5,)) + self.assertEqual(x.dtype, dtype) + self.assertEqual(x_dot.shape, (5,)) + self.assertEqual(x_dot.dtype, jnp.dtype('bfloat16')) + return x, x_dot + x = g(x) + + x = jax.lax.convert_element_type(x, 'float32') + + @jax.custom_jvp + def h(x): return x + @h.defjvp + def h_jvp(primals, tangents): + (x,), (x_dot,) = primals, tangents + self.assertEqual(x.shape, (5,)) + self.assertEqual(x.dtype, jnp.dtype('float32')) + self.assertEqual(x_dot.shape, (5,)) + self.assertEqual(x_dot.dtype, jnp.dtype('float32')) + return x, x_dot + x = h(x) + + return 0. + + f(jnp.zeros(5, dtype='float32')) # test assertions in the function class EArrayTest(jtu.JaxTestCase): @@ -618,10 +727,7 @@ def test_extended_dtypes_at_rest(self, jit): class foo(dtypes.extended): pass class FooTyRules: - - @staticmethod - def convert_to(foo_dtype, target_dtype): - return True + allow_conversion: bool = True @staticmethod def physical_element_aval(foo_dtype): diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index cc2419fb3757..f6625e86ca14 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -1486,6 +1486,9 @@ def f(i): jax_traceback_filtering='off') class JumbleTest(jtu.JaxTestCase): + def setUp(self): + if jax.config.x64_enabled: raise unittest.SkipTest() + @parameterized.parameters((True,), (False,)) def test_internal_jumble(self, disable_jit): with jax.disable_jit(disable_jit): From 76039016084f645d39364f3baeb7f6be994616c2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Sep 2024 17:11:18 -0700 Subject: [PATCH 648/702] [mosaic] Fix a warning causing CI failures. An array ref object was passed where a dtype was expected. PiperOrigin-RevId: 678451446 --- jax/_src/pallas/mosaic/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 2c96aa512e41..1514f67a9e33 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -512,7 +512,7 @@ def set_accumulator(self, init=False): def _init(): self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...]) def _set(): - self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref) + self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref.dtype) lax.cond(init, _init, _set) def accumulate(self): From 291e52a713bf581f669dda241402b34bafa87e41 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 24 Sep 2024 17:24:32 -0700 Subject: [PATCH 649/702] Fix some warnings causing CI failures on ARM. PiperOrigin-RevId: 678454816 --- tests/lax_numpy_test.py | 1 + tests/scipy_optimize_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a10a7369721f..c6d56885a6a8 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5706,6 +5706,7 @@ def testTraceMethod(self): self.assertAllClose(x.trace(), jnp.array(x).trace()) self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) + @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero") def testIntegerPowersArePrecise(self): # See https://github.com/jax-ml/jax/pull/3036 # Checks if the squares of float32 integers have no numerical errors. diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py index 70a00e14c468..ffa576850538 100644 --- a/tests/scipy_optimize_test.py +++ b/tests/scipy_optimize_test.py @@ -117,6 +117,7 @@ def zakharov_fn(x): jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS') self.assertLess(jax_res.fun, 1e-6) + @jtu.ignore_warning(category=RuntimeWarning, message='divide by zero') def test_minimize_bad_initial_values(self): # This test runs deliberately "bad" initial values to test that handling # of failed line search, etc. is the same across implementations From d21ad1e19a94ccd44bf33aa1561205e7b9946761 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 25 Sep 2024 07:25:08 +0530 Subject: [PATCH 650/702] Improve docs for jax.numpy: arcsin, arccos and arctan --- jax/_src/numpy/ufuncs.py | 114 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 111 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index a88ea3d760dc..8cde0348ce5d 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -587,19 +587,127 @@ def tan(x: ArrayLike, /) -> Array: """ return lax.tan(*promote_args_inexact('tan', x)) -@implements(np.arcsin, module='numpy') + @partial(jit, inline=True) def arcsin(x: ArrayLike, /) -> Array: + r"""Compute element-wise inverse of trigonometric sine of input. + + JAX implementation of :obj:`numpy.arcsin`. + + Args: + x: input array or scalar. + + Returns: + An array containing the inverse trigonometric sine of each element of ``x`` + in radians in the range ``[-pi/2, pi/2]``, promoting to inexact dtype. + + Note: + - ``jnp.arcsin`` returns ``nan`` when ``x`` is real-valued and not in the closed + interval ``[-1, 1]``. + - ``jnp.arcsin`` follows the branch cut convention of :func:`numpy.arcsin` for + complex inputs. + + See also: + - :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input. + - :func:`jax.numpy.arccos` and :func:`jax.numpy.acos`: Computes the inverse of + trigonometric cosine of each element of input. + - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of + trigonometric tangent of each element of input. + + Examples: + >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arcsin(x) + Array([ nan, -1.571, -0.524, 0. , 0.524, 1.571, nan], dtype=float32) + + For complex-valued inputs: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arcsin(3+4j) + Array(0.634+2.306j, dtype=complex64, weak_type=True) + """ return lax.asin(*promote_args_inexact('arcsin', x)) -@implements(np.arccos, module='numpy') + @partial(jit, inline=True) def arccos(x: ArrayLike, /) -> Array: + """Compute element-wise inverse of trigonometric cosine of input. + + JAX implementation of :obj:`numpy.arccos`. + + Args: + x: input array or scalar. + + Returns: + An array containing the inverse trigonometric cosine of each element of ``x`` + in radians in the range ``[0, pi]``, promoting to inexact dtype. + + Note: + - ``jnp.arccos`` returns ``nan`` when ``x`` is real-valued and not in the closed + interval ``[-1, 1]``. + - ``jnp.arccos`` follows the branch cut convention of :func:`numpy.arccos` for + complex inputs. + + See also: + - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of + input. + - :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of + trigonometric sine of each element of input. + - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of + trigonometric tangent of each element of input. + + Examples: + >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arccos(x) + Array([ nan, 3.142, 2.094, 1.571, 1.047, 0. , nan], dtype=float32) + + For complex inputs: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arccos(4-1j) + Array(0.252+2.097j, dtype=complex64, weak_type=True) + """ return lax.acos(*promote_args_inexact('arccos', x)) -@implements(np.arctan, module='numpy') + @partial(jit, inline=True) def arctan(x: ArrayLike, /) -> Array: + """Compute element-wise inverse of trigonometric tangent of input. + + JAX implement of :obj:`numpy.arctan`. + + Args: + x: input array or scalar. + + Returns: + An array containing the inverse trigonometric tangent of each element ``x`` + in radians in the range ``[-pi/2, pi/2]``, promoting to inexact dtype. + + Note: + ``jnp.arctan`` follows the branch cut convention of :func:`numpy.arctan` for + complex inputs. + + See also: + - :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of + input. + - :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of + trigonometric sine of each element of input. + - :func:`jax.numpy.arccos` and :func:`jax.numpy.atan`: Computes the inverse of + trigonometric cosine of each element of input. + + Examples: + >>> x = jnp.array([-jnp.inf, -20, -1, 0, 1, 20, jnp.inf]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arctan(x) + Array([-1.571, -1.521, -0.785, 0. , 0.785, 1.521, 1.571], dtype=float32) + + For complex-valued inputs: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arctan(2+7j) + Array(1.532+0.133j, dtype=complex64, weak_type=True) + """ return lax.atan(*promote_args_inexact('arctan', x)) @implements(np.sinh, module='numpy') From 1fe0c5dad51cdb7acb8a8420df3d758be3a01bad Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 24 Sep 2024 19:00:33 -0700 Subject: [PATCH 651/702] Fix printing of saved_residual for `jit` by looking for `pjit` as the primitive instead of `xla_call` which was removed 2 years ago PiperOrigin-RevId: 678479141 --- jax/_src/ad_checkpoint.py | 2 +- tests/api_test.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 39df07359c18..bd7482eb50cf 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -473,7 +473,7 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]: if v in res_vars: if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]): results.append((v.aval, f"named '{eqn.params['name']}' from {src}")) - elif str(eqn.primitive) == 'xla_call': + elif str(eqn.primitive) == 'pjit': results.append((v.aval, f"output of jitted function '{eqn.params['name']}' " f"from {src}")) diff --git a/tests/api_test.py b/tests/api_test.py index d1c77c75eb79..c73d5960f123 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5768,6 +5768,27 @@ def f(x, y): self.assertStartsWith(res[4][1], "named 'z'") self.assertEqual(res[5][0].shape, ()) + def test_saved_residuals_utility_jit(self): + @jax.jit + def f(x, y): + x1, x2 = x + z = checkpoint_name(jnp.sin(3.), 'z') + return z * ((x1 * x2) * y) * np.array([3.]) + + res = saved_residuals(f, (2., 3.), y=4.) + self.assertLen(res, 6) + self.assertEqual(res[0][0].shape, ()) + self.assertEqual(res[0][1], "from the argument x[0]") + self.assertEqual(res[1][0].shape, ()) + self.assertEqual(res[1][1], "from the argument x[1]") + self.assertEqual(res[2][0].shape, ()) + self.assertEqual(res[2][1], "from the argument y") + self.assertEqual(res[3][0].shape, ()) + self.assertStartsWith(res[3][1], "output of jitted function 'f'") + self.assertEqual(res[4][0].shape, ()) + self.assertEqual(res[5][0].shape, (1,)) + self.assertStartsWith(res[5][1], "output of jitted function 'f'") + @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ From eff00cc4499cfe3f3f24bafda6c1ecf908232ff3 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Wed, 25 Sep 2024 04:52:28 -0700 Subject: [PATCH 652/702] [JAX] add support for gather/scatter batching dims following the new attributes in stablehlo. This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota. See https://github.com/openxla/stablehlo/pull/2259 PiperOrigin-RevId: 678649138 --- jax/BUILD | 4 + jax/_src/internal_test_util/test_harnesses.py | 104 ++-- jax/_src/lax/lax.py | 17 +- jax/_src/lax/parallel.py | 3 +- jax/_src/lax/slicing.py | 394 ++++++++++---- jax/_src/numpy/lax_numpy.py | 22 +- jax/_src/ops/scatter.py | 4 +- jax/experimental/jax2tf/jax2tf.py | 6 + tests/lax_test.py | 482 +++++++++++++++--- tests/lax_vmap_test.py | 32 ++ 10 files changed, 832 insertions(+), 236 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index c25d0004e772..d49e783e61d6 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -147,7 +147,11 @@ py_library( srcs = ["_src/internal_test_util/test_harnesses.py"], visibility = [":internal"] + jax_internal_test_harnesses_visibility, deps = [ + ":ad_util", + ":config", ":jax", + ":test_util", + "//jax/_src/lib", ] + py_deps("numpy"), ) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 2c94907568d9..31c3fec94536 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -1169,6 +1169,18 @@ def _make_broadcast_in_dim_harness(name, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3), True), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), + lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(1,), + start_index_map=(1,), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (1, 1), True), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + lax.GatherDimensionNumbers( + offset_dims=(2,), collapsed_slice_dims=(), + start_index_map=(2,), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (1, 1, 3), True) ]: dtype = np.float32 for enable_xla in ([True] if needs_xla else [True, False]): @@ -1276,15 +1288,16 @@ def _make_scatter_harness(name, update_shape=(2,), mode=lax.GatherScatterMode.FILL_OR_DROP, dtype=np.float32, - dimension_numbers=((), (0,), (0,)), + dimension_numbers=lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,)), enable_and_disable_xla=False): - dimension_numbers = lax.ScatterDimensionNumbers(*dimension_numbers) xla_options = [True, False] if enable_and_disable_xla else [True] for enable_xla in xla_options: define( f_lax.__name__, - f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_updatewindowdims={dimension_numbers.update_window_dims}_insertedwindowdims={dimension_numbers.inserted_window_dims}_scatterdimstooperanddims={dimension_numbers.scatter_dims_to_operand_dims}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}" + f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_{dimension_numbers=}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}" .replace(" ", ""), partial( f_lax, @@ -1328,8 +1341,19 @@ def _make_scatter_harness(name, # Validate shapes, dimension numbers and scatter indices. All are in bounds. for shape, scatter_indices, update_shape, dimension_numbers in [ - ((10,), [[0], [0], [0]], (3, 2), ((1,), (), (0,))), - ((10, 5), [[0], [2], [1]], (3, 3), ((1,), (0,), (0,))) + ((10,), [[0], [0], [0]], (3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0,))), + ((10, 5), [[0], [2], [1]], (3, 3), + lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,))), + ((2, 3, 10), [[[0], [1]], [[2], [3]], [[4], [5]]], (3, 2, 3), + lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]: _make_scatter_harness( "shapes_and_dimension_numbers", @@ -1358,13 +1382,16 @@ def _make_scatter_harness(name, _make_scatter_harness("modes_in_bounds", f_lax=f_lax, mode=mode) - _make_scatter_harness("modes_out_of_bounds", mode=mode, - shape=(1, 5), - f_lax=f_lax, - scatter_indices=np.array([10]), - update_shape=(1,), - dimension_numbers=((0,), (1,), (1,)), - enable_and_disable_xla=True) + _make_scatter_harness( + "modes_out_of_bounds", + mode=mode, + shape=(1, 5), + f_lax=f_lax, + scatter_indices=np.array([10]), + update_shape=(1,), + dimension_numbers=lax.ScatterDimensionNumbers((0,), (1,), (1,)), + enable_and_disable_xla=True, + ) # Validate no XLA scatters for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex) - set(jtu.dtypes.boolean): @@ -1372,22 +1399,34 @@ def _make_scatter_harness(name, lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min, lax.scatter ]: for shape, scatter_indices, update_shape, dimension_numbers in [ - ((1,), [0], (), ((), (0,), (0,))), # zero case - ((1, 1), [0], (1,), ((0,), (0,), (0,))), - ((1, 1, 1), [0], (1, 1), ((0, 1), (0,), (0,))), - ((1, 50, 3), [32], (1, 3), ((0, 1), (1,), (1,))), - ((1, 2, 3), [1], (1, 3), ((0, 1), (1,), (1,))), # slice 2nd dim - ((1, 2, 3), [0], (2, 3), ((0, 1), (0,), (0,))), # slice 1st dim - ((1, 2, 3), [1, 2], (1,), ((0,), (1, 2), (1, 2))), # 2nd and 3rd - ((4, 2, 3), [3, 2], (2,), ((0,), (0, 2), (0, 2))), # 1st and 3rd - ((4, 2, 3, 5), [0, 4], (4, 3), ((0, 1), (1, 3), (1, 3))), # 2nd and 4th + ((1,), [0], (), + lax.ScatterDimensionNumbers((), (0,), (0,))), # zero case + ((1, 1), [0], (1,), + lax.ScatterDimensionNumbers((0,), (0,), (0,))), + ((1, 1, 1), [0], (1, 1), + lax.ScatterDimensionNumbers((0, 1), (0,), (0,))), + ((1, 50, 3), [32], (1, 3), + lax.ScatterDimensionNumbers((0, 1), (1,), (1,))), + ((1, 2, 3), [1], (1, 3), + lax.ScatterDimensionNumbers((0, 1), (1,), (1,))), # slice 2nd dim + ((1, 2, 3), [0], (2, 3), + lax.ScatterDimensionNumbers((0, 1), (0,), (0,))), # slice 1st dim + ((1, 2, 3), [1, 2], (1,), + lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))), # 2nd and 3rd + ((4, 2, 3), [3, 2], (2,), + lax.ScatterDimensionNumbers((0,), (0, 2), (0, 2))), # 1st and 3rd + ((4, 2, 3, 5), [0, 4], (4, 3), + lax.ScatterDimensionNumbers((0, 1), (1, 3), (1, 3))), # 2nd and 4th ((5, 6, 7), [[0, 1], [2, 3]], (2, 7), - ((1,), (0, 1), (0, 1))), # .at[((3,4),(5,5))] shapes + lax.ScatterDimensionNumbers((1,), (0, 1), (0, 1))), + # .at[((3,4),(5,5))] shapes ((5, 6, 7), [[[0], [1]], [[2], [3]]], (5, 2, 2, 7), - ((0, 3), (1,), (1,))), # .at[:, ((3,4),(5,5))] shapes + lax.ScatterDimensionNumbers((0, 3), (1,), (1,))), + # .at[:, ((3,4),(5,5))] shapes ((5, 6, 7), [[[0, 1], [2, 3]], [[4, 0], [1, 2]]], (5, 2, 2), - ((0,), (1, 2), (1, 2))), # .at[:, ((3,4),(5,5)), 3] shapes - ((1, 125), [0], (1,), ((0,), (1,), (1,))), + lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))), + # .at[:, ((3,4),(5,5)), 3] shapes + ((1, 125), [0], (1,), lax.ScatterDimensionNumbers((0,), (1,), (1,))), ]: for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS, lax.GatherScatterMode.FILL_OR_DROP): @@ -1410,11 +1449,16 @@ def _make_scatter_harness(name, lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min ]: for shape, scatter_indices, update_shape, dimension_numbers in [ - ((1,), [[0],[0]], (2,), ((), (0,), (0,))), # .at[((0,0),)] - ((3,), [[1],[0],[1]], (3,), ((), (0,), (0,))), # .at[((1,0,1),)] - ((2, 3), [[[2],[2],[2]]], (2, 1, 3), ((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)] - ((3, 5, 40), [[1],[1]], (3, 5, 2), ((0, 1), (2,), (2,))), - ((3, 5, 4), [[1],[1]], (3, 2, 4), ((0, 2), (1,), (1,))), + ((1,), [[0],[0]], (2,), + lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((0,0),)] + ((3,), [[1],[0],[1]], (3,), + lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((1,0,1),)] + ((2, 3), [[[2],[2],[2]]], (2, 1, 3), + lax.ScatterDimensionNumbers((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)] + ((3, 5, 40), [[1],[1]], (3, 5, 2), + lax.ScatterDimensionNumbers((0, 1), (2,), (2,))), + ((3, 5, 4), [[1],[1]], (3, 2, 4), + lax.ScatterDimensionNumbers((0, 2), (1,), (1,))), ]: for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS, lax.GatherScatterMode.FILL_OR_DROP): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 7226ea25922b..e28d0857d624 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4654,18 +4654,15 @@ def _top_k_jvp(primals, tangents, *, k): idx_shape = k_idxs.shape rank = len(idx_shape) gather_index_shape = idx_shape + (1,) - gather_indices = [] - for i in range(rank-1): - _iota = iota(k_idxs.dtype, idx_shape[i]) - _iota = broadcast_in_dim(_iota, gather_index_shape, (i,)) - gather_indices.append(_iota) - gather_indices.append(reshape(k_idxs, gather_index_shape)) - gather_indices = concatenate(gather_indices, dimension=rank) + gather_indices = reshape(k_idxs, gather_index_shape) slice_sizes = (1,) * rank dnums = slicing.GatherDimensionNumbers( - offset_dims=(), - collapsed_slice_dims=tuple(range(rank)), - start_index_map=tuple(range(rank))) + offset_dims=(), + collapsed_slice_dims=(rank - 1,), + operand_batching_dims=tuple(range(rank - 1)), + start_indices_batching_dims=tuple(range(rank - 1)), + start_index_map=(rank - 1,), + ) tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes) return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1])) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index c9a07072ddc7..9d4614f344fb 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1500,7 +1500,8 @@ def _pgather_impl(src, idx, *, axes): dnums = slicing.GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=(0,), - start_index_map=(0,)) + start_index_map=(0,), + ) return slicing.gather(src_one_axis_front, idx, dimension_numbers=dnums, slice_sizes=tuple(slice_sizes)) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 60dfa0e1b3d2..372ebd1a8389 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -233,6 +233,16 @@ class GatherDimensionNumbers(NamedTuple): start_index_map: for each dimension in `start_indices`, gives the corresponding dimension in the `operand` that is to be sliced. Must be a tuple of integers with size equal to `start_indices.shape[-1]`. + operand_batching_dims: the set of batching dimensions `i` in `operand` that + have `slice_sizes[i] == 1` and that should have a corresponding dimension + in both the `start_indices` (at the same index in + `start_indices_batching_dims`) and output of the gather. Must be a tuple + of integers in ascending order. + start_indices_batching_dims: the set of batching dimensions `i` in + `start_indices` that should have a corresponding dimension in both the + `operand` (at the same index in `operand_batching_dims`) and output of the + gather. Must be a tuple of integers (order is fixed based on + correspondence with `operand_batching_dims`). Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the @@ -241,6 +251,8 @@ class GatherDimensionNumbers(NamedTuple): offset_dims: tuple[int, ...] collapsed_slice_dims: tuple[int, ...] start_index_map: tuple[int, ...] + operand_batching_dims: tuple[int, ...] = () + start_indices_batching_dims: tuple[int, ...] = () class GatherScatterMode(enum.Enum): @@ -370,6 +382,17 @@ class ScatterDimensionNumbers(NamedTuple): scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives the corresponding dimension in `operand`. Must be a sequence of integers with size equal to `scatter_indices.shape[-1]`. + operand_batching_dims: the set of batching dimensions `i` in `operand` that + should have a corresponding dimension in both the `scatter_indices` (at + the same index in `scatter_indices_batching_dims`) and `updates`. Must be + a tuple of integers in ascending order. These are the mirror image of + `operand_batching_dims` in the case of `gather`. + scatter_indices_batching_dims: the set of batching dimensions `i` in + `scatter_indices` that should have a corresponding dimension in both the + `operand` (at the same index in `operand_batching_dims`) and output of the + gather. Must be a tuple of integers (order is fixed based on + correspondence with `input_batching_dims`). These are the mirror image of + `start_indices_batching_dims` in the case of `gather`. Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the @@ -378,6 +401,8 @@ class ScatterDimensionNumbers(NamedTuple): update_window_dims: Sequence[int] inserted_window_dims: Sequence[int] scatter_dims_to_operand_dims: Sequence[int] + operand_batching_dims: Sequence[int] = () + scatter_indices_batching_dims: Sequence[int] = () def scatter_add( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, @@ -694,7 +719,8 @@ def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array: dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=tuple(axes), - start_index_map=tuple(axes)) + start_index_map=tuple(axes), + ) return gather(src, indices, dimension_numbers=dnums, slice_sizes=tuple(slice_sizes)) @@ -1256,8 +1282,11 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): dims = tuple(range(ndims)) start_indices, dyn_slice_sizes = util.split_list(start_indices_and_dyn, [ndims]) start_idx_bds, dyn_slice_size_bds = util.split_list(start_idx_and_dyn_bds, [ndims]) - dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(), - start_index_map=dims) + dnums = GatherDimensionNumbers( + offset_dims=dims, + collapsed_slice_dims=(), + start_index_map=dims, + ) index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds) return _gather_batching_rule( [operand, index, *dyn_slice_sizes], @@ -1396,9 +1425,11 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): update_shape = (np.shape(update) if update_bd is batching.not_mapped else tuple(np.delete(np.shape(update), update_bd))) dims = tuple(range(len(update_shape))) - dnums = ScatterDimensionNumbers(update_window_dims=dims, - inserted_window_dims=(), - scatter_dims_to_operand_dims=dims) + dnums = ScatterDimensionNumbers( + update_window_dims=dims, + inserted_window_dims=(), + scatter_dims_to_operand_dims=dims, + ) index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd) return api.vmap( partial(scatter, dimension_numbers=dnums, @@ -1437,6 +1468,12 @@ def _is_sorted(dims, op_name, name): if dims[i] < dims[i - 1]: raise TypeError(f"{name} in {op_name} op must be sorted; got {dims}") +def _dims_in_range(dims, rank, op_name, name): + for dim in dims: + if dim < 0 or dim >= rank: + raise TypeError(f"Invalid {name} set in {op_name} op; valid range is " + f"[0, {rank}); got: {dim}.") + def _sorted_dims_in_range(dims, rank, op_name, name): if len(dims) == 0: return @@ -1453,6 +1490,11 @@ def _no_duplicate_dims(dims, op_name, name): if len(set(dims)) != len(dims): raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.") +def _disjoint_dims(dims1, dims2, op_name, name1, name2): + if not set(dims1).isdisjoint(set(dims2)): + raise TypeError(f"{name1} and {name2} in {op_name} op must be disjoint; " + f"got: {dims1} and {dims2}.") + def _gather_shape_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1466,6 +1508,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + operand_batching_dims = dimension_numbers.operand_batching_dims + start_indices_batching_dims = dimension_numbers.start_indices_batching_dims start_index_map = dimension_numbers.start_index_map # Note: in JAX, index_vector_dim is always computed as below, cf. the @@ -1521,6 +1565,50 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims") _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") + + _no_duplicate_dims(operand_batching_dims, "gather", "operand_batching_dims") + _is_sorted(operand_batching_dims, "gather", "operand_batching_dims") + _sorted_dims_in_range( + operand_batching_dims, _rank(operand), "gather", "operand_batching_dims" + ) + + _disjoint_dims(collapsed_slice_dims, operand_batching_dims, "gather", + "collapsed_slice_dims", "operand_batching_dims") + _disjoint_dims(start_index_map, operand_batching_dims, "gather", + "start_index_map", "operand_batching_dims") + + _no_duplicate_dims( + start_indices_batching_dims, "gather", "start_indices_batching_dims" + ) + _dims_in_range( + start_indices_batching_dims, + _rank(indices), + "gather", + "start_indices_batching_dims", + ) + if index_vector_dim in start_indices_batching_dims: + raise TypeError( + "Gather op cannot have the index vector dimension as a batching " + f"dimension; got {start_indices_batching_dims}." + ) + + if len(operand_batching_dims) != len(start_indices_batching_dims): + raise TypeError( + "Gather op requires equal numbers of operand_batching_dims and " + f"start_indices_batching_dims, got {operand_batching_dims} and" + f"{start_indices_batching_dims}." + ) + + operand_batch_shape = tuple(operand.shape[i] for i in operand_batching_dims) + indices_batch_shape = tuple( + indices.shape[i] for i in start_indices_batching_dims + ) + if not core.definitely_equal_shape(operand_batch_shape, indices_batch_shape): + raise TypeError( + "Gather op requires operand batching dimensions and indices batching " + f"dimensions to have the same shape, got {operand_batch_shape} and " + f"{indices_batch_shape}." + ) # End ValidateGatherDimensions if _rank(operand) != len(slice_sizes): @@ -1528,12 +1616,17 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " f"input_shape.rank={_rank(operand)}") - if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): - raise TypeError(f"All components of the offset index in a gather op must " - f"either be a offset dimension or explicitly collapsed; " - f"got len(slice_sizes)={len(slice_sizes)}, " - f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" - f"{collapsed_slice_dims}.") + if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims) + len( + operand_batching_dims + ): + raise TypeError( + "All components of the offset index in a gather op must " + "either be a offset dimension or explicitly collapsed/batching; " + f"got len(slice_sizes)={len(slice_sizes)}, " + f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" + f"{collapsed_slice_dims}, operand_batching_dims=" + f"{operand_batching_dims}." + ) for i in range(len(slice_sizes)): slice_size = slice_sizes[i] @@ -1552,12 +1645,21 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, f"but bound is {bound} for index " f"{collapsed_slice_dims[i]} at position {i}.") + for i in range(len(operand_batching_dims)): + bound = slice_sizes[operand_batching_dims[i]] + if bound > 1: + raise TypeError(f"Gather op can only have operand batching dims with " + f"bound 0/1, but bound is {bound} for index " + f"{operand_batching_dims[i]} at position {i}." + ) + return _gather_shape_computation(indices, dimension_numbers, slice_sizes) def _gather_shape_computation(indices, dimension_numbers, slice_sizes): offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims + operand_batching_dims = dimension_numbers.operand_batching_dims output_shape_rank = len(offset_dims) + _rank(indices) - 1 index_vector_dim = _rank(indices) - 1 @@ -1572,8 +1674,11 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes): indices_shape_gen = iter(expanded_indices_shape) - slice_sizes_gen = (s for i, s in enumerate(slice_sizes) - if i not in collapsed_slice_dims) + slice_sizes_gen = ( + s + for i, s in enumerate(slice_sizes) + if i not in collapsed_slice_dims and i not in operand_batching_dims + ) ans = tuple(next(slice_sizes_gen) if i in offset_dims else next(indices_shape_gen) for i in range(output_shape_rank)) return ans @@ -1631,9 +1736,12 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers, else: zeros = lax.full(operand_shape, lax._zero(t)) scatter_dnums = ScatterDimensionNumbers( - update_window_dims=dimension_numbers.offset_dims, - inserted_window_dims=dimension_numbers.collapsed_slice_dims, - scatter_dims_to_operand_dims=dimension_numbers.start_index_map) + update_window_dims=dimension_numbers.offset_dims, + inserted_window_dims=dimension_numbers.collapsed_slice_dims, + scatter_dims_to_operand_dims=dimension_numbers.start_index_map, + operand_batching_dims=dimension_numbers.operand_batching_dims, + scatter_indices_batching_dims=dimension_numbers.start_indices_batching_dims, + ) out = scatter_add(zeros, indices, t, scatter_dnums, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, @@ -1652,11 +1760,17 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, slice_sizes = (operand.shape[0],) + slice_sizes offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims)) collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) + operand_batching_dims = tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map) + start_index_map=start_index_map, + operand_batching_dims=operand_batching_dims, + start_indices_batching_dims=dimension_numbers.start_indices_batching_dims, + ) if isinstance(operand_bdim, batching.RaggedAxis): ragged_slice_sizes = batching.bdim_as_shape(operand_bdim, slice_sizes) for orig, fabricated in zip( @@ -1687,10 +1801,16 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, elif operand_bdim is None and indices_bdim is not None: indices = batching.moveaxis(indices, indices_bdim, 0) offset_dims = tuple(1 + d for d in dimension_numbers.offset_dims) + start_indices_batching_dims = tuple( + np.add(1, dimension_numbers.start_indices_batching_dims) + ) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=dimension_numbers.collapsed_slice_dims, - start_index_map=dimension_numbers.start_index_map) + start_index_map=dimension_numbers.start_index_map, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=start_indices_batching_dims, + ) # If batching indexed accesses into the same array, the batched gather may # no longer have sorted or unique indices. return gather(operand, indices, dimension_numbers=dnums, @@ -1702,61 +1822,34 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, operand = batching.moveaxis(operand, operand_bdim, 0) indices = batching.moveaxis(indices, indices_bdim, 0) - # This slightly awkward special case is needed because the shape rule for - # gather does not allow size-1 slices out of a size-0 dimension, even if - # the number of slices is zero. Likely the best fix would be to change the - # definition of gather() so it can be batched without the construction of - # an explicit iota of size-1 slices. if core.definitely_equal(operand.shape[0], 0): - output_shape = _gather_shape_rule( - core.ShapedArray(operand.shape[1:], operand.dtype), - core.ShapedArray(indices.shape[1:], - dtypes.canonicalize_dtype(indices.dtype)), - dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, - unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, - mode=mode, fill_value=fill_value) - return lax.full((0,) + output_shape, lax._zero(operand)), 0 - - # Example: user code had indices shape (3, 4, 5), and we have to deal with - # indices shape (7, 3, 4, 5). We transform that to indices of shape - # (7, 3, 4, 6) where we concatenated an iota that counts along our batch - # dimension to the front of the ndindex. - index_dtype = _promote_dtype_for_size(indices.dtype, indices.shape[0]) - count_shape = list(indices.shape) - count_shape[-1] = 1 - counts = lax.broadcasted_iota(index_dtype, tuple(count_shape), 0) - indices = lax.concatenate([counts, indices.astype(index_dtype)], - len(count_shape) - 1) - - slice_sizes = (1,) + slice_sizes - collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) + slice_sizes = (0,) + slice_sizes + else: + slice_sizes = (1,) + slice_sizes + collapsed_slice_dims = tuple( + np.add(1, dimension_numbers.collapsed_slice_dims) + ) + operand_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) + start_indices_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.start_indices_batching_dims) + ) offset_dims = tuple(np.add(1, dimension_numbers.offset_dims)) - start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map)) + start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map) + start_index_map=start_index_map, + operand_batching_dims=operand_batching_dims, + start_indices_batching_dims=start_indices_batching_dims, + ) return gather(operand, indices, dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), 0 -def _promote_dtype_for_size(dtype, size): - if not dtypes.issubdtype(dtype, np.integer): - return dtype - # size may be a dynamic shape, in which case we return at least int32 - try: - size = int(size) - except: - return dtype if np.iinfo(dtype).bits >= 32 else np.dtype('int32') - if size <= np.iinfo(dtype).max: - return dtype - elif size <= np.iinfo(np.int32).max: - return np.dtype('int32') - else: - return dtypes.canonicalize_dtype(np.int64) - def _gather_pad_rule(in_avals, out_avals, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1821,8 +1914,10 @@ def _gather_lower(ctx, operand, indices, *, GatherScatterMode.CLIP), mode dnums = hlo.GatherDimensionNumbers.get( collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), - operand_batching_dims=[], - start_indices_batching_dims=[], + operand_batching_dims=list(dimension_numbers.operand_batching_dims), + start_indices_batching_dims=list( + dimension_numbers.start_indices_batching_dims + ), index_vector_dim=len(ctx.avals_in[1].shape) - 1, offset_dims=list(dimension_numbers.offset_dims), start_index_map=list(dimension_numbers.start_index_map), @@ -1872,6 +1967,8 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, update_window_dims = dimension_numbers.update_window_dims inserted_window_dims = dimension_numbers.inserted_window_dims + operand_batching_dims = dimension_numbers.operand_batching_dims + scatter_indices_batching_dims = dimension_numbers.scatter_indices_batching_dims scatter_dims_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims # Note: in JAX, index_vector_dim is always computed as below, cf. the # documentation of the ScatterDimensionNumbers class. @@ -1909,8 +2006,55 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, _sorted_dims_in_range(inserted_window_dims, _rank(operand), "scatter", "inserted_window_dims") + # Validate operand_batching_dims and scatter_indices_batching_dims + _is_sorted(operand_batching_dims, "scatter", "operand_batching_dims") + _no_duplicate_dims(operand_batching_dims, "scatter", "operand_batching_dims") + _sorted_dims_in_range( + operand_batching_dims, _rank(operand), "scatter", "operand_batching_dims" + ) + _disjoint_dims(inserted_window_dims, operand_batching_dims, "scatter", + "inserted_window_dims", "operand_batching_dims") + _disjoint_dims(scatter_dims_to_operand_dims, operand_batching_dims, "scatter", + "scatter_dims_to_operand_dims", "operand_batching_dims") + + _no_duplicate_dims( + scatter_indices_batching_dims, "scatter", "scatter_indices_batching_dims" + ) + _dims_in_range( + scatter_indices_batching_dims, + _rank(indices), + "scatter", + "scatter_indices_batching_dims", + ) + if index_vector_dim in scatter_indices_batching_dims: + raise TypeError( + "Scatter op cannot have the index vector dimension as a batching " + f"dimension; got {scatter_indices_batching_dims}.") + + if len(operand_batching_dims) != len(scatter_indices_batching_dims): + raise TypeError( + "Scatter op requires equal numbers of operand_batching_dims and " + f"scatter_indices_batching_dims, got {operand_batching_dims} and " + f"{scatter_indices_batching_dims}." + ) + + operand_batch_shape = tuple(operand.shape[i] for i in operand_batching_dims) + indices_batch_shape = tuple( + indices.shape[i] for i in scatter_indices_batching_dims + ) + if not core.definitely_equal_shape(operand_batch_shape, indices_batch_shape): + raise TypeError( + "Scatter op requires operand batching dimensions and indices batching " + f"dimensions to have the same shape, got {operand_batch_shape} and " + f"{indices_batch_shape}." + ) + # Validate window_size - window_size = len(update_window_dims) + len(inserted_window_dims) + window_size = ( + len(update_window_dims) + + len(inserted_window_dims) + + len(operand_batching_dims) + ) if _rank(operand) != window_size: raise TypeError(f"Scatter op has window of size {window_size}; doesn't " f"match operand of rank {_rank(operand)}.") @@ -1933,8 +2077,14 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, _no_duplicate_dims(scatter_dims_to_operand_dims, "scatter", "scatter_dims_to_operand_dims") - max_update_slice_sizes = [operand.shape[i] for i in range(len(operand.shape)) - if not i in set(inserted_window_dims)] + max_update_slice_sizes = [ + operand.shape[i] + for i in range(len(operand.shape)) + if ( + i not in set(inserted_window_dims) + and i not in set(operand_batching_dims) + ) + ] for i in range(len(update_window_dims)): update_window_dim = update_window_dims[i] @@ -1968,7 +2118,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums): slice_sizes = [] pos = 0 for i in range(len(operand.shape)): - if i in dnums.inserted_window_dims: + if i in dnums.inserted_window_dims or i in dnums.operand_batching_dims: slice_sizes.append(1) else: slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) @@ -2029,13 +2179,19 @@ def _scatter_add_transpose_rule(t, operand, indices, updates, *, if ad.is_undefined_primal(updates): gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if i in dimension_numbers.inserted_window_dims: + if ( + i in dimension_numbers.inserted_window_dims + or i in dimension_numbers.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2067,13 +2223,19 @@ def _scatter_mul_transpose_rule(t, operand, indices, updates, *, raise NotImplementedError( "scatter_mul gradients are only implemented if `unique_indices=True`") gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if i in dimension_numbers.inserted_window_dims: + if ( + i in dimension_numbers.inserted_window_dims + or i in dimension_numbers.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2095,40 +2257,52 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims) if ax is not None) operand = batching.bdim_at_front(operand, operand_bdim, size) - operand_bdim = 0 updates = batching.bdim_at_front(updates, updates_bdim, size) if indices_bdim is None: inserted_window_dims = tuple(np.add(1, dimension_numbers.inserted_window_dims)) update_window_dims = (0,) + tuple(np.add(1, dimension_numbers.update_window_dims)) + operand_batching_dims = tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) scatter_dims_to_operand_dims = tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims) + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, + operand_batching_dims=operand_batching_dims, + scatter_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) return scatter_op.bind( operand, indices, updates, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0 - # see the third case in _gather_batching_rule for comparison and comments indices = batching.bdim_at_front(indices, indices_bdim, size) - count_shape = list(indices.shape) - count_shape[-1] = 1 - counts = lax.broadcasted_iota(indices.dtype, tuple(count_shape), 0) - indices = lax.concatenate([counts, indices], len(count_shape) - 1) - update_window_dims = tuple(np.add(1, dimension_numbers.update_window_dims)) - inserted_window_dims = (0,) + tuple(np.add(1, dimension_numbers.inserted_window_dims)) - scatter_dims_to_operand_dims = (0,) + tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) + inserted_window_dims = tuple( + np.add(1, dimension_numbers.inserted_window_dims) + ) + operand_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.operand_batching_dims) + ) + scatter_indices_batching_dims = (0,) + tuple( + np.add(1, dimension_numbers.scatter_indices_batching_dims) + ) + scatter_dims_to_operand_dims = tuple( + np.add(1, dimension_numbers.scatter_dims_to_operand_dims) + ) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims) + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, + operand_batching_dims=operand_batching_dims, + scatter_indices_batching_dims=scatter_indices_batching_dims, + ) return scatter_op.bind( operand, indices, updates, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, @@ -2190,12 +2364,18 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, gather_dnums = GatherDimensionNumbers( offset_dims=scatter_dnums.update_window_dims, collapsed_slice_dims=scatter_dnums.inserted_window_dims, - start_index_map=scatter_dnums.scatter_dims_to_operand_dims) + start_index_map=scatter_dnums.scatter_dims_to_operand_dims, + operand_batching_dims=scatter_dnums.operand_batching_dims, + start_indices_batching_dims=scatter_dnums.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(operand.shape)): - if i in scatter_dnums.inserted_window_dims: + if ( + i in scatter_dnums.inserted_window_dims + or i in scatter_dnums.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[scatter_dnums.update_window_dims[pos]]) @@ -2323,7 +2503,6 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # of using scatter-add here is that we don't need a `scatter` transpose # rule. - # a) attach a positive ID to each update in `updates`, and perform a scatter # on the IDs. ids_shape = list(updates.shape) @@ -2344,13 +2523,16 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # b) compute the inverse gather that "undoes" the scatter on the id values. gather_dnums = GatherDimensionNumbers( - offset_dims=dnums.update_window_dims, - collapsed_slice_dims=dnums.inserted_window_dims, - start_index_map=dnums.scatter_dims_to_operand_dims) + offset_dims=dnums.update_window_dims, + collapsed_slice_dims=dnums.inserted_window_dims, + start_index_map=dnums.scatter_dims_to_operand_dims, + operand_batching_dims=dnums.operand_batching_dims, + start_indices_batching_dims=dnums.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(scattered_ids.shape)): - if i in dnums.inserted_window_dims: + if i in dnums.inserted_window_dims or i in dnums.operand_batching_dims: slice_sizes.append(1) else: slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) @@ -2405,13 +2587,19 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, if ad.is_undefined_primal(updates): gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims, + operand_batching_dims=dimension_numbers.operand_batching_dims, + start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, + ) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if i in dimension_numbers.inserted_window_dims: + if ( + i in dimension_numbers.inserted_window_dims + or i in dimension_numbers.operand_batching_dims + ): slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2479,8 +2667,8 @@ def _scatter_lower(ctx, operand, indices, updates, *, scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), - input_batching_dims=[], - scatter_indices_batching_dims=[], + input_batching_dims=list(dnums.operand_batching_dims), + scatter_indices_batching_dims=list(dnums.scatter_indices_batching_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, ) @@ -2539,8 +2727,8 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates, scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), - input_batching_dims=[], - scatter_indices_batching_dims=[], + input_batching_dims=list(dnums.operand_batching_dims), + scatter_indices_batching_dims=list(dnums.scatter_indices_batching_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, ) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 559d17cd9514..ac3074f45934 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9824,6 +9824,8 @@ def replace(tup, val): offset_dims = [] start_index_map = [] collapsed_slice_dims = [] + operand_batching_dims = [] + start_indices_batching_dims = [] j = 0 for i in range(rank): if i == axis_int: @@ -9848,21 +9850,23 @@ def replace(tup, val): collapsed_slice_dims.append(i) j += 1 else: - # Otherwise, idx_shape[i] == arr_shape[i]. Use an iota index so - # corresponding elements of array and index are gathered. - # TODO(mattjj): next line needs updating for dynamic shapes - iota = lax.broadcasted_iota(index_dtype, gather_index_shape, j) - gather_indices.append(iota) - slice_sizes.append(1) - start_index_map.append(i) - collapsed_slice_dims.append(i) + # Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both + # array and index as batching so corresponding elements are gathered. + if core.definitely_equal(arr_shape[i], 0): + slice_sizes.append(0) + else: + slice_sizes.append(1) + operand_batching_dims.append(i) + start_indices_batching_dims.append(j) j += 1 gather_indices_arr = lax.concatenate(gather_indices, dimension=j) dnums = lax.GatherDimensionNumbers( offset_dims=tuple(offset_dims), collapsed_slice_dims=tuple(collapsed_slice_dims), - start_index_map=tuple(start_index_map)) + start_index_map=tuple(start_index_map), + operand_batching_dims=tuple(operand_batching_dims), + start_indices_batching_dims=tuple(start_indices_batching_dims)) return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes), mode="fill" if mode is None else mode, fill_value=fill_value) diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 2bcfe96ad2f0..809df8195d54 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -122,7 +122,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, dnums = lax.ScatterDimensionNumbers( update_window_dims=indexer.dnums.offset_dims, inserted_window_dims=indexer.dnums.collapsed_slice_dims, - scatter_dims_to_operand_dims=indexer.dnums.start_index_map + scatter_dims_to_operand_dims=indexer.dnums.start_index_map, + operand_batching_dims=indexer.dnums.operand_batching_dims, + scatter_indices_batching_dims=indexer.dnums.start_indices_batching_dims, ) out = scatter_op( x, indexer.gather_indices, y, dnums, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 6d7f2c2a1e2c..525b163a6140 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2870,6 +2870,9 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers): proto.offset_dims.extend(dimension_numbers.offset_dims) proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims) proto.start_index_map.extend(dimension_numbers.start_index_map) + proto.operand_batching_dims.extend(dimension_numbers.operand_batching_dims) + proto.start_indices_batching_dims.extend( + dimension_numbers.start_indices_batching_dims) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto @@ -2981,6 +2984,9 @@ def _scatter_dimensions_proto(indices_shape, dimension_numbers): proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims) proto.scatter_dims_to_operand_dims.extend( dimension_numbers.scatter_dims_to_operand_dims) + proto.input_batching_dims.extend(dimension_numbers.operand_batching_dims) + proto.scatter_indices_batching_dims.extend( + dimension_numbers.scatter_indices_batching_dims) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto diff --git a/tests/lax_test.py b/tests/lax_test.py index d82b35c6b711..a2d4c939df55 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2512,6 +2512,18 @@ def testIndexTake(self, shape, dtype, idxs, axes): ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), + lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(1,), + start_index_map=(1,), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (1, 1)), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + lax.GatherDimensionNumbers( + offset_dims=(2,), collapsed_slice_dims=(), + start_index_map=(2,), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (1, 1, 3)) ]], dtype=lax_test_util.all_dtypes, ) @@ -2529,63 +2541,196 @@ def testGather(self, shape, dtype, idxs, dnums, slice_sizes): @parameterized.named_parameters( {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape, "indices_shape": indices_shape, - "dimension_numbers": lax.GatherDimensionNumbers( - offset_dims=offset_dims, - collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map), + "dimension_numbers": dimension_numbers, "slice_sizes": slice_sizes, "msg": msg} - for (testcase_name, operand_shape, indices_shape, offset_dims, - collapsed_slice_dims, start_index_map, slice_sizes, msg) in [ + for (testcase_name, operand_shape, indices_shape, dimension_numbers, + slice_sizes, msg) in [ ("NonAscendingWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 6, 8, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "offset_dims in gather op must be sorted"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 8, 7), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "offset_dims in gather op must be sorted"), ("RepeatedWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 6, 7, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "offset_dims in gather op must not repeat"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 7), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "offset_dims in gather op must not repeat"), ("WindowIndexOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 100, 101, 102), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "Offset dimension 2 in gather op is out of bounds"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 100, 101, 102), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "Offset dimension 2 in gather op is out of bounds"), ("WindowIndexBarelyOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - (4, 5, 6, 7, 9), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "Offset dimension 4 in gather op is out of bounds"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 9), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "Offset dimension 4 in gather op is out of bounds"), ("MismatchingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (4,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), ("All components of the offset index in a gather op must either be a " - "offset dimension or explicitly collapsed")), + "offset dimension or explicitly collapsed/batching")), + ("MismatchingElidedWindowDimsV2", (10, 9, 8, 7, 6, 5), (10, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,), + start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (10, 9, 8, 7, 6, 5), + ("All components of the offset index in a gather op must either be a " + "offset dimension or explicitly collapsed/batching")), ("OutOfBoundsWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (0, 1, 2, 3, 19), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 19), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "Invalid collapsed_slice_dims set in gather op; valid range is"), ("RepeatedWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (0, 1, 2, 3, 3), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), - "collapsed_slice_dims in gather op must not repeat"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 3), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "collapsed_slice_dims in gather op must not repeat"), ("MismatchingGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (), (0, 1, 2, 3), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3)), + (10, 9, 8, 7, 6), ("Gather op has 4 elements in start_index_map and the bound of " "dimension index_vector_dim=4 of indices is 5. These two " "numbers must be equal.")), ("OutOfBoundsGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (), (0, 1, 2, 3, 7), (10, 9, 8, 7, 6), - "Invalid start_index_map"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 7)), + (10, 9, 8, 7, 6), "Invalid start_index_map"), ("RepeatedGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (), (0, 1, 2, 3, 3), (10, 9, 8, 7, 6), - "start_index_map in gather op must not repeat"), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 3)), + (10, 9, 8, 7, 6), "start_index_map in gather op must not repeat"), ("NonAscendingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7, 8), (2, 1), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 1), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), "collapsed_slice_dims in gather op must be sorted"), ("WindowBoundsTooLarge", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7), (2,), (0, 1, 2, 3, 4), (10, 9, 8, 100, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(2,), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 100, 6), "Slice size at index 3 in gather op is out of range"), ("MismatchingNumberOfWindowBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7), "Gather op must have one slice size for every input dimension"), ("WindowBoundsNot1ForElidedDim", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - (4, 5, 6, 7), (1,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(1,), + start_index_map=(0, 1, 2, 3, 4)), + (10, 9, 8, 7, 6), ("Gather op can only collapse slice dims with bound 1, but bound " - "is 9 for index 1 at position 0.")) + "is 9 for index 1 at position 0.")), + ("RepeatedOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 4), operand_batching_dims=(2, 3, 3)), + (10, 9, 8, 7, 6), + "operand_batching_dims in gather op must not repeat"), + ("NonAscendingOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 4), operand_batching_dims=(3, 2)), + (10, 9, 8, 7, 6), + "operand_batching_dims in gather op must be sorted"), + ("OutOfBoundsOperandBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + operand_batching_dims=(0, 10)), + (10, 9, 8, 7, 6), + "Invalid operand_batching_dims set in gather op; valid range is"), + ("NonDisjointCollapsedAndBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1, 2), + start_index_map=(0, 1, 4), operand_batching_dims=(2, 3)), + (10, 9, 8, 7, 6), + ("collapsed_slice_dims and operand_batching_dims in gather op must be " + "disjoint")), + ("NonDisjointStartIndexMapAndBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 4), operand_batching_dims=(2, 3)), + (10, 9, 8, 7, 6), + ("start_index_map and operand_batching_dims in gather op must be " + "disjoint")), + ("WindowBoundsNot1ForBatchingDim", (10, 9, 8, 7, 6), (9, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(), + start_index_map=(0, 2, 3, 4), operand_batching_dims=(1,), + start_indices_batching_dims=(0,)), + (10, 9, 8, 7, 6), + ("Gather op can only have operand batching dims with bound 0/1, but " + "bound is 9 for index 1 at position 0.")), + ("RepeatedStartIndicesBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + start_indices_batching_dims=(0, 1, 0)), + (10, 9, 8, 7, 6), + "start_indices_batching_dims in gather op must not repeat"), + ("OutOfBoundsStartIndicesBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + start_indices_batching_dims=(0, 5)), + (10, 9, 8, 7, 6), + "Invalid start_indices_batching_dims set in gather op; valid range"), + ("IndexVectorDimInStartIndicesBatchingDims", (10, 9, 8, 7, 6), + (5, 4, 3, 2, 5), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), + start_index_map=(0, 1, 2, 3, 4), + start_indices_batching_dims=(0, 4)), + (10, 9, 8, 7, 6), + ("Gather op cannot have the index vector dimension as a batching " + "dimension")), + ("MismatchingNumberOfBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 4), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6), collapsed_slice_dims=(1, 2), + start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,), + start_indices_batching_dims=(0, 1)), + (10, 9, 8, 7, 6), + ("Gather op requires equal numbers of operand_batching_dims and " + "start_indices_batching_dims")), + ("MismatchingBatchingDimSizes", (10, 9, 8, 7, 6), (10, 9, 3, 2, 3), + lax.GatherDimensionNumbers( + offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 3, 4), + start_index_map=(2, 3, 4), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (10, 9, 8, 7, 6), + ("Gather op requires operand batching dimensions and indices batching " + "dimensions to have the same shape")) ] ) def testGatherShapeCheckingRule(self, operand_shape, indices_shape, dimension_numbers, slice_sizes, msg): + """ + + Args: + operand_shape: + indices_shape: + dimension_numbers: + slice_sizes: + msg: + """ operand = np.ones(operand_shape, dtype=np.int32) indices = np.ones(indices_shape, dtype=np.int32) @@ -2602,9 +2747,19 @@ def testGatherShapeCheckingRule(self, operand_shape, indices_shape, ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.inexact_dtypes, mode=["clip", "fill", None], @@ -2628,9 +2783,19 @@ def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, mode): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2653,9 +2818,19 @@ def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2677,9 +2852,19 @@ def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2701,9 +2886,19 @@ def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2721,84 +2916,207 @@ def testScatter(self, arg_shape, dtype, idxs, update_shape, dnums): # variations to account for the implicit setting of index_vector_dim in JAX. @parameterized.named_parameters( {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape, - "indices": indices, "update_shape": update_shape, - "dimension_numbers": lax.ScatterDimensionNumbers( - update_window_dims=update_window_dims, - inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims), + "indices_shape": indices_shape, "update_shape": update_shape, + "dimension_numbers": dimension_numbers, "msg": msg} - for (testcase_name, operand_shape, indices, update_shape, - update_window_dims, inserted_window_dims, - scatter_dims_to_operand_dims, msg) in [ - ("ScatterWithUpdatesBiggerThanInput", (64, 48), np.zeros((32, 1)), - (65, 32), (0,), (1,), (1,), "Bounds of the window dimensions"), - ("ScatterWithUpdatesBiggerThanInputV2", (64, 48), - np.zeros((32, 1)), (32, 49), (1,), (0,), (1,), + for (testcase_name, operand_shape, indices_shape, update_shape, + dimension_numbers, msg) in [ + ("ScatterWithUpdatesBiggerThanInput", (64, 48), (32, 1), (65, 32), + lax.ScatterDimensionNumbers( + update_window_dims=(0,), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,)), "Bounds of the window dimensions"), - ("ScatterWithUpdatesNotMatchingIndices", (64, 48), - np.zeros((32, 1)), (64, 31), (0,), (1,), (1,), + ("ScatterWithUpdatesBiggerThanInputV2", (64, 48), (32, 1), + (32, 49), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(1,)), + "Bounds of the window dimensions"), + ("ScatterWithUpdatesNotMatchingIndices", (64, 48), (32, 1), + (64, 31), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(1,)), "Bounds of the scatter dimensions"), - ("ScatterWithUpdatesNotMatchingIndicesV2", (64, 48), - np.zeros((32, 1)), (31, 48), (1,), (0,), (1,), + ("ScatterWithUpdatesNotMatchingIndicesV2", (64, 48), (32, 1), + (31, 48), lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(1,)), "Bounds of the scatter dimensions"), ("ScatterNdWithUpdatesBiggerThanInput", (64, 48), - np.zeros((10, 9, 8, 7, 1)), (10, 9, 8, 7, 65), (4,), (1,), - (0,), "Bounds of the window dimensions"), + (10, 9, 8, 7, 1), (10, 9, 8, 7, 65), + lax.ScatterDimensionNumbers( + update_window_dims=(4,), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,)), + "Bounds of the window dimensions"), ("ScatterNdWithUpdatesNotMatchingIndices", (64, 48), - np.zeros((10, 9, 8, 7, 1)), (9, 9, 8, 7, 64), (4,), (1,), (0,), + (10, 9, 8, 7, 1), (9, 9, 8, 7, 64), + lax.ScatterDimensionNumbers( + update_window_dims=(4,), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(0,)), "Bounds of the scatter dimensions"), - ("InvalidUpdates", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4, 1), - (4, 5, 6), (1, 2), (0, 1, 2, 3, 4), + ("InvalidUpdates", (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), + (10, 9, 8, 7, 3, 2, 4, 1), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "Updates tensor must be of rank 7; got 8."), - ("NonAscendingUpdateWindowDims", (6, 5, 4, 3, 2), - np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), - (4, 5, 6, 8, 7), (), (0, 1, 2, 3, 4), + ("NonAscendingUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), + (10, 9, 8, 7, 6, 5, 4, 3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6, 8, 7), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "update_window_dims in scatter op must be sorted"), - ("RepeatedUpdateWindowDims", (6, 5, 4, 3, 2), - np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), - (4, 5, 6, 7, 7), (), (0, 1, 2, 3, 4), + ("RepeatedUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), + (10, 9, 8, 7, 6, 5, 4, 3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6, 7, 7), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "update_window_dims in scatter op must not repeat"), - ("OutOfBoundsUpdateWindowDims", (6, 5, 4, 3, 2), - np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), - (4, 5, 6, 7, 9), (), (0, 1, 2, 3, 4), + ("OutOfBoundsUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), + (10, 9, 8, 7, 6, 5, 4, 3, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6, 7, 9), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "Invalid update_window_dims set in scatter op"), ("NonAscendingInsertedWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (2, 1), (0, 1, 2, 3, 4), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(2, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "inserted_window_dims in scatter op must be sorted"), ("RepeatedInsertedWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 1), (0, 1, 2, 3, 4), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "inserted_window_dims in scatter op must not repeat"), ("OutOfBoundsInsertedWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 5), (0, 1, 2, 3, 4), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 5), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), "Invalid inserted_window_dims set in scatter op"), ("MismatchingScatterDimsToOperandDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 2), (0, 1, 2, 3), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 3)), ("Scatter op has 4 elements in scatter_dims_to_operand_dims and " "the bound of dimension index_vector_dim=4 of indices " "is 5. These two numbers must be equal")), ("OutOfBoundsScatterDimsToOperandDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 2), (0, 1, 2, 3, 10), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 10)), "Invalid scatter_dims_to_operand_dims mapping"), ("RepeatedValuesInScatterDimsToOperandDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1, 2), (0, 1, 2, 2, 3), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(0, 1, 2, 2, 3)), "scatter_dims_to_operand_dims in scatter op must not repeat"), ("InsufficientWindowDims", (50, 49, 48, 47, 46), - np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), - (4, 5, 6), (1,), (0, 1, 2, 3), + (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(0, 1, 2, 3)), ("Scatter op has window of size 4; doesn't match operand of " - "rank 5.")) + "rank 5.")), + ("InsufficientWindowDimsV2", (10, 49, 48, 47, 46, 45), + (10, 9, 8, 7, 3), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1, 2, 3), + operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,)), + ("Scatter op has window of size 5; doesn't match operand of " + "rank 6.")), + ("RepeatedOperandBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 4), + operand_batching_dims=(2, 3, 3)), + "operand_batching_dims in scatter op must not repeat"), + ("NonAscendingOperandBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 4), + operand_batching_dims=(3, 2)), + "operand_batching_dims in scatter op must be sorted"), + ("OutOfBoundsOperandBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + operand_batching_dims=(0, 10)), + ("Invalid operand_batching_dims set in scatter op; valid range " + "is")), + ("NonDisjointCollapsedAndBatchingDims", (50, 49, 48, 47, 46, 45), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 4), + operand_batching_dims=(1, 2)), + ("inserted_window_dims and operand_batching_dims in scatter op " + "must be disjoint")), + ("NonDisjointScatterDimsToOperandDimsAndBatchingDims", + (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), + (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 4), + operand_batching_dims=(2, 3)), + ("scatter_dims_to_operand_dims and operand_batching_dims in " + "scatter op must be disjoint")), + ("RepeatedScatterIndicesBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + scatter_indices_batching_dims=(0, 1, 0)), + "scatter_indices_batching_dims in scatter op must not repeat"), + ("OutOfBoundsScatterIndicesBatchingDims", (50, 49, 48, 47, 46), + (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + scatter_indices_batching_dims=(0, 5)), + ("Invalid scatter_indices_batching_dims set in scatter op; " + "valid range")), + ("IndexVectorDimInScatterIndicesBatchingDims", + (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), + (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), + scatter_indices_batching_dims=(0, 4)), + ("Scatter op cannot have the index vector dimension as a " + "batching dimension")), + ("MismatchingNumberOfBatchingDims", (50, 49, 48, 47, 46, 45), + (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), + scatter_dims_to_operand_dims=(1, 2, 3, 4), + operand_batching_dims=(0,), + scatter_indices_batching_dims=(0, 1)), + ("Scatter op requires equal numbers of operand_batching_dims " + "and scatter_indices_batching_dims")), + ("MismatchingBatchingDimSizes", (10, 9, 48, 47, 46, 45), + (10, 9, 8, 7, 2), (10, 9, 8, 7, 3, 2, 4), + lax.ScatterDimensionNumbers( + update_window_dims=(4, 5, 6), inserted_window_dims=(2,), + scatter_dims_to_operand_dims=(2, 3), + operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0)), + ("Scatter op requires operand batching dimensions and indices " + "batching dimensions to have the same shape")) ] ) - def testScatterShapeCheckingRule(self, operand_shape, indices, + def testScatterShapeCheckingRule(self, operand_shape, indices_shape, update_shape, dimension_numbers, msg): - + indices = np.zeros(indices_shape, dtype=np.int32) def f(x, y): operand = lax.broadcast(x, operand_shape) updates = lax.broadcast(y, update_shape) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 37a0011e7bd0..0f259bf490e6 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -566,6 +566,18 @@ def testFft(self, fft_ndims, shape, bdims): ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), + lax.GatherDimensionNumbers( + offset_dims=(), collapsed_slice_dims=(1,), + start_index_map=(1,), operand_batching_dims=(0,), + start_indices_batching_dims=(0,)), + (1, 1)), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + lax.GatherDimensionNumbers( + offset_dims=(2,), collapsed_slice_dims=(), + start_index_map=(2,), operand_batching_dims=(0, 1), + start_indices_batching_dims=(1, 0)), + (1, 1, 3)) ] for bdims in lax_test_util.all_bdims(shape, idxs.shape)], dtype=lax_test_util.all_dtypes @@ -590,6 +602,16 @@ def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ] for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape, update_shape)], dtype=lax_test_util.float_dtypes @@ -613,6 +635,16 @@ def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), + ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), + lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(1,), + scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), + scatter_indices_batching_dims=(0,))), + ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), + (3, 2, 3), lax.ScatterDimensionNumbers( + update_window_dims=(2,), inserted_window_dims=(), + scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), + scatter_indices_batching_dims=(1, 0))) ] for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape)], dtype=lax_test_util.float_dtypes, From bc1e1a0220cbfd2761a42104f430a08ff8ed12f5 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 25 Sep 2024 06:16:22 -0700 Subject: [PATCH 653/702] Add support for setting a dot product "algorithm" for lax.dot_general. The StableHLO spec has a new "algorithm" parameter that allows specifying the algorithm that is used to execute a matrix multiplication, and it can tune the trade-off between performance and computational cost. Historically, in JAX, the precision and preferred_element_type parameters have been used to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases. This change adds a new "algorithm" parameter to dot_general to add support for the new explicit API. This parameter can be a member of the `SupportedDotAlgorithm` `Enum` to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the `DotAlgorithm` data structure which exposes the full generality of the StableHLO spec. Transposition is supported using the `transpose_algorithm` argument. PiperOrigin-RevId: 678672686 --- docs/jax.lax.rst | 1 + jax/_src/lax/lax.py | 340 +++++++++++++++++++++++-- jax/_src/pallas/triton/lowering.py | 4 +- jax/experimental/jax2tf/impl_no_xla.py | 3 + jax/experimental/jax2tf/jax2tf.py | 3 + jax/experimental/sparse/bcoo.py | 11 +- jax/experimental/sparse/bcsr.py | 8 +- jax/experimental/sparse/util.py | 3 +- jax/lax/__init__.py | 3 + tests/lax_test.py | 174 +++++++++++++ 10 files changed, 528 insertions(+), 22 deletions(-) diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index e0fc5ad46b3e..3a03665b3217 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -249,6 +249,7 @@ Argument classes .. autoclass:: ConvDimensionNumbers .. autoclass:: ConvGeneralDilatedDimensionNumbers +.. autoclass:: DotAlgorithm .. autoclass:: GatherDimensionNumbers .. autoclass:: GatherScatterMode .. autoclass:: Precision diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e28d0857d624..ef93304697f3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -22,7 +22,7 @@ import itertools import math import operator -from typing import Any, TypeVar, Union, cast as type_cast, overload +from typing import Any, NamedTuple, TypeVar, Union, cast as type_cast, overload import warnings import numpy as np @@ -709,8 +709,197 @@ def __str__(self) -> str: None, ] + +class DotAlgorithm(NamedTuple): + """Specify the algorithm used for computing dot products. + + When used as input to :func:`~jax.lax.dot_general`, this data structure is + used for controlling the properties of the algorithm used for computing the + dot product. This API controls the precision used for the computation, and + allows users to access hardware-specific accelerations. + + Support for these algorithms is platform dependent, and using an unsupported + algorithm will raise a Python exception when the computation is compiled. The + algorithms that are known to be supported on at least some platforms are + listed in the :class:`~jax.lax.DotAlgorithm.Preset` enum, and these are a + good starting point for experimenting with this API. + + A "dot algorithm" is specified by the following parameters: + + * ``lhs_precision_type`` and ``rhs_precision_type``, the data types that the + LHS and RHS of the operation are rounded to. + * ``accumulation_type`` the data type used for accumulation. + * ``lhs_component_count``, ``rhs_component_count``, and + ``num_primitive_operations`` apply to algorithms that decompose the LHS + and/or RHS into multiple components and execute multiple operations on + those values, usually to emulate a higher precision. For algorithms with no + decomposition, these values should be set to ``1``. + * ``allow_imprecise_accumulation`` to specify if accumulation in lower + precision is permitted for some steps (e.g. + ``CUBLASLT_MATMUL_DESC_FAST_ACCUM``). + + The `StableHLO spec `_ for + the dot operation doesn't require that the precision types be the same as the + storage types for the inputs or outputs, but some plaforms may require that + these types match. Furthermore, the return type of + :func:`~jax.lax.dot_general` is always defined by the ``accumulation_type`` + parameter of the input algorithm, if specified. + + Examples: + + Accumulate two 16-bit floats using a 32-bit float accumulator: + + >>> algorithm = DotAlgorithm( + ... lhs_precision_type=np.float16, + ... rhs_precision_type=np.float16, + ... accumulation_type=np.float32, + ... ) + >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) + >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) + >>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP + array([ 1., 4., 9., 16.], dtype=float32) + + Or, equivalently, using a preset: + + >>> algorithm = DotAlgorithm.Preset.F16_F16_F32 + >>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP + array([ 1., 4., 9., 16.], dtype=float32) + """ + + lhs_precision_type: DTypeLike + rhs_precision_type: DTypeLike + accumulation_type: DTypeLike + lhs_component_count: int = 1 + rhs_component_count: int = 1 + num_primitive_operations: int = 1 + allow_imprecise_accumulation: bool = False + + def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, + rhs_dtype: DTypeLike) -> hlo.DotAlgorithm: + del lhs_dtype, rhs_dtype # unused + return hlo.DotAlgorithm.get( + mlir.dtype_to_ir_type(dtypes.dtype(self.lhs_precision_type)), + mlir.dtype_to_ir_type(dtypes.dtype(self.rhs_precision_type)), + mlir.dtype_to_ir_type(dtypes.dtype(self.accumulation_type)), + self.lhs_component_count, + self.rhs_component_count, + self.num_primitive_operations, + self.allow_imprecise_accumulation, + ) + + # mypy doesn't currently support nested classes in a NamedTuple definition. + class Preset(enum.Enum): # type: ignore[misc] + DEFAULT = 0 + ANY_F8_ANY_F8_F32 = 1 + ANY_F8_ANY_F8_F32_FAST_ACCUM = 2 + F16_F16_F16 = 3 + F16_F16_F32 = 4 + BF16_BF16_BF16 = 5 + BF16_BF16_F32 = 6 + BF16_BF16_F32_X3 = 7 + BF16_BF16_F32_X6 = 8 + TF32_TF32_F32 = 9 + TF32_TF32_F32_X3 = 10 + F32_F32_F32 = 11 + F64_F64_F64 = 12 + + def __repr__(self) -> str: + return f'{self.__class__.__name__}.{self.name}' + + def __str__(self) -> str: + return self.name + + @property + def accumulation_type(self) -> DTypeLike: + match self: + case DotAlgorithm.Preset.DEFAULT: + raise TypeError( + "The default dot algorithm does not have an accumulation type.") + case DotAlgorithm.Preset.F16_F16_F16: + return np.float16 + case DotAlgorithm.Preset.BF16_BF16_BF16: + return dtypes.bfloat16 + case DotAlgorithm.Preset.F64_F64_F64: + return np.float64 + case _: + return np.float32 + + def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, + rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: + if self == DotAlgorithm.Preset.DEFAULT: + return None + + if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, + DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM): + fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), + np.dtype(dtypes.float8_e4m3fn), + np.dtype(dtypes.float8_e4m3fnuz), + np.dtype(dtypes.float8_e5m2), + np.dtype(dtypes.float8_e5m2fnuz)) + if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: + raise ValueError( + f"The dot algorithm '{self}' requires both inputs to have float8 " + f"dtypes. Got {lhs_dtype} and {rhs_dtype} instead.") + lhs = mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype)) + rhs = mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype)) + acc = ir.F32Type.get() + return hlo.DotAlgorithm.get( + lhs, rhs, acc, 1, 1, 1, + self == DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM) + + else: + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + f64 = ir.F64Type.get() + bf16 = ir.BF16Type.get() + tf32 = ir.FloatTF32Type.get() + match self: + case DotAlgorithm.Preset.F16_F16_F16: + return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False) + case DotAlgorithm.Preset.F16_F16_F32: + return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False) + case DotAlgorithm.Preset.BF16_BF16_BF16: + return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False) + case DotAlgorithm.Preset.BF16_BF16_F32: + return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False) + case DotAlgorithm.Preset.BF16_BF16_F32_X3: + return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False) + case DotAlgorithm.Preset.BF16_BF16_F32_X6: + return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False) + case DotAlgorithm.Preset.TF32_TF32_F32: + return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False) + case DotAlgorithm.Preset.TF32_TF32_F32_X3: + return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False) + case DotAlgorithm.Preset.F32_F32_F32: + return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False) + case DotAlgorithm.Preset.F64_F64_F64: + return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False) + case _: + raise NotImplementedError("unreachable") + + +DotAlgorithmLike = Union[ + DotAlgorithm, + DotAlgorithm.Preset, + str, + None, +] +_DotAlgorithmLike = Union[ + DotAlgorithm, + DotAlgorithm.Preset, + None, +] +DotTransposeAlgorithmLike = Union[ + DotAlgorithmLike, + tuple[DotAlgorithmLike, DotAlgorithmLike], +] +DotTransposeAlgorithm = tuple[_DotAlgorithmLike, _DotAlgorithmLike] + + def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + algorithm: DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array: """Vector/vector, matrix/vector, and matrix/matrix multiplication. Wraps XLA's `Dot @@ -729,6 +918,17 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, preferred_element_type: Optional. Either ``None``, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. + algorithm: Optional. Specify the algorithm used for accumulating the dot + product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument + cannot be used with ``precision`` or ``preferred_element_type``. + transpose_algorithm: Optional. This allows specifying the algorithm used when + this operation is transposed, typically as part of reverse-mode automatic + differentiation. This argument can either be a single + :class:`~jax.lax.DotAlgorithm` or a tuple of two + :class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the + algorithm for transposing the LHS and RHS, respectively. + ``transpose_algorithm`` must be explicitly specified when transposing a + dot product where a specific ``algorithm`` was used on the forward pass. Returns: An array containing the product. @@ -736,7 +936,9 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]): return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())), precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + algorithm=algorithm, + transpose_algorithm=transpose_algorithm) else: raise TypeError("Incompatible shapes for dot: got {} and {}.".format( lhs.shape, rhs.shape)) @@ -747,7 +949,9 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + algorithm: DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array: """General dot product/contraction operator. Wraps XLA's `DotGeneral @@ -774,6 +978,17 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN preferred_element_type: Optional. Either ``None``, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. + algorithm: Optional. Specify the algorithm used for accumulating the dot + product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument + cannot be used with ``precision`` or ``preferred_element_type``. + transpose_algorithm: Optional. This allows specifying the algorithm used when + this operation is transposed, typically as part of reverse-mode automatic + differentiation. This argument can either be a single + :class:`~jax.lax.DotAlgorithm` or a tuple of two + :class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the + algorithm for transposing the LHS and RHS, respectively. + ``transpose_algorithm`` must be explicitly specified when transposing a + dot product where a specific ``algorithm`` was used on the forward pass. Returns: An array whose first dimensions are the (shared) batch dimensions, followed by @@ -791,7 +1006,9 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + algorithm=canonicalize_dot_algorithm(algorithm), + transpose_algorithm=canonicalize_dot_transpose_algorithm(transpose_algorithm)) def ragged_dot( @@ -2838,7 +3055,9 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type): def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None): (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): @@ -2914,7 +3133,10 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None): + del dimension_numbers, precision # unused # We're mostly matching XLA's logic here, namely in shape_inference.cc and # primitive_util.h's HigherPrecisionType, e.g. # https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329 @@ -2936,6 +3158,21 @@ def type_properties(dt): f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}") result_dtype = lhs.dtype + if transpose_algorithm is not None and algorithm is None: + raise ValueError( + "When the algorithm argument to dot_general is None, the " + "transpose_algorithm argument is unused and must also be None.") + + if algorithm is not None and algorithm != DotAlgorithm.Preset.DEFAULT: + if preferred_element_type is not None: + raise ValueError( + "The preferred_element_type and algorithm arguments to dot_general " + "cannot both be specified.") + + # This is used to ensure that the output type is equal to the accumulation + # type whenever an algorithm is specified. + preferred_element_type = algorithm.accumulation_type + return _maybe_upcast(result_dtype, preferred_element_type) def _bit_width(d): @@ -2959,6 +3196,8 @@ def _maybe_upcast(result_dtype, preferred_element_type): def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None, swap_ans=False): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim @@ -2971,20 +3210,35 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, dims = ((ans_y, y_kept), (ans_batch, y_batch)) x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) + if algorithm is not None: + if transpose_algorithm is None or transpose_algorithm[0] is None: + raise ValueError( + "When a dot_general algorithm is specified on the forward pass, " + "transpose_algorithm must be specified for the backward pass.") + lhs_alg, rhs_alg = transpose_algorithm + transpose_algorithm = (algorithm, rhs_alg) + algorithm = lhs_alg x_bar = transpose(dot_general(g, y, dims, precision=precision, - preferred_element_type=preferred_element_type), + preferred_element_type=preferred_element_type, + algorithm=algorithm, + transpose_algorithm=transpose_algorithm), tuple(out_axes)) if x_bar.dtype != x.aval.dtype: x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type) return x_bar def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) + transpose_algorithm = None if transpose_algorithm is None else ( + transpose_algorithm[1], transpose_algorithm[0]) y_bar = _dot_general_transpose_lhs( g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision, - preferred_element_type=preferred_element_type, + preferred_element_type=preferred_element_type, algorithm=algorithm, + transpose_algorithm=transpose_algorithm, swap_ans=True) if y_bar.dtype != y.aval.dtype: y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type) @@ -2992,7 +3246,9 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None): lhs, rhs = batched_args lbd, rbd = batch_dims (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers @@ -3018,7 +3274,9 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, rhs_shape = np.shape(rhs) batched_out = dot_general(lhs, rhs, new_dimension_numbers, precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + algorithm=algorithm, + transpose_algorithm=transpose_algorithm) result_batch_dim = batching.shape_as_bdim( result_stack_dim, _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers)) @@ -3115,9 +3373,19 @@ def precision_attr(precision: Precision) -> ir.ArrayAttr: [hlo.PrecisionAttr.get(str(p)) for p in full_precision]) +def dot_algorithm_attr(algorithm: _DotAlgorithmLike, lhs_dtype: DTypeLike, + rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: + if algorithm is None: + return None + return algorithm._convert_to_hlo_attr(lhs_dtype, rhs_dtype) + + def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: np.dtype | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None, platform: str = "default"): + del transpose_algorithm # unused def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) @@ -3158,13 +3426,30 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): rhs_batching_dimensions=list(rhs_batch), lhs_contracting_dimensions=list(lhs_contracting), rhs_contracting_dimensions=list(rhs_contracting)) + + if algorithm is not None and precision not in { + None, Precision.DEFAULT, (Precision.DEFAULT, Precision.DEFAULT)}: + raise ValueError( + "The dot_general precision must be None or DEFAULT when an algorithm " + "is specified.") + if jaxlib_version <= (0, 4, 33): + if algorithm is not None: + raise ValueError( + "The dot_general algorithm parameter is only supported for jaxlib " + "versions larger than 0.4.33.") + algorithm_kwargs = {} + else: + algorithm_kwargs = {"algorithm": dot_algorithm_attr(algorithm, lhs_dtype, + rhs_dtype)} return [ hlo.dot_general( mlir.aval_to_ir_type(aval_out), lhs, rhs, dot_dnums, - precision_config=precision_attr(precision)) + precision_config=precision_attr(precision), + **algorithm_kwargs, + ) ] mlir.register_lowering(dot_general_p, _dot_general_lower) @@ -3189,11 +3474,13 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S _RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], [])) def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, - precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype: + precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype: if not dtypes.issubdtype(group_sizes.dtype, np.integer): raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. - return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type) + return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + precision=precision, preferred_element_type=preferred_element_type, + algorithm=None, transpose_algorithm=None) def _ragged_dot_jvp_rule( @@ -5387,6 +5674,29 @@ def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precisi "a lax.Precision value or a tuple of two lax.Precision values or " f"strings; got {precision}.") +def canonicalize_dot_algorithm(algorithm: DotAlgorithmLike) -> _DotAlgorithmLike: + if isinstance(algorithm, str): + algorithm = DotAlgorithm.Preset[algorithm] + if algorithm is None or algorithm == DotAlgorithm.Preset.DEFAULT: + return None + return algorithm + +def canonicalize_dot_transpose_algorithm( + algorithm: DotTransposeAlgorithmLike) -> DotTransposeAlgorithm | None: + if algorithm is None: + return None + elif isinstance(algorithm, DotAlgorithm): + return (algorithm, algorithm) + elif isinstance(algorithm, tuple): + if len(algorithm) != 2: + raise ValueError( + "The transpose_algorithm argument must be a single value or a tuple " + f"of two values; got {algorithm}.") + return (canonicalize_dot_algorithm(algorithm[0]), + canonicalize_dot_algorithm(algorithm[1])) + algorithm = canonicalize_dot_algorithm(algorithm) + return (algorithm, algorithm) + def _balanced_eq(x, z, y): return div(select(_eq_meet(x, z), _ones(z), _zeros(z)), select(_eq_meet(y, z), _twos(z), _ones(z))) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 4722a31db92c..9db5e4081239 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -2094,8 +2094,10 @@ def _dot_general_lowering( dimension_numbers, precision, preferred_element_type, + algorithm, + transpose_algorithm, ): - del preferred_element_type # Unused. + del preferred_element_type, algorithm, transpose_algorithm # Unused. ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers assert batch_dims == ((), ()) diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 310cbaab6d59..3c51e5d63f25 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -364,12 +364,15 @@ def _conv_general_dilated( def _dot_general(lhs, rhs, *, dimension_numbers, precision: tuple[PrecisionType, PrecisionType] | None, preferred_element_type: DType | None, + algorithm: Any, transpose_algorithm: Any, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" # Unused arguments. del precision del preferred_element_type + del algorithm + del transpose_algorithm lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype( lhs, _in_avals[0], rhs, _in_avals[1], _out_aval) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 525b163a6140..bc533e6d145b 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2176,9 +2176,12 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None): def _dot_general(lhs, rhs, *, dimension_numbers, precision: tuple[PrecisionType, PrecisionType] | None, preferred_element_type: DType | None, + algorithm: Any, transpose_algorithm: Any, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" + del algorithm, transpose_algorithm # unused + # TODO(b/293247337): we ought to turn on this safety check, but this leads to # failures. Since we are going to turn on native serializaton soon, wait # until then to turn on this check. diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 9f2f0f69be63..b20ed8da0326 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -609,7 +609,8 @@ def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation: Sequenc bcoo_dot_general_p = core.Primitive('bcoo_dot_general') def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers, - precision: None = None, preferred_element_type: None = None) -> BCOO | Array: + precision: None = None, preferred_element_type: None = None, + algorithm: None = None, transpose_algorithm: None = None) -> BCOO | Array: """A general contraction operation. Args: @@ -620,6 +621,8 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: (lhs_batch_dims, rhs_batch_dims))`. precision: unused preferred_element_type: unused + algorithm: unused + transpose_algorithm: unused Returns: An ndarray or BCOO-format sparse array containing the result. If both inputs @@ -627,7 +630,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: the result will be dense, of type ndarray. """ # TODO(jakevdp) make use of these? - del precision # unused + del precision, algorithm, transpose_algorithm # unused if isinstance(lhs, BCOO) and isinstance(rhs, BCOO): shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers) @@ -1053,7 +1056,9 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers) indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True) kwds = {'dimension_numbers': dimension_numbers, 'precision': None, - 'preferred_element_type': None} + 'preferred_element_type': None, + 'algorithm': None, + 'transpose_algorithm': None} A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds) return A, B, indices diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7275d6bb20aa..1b877aec9c75 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -463,7 +463,9 @@ def _bcsr_extract_batching_rule(batched_args, batch_dims): def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, dimension_numbers: DotDimensionNumbers, precision: None = None, - preferred_element_type: None = None) -> Array: + preferred_element_type: None = None, + algorithm: None = None, + transpose_algorithm: None = None) -> Array: """A general contraction operation. Args: @@ -474,13 +476,15 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, (lhs_batch_dims, rhs_batch_dims))`. precision: unused preferred_element_type: unused + algorithm: unused + transpose_algorithm: unused Returns: An ndarray or BCSR-format sparse array containing the result. If both inputs are sparse, the result will be sparse, of type BCSR. If either input is dense, the result will be dense, of type ndarray. """ - del precision # unused + del precision, algorithm, transpose_algorithm # unused if isinstance(rhs, (np.ndarray, jax.Array)): if isinstance(lhs, (np.ndarray, jax.Array)): return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers, diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 9aa9e42f2a60..c79dee09cec2 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -113,4 +113,5 @@ def _dot_general_validated_shape( rhs = core.ShapedArray(rhs_shape, np.float32) return _dot_general_shape_rule( lhs, rhs, dimension_numbers=dimension_numbers, - precision=None, preferred_element_type=None) + precision=None, preferred_element_type=None, algorithm=None, + transpose_algorithm=None) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 293bd02446fe..bb72abb2ec32 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -19,6 +19,9 @@ DotDimensionNumbers as DotDimensionNumbers, Precision as Precision, PrecisionLike as PrecisionLike, + DotAlgorithm as DotAlgorithm, + DotAlgorithmLike as DotAlgorithmLike, + DotTransposeAlgorithmLike as DotTransposeAlgorithmLike, RandomAlgorithm as RandomAlgorithm, RoundingMethod as RoundingMethod, abs as abs, diff --git a/tests/lax_test.py b/tests/lax_test.py index a2d4c939df55..e8f3b730e3da 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -41,10 +41,12 @@ from jax._src import dtypes from jax._src import lax_reference from jax._src import test_util as jtu +from jax._src import xla_bridge from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util from jax._src.lax import lax as lax_internal +from jax._src.lib import version as jaxlib_version from jax._src.util import NumpyComplexWarning, safe_zip from jax._src.tree_util import tree_map @@ -1041,6 +1043,178 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker) + @parameterized.parameters([ + (algorithm, dtype) + for algorithm, test_dtypes in [ + (lax.DotAlgorithm( + lhs_precision_type=np.float32, + rhs_precision_type=np.float32, + accumulation_type=np.float32, + lhs_component_count=1, + rhs_component_count=1, + num_primitive_operations=1, + allow_imprecise_accumulation=False, + ), [np.float32]), + (lax.DotAlgorithm( + lhs_precision_type=np.float16, + rhs_precision_type=np.float16, + accumulation_type=np.float32, + ), [np.float16]), + ("F16_F16_F32", [np.float16]), + (lax.DotAlgorithm.Preset.DEFAULT, lax_test_util.float_dtypes), + (lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes), + (lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes), + (lax.DotAlgorithm.Preset.F16_F16_F16, [np.float16]), + (lax.DotAlgorithm.Preset.F16_F16_F32, [np.float16]), + (lax.DotAlgorithm.Preset.BF16_BF16_BF16, [dtypes.bfloat16]), + (lax.DotAlgorithm.Preset.BF16_BF16_F32, [dtypes.bfloat16]), + (lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, [np.float32]), + (lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, [np.float32]), + (lax.DotAlgorithm.Preset.TF32_TF32_F32, [np.float32]), + (lax.DotAlgorithm.Preset.TF32_TF32_F32_X3, [np.float32]), + (lax.DotAlgorithm.Preset.F32_F32_F32, [np.float32]), + (lax.DotAlgorithm.Preset.F64_F64_F64, [np.float64]), + ] for dtype in test_dtypes + if jtu.dtypes.supported([dtype]) + ]) + def testDotAlgorithm(self, algorithm, dtype): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + if jtu.test_device_matches(["gpu"]): + # GPU algorithm support is a little spotty. It is checked in + # xla/service/algorithm_util.cc and the logic is copied here. + if algorithm in { + lax.DotAlgorithm.Preset.F16_F16_F32, + lax.DotAlgorithm.Preset.TF32_TF32_F32, + lax.DotAlgorithm.Preset.BF16_BF16_F32, + lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, # Must have f32 input + lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, # Must have f32 input + }: + if not jtu.is_cuda_compute_capability_at_least("8.0"): + raise SkipTest( + f"The dot algorithm '{algorithm}' requires CUDA compute " + "capability >= 8.0.") + elif algorithm not in { + lax.DotAlgorithm.Preset.DEFAULT, + lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, + lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, + lax.DotAlgorithm.Preset.F32_F32_F32, + lax.DotAlgorithm.Preset.F64_F64_F64, + }: + raise SkipTest( + f"The dot algorithm '{algorithm}' is not supported on GPU.") + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + self._CompileAndCheck(partial(lax.dot, algorithm=algorithm), args_maker) + # Check that accumulation type sets the output type + output = lax.dot(*args_maker(), algorithm=algorithm) + algorithm = lax_internal.canonicalize_dot_algorithm(algorithm) + expected_dtype = dtype if algorithm is None else algorithm.accumulation_type + self.assertEqual(output.dtype, expected_dtype) + + def testDotAlgorithmInvalidFloat8Type(self): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, dtypes.float8_e4m3fn) + with self.assertRaisesRegex(ValueError, "The dot algorithm"): + lax.dot(lhs, rhs, algorithm="ANY_F8_ANY_F8_F32") + + @parameterized.parameters([ + ({"precision": lax.Precision.HIGHEST}, "The dot_general precision must be None or DEFAULT"), + ({"preferred_element_type": np.float32}, "The preferred_element_type and algorithm arguments"), + ]) + def testDotAlgorithmInvalidParameters(self, kwargs, pattern): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) + with self.assertRaisesRegex(ValueError, pattern): + lax.dot(lhs, rhs, algorithm="F32_F32_F32", **kwargs) + + def testDotAlgorithmTransposeRequired(self): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) + fun = partial(lax.dot, algorithm="F32_F32_F32") + out = fun(lhs, rhs) + _, vjp_fun = jax.vjp(fun, lhs, rhs) + with self.assertRaisesRegex( + ValueError, "When a dot_general algorithm is specified"): + vjp_fun(out) + + @parameterized.parameters([ + ("F32_F32_F32", "F16_F16_F32"), + ("F32_F32_F32", ("F16_F16_F32", "F64_F64_F64")), + ]) + def testDotAlgorithmTranspose(self, algorithm, transpose_algorithm): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + def fun(x, y): + return lax.dot(x, y, algorithm=algorithm, + transpose_algorithm=transpose_algorithm) + + algorithm_ = lax_internal.canonicalize_dot_algorithm(algorithm) + lhs_alg, rhs_alg = lax_internal.canonicalize_dot_transpose_algorithm( + transpose_algorithm) + + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) + out = fun(lhs, rhs) + + def check_transpose_algorithm(f, arg, alg, trans_alg, trans_trans_alg): + fun_trans = jax.linear_transpose(f, arg) + jaxpr = jax.make_jaxpr(fun_trans)(out) + eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr.eqns)) + self.assertEqual(eqn.params["algorithm"], alg) + self.assertEqual(eqn.params["transpose_algorithm"], trans_alg) + + fun_ = jax.linear_transpose(lambda x: fun_trans(x)[0], out) + jaxpr_ = jax.make_jaxpr(fun_)(arg) + eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr_.eqns)) + self.assertEqual(eqn.params["algorithm"], algorithm_) + + # Note that transposing the RHS of a dot_general introduce extra + # transposes on the input and output, so we don't actually end up with + # the same `transpose_algorithm` parameter after 2 transposes. + self.assertEqual(eqn.params["transpose_algorithm"], trans_trans_alg) + + check_transpose_algorithm(partial(fun, y=rhs), lhs, lhs_alg, + (algorithm_, rhs_alg), (lhs_alg, rhs_alg)) + check_transpose_algorithm(partial(fun, lhs), rhs, rhs_alg, + (algorithm_, lhs_alg), (rhs_alg, lhs_alg)) + @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]], From cdea3d40508fbdae6fedf8eb5b3ee7c1df8f9090 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 25 Sep 2024 06:34:39 -0700 Subject: [PATCH 654/702] lax.fori_loop now allows scalars in its cary when lowering to Mosaic GPU PiperOrigin-RevId: 678677508 --- jax/_src/pallas/mosaic_gpu/lowering.py | 9 +++++---- tests/pallas/mosaic_gpu_test.py | 17 ++++++++++++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 80222dbaea22..ca961ac5848a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -798,12 +798,13 @@ def _scan_lowering_rule( _consts_avals, arg_avals = util.split_list(ctx.avals_in, [num_consts]) if has_loop_index: start, *args = args - index_aval, *_arg_avals = arg_avals - start = _ensure_ir_value(start, index_aval) + index_aval, *arg_avals = arg_avals + start = _ensure_ir_value(start, index_aval.dtype) length = _ir_constant(length, start.type) else: start = _i32_constant(0) length = _i32_constant(length) + args = map(lambda arg, aval: _ensure_fa(arg, aval.dtype), args, arg_avals) for_out = _lower_jaxpr_to_for_loop( ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index ) @@ -853,11 +854,11 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: raise NotImplementedError(f"Unsupported type: {type(x)}") -def _ensure_ir_value(x: object, aval: jax_core.ShapedArray) -> ir.Value: +def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: if isinstance(x, ir.Value): return x elif isinstance(x, (np.number, np.ndarray, int, float)): - return _ir_constant(x, mgpu_utils.dtype_to_ir_type(aval.dtype)) + return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) raise NotImplementedError(f"Unsupported type: {type(x)}") diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 2f247ca60cff..cdca977bf122 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -328,7 +328,7 @@ def kernel(x_ref, o_ref): result = kernel(x) self.assertEqual(result.shape, (4, 2, 64, 64)) - def test_fori_loop(self): + def test_fori_loop_array(self): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), @@ -340,6 +340,21 @@ def kernel(x_ref, o_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 2.0 + 3.0) + def test_fori_loop_scalar(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(o_ref): + # Equivalent to 2 + 3. + o_ref[...] = jax.lax.broadcast( + jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0.0), o_ref.shape + ) + + np.testing.assert_array_equal( + kernel(), jnp.full([256], 5.0, dtype=jnp.float32) + ) + def test_wgmma(self): dtype = jnp.float16 swizzle = 128 From 390b0ba4a6432f46cc2b8495edc2c9f9e91f7dc5 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 25 Sep 2024 06:59:34 -0700 Subject: [PATCH 655/702] [pallas::mosaic_gpu] Support for tiled transpose transforms. For the time being this feature only supports 2D on the GMEM side and 4D after tiling on the SMEM side. PiperOrigin-RevId: 678683983 --- jax/_src/pallas/mosaic_gpu/core.py | 23 +++++++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 12 ++++++++---- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index b6d2ada5ee28..c4489226c860 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -105,6 +105,26 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) +@dataclasses.dataclass(frozen=True) +class TransposeTransform(MemoryRefTransform): + """Transpose a tiled memref.""" + + permutation: tuple[int, ...] + + def __call__( + self, block_aval: pallas_core.AbstractMemoryRef + ) -> pallas_core.AbstractMemoryRef: + shape = block_aval.shape # pytype: disable=attribute-error + return block_aval.update( + inner_aval=block_aval.inner_aval.update( + shape=self.to_gpu_transform().transform_shape(shape) + ) + ) + + def to_gpu_transform(self) -> mgpu.MemRefTransform: + return mgpu.TransposeTransform(self.permutation) + + @dataclasses.dataclass(frozen=True) class GPUBlockMapping(pallas_core.BlockMapping): swizzle: int | None = None @@ -114,6 +134,7 @@ class GPUBlockMapping(pallas_core.BlockMapping): class GPUBlockSpec(pallas_core.BlockSpec): # TODO(justinfu): Replace tiling a list of transforms. tiling: tuple[int, ...] | None = None + transpose_permutation: tuple[int, ...] | None = None swizzle: int | None = None def to_block_mapping( @@ -137,6 +158,8 @@ def to_block_mapping( transforms: tuple[pallas_core.MemoryRefTransform, ...] = () if self.tiling is not None: transforms += (TilingTransform(self.tiling),) + if self.transpose_permutation is not None: + transforms += (TransposeTransform(self.transpose_permutation),) return GPUBlockMapping( block_shape=bm.block_shape, block_aval=bm.block_aval, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index cdca977bf122..d2a5ba1511e0 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -355,13 +355,16 @@ def kernel(o_ref): kernel(), jnp.full([256], 5.0, dtype=jnp.float32) ) - def test_wgmma(self): - dtype = jnp.float16 + @parameterized.parameters(jnp.float16, jnp.float32) + def test_wgmma(self, dtype): + # TensorCores can only fuse transposes of 16-bit values, and RHS + # is expected to be column major by default. + rhs_transpose = jnp.dtype(dtype).itemsize != 2 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize def kernel(a_ref, b_ref, o_ref): acc = plgpu.zero_accumulator((64, 128), jnp.float32) - acc = plgpu.wgmma(acc, a_ref, b_ref, rhs_transpose=False) + acc = plgpu.wgmma(acc, a_ref, b_ref, rhs_transpose=rhs_transpose) plgpu.wgmma_wait(0) # TODO(cperivol): turn acc into a reference so we can reason about effects. o_ref[...] = acc.as_array() @@ -382,6 +385,7 @@ def kernel(a_ref, b_ref, o_ref): plgpu.GPUBlockSpec( (128, 128), lambda *i: i, + transpose_permutation=(1, 0, 2, 3) if rhs_transpose else None, tiling=(elems_128b, elems_128b), swizzle=128, ), @@ -391,7 +395,7 @@ def kernel(a_ref, b_ref, o_ref): grid=(1, 1), )(a, b) np.testing.assert_allclose( - res, a @ b, rtol=1e-3 + res, a @ (b.T if rhs_transpose else b), rtol=1e-3 ) def test_input_output_aliases(self): From d556b592558b9691552f6f3717e8a0e0cd744048 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:24:13 +0000 Subject: [PATCH 656/702] Bump actions/setup-python from 5.1.1 to 5.2.0 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.1.1 to 5.2.0. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/39cd14951b08e74b54015e9e001cdefcf80e669f...f677139bbe7f9c59b41e40162b753c062f5d49a3) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 12 ++++++------ .github/workflows/jax-array-api.yml | 2 +- .github/workflows/upstream-nightly.yml | 4 ++-- .github/workflows/wheel_win_x64.yml | 2 +- .github/workflows/windows_ci.yml | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index f5c1a1d6348c..315db489a818 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -31,7 +31,7 @@ jobs: steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python 3.11 - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: 3.11 - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet: pre-commit/action@v3.0.1 @@ -59,7 +59,7 @@ jobs: steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -106,7 +106,7 @@ jobs: steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -143,7 +143,7 @@ jobs: steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -179,7 +179,7 @@ jobs: steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -218,7 +218,7 @@ jobs: steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python 3.11 - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: 3.11 - name: Get pip cache dir diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index bbbe53732a69..f1dc8eee8a75 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -32,7 +32,7 @@ jobs: submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 79c1e22d2cea..2bdd8ba5192e 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -38,7 +38,7 @@ jobs: steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements @@ -107,7 +107,7 @@ jobs: shell: bash steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: "3.x" - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # ratchet:actions/download-artifact@v4 diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 0195032ceaf1..f4fb7727da6b 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.pyver }} cache: 'pip' diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 7097d5589426..03a6876cdbb1 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -35,7 +35,7 @@ jobs: with: path: jax - - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.pyver }} cache: 'pip' From a43c7f2aceec94214be4eaec857a981837257140 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 25 Sep 2024 07:37:04 -0700 Subject: [PATCH 657/702] Enable more H100 tests in CI. Rename "gpu" config CI tag to "gpu_v100". PiperOrigin-RevId: 678695003 --- benchmarks/mosaic/BUILD | 2 +- tests/BUILD | 8 ++++---- tests/mosaic/BUILD | 7 +------ tests/pallas/BUILD | 18 +++++++++--------- 4 files changed, 15 insertions(+), 20 deletions(-) diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 4345e620a3ae..727e347e5a64 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -34,7 +34,7 @@ DISABLED_BACKENDS = [ ] DISABLED_CONFIGS = [ - "gpu", + "gpu_v100", "gpu_a100", "gpu_p100", "gpu_p100_x32", diff --git a/tests/BUILD b/tests/BUILD index 0cc6ed6d9d8c..11e382b1a95c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1270,7 +1270,7 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "gpu", + "gpu_h100", "cpu", ], tags = ["multiaccelerator"], @@ -1280,7 +1280,7 @@ jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], enable_configs = [ - "gpu", + "gpu_h100", "cpu", ], ) @@ -1301,7 +1301,7 @@ jax_multiplatform_test( name = "debugger_test", srcs = ["debugger_test.py"], enable_configs = [ - "gpu", + "gpu_h100", "cpu", ], ) @@ -1317,7 +1317,7 @@ jax_multiplatform_test( "tpu_pjrt_c_api": ["--jax_num_generated_cases=1"], }, enable_configs = [ - "gpu", + "gpu_h100", "cpu", ], shard_count = { diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 9eadc08d4987..4d33e228b906 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -39,18 +39,13 @@ DISABLED_CONFIGS = [ "gpu_p100", "gpu_p100_x32", "gpu_pjrt_c_api", + "gpu_v100", "gpu_x32", - "gpu", ] jax_multiplatform_test( name = "gpu_test", srcs = ["gpu_test.py"], - config_tags_overrides = { - "gpu_h100": { - "ondemand": False, # Include in presubmit. - }, - }, disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, enable_configs = [ diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index e535f1f59dac..ba82b8c4223c 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -39,7 +39,7 @@ jax_multiplatform_test( }, }, disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", "gpu_p100", "gpu_p100_x32", @@ -68,7 +68,7 @@ jax_multiplatform_test( "pallas_jumble_test.py", ], disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", "gpu_a100", "gpu_p100", @@ -96,7 +96,7 @@ jax_multiplatform_test( }, }, disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", "gpu_p100", "gpu_p100_x32", @@ -155,7 +155,7 @@ jax_multiplatform_test( }, }, disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", "gpu_a100", "gpu_h100", @@ -191,7 +191,7 @@ jax_multiplatform_test( "tpu", ], disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", "gpu_a100", "gpu_a100_x32", @@ -222,7 +222,7 @@ jax_multiplatform_test( }, }, disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", "gpu_a100", "gpu_h100", @@ -253,7 +253,7 @@ jax_multiplatform_test( }, }, disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", "gpu_a100", "gpu_h100", @@ -537,7 +537,7 @@ jax_multiplatform_test( "tpu", ], disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", "gpu_p100", "gpu_p100_x32", @@ -570,7 +570,7 @@ jax_multiplatform_test( "tpu", ], disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", "gpu_a100", "gpu_h100", From 60a06fd4c9b9867648d5584e1dccb928933c5336 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:55:46 +0000 Subject: [PATCH 658/702] Update pillow version in JAX build test-requirements.txt --- build/test-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 0c9aa086f109..bec6afce1853 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -7,7 +7,7 @@ flatbuffers hypothesis mpmath>=1.3 numpy>=1.22 -pillow>=9.1.0 +pillow>=10.4.0 portpicker pytest-xdist wheel From a373e37be2176bbb3a4b657dc7346574107dec44 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 25 Sep 2024 08:49:36 -0700 Subject: [PATCH 659/702] Fixed `mgpu.FragmentedArray.reduce_sum` for integer types The implementation previously assumed the type is floating and used addf. PiperOrigin-RevId: 678718871 --- .../mosaic/gpu/fragmented_array.py | 16 +++++----- tests/mosaic/gpu_test.py | 30 +++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 1c1ec18d3cf2..ea45a7dcb7b9 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -660,12 +660,19 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): ) def reduce_sum(self, scratch) -> ir.Value: + if ir.FloatType.isinstance(self.mlir_dtype): + op = arith.addf + elif ir.IntegerType.isinstance(self.mlir_dtype): + op = arith.addi + else: + raise NotImplementedError(self.mlir_dtype) + index = ir.IndexType.get() if not isinstance(self.layout, WGStridedFragLayout): raise NotImplementedError(f"Unsupported layout {self.layout}") result = c(0, self.mlir_dtype) for reg in self.registers: - result = arith.addf( + result = op( result, vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg), ) @@ -673,13 +680,6 @@ def reduce_sum(self, scratch) -> ir.Value: if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]: raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})") - if ir.FloatType.isinstance(self.mlir_dtype): - op = arith.addf - elif ir.IntegerType.isinstance(self.mlir_dtype): - op = arith.addi - else: - raise NotImplementedError(self.mlir_dtype) - warp_result = utils.warp_tree_reduce(result, op, 32) warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index)) memref.store(warp_result, scratch, [warp_id]) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 30f830c31ccf..f949b63c7844 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1311,6 +1311,36 @@ def kernel(ctx, dst, _): rtol = 4e-6 if approx else 2e-7 np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol) + @parameterized.product( + dtype=[jnp.float32, jnp.int32], + m=[128], + n=[32, 64], + ) + def test_reduce_sum(self, dtype, m, n): + def kernel(ctx, src, dst, scratch): + src = mgpu.FragmentedArray.load_strided( + src, is_signed=utils.is_signed(dtype) + ) + acc = mgpu.FragmentedArray.splat( + src.reduce_sum(scratch), + (m,), + is_signed=src.is_signed + ) + acc.store_untiled(dst) + + in_shape = jax.ShapeDtypeStruct((m, n), dtype) + out_shape = jax.ShapeDtypeStruct((m,), dtype) + kernel_fn = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + in_shape, + out_shape, + smem_scratch_shape=jax.ShapeDtypeStruct((4,), dtype), + ) + x = np.arange(m * n, dtype=dtype).reshape(m, n) + np.testing.assert_array_equal(kernel_fn(x), jnp.full((m,), x.sum())) + @parameterized.product( op=(arith.addf, arith.maximumf), m=(64, 128), From 19494137392510078965c6769d00db705c259410 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 25 Sep 2024 08:53:47 -0700 Subject: [PATCH 660/702] Increase sharding of checkify_test on TPU to fix CI flakes. PiperOrigin-RevId: 678720498 --- tests/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index 11e382b1a95c..46345b6475d9 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1072,7 +1072,7 @@ jax_multiplatform_test( srcs = ["checkify_test.py"], shard_count = { "gpu": 2, - "tpu": 2, + "tpu": 4, }, ) From 13774d1382322ffa410a72fc203767f77da66440 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 25 Sep 2024 21:26:05 +0530 Subject: [PATCH 661/702] Fix Typos --- docs/key-concepts.md | 8 ++++---- docs/quickstart.md | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/key-concepts.md b/docs/key-concepts.md index b87808d14faf..daab2c9fdde4 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -23,13 +23,13 @@ This section briefly introduces some key concepts of the JAX package. ## JAX arrays ({class}`jax.Array`) The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to -the {class}`numpy.ndarray` type that you may be familar with from the NumPy package, but it +the {class}`numpy.ndarray` type that you may be familiar with from the NumPy package, but it has some important differences. ### Array creation We typically don't call the {class}`jax.Array` constructor directly, but rather create arrays via JAX API functions. -For example, {mod}`jax.numpy` provides familar NumPy-style array construction functionality +For example, {mod}`jax.numpy` provides familiar NumPy-style array construction functionality such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc. ```{code-cell} @@ -147,10 +147,10 @@ jaxprs later in {ref}`jax-internals-jaxpr`. ## Pytrees JAX functions and transformations fundamentally operate on arrays, but in practice it is -convenient to write code that work with collections of arrays: for example, a neural +convenient to write code that works with collection of arrays: for example, a neural network might organize its parameters in a dictionary of arrays with meaningful keys. Rather than handle such structures on a case-by-case basis, JAX relies on the {term}`pytree` -abstraction to treat such collections in a uniform matter. +abstraction to treat such collections in a uniform manner. Here are some examples of objects that can be treated as pytrees: diff --git a/docs/quickstart.md b/docs/quickstart.md index e19cb33ea9c5..77cbb9d46ab8 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -16,7 +16,7 @@ kernelspec: -**JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**. +**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**. This document provides a quick overview of essential JAX features, so you can get started with JAX quickly: @@ -88,8 +88,8 @@ _ = selu_jit(x) # compiles on first call %timeit selu_jit(x).block_until_ready() ``` -The above timing represent execution on CPU, but the same code can be run on GPU or TPU, -typically for an even greater speedup. +The above timing represents execution on CPU, but the same code can be run on GPU or +TPU, typically for an even greater speedup. For more on JIT compilation in JAX, check out {ref}`jit-compilation`. @@ -183,7 +183,7 @@ print('Naively batched') %timeit naively_batched_apply_matrix(batched_x).block_until_ready() ``` -A programmer familiar with the the `jnp.dot` function might recognize that `apply_matrix` can +A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`: ```{code-cell} From b49d8b2615ed9df8ef36e176e8ed36fa32a1d65c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 25 Sep 2024 09:10:12 -0700 Subject: [PATCH 662/702] Fixed `pl.debug_print`ing of scalar fragmented arrays under Mosaic GPU PiperOrigin-RevId: 678726245 --- jax/_src/pallas/mosaic_gpu/lowering.py | 18 +++++++++--- jax/experimental/mosaic/gpu/__init__.py | 1 + tests/pallas/mosaic_gpu_test.py | 37 ++++++++++++++++++------- 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ca961ac5848a..702f5ca17e61 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -708,12 +708,19 @@ def _debug_print_lowering_rule( del has_placeholders # Unused. primitives.check_debug_print_format(fmt, *args) if not any(aval.shape for aval in ctx.avals_in): - mgpu.debug_print(fmt, *args) + mgpu.debug_print( + fmt, + *( + _ensure_ir_value(arg, aval.dtype) + for arg, aval in zip(args, ctx.avals_in) + ), + ) elif len(ctx.avals_in) == 1: - @args[0].foreach + [arg] = args + @arg.foreach def _(val, idx): idx_fmt = ", ".join(["{}"] * len(idx)) - fmt_str = fmt.format(f"[{idx_fmt}]/{list(args[0].shape)}: {{}}") + fmt_str = fmt.format(f"[{idx_fmt}]/{list(arg.shape)}: {{}}") mgpu.debug_print(fmt_str, *idx, val, uniform=False) else: raise NotImplementedError( @@ -799,7 +806,7 @@ def _scan_lowering_rule( if has_loop_index: start, *args = args index_aval, *arg_avals = arg_avals - start = _ensure_ir_value(start, index_aval.dtype) + start: ir.Value = _ensure_ir_value(start, index_aval.dtype) length = _ir_constant(length, start.type) else: start = _i32_constant(0) @@ -859,6 +866,9 @@ def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: return x elif isinstance(x, (np.number, np.ndarray, int, float)): return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) + elif isinstance(x, mgpu.FragmentedArray): + if isinstance(x.layout, mgpu.WGSplatFragLayout): + return x.registers.item() raise NotImplementedError(f"Unsupported type: {type(x)}") diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 21c7f666b233..8057f97a7dfc 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -30,6 +30,7 @@ FragmentedLayout, WGMMA_LAYOUT, WGMMA_ROW_LAYOUT, + WGSplatFragLayout, WGStridedFragLayout, ) from .utils import ( diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d2a5ba1511e0..0eb7d91960ff 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -227,35 +227,52 @@ def kernel(x_ref, o_ref): self.assertEqual(output(), "It works!\n") - def test_print_with_values(self): + def test_print_scalar(self): @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): del o_ref - pl.debug_print("x[0] = {}", x_ref[0]) + pl.debug_print("x.sum() = {}", x_ref[...].sum()) - x = jnp.arange(256).astype(jnp.float32) - with self.assertRaises(Exception): - # TODO(slebedev): Remove assertRaises() once we support indexing. - kernel(x) + x = jnp.arange(256) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn(f"x.sum() = {x.sum()}", output()) + + def test_print_scalar_array(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + del o_ref + pl.debug_print("x.sum() = {}", x_ref[...].sum() + 1) + + x = jnp.arange(256) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn(f"x.sum() = {x.sum() + 1}", output()) def test_print_array(self): in_shape = [2, 1, 64, 64] + @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(in_shape, jnp.float32), + out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32), ) def kernel(x_ref, o_ref): del o_ref pl.debug_print("x: {}", x_ref[...]) - x = jnp.arange(math.prod(in_shape)).reshape(in_shape).astype(jnp.float32) + x = jnp.arange(math.prod(in_shape)).reshape(in_shape) with jtu.capture_stdout() as output: jax.block_until_ready(kernel(x)) - self.assertIn(f"x: [1, 0, 43, 23]/{list(in_shape)}: 6871.000000\n", output()) + self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output()) def test_scoped_allocation(self): def kernel(x_ref, o_ref): From 96268dcae6edb425d1c5250a6e9f44c007ea527b Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 25 Sep 2024 12:55:43 -0400 Subject: [PATCH 663/702] Fix dtype bug in jax.scipy.fft.idct --- jax/_src/scipy/fft.py | 6 +++--- tests/scipy_fft_test.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index a826d4746b1e..a0050cc81055 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -21,7 +21,7 @@ from jax import lax import jax.numpy as jnp from jax._src.util import canonicalize_axis -from jax._src.numpy.util import promote_dtypes_complex +from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact from jax._src.typing import Array def _W4(N: int, k: Array) -> Array: @@ -298,12 +298,12 @@ def idct(x: Array, type: int = 2, n: int | None = None, [(0, n - x.shape[axis] if a == axis else 0, 0) for a in range(x.ndim)]) N = x.shape[axis] - x = x.astype(jnp.float32) + x, = promote_dtypes_inexact(x) if norm is None or norm == 'backward': x = _dct_ortho_norm(x, axis) x = _dct_ortho_norm(x, axis) - k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis]) + k = lax.expand_dims(jnp.arange(N, dtype=x.dtype), [a for a in range(x.ndim) if a != axis]) # everything is complex from here... w4 = _W4(N,k) x = x.astype(w4.dtype) diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py index 6c549f5ed10c..a6fdd1b79f58 100644 --- a/tests/scipy_fft_test.py +++ b/tests/scipy_fft_test.py @@ -13,9 +13,12 @@ # limitations under the License. import itertools +import numpy as np + from absl.testing import absltest import jax +from jax._src import config from jax._src import test_util as jtu import jax.scipy.fft as jsp_fft import scipy.fft as osp_fft @@ -117,5 +120,15 @@ def testiDctn(self, shape, dtype, s, axes, norm): tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) + def testIdctNormalizationPrecision(self): + # reported in https://github.com/jax-ml/jax/issues/23895 + if not config.enable_x64.value: + raise self.skipTest("requires jax_enable_x64=true") + x = np.ones(3, dtype="float64") + n = 10 + expected = osp_fft.idct(x, n=n, type=2) + actual = jsp_fft.idct(x, n=n, type=2) + self.assertArraysAllClose(actual, expected, atol=1e-14) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 111f13e2795ea88ec42fb63df04498a1f8461fd0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 25 Sep 2024 10:13:53 -0700 Subject: [PATCH 664/702] Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b PiperOrigin-RevId: 678748794 --- CHANGELOG.md | 23 +++++++++++++++-------- tests/tree_util_test.py | 7 +++++++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bdcd1c20106..9c3c63b6ee04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,16 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * This release includes wheels for Python 3.13. Free-threading mode is not yet supported. +* Breaking changes + * `jax_pmap_no_rank_reduction` flag is set to `True` by default. + * array[0] on a pmap result now introduces a reshape (use array[0:1] + instead). + * The per-shard shape (accessable via jax_array.addressable_shards or + jax_array.addressable_data(0)) now has a leading (1, ...). Update code + that directly accesses shards accordingly. The rank of the per-shard-shape + now matches that of the global shape which is the same behavior as jit. + This avoids costly reshapes when passing results from pmap into jit. + * Deprecations * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments with `ndim != 1` are now deprecated, and in the future will result @@ -34,6 +44,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. * {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. The argument was only used by `xmap` which was removed in 0.4.31. + * `jax.tree.map(f, None, non-None)`, which previously emitted a + `DeprecationWarning`, now raises an error in a future version of jax. `None` + is only a tree-prefix of itself. To preserve the current behavior, you can + ask `jax.tree.map` to treat `None` as a leaf value by writing: + `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. * Bug fixes * Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs @@ -62,14 +77,6 @@ See the 0.4.33 release notes for more details. C++ and CUDA code from JAX. * Changes - * `jax_pmap_no_rank_reduction` flag is set to `True` by default. - * array[0] on a pmap result now introduces a reshape (use array[0:1] - instead). - * The per-shard shape (accessable via jax_array.addressable_shards or - jax_array.addressable_data(0)) now has a leading (1, ...). Update code - that directly accesses shards accordingly. The rank of the per-shard-shape - now matches that of the global shape which is the same behavior as jit. - This avoids costly reshapes when passing results from pmap into jit. * `jax_enable_memories` flag is set to `True` by default. * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information. diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index f8792a263117..c5342a99365d 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -24,6 +24,7 @@ import jax from jax import flatten_util from jax import tree_util +from jax._src.lib import xla_extension_version from jax._src import test_util as jtu from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp @@ -395,6 +396,7 @@ def testFlattenOrder(self): ({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]), ({"a": 1}, {"a": (7,)}, [(7,)]), ({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]), + (None, None, []) ) def testFlattenUpTo(self, tree, xs, expected): _, tree_def = tree_util.tree_flatten(tree) @@ -483,6 +485,11 @@ def testFlattenUpTo(self, tree, xs, expected): [([1], (2,), {"a": [1]})], re.escape("Custom node type mismatch"), ), + *( + [] + if xla_extension_version < 288 + else [(None, [2], re.escape("Expected None, got [2]."))] + ), ) def testFlattenUpToErrors(self, tree, xs, error): _, tree_def = tree_util.tree_flatten(tree) From 6cf09f8c24c67ff650b95d174501fff3cb59db0d Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Wed, 25 Sep 2024 10:33:01 -0700 Subject: [PATCH 665/702] Reverts eff00cc4499cfe3f3f24bafda6c1ecf908232ff3 PiperOrigin-RevId: 678756266 --- jax/BUILD | 4 - jax/_src/internal_test_util/test_harnesses.py | 104 ++-- jax/_src/lax/lax.py | 17 +- jax/_src/lax/parallel.py | 3 +- jax/_src/lax/slicing.py | 394 ++++---------- jax/_src/numpy/lax_numpy.py | 22 +- jax/_src/ops/scatter.py | 4 +- jax/experimental/jax2tf/jax2tf.py | 6 - tests/lax_test.py | 482 +++--------------- tests/lax_vmap_test.py | 32 -- 10 files changed, 236 insertions(+), 832 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index d49e783e61d6..c25d0004e772 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -147,11 +147,7 @@ py_library( srcs = ["_src/internal_test_util/test_harnesses.py"], visibility = [":internal"] + jax_internal_test_harnesses_visibility, deps = [ - ":ad_util", - ":config", ":jax", - ":test_util", - "//jax/_src/lib", ] + py_deps("numpy"), ) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 31c3fec94536..2c94907568d9 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -1169,18 +1169,6 @@ def _make_broadcast_in_dim_harness(name, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3), True), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), - lax.GatherDimensionNumbers( - offset_dims=(), collapsed_slice_dims=(1,), - start_index_map=(1,), operand_batching_dims=(0,), - start_indices_batching_dims=(0,)), - (1, 1), True), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - lax.GatherDimensionNumbers( - offset_dims=(2,), collapsed_slice_dims=(), - start_index_map=(2,), operand_batching_dims=(0, 1), - start_indices_batching_dims=(1, 0)), - (1, 1, 3), True) ]: dtype = np.float32 for enable_xla in ([True] if needs_xla else [True, False]): @@ -1288,16 +1276,15 @@ def _make_scatter_harness(name, update_shape=(2,), mode=lax.GatherScatterMode.FILL_OR_DROP, dtype=np.float32, - dimension_numbers=lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,)), + dimension_numbers=((), (0,), (0,)), enable_and_disable_xla=False): + dimension_numbers = lax.ScatterDimensionNumbers(*dimension_numbers) xla_options = [True, False] if enable_and_disable_xla else [True] for enable_xla in xla_options: define( f_lax.__name__, - f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_{dimension_numbers=}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}" + f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_scatterindices={scatter_indices.tolist()}_updateshape={update_shape}_updatewindowdims={dimension_numbers.update_window_dims}_insertedwindowdims={dimension_numbers.inserted_window_dims}_scatterdimstooperanddims={dimension_numbers.scatter_dims_to_operand_dims}_indicesaresorted={indices_are_sorted}_uniqueindices={unique_indices}_{mode=!s}_enablexla={enable_xla}" .replace(" ", ""), partial( f_lax, @@ -1341,19 +1328,8 @@ def _make_scatter_harness(name, # Validate shapes, dimension numbers and scatter indices. All are in bounds. for shape, scatter_indices, update_shape, dimension_numbers in [ - ((10,), [[0], [0], [0]], (3, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(0,))), - ((10, 5), [[0], [2], [1]], (3, 3), - lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,))), - ((2, 3, 10), [[[0], [1]], [[2], [3]], [[4], [5]]], (3, 2, 3), - lax.ScatterDimensionNumbers( - update_window_dims=(2,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), - scatter_indices_batching_dims=(1, 0))) + ((10,), [[0], [0], [0]], (3, 2), ((1,), (), (0,))), + ((10, 5), [[0], [2], [1]], (3, 3), ((1,), (0,), (0,))) ]: _make_scatter_harness( "shapes_and_dimension_numbers", @@ -1382,16 +1358,13 @@ def _make_scatter_harness(name, _make_scatter_harness("modes_in_bounds", f_lax=f_lax, mode=mode) - _make_scatter_harness( - "modes_out_of_bounds", - mode=mode, - shape=(1, 5), - f_lax=f_lax, - scatter_indices=np.array([10]), - update_shape=(1,), - dimension_numbers=lax.ScatterDimensionNumbers((0,), (1,), (1,)), - enable_and_disable_xla=True, - ) + _make_scatter_harness("modes_out_of_bounds", mode=mode, + shape=(1, 5), + f_lax=f_lax, + scatter_indices=np.array([10]), + update_shape=(1,), + dimension_numbers=((0,), (1,), (1,)), + enable_and_disable_xla=True) # Validate no XLA scatters for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex) - set(jtu.dtypes.boolean): @@ -1399,34 +1372,22 @@ def _make_scatter_harness(name, lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min, lax.scatter ]: for shape, scatter_indices, update_shape, dimension_numbers in [ - ((1,), [0], (), - lax.ScatterDimensionNumbers((), (0,), (0,))), # zero case - ((1, 1), [0], (1,), - lax.ScatterDimensionNumbers((0,), (0,), (0,))), - ((1, 1, 1), [0], (1, 1), - lax.ScatterDimensionNumbers((0, 1), (0,), (0,))), - ((1, 50, 3), [32], (1, 3), - lax.ScatterDimensionNumbers((0, 1), (1,), (1,))), - ((1, 2, 3), [1], (1, 3), - lax.ScatterDimensionNumbers((0, 1), (1,), (1,))), # slice 2nd dim - ((1, 2, 3), [0], (2, 3), - lax.ScatterDimensionNumbers((0, 1), (0,), (0,))), # slice 1st dim - ((1, 2, 3), [1, 2], (1,), - lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))), # 2nd and 3rd - ((4, 2, 3), [3, 2], (2,), - lax.ScatterDimensionNumbers((0,), (0, 2), (0, 2))), # 1st and 3rd - ((4, 2, 3, 5), [0, 4], (4, 3), - lax.ScatterDimensionNumbers((0, 1), (1, 3), (1, 3))), # 2nd and 4th + ((1,), [0], (), ((), (0,), (0,))), # zero case + ((1, 1), [0], (1,), ((0,), (0,), (0,))), + ((1, 1, 1), [0], (1, 1), ((0, 1), (0,), (0,))), + ((1, 50, 3), [32], (1, 3), ((0, 1), (1,), (1,))), + ((1, 2, 3), [1], (1, 3), ((0, 1), (1,), (1,))), # slice 2nd dim + ((1, 2, 3), [0], (2, 3), ((0, 1), (0,), (0,))), # slice 1st dim + ((1, 2, 3), [1, 2], (1,), ((0,), (1, 2), (1, 2))), # 2nd and 3rd + ((4, 2, 3), [3, 2], (2,), ((0,), (0, 2), (0, 2))), # 1st and 3rd + ((4, 2, 3, 5), [0, 4], (4, 3), ((0, 1), (1, 3), (1, 3))), # 2nd and 4th ((5, 6, 7), [[0, 1], [2, 3]], (2, 7), - lax.ScatterDimensionNumbers((1,), (0, 1), (0, 1))), - # .at[((3,4),(5,5))] shapes + ((1,), (0, 1), (0, 1))), # .at[((3,4),(5,5))] shapes ((5, 6, 7), [[[0], [1]], [[2], [3]]], (5, 2, 2, 7), - lax.ScatterDimensionNumbers((0, 3), (1,), (1,))), - # .at[:, ((3,4),(5,5))] shapes + ((0, 3), (1,), (1,))), # .at[:, ((3,4),(5,5))] shapes ((5, 6, 7), [[[0, 1], [2, 3]], [[4, 0], [1, 2]]], (5, 2, 2), - lax.ScatterDimensionNumbers((0,), (1, 2), (1, 2))), - # .at[:, ((3,4),(5,5)), 3] shapes - ((1, 125), [0], (1,), lax.ScatterDimensionNumbers((0,), (1,), (1,))), + ((0,), (1, 2), (1, 2))), # .at[:, ((3,4),(5,5)), 3] shapes + ((1, 125), [0], (1,), ((0,), (1,), (1,))), ]: for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS, lax.GatherScatterMode.FILL_OR_DROP): @@ -1449,16 +1410,11 @@ def _make_scatter_harness(name, lax.scatter_add, lax.scatter_mul, lax.scatter_max, lax.scatter_min ]: for shape, scatter_indices, update_shape, dimension_numbers in [ - ((1,), [[0],[0]], (2,), - lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((0,0),)] - ((3,), [[1],[0],[1]], (3,), - lax.ScatterDimensionNumbers((), (0,), (0,))), # .at[((1,0,1),)] - ((2, 3), [[[2],[2],[2]]], (2, 1, 3), - lax.ScatterDimensionNumbers((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)] - ((3, 5, 40), [[1],[1]], (3, 5, 2), - lax.ScatterDimensionNumbers((0, 1), (2,), (2,))), - ((3, 5, 4), [[1],[1]], (3, 2, 4), - lax.ScatterDimensionNumbers((0, 2), (1,), (1,))), + ((1,), [[0],[0]], (2,), ((), (0,), (0,))), # .at[((0,0),)] + ((3,), [[1],[0],[1]], (3,), ((), (0,), (0,))), # .at[((1,0,1),)] + ((2, 3), [[[2],[2],[2]]], (2, 1, 3), ((0,), (1,), (1,))), # 2nd dim, .at[:, ((2,2,2),)] + ((3, 5, 40), [[1],[1]], (3, 5, 2), ((0, 1), (2,), (2,))), + ((3, 5, 4), [[1],[1]], (3, 2, 4), ((0, 2), (1,), (1,))), ]: for mode in (lax.GatherScatterMode.PROMISE_IN_BOUNDS, lax.GatherScatterMode.FILL_OR_DROP): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ef93304697f3..f51f0436b7a9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4941,15 +4941,18 @@ def _top_k_jvp(primals, tangents, *, k): idx_shape = k_idxs.shape rank = len(idx_shape) gather_index_shape = idx_shape + (1,) - gather_indices = reshape(k_idxs, gather_index_shape) + gather_indices = [] + for i in range(rank-1): + _iota = iota(k_idxs.dtype, idx_shape[i]) + _iota = broadcast_in_dim(_iota, gather_index_shape, (i,)) + gather_indices.append(_iota) + gather_indices.append(reshape(k_idxs, gather_index_shape)) + gather_indices = concatenate(gather_indices, dimension=rank) slice_sizes = (1,) * rank dnums = slicing.GatherDimensionNumbers( - offset_dims=(), - collapsed_slice_dims=(rank - 1,), - operand_batching_dims=tuple(range(rank - 1)), - start_indices_batching_dims=tuple(range(rank - 1)), - start_index_map=(rank - 1,), - ) + offset_dims=(), + collapsed_slice_dims=tuple(range(rank)), + start_index_map=tuple(range(rank))) tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes) return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1])) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 9d4614f344fb..c9a07072ddc7 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1500,8 +1500,7 @@ def _pgather_impl(src, idx, *, axes): dnums = slicing.GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=(0,), - start_index_map=(0,), - ) + start_index_map=(0,)) return slicing.gather(src_one_axis_front, idx, dimension_numbers=dnums, slice_sizes=tuple(slice_sizes)) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 372ebd1a8389..60dfa0e1b3d2 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -233,16 +233,6 @@ class GatherDimensionNumbers(NamedTuple): start_index_map: for each dimension in `start_indices`, gives the corresponding dimension in the `operand` that is to be sliced. Must be a tuple of integers with size equal to `start_indices.shape[-1]`. - operand_batching_dims: the set of batching dimensions `i` in `operand` that - have `slice_sizes[i] == 1` and that should have a corresponding dimension - in both the `start_indices` (at the same index in - `start_indices_batching_dims`) and output of the gather. Must be a tuple - of integers in ascending order. - start_indices_batching_dims: the set of batching dimensions `i` in - `start_indices` that should have a corresponding dimension in both the - `operand` (at the same index in `operand_batching_dims`) and output of the - gather. Must be a tuple of integers (order is fixed based on - correspondence with `operand_batching_dims`). Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the @@ -251,8 +241,6 @@ class GatherDimensionNumbers(NamedTuple): offset_dims: tuple[int, ...] collapsed_slice_dims: tuple[int, ...] start_index_map: tuple[int, ...] - operand_batching_dims: tuple[int, ...] = () - start_indices_batching_dims: tuple[int, ...] = () class GatherScatterMode(enum.Enum): @@ -382,17 +370,6 @@ class ScatterDimensionNumbers(NamedTuple): scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives the corresponding dimension in `operand`. Must be a sequence of integers with size equal to `scatter_indices.shape[-1]`. - operand_batching_dims: the set of batching dimensions `i` in `operand` that - should have a corresponding dimension in both the `scatter_indices` (at - the same index in `scatter_indices_batching_dims`) and `updates`. Must be - a tuple of integers in ascending order. These are the mirror image of - `operand_batching_dims` in the case of `gather`. - scatter_indices_batching_dims: the set of batching dimensions `i` in - `scatter_indices` that should have a corresponding dimension in both the - `operand` (at the same index in `operand_batching_dims`) and output of the - gather. Must be a tuple of integers (order is fixed based on - correspondence with `input_batching_dims`). These are the mirror image of - `start_indices_batching_dims` in the case of `gather`. Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the @@ -401,8 +378,6 @@ class ScatterDimensionNumbers(NamedTuple): update_window_dims: Sequence[int] inserted_window_dims: Sequence[int] scatter_dims_to_operand_dims: Sequence[int] - operand_batching_dims: Sequence[int] = () - scatter_indices_batching_dims: Sequence[int] = () def scatter_add( operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike, @@ -719,8 +694,7 @@ def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array: dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=tuple(axes), - start_index_map=tuple(axes), - ) + start_index_map=tuple(axes)) return gather(src, indices, dimension_numbers=dnums, slice_sizes=tuple(slice_sizes)) @@ -1282,11 +1256,8 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): dims = tuple(range(ndims)) start_indices, dyn_slice_sizes = util.split_list(start_indices_and_dyn, [ndims]) start_idx_bds, dyn_slice_size_bds = util.split_list(start_idx_and_dyn_bds, [ndims]) - dnums = GatherDimensionNumbers( - offset_dims=dims, - collapsed_slice_dims=(), - start_index_map=dims, - ) + dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(), + start_index_map=dims) index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds) return _gather_batching_rule( [operand, index, *dyn_slice_sizes], @@ -1425,11 +1396,9 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): update_shape = (np.shape(update) if update_bd is batching.not_mapped else tuple(np.delete(np.shape(update), update_bd))) dims = tuple(range(len(update_shape))) - dnums = ScatterDimensionNumbers( - update_window_dims=dims, - inserted_window_dims=(), - scatter_dims_to_operand_dims=dims, - ) + dnums = ScatterDimensionNumbers(update_window_dims=dims, + inserted_window_dims=(), + scatter_dims_to_operand_dims=dims) index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd) return api.vmap( partial(scatter, dimension_numbers=dnums, @@ -1468,12 +1437,6 @@ def _is_sorted(dims, op_name, name): if dims[i] < dims[i - 1]: raise TypeError(f"{name} in {op_name} op must be sorted; got {dims}") -def _dims_in_range(dims, rank, op_name, name): - for dim in dims: - if dim < 0 or dim >= rank: - raise TypeError(f"Invalid {name} set in {op_name} op; valid range is " - f"[0, {rank}); got: {dim}.") - def _sorted_dims_in_range(dims, rank, op_name, name): if len(dims) == 0: return @@ -1490,11 +1453,6 @@ def _no_duplicate_dims(dims, op_name, name): if len(set(dims)) != len(dims): raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.") -def _disjoint_dims(dims1, dims2, op_name, name1, name2): - if not set(dims1).isdisjoint(set(dims2)): - raise TypeError(f"{name1} and {name2} in {op_name} op must be disjoint; " - f"got: {dims1} and {dims2}.") - def _gather_shape_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1508,8 +1466,6 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims - operand_batching_dims = dimension_numbers.operand_batching_dims - start_indices_batching_dims = dimension_numbers.start_indices_batching_dims start_index_map = dimension_numbers.start_index_map # Note: in JAX, index_vector_dim is always computed as below, cf. the @@ -1565,50 +1521,6 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims") _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") - - _no_duplicate_dims(operand_batching_dims, "gather", "operand_batching_dims") - _is_sorted(operand_batching_dims, "gather", "operand_batching_dims") - _sorted_dims_in_range( - operand_batching_dims, _rank(operand), "gather", "operand_batching_dims" - ) - - _disjoint_dims(collapsed_slice_dims, operand_batching_dims, "gather", - "collapsed_slice_dims", "operand_batching_dims") - _disjoint_dims(start_index_map, operand_batching_dims, "gather", - "start_index_map", "operand_batching_dims") - - _no_duplicate_dims( - start_indices_batching_dims, "gather", "start_indices_batching_dims" - ) - _dims_in_range( - start_indices_batching_dims, - _rank(indices), - "gather", - "start_indices_batching_dims", - ) - if index_vector_dim in start_indices_batching_dims: - raise TypeError( - "Gather op cannot have the index vector dimension as a batching " - f"dimension; got {start_indices_batching_dims}." - ) - - if len(operand_batching_dims) != len(start_indices_batching_dims): - raise TypeError( - "Gather op requires equal numbers of operand_batching_dims and " - f"start_indices_batching_dims, got {operand_batching_dims} and" - f"{start_indices_batching_dims}." - ) - - operand_batch_shape = tuple(operand.shape[i] for i in operand_batching_dims) - indices_batch_shape = tuple( - indices.shape[i] for i in start_indices_batching_dims - ) - if not core.definitely_equal_shape(operand_batch_shape, indices_batch_shape): - raise TypeError( - "Gather op requires operand batching dimensions and indices batching " - f"dimensions to have the same shape, got {operand_batch_shape} and " - f"{indices_batch_shape}." - ) # End ValidateGatherDimensions if _rank(operand) != len(slice_sizes): @@ -1616,17 +1528,12 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " f"input_shape.rank={_rank(operand)}") - if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims) + len( - operand_batching_dims - ): - raise TypeError( - "All components of the offset index in a gather op must " - "either be a offset dimension or explicitly collapsed/batching; " - f"got len(slice_sizes)={len(slice_sizes)}, " - f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" - f"{collapsed_slice_dims}, operand_batching_dims=" - f"{operand_batching_dims}." - ) + if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): + raise TypeError(f"All components of the offset index in a gather op must " + f"either be a offset dimension or explicitly collapsed; " + f"got len(slice_sizes)={len(slice_sizes)}, " + f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" + f"{collapsed_slice_dims}.") for i in range(len(slice_sizes)): slice_size = slice_sizes[i] @@ -1645,21 +1552,12 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers, f"but bound is {bound} for index " f"{collapsed_slice_dims[i]} at position {i}.") - for i in range(len(operand_batching_dims)): - bound = slice_sizes[operand_batching_dims[i]] - if bound > 1: - raise TypeError(f"Gather op can only have operand batching dims with " - f"bound 0/1, but bound is {bound} for index " - f"{operand_batching_dims[i]} at position {i}." - ) - return _gather_shape_computation(indices, dimension_numbers, slice_sizes) def _gather_shape_computation(indices, dimension_numbers, slice_sizes): offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims - operand_batching_dims = dimension_numbers.operand_batching_dims output_shape_rank = len(offset_dims) + _rank(indices) - 1 index_vector_dim = _rank(indices) - 1 @@ -1674,11 +1572,8 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes): indices_shape_gen = iter(expanded_indices_shape) - slice_sizes_gen = ( - s - for i, s in enumerate(slice_sizes) - if i not in collapsed_slice_dims and i not in operand_batching_dims - ) + slice_sizes_gen = (s for i, s in enumerate(slice_sizes) + if i not in collapsed_slice_dims) ans = tuple(next(slice_sizes_gen) if i in offset_dims else next(indices_shape_gen) for i in range(output_shape_rank)) return ans @@ -1736,12 +1631,9 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers, else: zeros = lax.full(operand_shape, lax._zero(t)) scatter_dnums = ScatterDimensionNumbers( - update_window_dims=dimension_numbers.offset_dims, - inserted_window_dims=dimension_numbers.collapsed_slice_dims, - scatter_dims_to_operand_dims=dimension_numbers.start_index_map, - operand_batching_dims=dimension_numbers.operand_batching_dims, - scatter_indices_batching_dims=dimension_numbers.start_indices_batching_dims, - ) + update_window_dims=dimension_numbers.offset_dims, + inserted_window_dims=dimension_numbers.collapsed_slice_dims, + scatter_dims_to_operand_dims=dimension_numbers.start_index_map) out = scatter_add(zeros, indices, t, scatter_dnums, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, @@ -1760,17 +1652,11 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, slice_sizes = (operand.shape[0],) + slice_sizes offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims)) collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) - operand_batching_dims = tuple( - np.add(1, dimension_numbers.operand_batching_dims) - ) start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map, - operand_batching_dims=operand_batching_dims, - start_indices_batching_dims=dimension_numbers.start_indices_batching_dims, - ) + start_index_map=start_index_map) if isinstance(operand_bdim, batching.RaggedAxis): ragged_slice_sizes = batching.bdim_as_shape(operand_bdim, slice_sizes) for orig, fabricated in zip( @@ -1801,16 +1687,10 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, elif operand_bdim is None and indices_bdim is not None: indices = batching.moveaxis(indices, indices_bdim, 0) offset_dims = tuple(1 + d for d in dimension_numbers.offset_dims) - start_indices_batching_dims = tuple( - np.add(1, dimension_numbers.start_indices_batching_dims) - ) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=dimension_numbers.collapsed_slice_dims, - start_index_map=dimension_numbers.start_index_map, - operand_batching_dims=dimension_numbers.operand_batching_dims, - start_indices_batching_dims=start_indices_batching_dims, - ) + start_index_map=dimension_numbers.start_index_map) # If batching indexed accesses into the same array, the batched gather may # no longer have sorted or unique indices. return gather(operand, indices, dimension_numbers=dnums, @@ -1822,34 +1702,61 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, operand = batching.moveaxis(operand, operand_bdim, 0) indices = batching.moveaxis(indices, indices_bdim, 0) + # This slightly awkward special case is needed because the shape rule for + # gather does not allow size-1 slices out of a size-0 dimension, even if + # the number of slices is zero. Likely the best fix would be to change the + # definition of gather() so it can be batched without the construction of + # an explicit iota of size-1 slices. if core.definitely_equal(operand.shape[0], 0): - slice_sizes = (0,) + slice_sizes - else: - slice_sizes = (1,) + slice_sizes - collapsed_slice_dims = tuple( - np.add(1, dimension_numbers.collapsed_slice_dims) - ) - operand_batching_dims = (0,) + tuple( - np.add(1, dimension_numbers.operand_batching_dims) - ) - start_indices_batching_dims = (0,) + tuple( - np.add(1, dimension_numbers.start_indices_batching_dims) - ) + output_shape = _gather_shape_rule( + core.ShapedArray(operand.shape[1:], operand.dtype), + core.ShapedArray(indices.shape[1:], + dtypes.canonicalize_dtype(indices.dtype)), + dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, + unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, + mode=mode, fill_value=fill_value) + return lax.full((0,) + output_shape, lax._zero(operand)), 0 + + # Example: user code had indices shape (3, 4, 5), and we have to deal with + # indices shape (7, 3, 4, 5). We transform that to indices of shape + # (7, 3, 4, 6) where we concatenated an iota that counts along our batch + # dimension to the front of the ndindex. + index_dtype = _promote_dtype_for_size(indices.dtype, indices.shape[0]) + count_shape = list(indices.shape) + count_shape[-1] = 1 + counts = lax.broadcasted_iota(index_dtype, tuple(count_shape), 0) + indices = lax.concatenate([counts, indices.astype(index_dtype)], + len(count_shape) - 1) + + slice_sizes = (1,) + slice_sizes + collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) offset_dims = tuple(np.add(1, dimension_numbers.offset_dims)) - start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) + start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, - start_index_map=start_index_map, - operand_batching_dims=operand_batching_dims, - start_indices_batching_dims=start_indices_batching_dims, - ) + start_index_map=start_index_map) return gather(operand, indices, dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), 0 +def _promote_dtype_for_size(dtype, size): + if not dtypes.issubdtype(dtype, np.integer): + return dtype + # size may be a dynamic shape, in which case we return at least int32 + try: + size = int(size) + except: + return dtype if np.iinfo(dtype).bits >= 32 else np.dtype('int32') + if size <= np.iinfo(dtype).max: + return dtype + elif size <= np.iinfo(np.int32).max: + return np.dtype('int32') + else: + return dtypes.canonicalize_dtype(np.int64) + def _gather_pad_rule(in_avals, out_avals, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1914,10 +1821,8 @@ def _gather_lower(ctx, operand, indices, *, GatherScatterMode.CLIP), mode dnums = hlo.GatherDimensionNumbers.get( collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), - operand_batching_dims=list(dimension_numbers.operand_batching_dims), - start_indices_batching_dims=list( - dimension_numbers.start_indices_batching_dims - ), + operand_batching_dims=[], + start_indices_batching_dims=[], index_vector_dim=len(ctx.avals_in[1].shape) - 1, offset_dims=list(dimension_numbers.offset_dims), start_index_map=list(dimension_numbers.start_index_map), @@ -1967,8 +1872,6 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, update_window_dims = dimension_numbers.update_window_dims inserted_window_dims = dimension_numbers.inserted_window_dims - operand_batching_dims = dimension_numbers.operand_batching_dims - scatter_indices_batching_dims = dimension_numbers.scatter_indices_batching_dims scatter_dims_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims # Note: in JAX, index_vector_dim is always computed as below, cf. the # documentation of the ScatterDimensionNumbers class. @@ -2006,55 +1909,8 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, _sorted_dims_in_range(inserted_window_dims, _rank(operand), "scatter", "inserted_window_dims") - # Validate operand_batching_dims and scatter_indices_batching_dims - _is_sorted(operand_batching_dims, "scatter", "operand_batching_dims") - _no_duplicate_dims(operand_batching_dims, "scatter", "operand_batching_dims") - _sorted_dims_in_range( - operand_batching_dims, _rank(operand), "scatter", "operand_batching_dims" - ) - _disjoint_dims(inserted_window_dims, operand_batching_dims, "scatter", - "inserted_window_dims", "operand_batching_dims") - _disjoint_dims(scatter_dims_to_operand_dims, operand_batching_dims, "scatter", - "scatter_dims_to_operand_dims", "operand_batching_dims") - - _no_duplicate_dims( - scatter_indices_batching_dims, "scatter", "scatter_indices_batching_dims" - ) - _dims_in_range( - scatter_indices_batching_dims, - _rank(indices), - "scatter", - "scatter_indices_batching_dims", - ) - if index_vector_dim in scatter_indices_batching_dims: - raise TypeError( - "Scatter op cannot have the index vector dimension as a batching " - f"dimension; got {scatter_indices_batching_dims}.") - - if len(operand_batching_dims) != len(scatter_indices_batching_dims): - raise TypeError( - "Scatter op requires equal numbers of operand_batching_dims and " - f"scatter_indices_batching_dims, got {operand_batching_dims} and " - f"{scatter_indices_batching_dims}." - ) - - operand_batch_shape = tuple(operand.shape[i] for i in operand_batching_dims) - indices_batch_shape = tuple( - indices.shape[i] for i in scatter_indices_batching_dims - ) - if not core.definitely_equal_shape(operand_batch_shape, indices_batch_shape): - raise TypeError( - "Scatter op requires operand batching dimensions and indices batching " - f"dimensions to have the same shape, got {operand_batch_shape} and " - f"{indices_batch_shape}." - ) - # Validate window_size - window_size = ( - len(update_window_dims) + - len(inserted_window_dims) + - len(operand_batching_dims) - ) + window_size = len(update_window_dims) + len(inserted_window_dims) if _rank(operand) != window_size: raise TypeError(f"Scatter op has window of size {window_size}; doesn't " f"match operand of rank {_rank(operand)}.") @@ -2077,14 +1933,8 @@ def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, _no_duplicate_dims(scatter_dims_to_operand_dims, "scatter", "scatter_dims_to_operand_dims") - max_update_slice_sizes = [ - operand.shape[i] - for i in range(len(operand.shape)) - if ( - i not in set(inserted_window_dims) - and i not in set(operand_batching_dims) - ) - ] + max_update_slice_sizes = [operand.shape[i] for i in range(len(operand.shape)) + if not i in set(inserted_window_dims)] for i in range(len(update_window_dims)): update_window_dim = update_window_dims[i] @@ -2118,7 +1968,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums): slice_sizes = [] pos = 0 for i in range(len(operand.shape)): - if i in dnums.inserted_window_dims or i in dnums.operand_batching_dims: + if i in dnums.inserted_window_dims: slice_sizes.append(1) else: slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) @@ -2179,19 +2029,13 @@ def _scatter_add_transpose_rule(t, operand, indices, updates, *, if ad.is_undefined_primal(updates): gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims, - operand_batching_dims=dimension_numbers.operand_batching_dims, - start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, - ) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if ( - i in dimension_numbers.inserted_window_dims - or i in dimension_numbers.operand_batching_dims - ): + if i in dimension_numbers.inserted_window_dims: slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2223,19 +2067,13 @@ def _scatter_mul_transpose_rule(t, operand, indices, updates, *, raise NotImplementedError( "scatter_mul gradients are only implemented if `unique_indices=True`") gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims, - operand_batching_dims=dimension_numbers.operand_batching_dims, - start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, - ) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if ( - i in dimension_numbers.inserted_window_dims - or i in dimension_numbers.operand_batching_dims - ): + if i in dimension_numbers.inserted_window_dims: slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2257,52 +2095,40 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims) if ax is not None) operand = batching.bdim_at_front(operand, operand_bdim, size) + operand_bdim = 0 updates = batching.bdim_at_front(updates, updates_bdim, size) if indices_bdim is None: inserted_window_dims = tuple(np.add(1, dimension_numbers.inserted_window_dims)) update_window_dims = (0,) + tuple(np.add(1, dimension_numbers.update_window_dims)) - operand_batching_dims = tuple( - np.add(1, dimension_numbers.operand_batching_dims) - ) scatter_dims_to_operand_dims = tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, - operand_batching_dims=operand_batching_dims, - scatter_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, - ) + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims) return scatter_op.bind( operand, indices, updates, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, update_jaxpr=update_jaxpr, update_consts=update_consts), 0 + # see the third case in _gather_batching_rule for comparison and comments indices = batching.bdim_at_front(indices, indices_bdim, size) + count_shape = list(indices.shape) + count_shape[-1] = 1 + counts = lax.broadcasted_iota(indices.dtype, tuple(count_shape), 0) + indices = lax.concatenate([counts, indices], len(count_shape) - 1) + update_window_dims = tuple(np.add(1, dimension_numbers.update_window_dims)) - inserted_window_dims = tuple( - np.add(1, dimension_numbers.inserted_window_dims) - ) - operand_batching_dims = (0,) + tuple( - np.add(1, dimension_numbers.operand_batching_dims) - ) - scatter_indices_batching_dims = (0,) + tuple( - np.add(1, dimension_numbers.scatter_indices_batching_dims) - ) - scatter_dims_to_operand_dims = tuple( - np.add(1, dimension_numbers.scatter_dims_to_operand_dims) - ) + inserted_window_dims = (0,) + tuple(np.add(1, dimension_numbers.inserted_window_dims)) + scatter_dims_to_operand_dims = (0,) + tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims)) dnums = ScatterDimensionNumbers( update_window_dims=update_window_dims, inserted_window_dims=inserted_window_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, - operand_batching_dims=operand_batching_dims, - scatter_indices_batching_dims=scatter_indices_batching_dims, - ) + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims) return scatter_op.bind( operand, indices, updates, dimension_numbers=dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, @@ -2364,18 +2190,12 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, gather_dnums = GatherDimensionNumbers( offset_dims=scatter_dnums.update_window_dims, collapsed_slice_dims=scatter_dnums.inserted_window_dims, - start_index_map=scatter_dnums.scatter_dims_to_operand_dims, - operand_batching_dims=scatter_dnums.operand_batching_dims, - start_indices_batching_dims=scatter_dnums.scatter_indices_batching_dims, - ) + start_index_map=scatter_dnums.scatter_dims_to_operand_dims) slice_sizes = [] pos = 0 for i in range(len(operand.shape)): - if ( - i in scatter_dnums.inserted_window_dims - or i in scatter_dnums.operand_batching_dims - ): + if i in scatter_dnums.inserted_window_dims: slice_sizes.append(1) else: slice_sizes.append(updates_shape[scatter_dnums.update_window_dims[pos]]) @@ -2503,6 +2323,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # of using scatter-add here is that we don't need a `scatter` transpose # rule. + # a) attach a positive ID to each update in `updates`, and perform a scatter # on the IDs. ids_shape = list(updates.shape) @@ -2523,16 +2344,13 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # b) compute the inverse gather that "undoes" the scatter on the id values. gather_dnums = GatherDimensionNumbers( - offset_dims=dnums.update_window_dims, - collapsed_slice_dims=dnums.inserted_window_dims, - start_index_map=dnums.scatter_dims_to_operand_dims, - operand_batching_dims=dnums.operand_batching_dims, - start_indices_batching_dims=dnums.scatter_indices_batching_dims, - ) + offset_dims=dnums.update_window_dims, + collapsed_slice_dims=dnums.inserted_window_dims, + start_index_map=dnums.scatter_dims_to_operand_dims) slice_sizes = [] pos = 0 for i in range(len(scattered_ids.shape)): - if i in dnums.inserted_window_dims or i in dnums.operand_batching_dims: + if i in dnums.inserted_window_dims: slice_sizes.append(1) else: slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) @@ -2587,19 +2405,13 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, if ad.is_undefined_primal(updates): gather_dnums = GatherDimensionNumbers( - offset_dims=dimension_numbers.update_window_dims, - collapsed_slice_dims=dimension_numbers.inserted_window_dims, - start_index_map=dimension_numbers.scatter_dims_to_operand_dims, - operand_batching_dims=dimension_numbers.operand_batching_dims, - start_indices_batching_dims=dimension_numbers.scatter_indices_batching_dims, - ) + offset_dims=dimension_numbers.update_window_dims, + collapsed_slice_dims=dimension_numbers.inserted_window_dims, + start_index_map=dimension_numbers.scatter_dims_to_operand_dims) slice_sizes = [] pos = 0 for i in range(len(t.shape)): - if ( - i in dimension_numbers.inserted_window_dims - or i in dimension_numbers.operand_batching_dims - ): + if i in dimension_numbers.inserted_window_dims: slice_sizes.append(1) else: slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]]) @@ -2667,8 +2479,8 @@ def _scatter_lower(ctx, operand, indices, updates, *, scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), - input_batching_dims=list(dnums.operand_batching_dims), - scatter_indices_batching_dims=list(dnums.scatter_indices_batching_dims), + input_batching_dims=[], + scatter_indices_batching_dims=[], scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, ) @@ -2727,8 +2539,8 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates, scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), - input_batching_dims=list(dnums.operand_batching_dims), - scatter_indices_batching_dims=list(dnums.scatter_indices_batching_dims), + input_batching_dims=[], + scatter_indices_batching_dims=[], scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, ) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ac3074f45934..559d17cd9514 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9824,8 +9824,6 @@ def replace(tup, val): offset_dims = [] start_index_map = [] collapsed_slice_dims = [] - operand_batching_dims = [] - start_indices_batching_dims = [] j = 0 for i in range(rank): if i == axis_int: @@ -9850,23 +9848,21 @@ def replace(tup, val): collapsed_slice_dims.append(i) j += 1 else: - # Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both - # array and index as batching so corresponding elements are gathered. - if core.definitely_equal(arr_shape[i], 0): - slice_sizes.append(0) - else: - slice_sizes.append(1) - operand_batching_dims.append(i) - start_indices_batching_dims.append(j) + # Otherwise, idx_shape[i] == arr_shape[i]. Use an iota index so + # corresponding elements of array and index are gathered. + # TODO(mattjj): next line needs updating for dynamic shapes + iota = lax.broadcasted_iota(index_dtype, gather_index_shape, j) + gather_indices.append(iota) + slice_sizes.append(1) + start_index_map.append(i) + collapsed_slice_dims.append(i) j += 1 gather_indices_arr = lax.concatenate(gather_indices, dimension=j) dnums = lax.GatherDimensionNumbers( offset_dims=tuple(offset_dims), collapsed_slice_dims=tuple(collapsed_slice_dims), - start_index_map=tuple(start_index_map), - operand_batching_dims=tuple(operand_batching_dims), - start_indices_batching_dims=tuple(start_indices_batching_dims)) + start_index_map=tuple(start_index_map)) return lax.gather(a, gather_indices_arr, dnums, tuple(slice_sizes), mode="fill" if mode is None else mode, fill_value=fill_value) diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 809df8195d54..2bcfe96ad2f0 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -122,9 +122,7 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, dnums = lax.ScatterDimensionNumbers( update_window_dims=indexer.dnums.offset_dims, inserted_window_dims=indexer.dnums.collapsed_slice_dims, - scatter_dims_to_operand_dims=indexer.dnums.start_index_map, - operand_batching_dims=indexer.dnums.operand_batching_dims, - scatter_indices_batching_dims=indexer.dnums.start_indices_batching_dims, + scatter_dims_to_operand_dims=indexer.dnums.start_index_map ) out = scatter_op( x, indexer.gather_indices, y, dnums, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index bc533e6d145b..f01a3ab7a036 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2873,9 +2873,6 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers): proto.offset_dims.extend(dimension_numbers.offset_dims) proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims) proto.start_index_map.extend(dimension_numbers.start_index_map) - proto.operand_batching_dims.extend(dimension_numbers.operand_batching_dims) - proto.start_indices_batching_dims.extend( - dimension_numbers.start_indices_batching_dims) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto @@ -2987,9 +2984,6 @@ def _scatter_dimensions_proto(indices_shape, dimension_numbers): proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims) proto.scatter_dims_to_operand_dims.extend( dimension_numbers.scatter_dims_to_operand_dims) - proto.input_batching_dims.extend(dimension_numbers.operand_batching_dims) - proto.scatter_indices_batching_dims.extend( - dimension_numbers.scatter_indices_batching_dims) assert indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto diff --git a/tests/lax_test.py b/tests/lax_test.py index e8f3b730e3da..c8f3ca797903 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2686,18 +2686,6 @@ def testIndexTake(self, shape, dtype, idxs, axes): ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), - lax.GatherDimensionNumbers( - offset_dims=(), collapsed_slice_dims=(1,), - start_index_map=(1,), operand_batching_dims=(0,), - start_indices_batching_dims=(0,)), - (1, 1)), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - lax.GatherDimensionNumbers( - offset_dims=(2,), collapsed_slice_dims=(), - start_index_map=(2,), operand_batching_dims=(0, 1), - start_indices_batching_dims=(1, 0)), - (1, 1, 3)) ]], dtype=lax_test_util.all_dtypes, ) @@ -2715,196 +2703,63 @@ def testGather(self, shape, dtype, idxs, dnums, slice_sizes): @parameterized.named_parameters( {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape, "indices_shape": indices_shape, - "dimension_numbers": dimension_numbers, + "dimension_numbers": lax.GatherDimensionNumbers( + offset_dims=offset_dims, + collapsed_slice_dims=collapsed_slice_dims, + start_index_map=start_index_map), "slice_sizes": slice_sizes, "msg": msg} - for (testcase_name, operand_shape, indices_shape, dimension_numbers, - slice_sizes, msg) in [ + for (testcase_name, operand_shape, indices_shape, offset_dims, + collapsed_slice_dims, start_index_map, slice_sizes, msg) in [ ("NonAscendingWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 8, 7), collapsed_slice_dims=(), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7, 6), "offset_dims in gather op must be sorted"), + (4, 5, 6, 8, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + "offset_dims in gather op must be sorted"), ("RepeatedWindowIndices", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 7), collapsed_slice_dims=(), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7, 6), "offset_dims in gather op must not repeat"), + (4, 5, 6, 7, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + "offset_dims in gather op must not repeat"), ("WindowIndexOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 100, 101, 102), collapsed_slice_dims=(), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7, 6), "Offset dimension 2 in gather op is out of bounds"), + (4, 5, 100, 101, 102), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + "Offset dimension 2 in gather op is out of bounds"), ("WindowIndexBarelyOutOfBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 1), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 9), collapsed_slice_dims=(), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7, 6), "Offset dimension 4 in gather op is out of bounds"), + (4, 5, 6, 7, 9), (), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + "Offset dimension 4 in gather op is out of bounds"), ("MismatchingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7, 6), + (4, 5, 6, 7, 8), (4,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), ("All components of the offset index in a gather op must either be a " - "offset dimension or explicitly collapsed/batching")), - ("MismatchingElidedWindowDimsV2", (10, 9, 8, 7, 6, 5), (10, 4, 3, 2, 4), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(4,), - start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,), - start_indices_batching_dims=(0,)), - (10, 9, 8, 7, 6, 5), - ("All components of the offset index in a gather op must either be a " - "offset dimension or explicitly collapsed/batching")), + "offset dimension or explicitly collapsed")), ("OutOfBoundsWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 19), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7, 6), + (4, 5, 6, 7, 8), (0, 1, 2, 3, 19), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), "Invalid collapsed_slice_dims set in gather op; valid range is"), ("RepeatedWindowToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(0, 1, 2, 3, 3), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7, 6), "collapsed_slice_dims in gather op must not repeat"), + (4, 5, 6, 7, 8), (0, 1, 2, 3, 3), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), + "collapsed_slice_dims in gather op must not repeat"), ("MismatchingGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), - start_index_map=(0, 1, 2, 3)), - (10, 9, 8, 7, 6), + (4, 5, 6, 7, 8), (), (0, 1, 2, 3), (10, 9, 8, 7, 6), ("Gather op has 4 elements in start_index_map and the bound of " "dimension index_vector_dim=4 of indices is 5. These two " "numbers must be equal.")), ("OutOfBoundsGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), - start_index_map=(0, 1, 2, 3, 7)), - (10, 9, 8, 7, 6), "Invalid start_index_map"), + (4, 5, 6, 7, 8), (), (0, 1, 2, 3, 7), (10, 9, 8, 7, 6), + "Invalid start_index_map"), ("RepeatedGatherToInputMapping", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(), - start_index_map=(0, 1, 2, 3, 3)), - (10, 9, 8, 7, 6), "start_index_map in gather op must not repeat"), + (4, 5, 6, 7, 8), (), (0, 1, 2, 3, 3), (10, 9, 8, 7, 6), + "start_index_map in gather op must not repeat"), ("NonAscendingElidedWindowDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 1), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7, 6), + (4, 5, 6, 7, 8), (2, 1), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), "collapsed_slice_dims in gather op must be sorted"), ("WindowBoundsTooLarge", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(2,), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 100, 6), + (4, 5, 6, 7), (2,), (0, 1, 2, 3, 4), (10, 9, 8, 100, 6), "Slice size at index 3 in gather op is out of range"), ("MismatchingNumberOfWindowBounds", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7), + (4, 5, 6, 7), (), (0, 1, 2, 3, 4), (10, 9, 8, 7), "Gather op must have one slice size for every input dimension"), ("WindowBoundsNot1ForElidedDim", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(1,), - start_index_map=(0, 1, 2, 3, 4)), - (10, 9, 8, 7, 6), + (4, 5, 6, 7), (1,), (0, 1, 2, 3, 4), (10, 9, 8, 7, 6), ("Gather op can only collapse slice dims with bound 1, but bound " - "is 9 for index 1 at position 0.")), - ("RepeatedOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), - start_index_map=(0, 1, 4), operand_batching_dims=(2, 3, 3)), - (10, 9, 8, 7, 6), - "operand_batching_dims in gather op must not repeat"), - ("NonAscendingOperandBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 3), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), - start_index_map=(0, 1, 4), operand_batching_dims=(3, 2)), - (10, 9, 8, 7, 6), - "operand_batching_dims in gather op must be sorted"), - ("OutOfBoundsOperandBatchingDims", (10, 9, 8, 7, 6), - (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), - start_index_map=(0, 1, 2, 3, 4), - operand_batching_dims=(0, 10)), - (10, 9, 8, 7, 6), - "Invalid operand_batching_dims set in gather op; valid range is"), - ("NonDisjointCollapsedAndBatchingDims", (10, 9, 8, 7, 6), - (5, 4, 3, 2, 3), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1, 2), - start_index_map=(0, 1, 4), operand_batching_dims=(2, 3)), - (10, 9, 8, 7, 6), - ("collapsed_slice_dims and operand_batching_dims in gather op must be " - "disjoint")), - ("NonDisjointStartIndexMapAndBatchingDims", (10, 9, 8, 7, 6), - (5, 4, 3, 2, 4), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), - start_index_map=(0, 1, 2, 4), operand_batching_dims=(2, 3)), - (10, 9, 8, 7, 6), - ("start_index_map and operand_batching_dims in gather op must be " - "disjoint")), - ("WindowBoundsNot1ForBatchingDim", (10, 9, 8, 7, 6), (9, 4, 3, 2, 4), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7), collapsed_slice_dims=(), - start_index_map=(0, 2, 3, 4), operand_batching_dims=(1,), - start_indices_batching_dims=(0,)), - (10, 9, 8, 7, 6), - ("Gather op can only have operand batching dims with bound 0/1, but " - "bound is 9 for index 1 at position 0.")), - ("RepeatedStartIndicesBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), - start_index_map=(0, 1, 2, 3, 4), - start_indices_batching_dims=(0, 1, 0)), - (10, 9, 8, 7, 6), - "start_indices_batching_dims in gather op must not repeat"), - ("OutOfBoundsStartIndicesBatchingDims", (10, 9, 8, 7, 6), - (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), - start_index_map=(0, 1, 2, 3, 4), - start_indices_batching_dims=(0, 5)), - (10, 9, 8, 7, 6), - "Invalid start_indices_batching_dims set in gather op; valid range"), - ("IndexVectorDimInStartIndicesBatchingDims", (10, 9, 8, 7, 6), - (5, 4, 3, 2, 5), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6), collapsed_slice_dims=(0, 1), - start_index_map=(0, 1, 2, 3, 4), - start_indices_batching_dims=(0, 4)), - (10, 9, 8, 7, 6), - ("Gather op cannot have the index vector dimension as a batching " - "dimension")), - ("MismatchingNumberOfBatchingDims", (10, 9, 8, 7, 6), (5, 4, 3, 2, 4), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6), collapsed_slice_dims=(1, 2), - start_index_map=(1, 2, 3, 4), operand_batching_dims=(0,), - start_indices_batching_dims=(0, 1)), - (10, 9, 8, 7, 6), - ("Gather op requires equal numbers of operand_batching_dims and " - "start_indices_batching_dims")), - ("MismatchingBatchingDimSizes", (10, 9, 8, 7, 6), (10, 9, 3, 2, 3), - lax.GatherDimensionNumbers( - offset_dims=(4, 5, 6, 7, 8), collapsed_slice_dims=(2, 3, 4), - start_index_map=(2, 3, 4), operand_batching_dims=(0, 1), - start_indices_batching_dims=(1, 0)), - (10, 9, 8, 7, 6), - ("Gather op requires operand batching dimensions and indices batching " - "dimensions to have the same shape")) + "is 9 for index 1 at position 0.")) ] ) def testGatherShapeCheckingRule(self, operand_shape, indices_shape, dimension_numbers, slice_sizes, msg): - """ - - Args: - operand_shape: - indices_shape: - dimension_numbers: - slice_sizes: - msg: - """ operand = np.ones(operand_shape, dtype=np.int32) indices = np.ones(indices_shape, dtype=np.int32) @@ -2921,19 +2776,9 @@ def testGatherShapeCheckingRule(self, operand_shape, indices_shape, ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), - scatter_indices_batching_dims=(0,))), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - (3, 2, 3), lax.ScatterDimensionNumbers( - update_window_dims=(2,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), - scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.inexact_dtypes, mode=["clip", "fill", None], @@ -2957,19 +2802,9 @@ def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, mode): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5,), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), - scatter_indices_batching_dims=(0,))), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - (3, 2, 3), lax.ScatterDimensionNumbers( - update_window_dims=(2,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), - scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -2992,19 +2827,9 @@ def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), - scatter_indices_batching_dims=(0,))), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - (3, 2, 3), lax.ScatterDimensionNumbers( - update_window_dims=(2,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), - scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -3026,19 +2851,9 @@ def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), - scatter_indices_batching_dims=(0,))), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - (3, 2, 3), lax.ScatterDimensionNumbers( - update_window_dims=(2,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), - scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -3060,19 +2875,9 @@ def testScatterApply(self, arg_shape, dtype, idxs, update_shape, dnums): ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), - ((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( + ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), - scatter_indices_batching_dims=(0,))), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - (3, 2, 3), lax.ScatterDimensionNumbers( - update_window_dims=(2,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), - scatter_indices_batching_dims=(1, 0))) ]], dtype=lax_test_util.float_dtypes, ) @@ -3090,207 +2895,84 @@ def testScatter(self, arg_shape, dtype, idxs, update_shape, dnums): # variations to account for the implicit setting of index_vector_dim in JAX. @parameterized.named_parameters( {"testcase_name": f"_{testcase_name}", "operand_shape": operand_shape, - "indices_shape": indices_shape, "update_shape": update_shape, - "dimension_numbers": dimension_numbers, + "indices": indices, "update_shape": update_shape, + "dimension_numbers": lax.ScatterDimensionNumbers( + update_window_dims=update_window_dims, + inserted_window_dims=inserted_window_dims, + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims), "msg": msg} - for (testcase_name, operand_shape, indices_shape, update_shape, - dimension_numbers, msg) in [ - ("ScatterWithUpdatesBiggerThanInput", (64, 48), (32, 1), (65, 32), - lax.ScatterDimensionNumbers( - update_window_dims=(0,), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1,)), + for (testcase_name, operand_shape, indices, update_shape, + update_window_dims, inserted_window_dims, + scatter_dims_to_operand_dims, msg) in [ + ("ScatterWithUpdatesBiggerThanInput", (64, 48), np.zeros((32, 1)), + (65, 32), (0,), (1,), (1,), "Bounds of the window dimensions"), + ("ScatterWithUpdatesBiggerThanInputV2", (64, 48), + np.zeros((32, 1)), (32, 49), (1,), (0,), (1,), "Bounds of the window dimensions"), - ("ScatterWithUpdatesBiggerThanInputV2", (64, 48), (32, 1), - (32, 49), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(1,)), - "Bounds of the window dimensions"), - ("ScatterWithUpdatesNotMatchingIndices", (64, 48), (32, 1), - (64, 31), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(1,)), + ("ScatterWithUpdatesNotMatchingIndices", (64, 48), + np.zeros((32, 1)), (64, 31), (0,), (1,), (1,), "Bounds of the scatter dimensions"), - ("ScatterWithUpdatesNotMatchingIndicesV2", (64, 48), (32, 1), - (31, 48), lax.ScatterDimensionNumbers( - update_window_dims=(1,), inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(1,)), + ("ScatterWithUpdatesNotMatchingIndicesV2", (64, 48), + np.zeros((32, 1)), (31, 48), (1,), (0,), (1,), "Bounds of the scatter dimensions"), ("ScatterNdWithUpdatesBiggerThanInput", (64, 48), - (10, 9, 8, 7, 1), (10, 9, 8, 7, 65), - lax.ScatterDimensionNumbers( - update_window_dims=(4,), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1,)), - "Bounds of the window dimensions"), + np.zeros((10, 9, 8, 7, 1)), (10, 9, 8, 7, 65), (4,), (1,), + (0,), "Bounds of the window dimensions"), ("ScatterNdWithUpdatesNotMatchingIndices", (64, 48), - (10, 9, 8, 7, 1), (9, 9, 8, 7, 64), - lax.ScatterDimensionNumbers( - update_window_dims=(4,), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(0,)), + np.zeros((10, 9, 8, 7, 1)), (9, 9, 8, 7, 64), (4,), (1,), (0,), "Bounds of the scatter dimensions"), - ("InvalidUpdates", (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), - (10, 9, 8, 7, 3, 2, 4, 1), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), + ("InvalidUpdates", (50, 49, 48, 47, 46), + np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4, 1), + (4, 5, 6), (1, 2), (0, 1, 2, 3, 4), "Updates tensor must be of rank 7; got 8."), - ("NonAscendingUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), - (10, 9, 8, 7, 6, 5, 4, 3, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6, 8, 7), inserted_window_dims=(), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), + ("NonAscendingUpdateWindowDims", (6, 5, 4, 3, 2), + np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), + (4, 5, 6, 8, 7), (), (0, 1, 2, 3, 4), "update_window_dims in scatter op must be sorted"), - ("RepeatedUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), - (10, 9, 8, 7, 6, 5, 4, 3, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6, 7, 7), inserted_window_dims=(), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), + ("RepeatedUpdateWindowDims", (6, 5, 4, 3, 2), + np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), + (4, 5, 6, 7, 7), (), (0, 1, 2, 3, 4), "update_window_dims in scatter op must not repeat"), - ("OutOfBoundsUpdateWindowDims", (6, 5, 4, 3, 2), (5, 4, 3, 2, 1), - (10, 9, 8, 7, 6, 5, 4, 3, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6, 7, 9), inserted_window_dims=(), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), + ("OutOfBoundsUpdateWindowDims", (6, 5, 4, 3, 2), + np.zeros((5, 4, 3, 2, 1)), (10, 9, 8, 7, 6, 5, 4, 3, 2), + (4, 5, 6, 7, 9), (), (0, 1, 2, 3, 4), "Invalid update_window_dims set in scatter op"), ("NonAscendingInsertedWindowDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(2, 1), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), + np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), + (4, 5, 6), (2, 1), (0, 1, 2, 3, 4), "inserted_window_dims in scatter op must be sorted"), ("RepeatedInsertedWindowDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(1, 1), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), + np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), + (4, 5, 6), (1, 1), (0, 1, 2, 3, 4), "inserted_window_dims in scatter op must not repeat"), ("OutOfBoundsInsertedWindowDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(1, 5), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4)), + np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), + (4, 5, 6), (1, 5), (0, 1, 2, 3, 4), "Invalid inserted_window_dims set in scatter op"), ("MismatchingScatterDimsToOperandDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), - scatter_dims_to_operand_dims=(0, 1, 2, 3)), + np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), + (4, 5, 6), (1, 2), (0, 1, 2, 3), ("Scatter op has 4 elements in scatter_dims_to_operand_dims and " "the bound of dimension index_vector_dim=4 of indices " "is 5. These two numbers must be equal")), ("OutOfBoundsScatterDimsToOperandDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 10)), + np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), + (4, 5, 6), (1, 2), (0, 1, 2, 3, 10), "Invalid scatter_dims_to_operand_dims mapping"), ("RepeatedValuesInScatterDimsToOperandDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), - scatter_dims_to_operand_dims=(0, 1, 2, 2, 3)), + np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), + (4, 5, 6), (1, 2), (0, 1, 2, 2, 3), "scatter_dims_to_operand_dims in scatter op must not repeat"), ("InsufficientWindowDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(0, 1, 2, 3)), + np.zeros((10, 9, 8, 7, 5)), (10, 9, 8, 7, 3, 2, 4), + (4, 5, 6), (1,), (0, 1, 2, 3), ("Scatter op has window of size 4; doesn't match operand of " - "rank 5.")), - ("InsufficientWindowDimsV2", (10, 49, 48, 47, 46, 45), - (10, 9, 8, 7, 3), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1, 2, 3), - operand_batching_dims=(0,), - scatter_indices_batching_dims=(0,)), - ("Scatter op has window of size 5; doesn't match operand of " - "rank 6.")), - ("RepeatedOperandBatchingDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1, 4), - operand_batching_dims=(2, 3, 3)), - "operand_batching_dims in scatter op must not repeat"), - ("NonAscendingOperandBatchingDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1, 4), - operand_batching_dims=(3, 2)), - "operand_batching_dims in scatter op must be sorted"), - ("OutOfBoundsOperandBatchingDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), - operand_batching_dims=(0, 10)), - ("Invalid operand_batching_dims set in scatter op; valid range " - "is")), - ("NonDisjointCollapsedAndBatchingDims", (50, 49, 48, 47, 46, 45), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1, 4), - operand_batching_dims=(1, 2)), - ("inserted_window_dims and operand_batching_dims in scatter op " - "must be disjoint")), - ("NonDisjointScatterDimsToOperandDimsAndBatchingDims", - (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), - (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1, 2, 4), - operand_batching_dims=(2, 3)), - ("scatter_dims_to_operand_dims and operand_batching_dims in " - "scatter op must be disjoint")), - ("RepeatedScatterIndicesBatchingDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), - scatter_indices_batching_dims=(0, 1, 0)), - "scatter_indices_batching_dims in scatter op must not repeat"), - ("OutOfBoundsScatterIndicesBatchingDims", (50, 49, 48, 47, 46), - (10, 9, 8, 7, 5), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), - scatter_indices_batching_dims=(0, 5)), - ("Invalid scatter_indices_batching_dims set in scatter op; " - "valid range")), - ("IndexVectorDimInScatterIndicesBatchingDims", - (50, 49, 48, 47, 46), (10, 9, 8, 7, 5), - (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(0, 1), - scatter_dims_to_operand_dims=(0, 1, 2, 3, 4), - scatter_indices_batching_dims=(0, 4)), - ("Scatter op cannot have the index vector dimension as a " - "batching dimension")), - ("MismatchingNumberOfBatchingDims", (50, 49, 48, 47, 46, 45), - (10, 9, 8, 7, 4), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(1, 2), - scatter_dims_to_operand_dims=(1, 2, 3, 4), - operand_batching_dims=(0,), - scatter_indices_batching_dims=(0, 1)), - ("Scatter op requires equal numbers of operand_batching_dims " - "and scatter_indices_batching_dims")), - ("MismatchingBatchingDimSizes", (10, 9, 48, 47, 46, 45), - (10, 9, 8, 7, 2), (10, 9, 8, 7, 3, 2, 4), - lax.ScatterDimensionNumbers( - update_window_dims=(4, 5, 6), inserted_window_dims=(2,), - scatter_dims_to_operand_dims=(2, 3), - operand_batching_dims=(0, 1), - scatter_indices_batching_dims=(1, 0)), - ("Scatter op requires operand batching dimensions and indices " - "batching dimensions to have the same shape")) + "rank 5.")) ] ) - def testScatterShapeCheckingRule(self, operand_shape, indices_shape, + def testScatterShapeCheckingRule(self, operand_shape, indices, update_shape, dimension_numbers, msg): - indices = np.zeros(indices_shape, dtype=np.int32) + def f(x, y): operand = lax.broadcast(x, operand_shape) updates = lax.broadcast(y, update_shape) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 0f259bf490e6..37a0011e7bd0 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -566,18 +566,6 @@ def testFft(self, fft_ndims, shape, bdims): ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)), (1, 3)), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), - lax.GatherDimensionNumbers( - offset_dims=(), collapsed_slice_dims=(1,), - start_index_map=(1,), operand_batching_dims=(0,), - start_indices_batching_dims=(0,)), - (1, 1)), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - lax.GatherDimensionNumbers( - offset_dims=(2,), collapsed_slice_dims=(), - start_index_map=(2,), operand_batching_dims=(0, 1), - start_indices_batching_dims=(1, 0)), - (1, 1, 3)) ] for bdims in lax_test_util.all_bdims(shape, idxs.shape)], dtype=lax_test_util.all_dtypes @@ -602,16 +590,6 @@ def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), - scatter_indices_batching_dims=(0,))), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - (3, 2, 3), lax.ScatterDimensionNumbers( - update_window_dims=(2,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), - scatter_indices_batching_dims=(1, 0))) ] for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape, update_shape)], dtype=lax_test_util.float_dtypes @@ -635,16 +613,6 @@ def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), - ((2, 5), np.array([[[0], [2]], [[1], [1]]]), (2, 2), - lax.ScatterDimensionNumbers( - update_window_dims=(), inserted_window_dims=(1,), - scatter_dims_to_operand_dims=(1,), operand_batching_dims=(0,), - scatter_indices_batching_dims=(0,))), - ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]), - (3, 2, 3), lax.ScatterDimensionNumbers( - update_window_dims=(2,), inserted_window_dims=(), - scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1), - scatter_indices_batching_dims=(1, 0))) ] for bdims in lax_test_util.all_bdims(arg_shape, idxs.shape)], dtype=lax_test_util.float_dtypes, From ee6fd5aeb2a1222063fb20797b9c2c4116486460 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 25 Sep 2024 09:47:31 -0700 Subject: [PATCH 666/702] Improve documentation for jnp.interp --- jax/_src/numpy/lax_numpy.py | 54 ++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 559d17cd9514..73e27245cfa9 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -35,7 +35,6 @@ import types from typing import (overload, Any, Literal, NamedTuple, Protocol, TypeVar, Union) -from textwrap import dedent as _dedent import warnings import numpy as np @@ -2316,15 +2315,58 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return f -@util.implements(np.interp, - lax_description=_dedent(""" - In addition to constant interpolation supported by NumPy, jnp.interp also - supports left='extrapolate' and right='extrapolate' to indicate linear - extrapolation instead.""")) def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, right: ArrayLike | str | None = None, period: ArrayLike | None = None) -> Array: + """One-dimensional linear interpolation. + + JAX implementation of :func:`numpy.interp`. + + Args: + x: N-dimensional array of x coordinates at which to evaluate the interpolation. + xp: one-dimensional sorted array of points to be interpolated. + fp: array of shape ``xp.shape`` containing the function values associated with ``xp``. + left: specify how to handle points ``x < xp[0]``. Default is to return ``fp[0]``. + If ``left`` is a scalar value, it will return this value. if ``left`` is the string + ``"extrapolate"``, then the value will be determined by linear extrapolation. + ``left`` is ignored if ``period`` is specified. + right: specify how to handle points ``x > xp[-1]``. Default is to return ``fp[-1]``. + If ``right`` is a scalar value, it will return this value. if ``right`` is the string + ``"extrapolate"``, then the value will be determined by linear extrapolation. + ``right`` is ignored if ``period`` is specified. + period: optionally specify the period for the *x* coordinates, for e.g. interpolation + in angular space. + + Returns: + an array of shape ``x.shape`` containing the interpolated function at values ``x``. + + Examples: + >>> xp = jnp.arange(10) + >>> fp = 2 * xp + >>> x = jnp.array([0.5, 2.0, 3.5]) + >>> interp(x, xp, fp) + Array([1., 4., 7.], dtype=float32) + + Unless otherwise specified, extrapolation will be constant: + + >>> x = jnp.array([-10., 10.]) + >>> interp(x, xp, fp) + Array([ 0., 18.], dtype=float32) + + Use ``"extrapolate"`` mode for linear extrapolation: + + >>> interp(x, xp, fp, left='extrapolate', right='extrapolate') + Array([-20., 20.], dtype=float32) + + For periodic interpolation, specify the ``period``: + + >>> xp = jnp.array([0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2]) + >>> fp = jnp.sin(xp) + >>> x = 2 * jnp.pi # note: not in input array + >>> jnp.interp(x, xp, fp, period=2 * jnp.pi) + Array(0., dtype=float32) + """ static_argnames = [] if isinstance(left, str) or left is None: static_argnames.append('left') From c93b272b784b4f41a41f2d592571ca6e90424985 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 25 Sep 2024 10:52:42 -0700 Subject: [PATCH 667/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/a473d30392e2cea68dc90f95377de3f568ea2055. PiperOrigin-RevId: 678764121 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index dcc12e68eae3..4ad4e48c02d1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "bcc98dcd1cd334a1aa833a1055a840bcd2ac87f5" -XLA_SHA256 = "c69acc5dd6eef894a400a5ae9076d3b53c0586acbd7d5970e7f9556d28b28462" +XLA_COMMIT = "a473d30392e2cea68dc90f95377de3f568ea2055" +XLA_SHA256 = "30324095a4d9454b5a8fdf0397b62cfd6f06155a077ce93cf75b64fb78f98fc0" def repo(): tf_http_archive( From 37641dd4fade625563321b7e1e87165df23cf4a8 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 25 Sep 2024 10:56:30 -0700 Subject: [PATCH 668/702] [Mosaic TPU] Support bitcast without forcing retiling. PiperOrigin-RevId: 678765762 --- .../tpu/transforms/infer_vector_layout.cc | 44 ++++++++++--------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 2894b0797e7b..408731e89415 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -938,16 +938,16 @@ class VectorLayoutInferer { auto out_ty = cast(op.getOutput().getType()); auto in_bitwidth = in_ty.getElementTypeBitWidth(); auto out_bitwidth = out_ty.getElementTypeBitWidth(); - auto src_layout = getLayout(op.getInput()); - LayoutOffsets src_offsets = src_layout->offsets(); - auto implicit_dim = src_layout->implicit_dim(); - if (src_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) { + auto in_layout = getLayout(op.getInput()); + LayoutOffsets in_offsets = in_layout->offsets(); + auto implicit_dim = in_layout->implicit_dim(); + if (in_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) { // Force offset to zero if the input offset on the second minor dimension // is not a multiple of the ratio of output and input bitwidth. - src_offsets[0] = 0; - } else if (!src_offsets[0].has_value() && in_bitwidth > out_bitwidth) { + in_offsets[0] = 0; + } else if (!in_offsets[0].has_value() && in_bitwidth > out_bitwidth) { // We can't preserve replicated offset for decreasing bitwidth. - src_offsets[0] = 0; + in_offsets[0] = 0; } // Force implicit dim to None if the bitwidth changes. Because we expect 2nd // minor dim size ratio matches the bitwidth ratio in input and output. @@ -959,20 +959,24 @@ class VectorLayoutInferer { } implicit_dim = ImplicitDim::kNone; } - // TODO(b/348485035): Instead of forcing to native tiling, bitcast should - // keep the input tiling and infer bitcastable tiling for output. For - // example, it is valid to bitcast vector<8x128xi32> with tile (1, 128) to - // vector<8x128xbf16> with tile (2, 128). + auto in_tiling = in_layout->tiling(); + auto out_tiling = in_tiling; + auto out_offsets = in_offsets; + if (in_offsets[0].has_value()) { + out_offsets[0] = in_offsets[0].value() * in_bitwidth / out_bitwidth; + } + if ((in_tiling[0] * in_bitwidth) % out_bitwidth == 0) { + out_tiling[0] = out_tiling[0] * in_bitwidth / out_bitwidth; + } else { + // If the input sublane tiling is not bitcastable to output, we use native + // tiling for both input and output. + in_tiling = nativeTiling(in_bitwidth); + out_tiling = nativeTiling(out_bitwidth); + } + setLayout( - op, - VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth), - implicit_dim), - VectorLayout(out_bitwidth, - {src_offsets[0].has_value() - ? src_offsets[0].value() * in_bitwidth / out_bitwidth - : src_offsets[0], - src_offsets[1]}, - nativeTiling(out_bitwidth), implicit_dim)); + op, VectorLayout(in_bitwidth, in_offsets, in_tiling, implicit_dim), + VectorLayout(out_bitwidth, out_offsets, out_tiling, implicit_dim)); return success(); } From e05c37c667b624b875b50a2cf1ba629e4b3f5ea3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 25 Sep 2024 11:19:20 -0700 Subject: [PATCH 669/702] Finalize deprecation of pretty-printing utils in `jax.core.pp_*` PiperOrigin-RevId: 678775782 --- CHANGELOG.md | 2 ++ jax/core.py | 38 +++++++++++++------------------------- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c3c63b6ee04..063324f0179c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments with `ndim != 1` are now deprecated, and in the future will result in an error. + * Internal pretty-printing tools `jax.core.pp_*` have been removed, after + being deprecated in JAX v0.4.30. * Deletion: * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation diff --git a/jax/core.py b/jax/core.py index 035dcdcdb7e8..90ef668b2493 100644 --- a/jax/core.py +++ b/jax/core.py @@ -154,19 +154,19 @@ ("jax.core.check_valid_jaxtype is deprecated. Instead, you can manually" " raise an error if core.valid_jaxtype() returns False."), _src_core.check_valid_jaxtype), - # Added 2024-06-12 - "pp_aval": ("jax.core.pp_aval is deprecated.", _src_core.pp_aval), - "pp_eqn": ("jax.core.pp_eqn is deprecated.", _src_core.pp_eqn), - "pp_eqn_rules": ("jax.core.pp_eqn_rules is deprecated.", _src_core.pp_eqn_rules), - "pp_eqns": ("jax.core.pp_eqns is deprecated.", _src_core.pp_eqns), - "pp_jaxpr": ("jax.core.pp_jaxpr is deprecated.", _src_core.pp_jaxpr), - "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range is deprecated.", _src_core.pp_jaxpr_eqn_range), - "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton is deprecated.", _src_core.pp_jaxpr_skeleton), - "pp_jaxprs": ("jax.core.pp_jaxprs is deprecated.", _src_core.pp_jaxprs), - "pp_kv_pair": ("jax.core.pp_kv_pair is deprecated.", _src_core.pp_kv_pair), - "pp_kv_pairs": ("jax.core.pp_kv_pairs is deprecated.", _src_core.pp_kv_pairs), - "pp_var": ("jax.core.pp_var is deprecated.", _src_core.pp_var), - "pp_vars": ("jax.core.pp_vars is deprecated.", _src_core.pp_vars), + # Finalized 2024-09-25; remove after 2024-12-25 + "pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None), + "pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None), + "pp_eqn_rules": ("jax.core.pp_eqn_rules was removed in JAX v0.4.34.", None), + "pp_eqns": ("jax.core.pp_eqns was removed in JAX v0.4.34.", None), + "pp_jaxpr": ("jax.core.pp_jaxpr was removed in JAX v0.4.34.", None), + "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range was removed in JAX v0.4.34.", None), + "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton was removed in JAX v0.4.34.", None), + "pp_jaxprs": ("jax.core.pp_jaxprs was removed in JAX v0.4.34.", None), + "pp_kv_pair": ("jax.core.pp_kv_pair was removed in JAX v0.4.34.", None), + "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), + "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), + "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), # Finalized 2024-05-13; remove after 2024-08-13 "DimSize": ( "jax.core.DimSize is deprecated. Use DimSize = int | Any.", @@ -201,18 +201,6 @@ check_type = _src_core.check_type check_valid_jaxtype = _src_core.check_valid_jaxtype non_negative_dim = _src_core.non_negative_dim - pp_aval = _src_core.pp_aval - pp_eqn = _src_core.pp_eqn - pp_eqn_rules = _src_core.pp_eqn_rules - pp_eqns = _src_core.pp_eqns - pp_jaxpr = _src_core.pp_jaxpr - pp_jaxpr_eqn_range = _src_core.pp_jaxpr_eqn_range - pp_jaxpr_skeleton = _src_core.pp_jaxpr_skeleton - pp_jaxprs = _src_core.pp_jaxprs - pp_kv_pair = _src_core.pp_kv_pair - pp_kv_pairs = _src_core.pp_kv_pairs - pp_var = _src_core.pp_var - pp_vars = _src_core.pp_vars else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) From 70346bda7412349db2af9a04a4ef410832e1e530 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 25 Sep 2024 11:24:58 -0700 Subject: [PATCH 670/702] [Pallas] Add scalar f32 downcast test cases. PiperOrigin-RevId: 678779025 --- tests/pallas/ops_test.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 8d242617efbb..ce9403bbc890 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -606,6 +606,30 @@ def kernel(x_ref, y_ref): y, y_ref = y.astype(np.float32), y_ref.astype(np.float32) np.testing.assert_allclose(y, y_ref, atol=0., rtol=0.) + @parameterized.parameters( + jnp.bfloat16, + jnp.float8_e5m2, + jnp.float8_e4m3fn, + ) + @jtu.skip_on_devices("gpu") + def test_scalar_downcast_float32(self, dtype): + + def kernel(x_ref, o_ref): + o_ref[0, 0] = x_ref[:][0, 0].astype(dtype) + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + result = self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec((8, 128), lambda *_: (0, 0)), + ], + out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()), + out_shape=jax.ShapeDtypeStruct([1, 1], dtype), + grid=(1,), + )(x) + + np.testing.assert_array_equal(result[0, 0], x[0, 0].astype(dtype)) + @parameterized.product( shape=((64,), (8, 8)), dtype=(jnp.int32, jnp.int16, jnp.int8), From ce99c18a7461eb3d46bd3fb03e3838752a3da8a1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 25 Sep 2024 13:34:00 -0700 Subject: [PATCH 671/702] Remove `CC="/usr/lib/llvm-18/bin/clang"` from clang config in .bazelrc Restore `cuda_clang` config in .bazelrc PiperOrigin-RevId: 678828039 --- .bazelrc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.bazelrc b/.bazelrc index 458ce69fae8b..5b7bc653373b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -57,7 +57,6 @@ build:native_arch_posix --host_copt=-march=native build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 -build:clang --action_env=CC="/usr/lib/llvm-18/bin/clang" # Disable clang extention that rejects type definitions within offsetof. # This was added in clang-16 by https://reviews.llvm.org/D133574. # Can be removed once upb is updated, since a type definition is used within @@ -91,12 +90,14 @@ build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" # acceptable, because the workaround is "remove the nvidia-..." pip packages. # The list of CUDA pip packages that JAX depends on are present in setup.py. build:cuda --linkopt=-Wl,--disable-new-dtags -build:cuda --@local_config_cuda//:cuda_compiler=clang -build:cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" # This flag is needed to include CUDA libraries for bazel tests. test:cuda --@local_config_cuda//cuda:include_cuda_libs=true +build:cuda_clang --config=clang +build:cuda_clang --@local_config_cuda//:cuda_compiler=clang +build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" + # Build with NVCC for CUDA build:cuda_nvcc --config=cuda build:cuda_nvcc --config=clang @@ -202,7 +203,7 @@ build:rbe_linux --host_linkopt=-lm # Use the GPU toolchain until the CPU one is ready. # https://github.com/bazelbuild/bazel/issues/13623 build:rbe_cpu_linux_base --config=rbe_linux -build:rbe_cpu_linux_base --config=clang +build:rbe_cpu_linux_base --config=cuda_clang build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain" build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" From a1f2edc968507836638bdb224af7e310e7099cd9 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Wed, 25 Sep 2024 13:39:39 -0700 Subject: [PATCH 672/702] Fix make_remote_async_copy -> make_async_remote_copy in async doc. --- docs/pallas/async_note.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/pallas/async_note.md b/docs/pallas/async_note.md index 96370ee48625..42e32a074fd7 100644 --- a/docs/pallas/async_note.md +++ b/docs/pallas/async_note.md @@ -37,7 +37,7 @@ Now imagine we aren’t using XLA’s `ppermute` but have our own custom Pallas ```py def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem): right_neighbor = ... - descriptor = pltpu.make_remote_async_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) + descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) descriptor.start() descriptor.wait_send() descriptor.wait_recv() @@ -54,7 +54,7 @@ def add_one(x_ref, z_ref): def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem): right_neighbor = ... - descriptor = pltpu.make_remote_async_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) + descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) descriptor.start() # Explicitly schedule inner kernel between start/wait From 08629a423320c4552fe73d585ffdc3cdf8699860 Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Wed, 25 Sep 2024 16:39:02 -0500 Subject: [PATCH 673/702] [Mosaic GPU] Fix mbarrier inline ptx for newer CTK --- jax/experimental/mosaic/gpu/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 8d1d48eb94d1..4701a72490c8 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -617,7 +617,7 @@ def arrive_expect_tx(self, bytes: int | ir.Value): i32 = ir.IntegerType.get_signless(32) bytes = arith.index_cast(i32, bytes) - nvvm.mbarrier_arrive_expect_tx(self.get_ptr(), bytes) + nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes) def get_ptr(self): ptr = ir.Type.parse("!llvm.ptr<3>") From ad6c3a7f64c90472af50d59ea294d14a2aab5575 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 25 Sep 2024 14:41:13 -0700 Subject: [PATCH 674/702] Improve docs for jnp.pad --- jax/_src/numpy/lax_numpy.py | 120 ++++++++++++++++++++++++++++++++++-- jax/_src/numpy/util.py | 7 --- 2 files changed, 115 insertions(+), 12 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 73e27245cfa9..75bcbede6650 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3750,13 +3750,123 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, "not implemented modes") -@util.implements(np.pad, lax_description="""\ -Unlike numpy, JAX "function" mode's argument (which is another function) should return -the modified array. This is because Jax arrays are immutable. -(In numpy, "function" mode's argument should modify a rank 1 array in-place.) -""") def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], mode: str | Callable[..., Any] = "constant", **kwargs) -> Array: + """Add padding to an array. + + JAX implementation of :func:`numpy.pad`. + + Args: + array: array to pad. + pad_width: specify the pad width for each dimension of an array. Padding widths + may be separately specified for *before* and *after* the array. Options are: + + - ``int`` or ``(int,)``: pad each array dimension with the same number of values + both before and after. + - ``(before, after)``: pad each array with ``before`` elements before, and ``after`` + elements after + - ``((before_1, after_1), (before_2, after_2), ... (before_N, after_N))``: specify + distinct ``before`` and ``after`` values for each array dimension. + + mode: a string or callable. Supported pad modes are: + + - ``'constant'`` (default): pad with a constant value, which defaults to zero. + - ``'empty'``: pad with empty values (i.e. zero) + - ``'edge'``: pad with the edge values of the array. + - ``'wrap'``: pad by wrapping the array. + - ``'linear_ramp'``: pad with a linear ramp to specified ``end_values``. + - ``'maximum'``: pad with the maximum value. + - ``'mean'``: pad with the mean value. + - ``'median'``: pad with the median value. + - ``'minimum'``: pad with the minimum value. + - ``'reflect'``: pad by reflection. + - ``'symmetric'``: pad by symmetric reflection. + - ````: a callable function. See Notes below. + + constant_values: referenced for ``mode = 'constant'``. Specify the constant value + to pad with. + stat_length: referenced for ``mode in ['maximum', 'mean', 'median', 'minimum']``. + An integer or tuple specifying the number of edge values to use when calculating + the statistic. + end_values: referenced for ``mode = 'linear_ramp'``. Specify the end values to + ramp the padding values to. + reflect_type: referenced for ``mode in ['reflect', 'symmetric']``. Specify whether + to use even or odd reflection. + + Returns: + A padded copy of ``array``. + + Notes: + When ``mode`` is callable, it should have the following signature:: + + def pad_func(row: Array, pad_width: tuple[int, int], + iaxis: int, kwargs: dict) -> Array: + ... + + Here ``row`` is a 1D slice of the padded array along axis ``iaxis``, with the pad + values filled with zeros. ``pad_width`` is a tuple specifying the ``(before, after)`` + padding sizes, and ``kwargs`` are any additional keyword arguments passed to the + :func:`jax.numpy.pad` function. + + Note that while in NumPy, the function should modify ``row`` in-place, in JAX the + function should return the modified ``row``. In JAX, the custom padding function + will be mapped across the padded axis using the :func:`jax.vmap` transformation. + + See also: + - :func:`jax.numpy.resize`: resize an array + - :func:`jax.numpy.tile`: create a larger array by tiling a smaller array. + - :func:`jax.numpy.repeat`: create a larger array by repeating values of a smaller array. + + Examples: + + Pad a 1-dimensional array with zeros: + + >>> x = jnp.array([10, 20, 30, 40]) + >>> jnp.pad(x, 2) + Array([ 0, 0, 10, 20, 30, 40, 0, 0], dtype=int32) + >>> jnp.pad(x, (2, 4)) + Array([ 0, 0, 10, 20, 30, 40, 0, 0, 0, 0], dtype=int32) + + Pad a 1-dimensional array with specified values: + + >>> jnp.pad(x, 2, constant_values=99) + Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32) + + Pad a 1-dimensional array with the mean array value: + + >>> jnp.pad(x, 2, mode='mean') + Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32) + + Pad a 1-dimensional array with reflected values: + + >>> jnp.pad(x, 2, mode='reflect') + Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32) + + Pad a 2-dimensional array with different paddings in each dimension: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.pad(x, ((1, 2), (3, 0))) + Array([[0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 3], + [0, 0, 0, 4, 5, 6], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], dtype=int32) + + Pad a 1-dimensional array with a custom padding function: + + >>> def custom_pad(row, pad_width, iaxis, kwargs): + ... # row represents a 1D slice of the zero-padded array. + ... before, after = pad_width + ... before_value = kwargs.get('before_value', 0) + ... after_value = kwargs.get('after_value', 0) + ... row = row.at[:before].set(before_value) + ... return row.at[len(row) - after:].set(after_value) + >>> x = jnp.array([2, 3, 4]) + >>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10) + Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32) + """ + util.check_arraylike("pad", array) pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width") if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1]) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e9d1db26731c..9c9bc5d389e1 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -114,7 +114,6 @@ def _parse_parameters(body: str) -> dict[str, str]: def implements( original_fun: Callable[..., Any] | None, update_doc: bool = True, - lax_description: str = "", sections: Sequence[str] = ('Parameters', 'Returns', 'References'), skip_params: Sequence[str] = (), module: str | None = None, @@ -132,8 +131,6 @@ def implements( update_doc: whether to transform the numpy docstring to remove references of parameters that are supported by the numpy version but not the JAX version. If False, include the numpy docstring verbatim. - lax_description: a string description that will be added to the beginning of - the docstring. sections: a list of sections to include in the docstring. The default is ["Parameters", "Returns", "References"] skip_params: a list of strings containing names of parameters accepted by the @@ -146,8 +143,6 @@ def decorator(wrapped_fun): wrapped_fun.__np_wrapped__ = original_fun # Allows this pattern: @implements(getattr(np, 'new_function', None)) if original_fun is None: - if lax_description: - wrapped_fun.__doc__ = lax_description return wrapped_fun docstr = getattr(original_fun, "__doc__", None) name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun))) @@ -181,8 +176,6 @@ def decorator(wrapped_fun): docstr = parsed.summary.strip() + "\n" if parsed.summary else "" docstr += f"\nLAX-backend implementation of :func:`{name}`.\n" - if lax_description: - docstr += "\n" + lax_description.strip() + "\n" docstr += "\n*Original docstring below.*\n" # We remove signatures from the docstrings, because they redundant at best and From f1b3251bf94bd80e703b029f9dfd35916fe2304d Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 25 Sep 2024 15:39:09 -0700 Subject: [PATCH 675/702] Change `CLANG_CUDA_COMPILER_PATH` set order. Add `--config=cuda_clang` to build.py Set `--action_env=CLANG_CUDA_COMPILER_PATH` after cuda_nvcc configuration Add `--config=cuda_clang` when `--nouse_cuda_nvcc` flag set PiperOrigin-RevId: 678873849 --- build/build.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/build/build.py b/build/build.py index 42db37fd74af..de0d5a9817fb 100755 --- a/build/build.py +++ b/build/build.py @@ -284,11 +284,13 @@ def write_bazelrc(*, remote_build, f.write("build --config=mkl_open_source_only\n") if enable_cuda: f.write("build --config=cuda\n") + if use_cuda_nvcc: + f.write("build --config=cuda_nvcc\n") + else: + f.write("build --config=cuda_clang\n") f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") if not enable_nccl: f.write("build --config=nonccl\n") - if use_cuda_nvcc: - f.write("build --config=cuda_nvcc\n") if cuda_version: f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n" .format(cuda_version=cuda_version)) From e4ca4f5a573b531c412aac87d104a52d0edf42e4 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 25 Sep 2024 16:02:11 -0700 Subject: [PATCH 676/702] Roll back cl/678765762 [Mosaic TPU] Support bitcast without forcing retiling. Reverts 37641dd4fade625563321b7e1e87165df23cf4a8 PiperOrigin-RevId: 678881199 --- .../tpu/transforms/infer_vector_layout.cc | 44 +++++++++---------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 408731e89415..2894b0797e7b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -938,16 +938,16 @@ class VectorLayoutInferer { auto out_ty = cast(op.getOutput().getType()); auto in_bitwidth = in_ty.getElementTypeBitWidth(); auto out_bitwidth = out_ty.getElementTypeBitWidth(); - auto in_layout = getLayout(op.getInput()); - LayoutOffsets in_offsets = in_layout->offsets(); - auto implicit_dim = in_layout->implicit_dim(); - if (in_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) { + auto src_layout = getLayout(op.getInput()); + LayoutOffsets src_offsets = src_layout->offsets(); + auto implicit_dim = src_layout->implicit_dim(); + if (src_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) { // Force offset to zero if the input offset on the second minor dimension // is not a multiple of the ratio of output and input bitwidth. - in_offsets[0] = 0; - } else if (!in_offsets[0].has_value() && in_bitwidth > out_bitwidth) { + src_offsets[0] = 0; + } else if (!src_offsets[0].has_value() && in_bitwidth > out_bitwidth) { // We can't preserve replicated offset for decreasing bitwidth. - in_offsets[0] = 0; + src_offsets[0] = 0; } // Force implicit dim to None if the bitwidth changes. Because we expect 2nd // minor dim size ratio matches the bitwidth ratio in input and output. @@ -959,24 +959,20 @@ class VectorLayoutInferer { } implicit_dim = ImplicitDim::kNone; } - auto in_tiling = in_layout->tiling(); - auto out_tiling = in_tiling; - auto out_offsets = in_offsets; - if (in_offsets[0].has_value()) { - out_offsets[0] = in_offsets[0].value() * in_bitwidth / out_bitwidth; - } - if ((in_tiling[0] * in_bitwidth) % out_bitwidth == 0) { - out_tiling[0] = out_tiling[0] * in_bitwidth / out_bitwidth; - } else { - // If the input sublane tiling is not bitcastable to output, we use native - // tiling for both input and output. - in_tiling = nativeTiling(in_bitwidth); - out_tiling = nativeTiling(out_bitwidth); - } - + // TODO(b/348485035): Instead of forcing to native tiling, bitcast should + // keep the input tiling and infer bitcastable tiling for output. For + // example, it is valid to bitcast vector<8x128xi32> with tile (1, 128) to + // vector<8x128xbf16> with tile (2, 128). setLayout( - op, VectorLayout(in_bitwidth, in_offsets, in_tiling, implicit_dim), - VectorLayout(out_bitwidth, out_offsets, out_tiling, implicit_dim)); + op, + VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth), + implicit_dim), + VectorLayout(out_bitwidth, + {src_offsets[0].has_value() + ? src_offsets[0].value() * in_bitwidth / out_bitwidth + : src_offsets[0], + src_offsets[1]}, + nativeTiling(out_bitwidth), implicit_dim)); return success(); } From 8ffeb2388ab77dec262f4f07b843bdd9d896466a Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Thu, 26 Sep 2024 09:39:18 +0530 Subject: [PATCH 677/702] Better doc for jnp.trace --- jax/_src/numpy/lax_numpy.py | 45 ++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 73e27245cfa9..ca70ff35dd9f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6652,10 +6652,53 @@ def triu(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) -@util.implements(np.trace, skip_params=['out']) @partial(jit, static_argnames=('axis1', 'axis2', 'dtype')) def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: + """Calculate sum of the diagonal of input along the given axes. + + JAX implementation of :func:`numpy.trace`. + + Args: + a: input array. Must have ``a.ndim >= 2``. + offset: optional, int, default=0. Diagonal offset from the main diagonal. + Can be positive or negative. + axis1: optional, default=0. The first axis along which to take the sum of + diagonal. Must be a static integer value. + axis2: optional, default=1. The second axis along which to take the sum of + diagonal. Must be a static integer value. + dtype: optional. The dtype of the output array. Should be provided as static + argument in JIT compilation. + out: Not used by JAX. + + Returns: + An array of dimension x.ndim-2 containing the sum of the diagonal elements + along axes (axis1, axis2) + + See also: + - :func:`jax.numpy.diag`: Returns the specified diagonal or constructs a diagonal + array + - :func:`jax.numpy.diagonal`: Returns the specified diagonal of an array. + - :func:`jax.numpy.diagflat`: Returns a 2-D array with the flattened input array + laid out on the diagonal. + + Examples: + >>> x = jnp.arange(1, 9).reshape(2, 2, 2) + >>> x + Array([[[1, 2], + [3, 4]], + + [[5, 6], + [7, 8]]], dtype=int32) + >>> jnp.trace(x) + Array([ 8, 10], dtype=int32) + >>> jnp.trace(x, offset=1) + Array([3, 4], dtype=int32) + >>> jnp.trace(x, axis1=1, axis2=2) + Array([ 5, 13], dtype=int32) + >>> jnp.trace(x, offset=1, axis1=1, axis2=2) + Array([2, 6], dtype=int32) + """ util.check_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") From b6d668e0d7f3fa333a1f0cb96818db8de7879e46 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 26 Sep 2024 02:17:49 -0700 Subject: [PATCH 678/702] [pallas::mosaic_gpu] Turn the accumulator into a reference * Changes the accumulator into a reference * Creates a discharged flavor of the wgmma op * run_scoped lowering discharges the input jaxpr * dereferencing the accumulator ref is done by a new primitive that behaves as expected when discharged * the deref primitive implies flushing the wgmma pipeline. * run_scoped does not allow references to be leaked. PiperOrigin-RevId: 679056765 --- jax/_src/pallas/mosaic_gpu/__init__.py | 3 +- jax/_src/pallas/mosaic_gpu/core.py | 50 ++++++++ jax/_src/pallas/mosaic_gpu/lowering.py | 66 ++++++++-- jax/_src/pallas/mosaic_gpu/primitives.py | 147 +++++++++++------------ tests/pallas/mosaic_gpu_test.py | 10 +- 5 files changed, 181 insertions(+), 95 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index ddf27361493a..bbada82ace82 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -18,14 +18,13 @@ from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace +from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wait_barrier from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem -from jax._src.pallas.mosaic_gpu.primitives import zero_accumulator from jax._src.pallas.mosaic_gpu.primitives import wgmma from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait GMEM = GPUMemorySpace.GMEM SMEM = GPUMemorySpace.SMEM -REGS = GPUMemorySpace.REGS diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index c4489226c860..df633619e6ee 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -203,3 +203,53 @@ def get_ref_aval(self) -> AbstractMemoryRef: [self.num_barriers], BarrierType(self.num_arrivals) ) return AbstractMemoryRef(aval, SMEM) + + +@dataclasses.dataclass(frozen=True) +class WGMMAAccumulatorRef: + shape: tuple[int, int] + dtype: jnp.dtype = jnp.float32 + + def get_ref_aval(self) -> AbstractMemoryRef: + return WGMMAAbstractAccumulatorRef( + jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), GPUMemorySpace.REGS + ) + + +def _is_trivial_index(idx): + _is_deref1 = lambda i: i is Ellipsis or i == slice(None) + if isinstance(idx, tuple): + return all(_is_deref1(i) for i in idx) + + return _is_deref1(idx) + +class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): + __slots__ = ["inner_aval", "memory_space"] + + def __repr__(self) -> str: + return f'Accumulator{{{self.inner_aval.str_short()}}}' + + def join(self, other): + return _as_accum(super().join(other)) + + def update(self, inner_aval=None, memory_space=None): + return _as_accum(super().update(inner_aval=None, memory_space=None)) + + def at_least_vspace(self): + return _as_accum(super().at_least_vspace()) + + def _getitem(self, tracer, idx): + if not _is_trivial_index(idx): + raise NotImplementedError(f"Can only dereference accumulators, not slice ({idx=}).") + from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error + return wgmma_accumulator_deref(tracer) + +def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: + return WGMMAAbstractAccumulatorRef( + inner_aval=ref.inner_aval, + memory_space=ref.memory_space, # pytype: disable=attribute-error + ) + +def _ref_raise_to_shaped(ref_aval, weak_type): + return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type)) +jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 702f5ca17e61..0c6d61fa9c96 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -28,6 +28,7 @@ from jax._src import core as jax_core from jax._src import pjit from jax._src import util +from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect @@ -37,6 +38,7 @@ from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.state import discharge from jax._src.state import primitives as sp import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core as mgpu_core @@ -615,6 +617,10 @@ def _swap_lowering_rule( del tree # Unused. if indexers: raise NotImplementedError("No support for indexers yet") + if not isinstance(value, mgpu.FragmentedArray): + raise TypeError(f"Can only store arrays (got {value}).") + if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): + raise TypeError(f"Can only store to references (got {value}).") x_aval, _ = ctx.avals_in old_value = mgpu.FragmentedArray.load_strided( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) @@ -735,15 +741,57 @@ def _(val, idx): def _run_scoped_lowering_rule( ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr ): - in_avals = [v.aval.inner_aval for v in jaxpr.invars] - bytes_allocated, input_refs = ctx.module_ctx.scratch_view([ - jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype) - for aval in in_avals - ]) - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts - ) - ctx.module_ctx.stack_free_smem(bytes_allocated) + input_refs = [] + bytes_allocated = 0 + should_discharge = [] + for a in jaxpr.invars: + a = a.aval + if isinstance(a, gpu_core.WGMMAAbstractAccumulatorRef): + mlir_dtype = mlir.dtype_to_ir_type(a.dtype) + input_refs.append(mgpu.WGMMAAccumulator.zero(*a.shape, mlir_dtype)) + should_discharge.append(True) + elif a.memory_space == gpu_core.SMEM: + ref_bytes, [input_ref] = ctx.module_ctx.scratch_view( + [jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype)] + ) + bytes_allocated += ref_bytes + input_refs.append(input_ref) + should_discharge.append(False) + else: + raise ValueError(f"Can't convert to ref: {a}") + + if any(should_discharge): + # We convert consts to args, because we only have ir.Values and + # not JAX values during lowering. discharge_state() produces JAX + # valiues for the aguments but expects them to be provided for the + # consts. We also don't want to wrap the values in refs. + no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) + should_discharge = [False] * len(consts) + should_discharge + discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) + new_input_vals = consts + tuple(input_refs) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, () + ) + # Discharge appends to the output the refs that got discharged. + outs = outs[:-sum(should_discharge)] + else: + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts + ) + + for o in outs: + # This is definitely one of the accumulators we produced. Each + # run_scoped call is responsible for dereferencing its own + # accumulators. + if isinstance(o, mgpu.WGMMAAccumulator) or ( + isinstance(o, ir.Value) and ir.MemRefType.isinstance(o.type) + ): + raise ValueError(f"No references are allowed to escape a scope. (got {o})") + + assert len(outs) == len(jaxpr.outvars), (jaxpr, outs) + if bytes_allocated: + ctx.module_ctx.stack_free_smem(bytes_allocated) + return outs diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index ef30dd0956ec..5295507ff2fa 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -16,12 +16,10 @@ from __future__ import annotations -import dataclasses - from jax._src import core as jax_core from jax._src import effects from jax._src import state -from jax._src.interpreters import mlir +from jax._src.state import discharge from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -118,62 +116,9 @@ class _WGMMAPipelineEffect(effects.Effect): _wgmma_pipeline_effect = _WGMMAPipelineEffect() effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect) - -# Not a shaped array to avoid unexpected operations. -class WGMMAAbstractAccumulator(jax_core.AbstractValue): - __slots__ = ['shape', 'dtype'] - - def __init__(self, shape, dtype): - self.shape = shape - self.dtype = dtype - - def __eq__(self, other): - return (type(self) is type(other) - and self.dtype == other.dtype and self.shape == other.shape) - - def __hash__(self): - return hash((self.shape, self.dtype)) - - def update(self, shape=None, dtype=None): - if shape is None: - shape = self.shape - if dtype is None: - dtype = self.dtype - return WGMMAAbstractAccumulator(shape, dtype) - - def str_short(self, short_dtypes=False) -> str: - del short_dtypes - shapestr = ",".join(map(str, self.shape)) - return f"Accumulator{{{self.dtype.name}}}[{shapestr}]" - -@dataclasses.dataclass(frozen=True) -class WGMMAAccumulator: - inner_aval: WGMMAAbstractAccumulator - - shape = property(lambda self: self.inner_aval.shape) - dtype = property(lambda self: self.inner_aval.dtype) - - def as_array(self) -> jax_core.ShapedArray: - return acc_to_shaped_array_p.bind(self.inner_aval) - - -jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulator] = lambda aval, _: aval - -acc_to_shaped_array_p = jax_core.Primitive("acc_to_shaped_array") - -@acc_to_shaped_array_p.def_abstract_eval -def _acc_to_shaped_array_abstract_eval(acc) -> jax_core.ShapedArray: - return jax_core.ShapedArray(shape=acc.shape, dtype=acc.dtype) - - -@lowering.register_lowering_rule(acc_to_shaped_array_p) -def _acc_to_shaped_array_lowering_rule( - ctx: lowering.LoweringRuleContext, acc -): - del ctx - return acc.value - -wgmma_p = jax_core.Primitive("wgmma") +# WGMMA on an accumulator reference +wgmma_ref_p = jax_core.Primitive("wgmma_ref") +wgmma_ref_p.multiple_results = True def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128): """Asynchronous warp group matmul. @@ -189,8 +134,8 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128): n_tile: The number of tiles to use. swizzle: The swizzle pattern. """ - if not isinstance(acc, WGMMAAccumulator): - raise TypeError(acc) + if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef): + raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}") rhs_transpose = ( (jnp.dtype(b.dtype).itemsize == 2) @@ -208,18 +153,43 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128): if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb: raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}") - outval = wgmma_p.bind(acc.inner_aval, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose) - return WGMMAAccumulator(outval) + return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose) -@wgmma_p.def_effectful_abstract_eval -def _wgmma_effectful_abstract_eval(acc, *args, **kwargs): - del args, kwargs - return acc, { + +@wgmma_ref_p.def_effectful_abstract_eval +def _wgmma_ref_effectful_abstract_eval(acc, *args, **kwargs): + del acc, args, kwargs + return [], { _wgmma_pipeline_effect, + state.WriteEffect(0), + state.ReadEffect(0), state.ReadEffect(1), state.ReadEffect(2), } + +@discharge.register_discharge_rule(wgmma_ref_p) +def _wgmma_ref_discharge_rule( + in_avals, out_avals, + acc, + a, + b, + swizzle, + rhs_transpose, +): + del in_avals, out_avals + return ( + wgmma_p.bind( + acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose + ), + None, + None, + ), [] + + +# Functional WGMMA, returns a shaped array. Internal. +wgmma_p = jax_core.Primitive("wgmma") + @lowering.register_lowering_rule(wgmma_p) def _wgmma_lowering_rule( ctx: lowering.LoweringRuleContext, @@ -242,6 +212,15 @@ def _wgmma_lowering_rule( nvvm_dialect.wgmma_commit_group_sync_aligned() return new_acc +@wgmma_p.def_effectful_abstract_eval +def _wgmma_effectful_abstract_eval(acc, *args, **kwargs): + del args, kwargs + return acc, { + _wgmma_pipeline_effect, + state.ReadEffect(1), + state.ReadEffect(2), + } + wgmma_wait_p = jax_core.Primitive("wgmma_wait") wgmma_wait_p.multiple_results = True @@ -260,19 +239,29 @@ def _wgmma_wait_lowering_rule(ctx: lowering.LoweringRuleContext, allow_groups): nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) return () -zero_accumulator_p = jax_core.Primitive("zero_accumulator") -def zero_accumulator(shape, dtype): - return WGMMAAccumulator(zero_accumulator_p.bind(shape=shape, dtype=dtype)) +wgmma_accumulator_deref_p = jax_core.Primitive("wgmma_accumulator_deref_p") +def wgmma_accumulator_deref(acc): + """Dereferences an accumulator register.""" -@zero_accumulator_p.def_abstract_eval -def _zero_accumulator_abstract_eval(shape, dtype): - return WGMMAAbstractAccumulator(shape=shape, dtype=dtype) + if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef): + raise TypeError(f"acc must be a WGMMAAccumulatorAbstractRef, got {acc.aval=}") + return wgmma_accumulator_deref_p.bind(acc) -@lowering.register_lowering_rule(zero_accumulator_p) -def _zero_accumulator_lowering_rule( - ctx: lowering.LoweringRuleContext, shape, dtype -): +@wgmma_accumulator_deref_p.def_effectful_abstract_eval +def _wgmma_accumulator_deref_abstract_eval(acc): + # Dereferencing implies flushing so we have a wgmma pipeline effect. + ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc + assert isinstance(ret, jax_core.ShapedArray), acc + return ret, {_wgmma_pipeline_effect} + +@discharge.register_discharge_rule(wgmma_accumulator_deref_p) +def _wgmma_accumulator_deref_discharge_rule(in_avals, out_avals, acc): + del in_avals, out_avals + return (None,), wgmma_accumulator_deref_p.bind(acc) + +@lowering.register_lowering_rule(wgmma_accumulator_deref_p) +def _wgmma_accumulator_deref_lowering_rule(ctx: lowering.LoweringRuleContext, acc): del ctx - m, n = shape - return mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=mlir.dtype_to_ir_type(jnp.dtype(dtype))) + nvvm_dialect.wgmma_wait_group_sync_aligned(0) + return acc.value diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 0eb7d91960ff..431dba133ce3 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -380,11 +380,11 @@ def test_wgmma(self, dtype): swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize def kernel(a_ref, b_ref, o_ref): - acc = plgpu.zero_accumulator((64, 128), jnp.float32) - acc = plgpu.wgmma(acc, a_ref, b_ref, rhs_transpose=rhs_transpose) - plgpu.wgmma_wait(0) - # TODO(cperivol): turn acc into a reference so we can reason about effects. - o_ref[...] = acc.as_array() + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=rhs_transpose) + return acc_ref[...] + + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32)) key1, key2 = jax.random.split(jax.random.key(42), 2) a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype) From 3c25da2c599aa9f4dc47ea95d0739a535d4e2374 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 26 Sep 2024 03:40:24 -0700 Subject: [PATCH 679/702] [Pallas/Mosaic GPU] Replace tiling/transpose fields of GPUBlockSpec with a transform list PiperOrigin-RevId: 679079269 --- jax/_src/pallas/mosaic_gpu/__init__.py | 2 ++ jax/_src/pallas/mosaic_gpu/core.py | 14 +++++--------- tests/pallas/mosaic_gpu_test.py | 14 ++++++++++---- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index bbada82ace82..187a84478c65 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -18,6 +18,8 @@ from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace +from jax._src.pallas.mosaic_gpu.core import TilingTransform +from jax._src.pallas.mosaic_gpu.core import TransposeTransform from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index df633619e6ee..0b79121c34c1 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -132,10 +132,8 @@ class GPUBlockMapping(pallas_core.BlockMapping): @dataclasses.dataclass class GPUBlockSpec(pallas_core.BlockSpec): - # TODO(justinfu): Replace tiling a list of transforms. - tiling: tuple[int, ...] | None = None - transpose_permutation: tuple[int, ...] | None = None - swizzle: int | None = None + transforms: MemoryRefTransform | tuple[MemoryRefTransform, ...] = () + swizzle: int | None = None # TODO: apaszke - Swizzle is also a transform. def to_block_mapping( self, @@ -155,11 +153,9 @@ def to_block_mapping( grid=grid, mapped_dims=mapped_dims, ) - transforms: tuple[pallas_core.MemoryRefTransform, ...] = () - if self.tiling is not None: - transforms += (TilingTransform(self.tiling),) - if self.transpose_permutation is not None: - transforms += (TransposeTransform(self.transpose_permutation),) + transforms = self.transforms + if not isinstance(transforms, tuple): + transforms = (transforms,) return GPUBlockMapping( block_shape=bm.block_shape, block_aval=bm.block_aval, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 431dba133ce3..5abf4e2e363e 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -326,11 +326,15 @@ def kernel(o_ref): ) def test_swizzled_blockspec_shapes(self): + @functools.partial( pl.pallas_call, in_specs=[ plgpu.GPUBlockSpec( - (128, 64), lambda *i: i, tiling=(64, 64), swizzle=128 + (128, 64), + lambda *i: i, + transforms=plgpu.TilingTransform((64, 64)), + swizzle=128, ), ], out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)), @@ -390,20 +394,22 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype) b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype) + rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),) + if rhs_transpose: + rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),) res = pl.pallas_call( kernel, in_specs=[ plgpu.GPUBlockSpec( (64, 128), lambda i, j: (i, j), - tiling=(64, elems_128b), + transforms=plgpu.TilingTransform((64, elems_128b)), swizzle=128, ), plgpu.GPUBlockSpec( (128, 128), lambda *i: i, - transpose_permutation=(1, 0, 2, 3) if rhs_transpose else None, - tiling=(elems_128b, elems_128b), + transforms=rhs_transforms, swizzle=128, ), ], From cf51ee7ef0ac0a8257a3c7f41fd46d8ecb58a52f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 26 Sep 2024 05:09:47 -0700 Subject: [PATCH 680/702] Improve documentation for jax.jacobian --- docs/jax.rst | 1 + jax/_src/api.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/jax.rst b/docs/jax.rst index a8781d31a448..ecfeaf29e3c0 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -91,6 +91,7 @@ Automatic differentiation grad value_and_grad + jacobian jacfwd jacrev hessian diff --git a/jax/_src/api.py b/jax/_src/api.py index bd8a951954ac..6d2fc4143066 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -680,7 +680,13 @@ def jacfun(*args, **kwargs): return jac_tree, aux return jacfun -jacobian = jacrev + + +def jacobian(fun: Callable, argnums: int | Sequence[int] = 0, + has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable: + """Alias of :func:`jax.jacrev`.""" + return jacrev(fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int) + _check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev") _check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev") From 57887732bedb9a6abbd13ce0a657c6ea50aae049 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 26 Sep 2024 05:52:02 -0700 Subject: [PATCH 681/702] [Pallas/Mosaic GPU] Disable inference of sequential axis shapes They should just be specified in the grid, so we don't need to do this. It's also incorrect, because it's not guaranteed that each input is sliced in the same dimension by the sequential axis. PiperOrigin-RevId: 679114626 --- jax/_src/pallas/mosaic_gpu/lowering.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 0c6d61fa9c96..0211b549a49e 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -404,26 +404,7 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: "Multiple sequential axes are not supported in Mosaic GPU lowering." ) [sequential_axis] = sequential_axes - if any( - b_gmem.shape[sequential_axis] % b_smem.shape[1 + sequential_axis] - for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) - if b_smem - ): - raise ValueError( - "Array dimensions along the sequential axis must be divisible by" - " the corresponding block dimensions." - ) - num_steps, *rest = { - b_gmem.shape[sequential_axis] // b_smem.shape[1 + sequential_axis] - for b_gmem, b_smem in zip(in_structs_gmem, in_structs_smem) - if b_smem - } - if rest: - raise ValueError( - "Array dimensions along the sequential axis must produce the same" - " number of steps when devided by the corresponding block" - " dimensions." - ) + num_steps = grid_mapping.grid[sequential_axis] else: num_steps = 1 From 8599dbc9b2c6e64c99da23726c87dfd819d92d88 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 26 Sep 2024 07:39:33 -0700 Subject: [PATCH 682/702] [Pallas/Mosaic GPU] Implement a more comprehensive matmul kernel to see what we're still missing I annotated a number of issues in the test. To make the test run I also needed to add support for the accumulator reference allocation and discharge in the main lowering part. Ideally, we'd defer it all to run_scoped, but run_scoped can't allocate barriers... PiperOrigin-RevId: 679143948 --- jax/_src/pallas/mosaic_gpu/lowering.py | 66 +++++++++++++++++++++--- jax/_src/pallas/mosaic_gpu/primitives.py | 9 +--- tests/pallas/mosaic_gpu_test.py | 56 ++++++++++++++++++++ 3 files changed, 116 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 0211b549a49e..3e5c7403ed6d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -55,6 +55,7 @@ zip, unsafe_zip = util.safe_zip, zip partial = functools.partial +SMEM = gpu_core.SMEM _smem_estimators = {} @@ -330,6 +331,45 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): start_indices, [grid_mapping.num_inputs] ) + smem_scratch_it = iter(scratch_buffers_smem) + scratch_buffers_template = [] + should_discharge = [] + accs = [] + for aval in scratch_avals: + match aval: + case gpu_core.WGMMAAbstractAccumulatorRef(): + scratch_buffers_template.append(None) + should_discharge.append(True) + accs.append( + mgpu.WGMMAAccumulator.zero( + *aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype) + ) + ) + case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM: + scratch_buffers_template.append(next(smem_scratch_it)) + should_discharge.append(False) + case _: + raise NotImplementedError( + f"Unsupported scratch operand type: {aval}" + ) + assert not jaxpr.outvars + if any(should_discharge): + # User-visible WGMMA APIs use the effectful accumulator references, but we + # can't lower that directly to Mosaic GPU that uses pure dataflow for + # accumulators. So we have to discharge the effects first. + assert not jaxpr.constvars + should_discharge = ( + [False] * len(grid_mapping.block_mappings) + + should_discharge + + [False] * len(extra_barriers) + ) + with grid_mapping.trace_env(): + lowered_jaxpr, _ = discharge.discharge_state( + jaxpr, (), should_discharge=should_discharge + ) + else: + lowered_jaxpr = jaxpr + # Precompute the total number of bytes transferred from GMEM to SMEM, # so that we can do a single arrive instruction for all of the inputs. in_transfer_bytes = 0 @@ -414,8 +454,8 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: for idx in range(grid_mapping.num_inputs): fetch(idx, _as_index(slot), _as_index(slot)) - @mgpu.fori(_as_index(num_steps), ()) - def _(step, _): + @mgpu.fori(_as_index(num_steps), accs) + def _(step, accs): slot = arith_dialect.remui(step, _as_index(num_stages)) if grid_mapping.num_inputs: # Only wait if async copies were issued. @@ -427,9 +467,18 @@ def _(step, _): else buffers_gmem[idx] for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem)) ] - args.extend(scratch_buffers_smem) + accs_it = iter(accs) + scratch_buffers = [ + b if b is not None else next(accs_it) + for b in scratch_buffers_template + ] + args.extend(scratch_buffers) + # TODO(apaszke): This assumes barriers come after buffers in scratch args, + # but that's not necessarily true. args.extend(extra_barriers) - _ = lower_jaxpr_to_mosaic_gpu(module_ctx, launch_ctx, jaxpr, args) + new_accs = lower_jaxpr_to_mosaic_gpu( + module_ctx, launch_ctx, lowered_jaxpr, args + ) mgpu.commit_shared() with mgpu.single_thread(): @@ -445,20 +494,22 @@ def _(step, _): fetch(idx, next_step, slot) barriers[slot].arrive_expect_tx(in_transfer_bytes) - return () + return list(new_accs) launch_ctx.await_async_copy(0) scratch_avals = [ var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops] ] + local_spaces = (gpu_core.SMEM, gpu_core.REGS) if not all( isinstance(aval, pallas_core.AbstractMemoryRef) - and aval.memory_space is gpu_core.SMEM + and aval.memory_space in local_spaces for aval in scratch_avals ): raise TypeError( - f"All scratch operands must be in SMEM, but got: {scratch_avals}" + "All scratch operands must be SMEM references or accumulators (ACC)," + f" but got: {scratch_avals}" ) extra_barriers = [ mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape) @@ -469,6 +520,7 @@ def _(step, _): jax.ShapeDtypeStruct(aval.shape, aval.dtype) for aval in scratch_avals if not isinstance(aval.dtype, gpu_core.BarrierType) + and aval.memory_space == gpu_core.SMEM ] smem_scratch_bytes = compiler_params.get("smem_scratch_bytes") if smem_scratch_bytes is None: diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 5295507ff2fa..dcec631e389b 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -25,7 +25,6 @@ from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering import jax.experimental.mosaic.gpu as mgpu -import jax.numpy as jnp async_copy_p = jax_core.Primitive("async_copy") async_copy_p.multiple_results = True @@ -120,7 +119,7 @@ class _WGMMAPipelineEffect(effects.Effect): wgmma_ref_p = jax_core.Primitive("wgmma_ref") wgmma_ref_p.multiple_results = True -def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128): +def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128): """Asynchronous warp group matmul. The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires @@ -137,12 +136,6 @@ def wgmma(acc, a, b, *, rhs_transpose: bool | None = None, swizzle: int = 128): if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef): raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}") - rhs_transpose = ( - (jnp.dtype(b.dtype).itemsize == 2) - if rhs_transpose is None - else rhs_transpose - ) - ma, ka, tma, tka = a.shape kb, nb, tkb, tnb = b.shape mc, nc = acc.shape diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 5abf4e2e363e..7c2bb12b4607 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -437,6 +437,62 @@ def kernel(a_ref, b_ref): )(a) np.testing.assert_array_equal(b, np.ones_like(a)) + def test_realistic_matmul(self): + dtype = jnp.float16 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + # TODO(apaszke): Make the grid and tile sizes larger + # grid_m, grid_k, grid_n = 132, 10, 4 + # TODO(apaszke): Increasing grid_k causes th test to fail. + # It seems like our pipelining implementation has a number of races. + grid_m, grid_k, grid_n = 2, 1, 2 + # tile_m = tile_n = 128 + tile_m = tile_n = 64 + tile_k = elems_128b + m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + def kernel(a_ref, b_ref, o_ref, acc_ref): + plgpu.wgmma(acc_ref, a_ref, b_ref) + plgpu.wgmma_wait(0) # TODO(apaszke): Delay the pipeline to avoid memory races + # TODO(apaszke): Only store in the last step. It doesn't work because we + # don't have partial discharge for control flow. + # is_last_step = pl.program_id(2) == grid_k - 1 + # @pl.when(is_last_step) + # def _epilogue(): + # pl.debug_print("{}", acc_ref[...]) + # TODO(apaszke): This is an untiled store! It's slow!! + o_ref[...] = acc_ref[...] + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) + b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + + res = pl.pallas_call( + kernel, + in_specs=[ + plgpu.GPUBlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + transforms=plgpu.TilingTransform((64, elems_128b)), + swizzle=128, + ), + plgpu.GPUBlockSpec( + (tile_k, tile_n), + lambda m, n, k: (k, n), + transforms=plgpu.TilingTransform((elems_128b, elems_128b)), + swizzle=128, + ), + ], + out_specs=plgpu.GPUBlockSpec((tile_m, tile_n), lambda m, n, k: (m, n)), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], + grid=(grid_m, grid_n, grid_k), + compiler_params=plgpu.GPUCompilerParams( + dimension_semantics=["parallel", "parallel", "sequential"], + num_stages=2, + ), + )(a, b) + np.testing.assert_allclose(res, a @ b, rtol=1e-3) + if __name__ == "__main__": absltest.main() From 0a66e2d0a4fe7bc39ea1386d32c779cd875806b8 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 26 Sep 2024 07:57:09 -0700 Subject: [PATCH 683/702] [Pallas/MGPU] Fix a race in the pipelining code We never checked if the output windows are done writing before we reused them. Also, rename num_stages to max_concurrent_steps since we always only have 2 stages, but might be running multiple iterations at a time. Also fix the test for this that has been passing for reasons that I don't understand (it didn't even write to all entries in the output??). PiperOrigin-RevId: 679148961 --- jax/_src/pallas/mosaic_gpu/core.py | 6 +++--- jax/_src/pallas/mosaic_gpu/lowering.py | 22 ++++++++++++++-------- tests/pallas/mosaic_gpu_test.py | 10 +++++----- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 0b79121c34c1..fe8daf43e995 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -41,13 +41,13 @@ class GPUCompilerParams(pallas_core.CompilerParams): dimension of the kernel. Either "parallel" for dimensions that can execute in any order, or "sequential" for dimensions that must be executed sequentially. - num_stages: The number of pipline stages in the kernel. Defaults to 1, - meaning no pipelining is done. + max_concurrent_steps: The maximum number of sequential stages that are + active concurrently. Defaults to 1. """ PLATFORM: ClassVar[str] = "mosaic_gpu" approx_math: bool = False dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None - num_stages: int = 1 + max_concurrent_steps: int = 1 class GPUMemorySpace(enum.Enum): diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 3e5c7403ed6d..902624192dcf 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -243,7 +243,7 @@ def lower_jaxpr_to_module( block = (128,) + (1,) * (len(grid) - 1) params = compiler_params.get("mosaic_gpu", {}) approx_math = params.get("approx_math", False) - num_stages = params.get("num_stages", 1) + max_concurrent_steps = params.get("max_concurrent_steps", 1) dimension_semantics = params.get("dimension_semantics") if dimension_semantics is None: dimension_semantics = ["parallel"] * len(grid_mapping.grid) @@ -272,7 +272,9 @@ def lower_jaxpr_to_module( for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] ] in_structs_smem = [ - jax.ShapeDtypeStruct([num_stages, *bm.ref_aval.shape], bm.ref_aval.dtype) + jax.ShapeDtypeStruct( + [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype + ) if in_smem else None for bm, in_smem in zip( @@ -293,7 +295,7 @@ def lower_jaxpr_to_module( out_structs_gmem = [*grid_mapping.out_shapes] # TODO(justinfu): Implement output Memref transforms out_structs_smem = [ - jax.ShapeDtypeStruct([num_stages, *bm.block_shape], s.dtype) + jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype) if in_smem else None for bm, in_smem, s in zip( @@ -449,17 +451,20 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: num_steps = 1 with mgpu.single_thread(): - for slot in range(min(num_stages, num_steps)): + for slot in range(min(max_concurrent_steps, num_steps)): barriers[slot].arrive_expect_tx(in_transfer_bytes) for idx in range(grid_mapping.num_inputs): fetch(idx, _as_index(slot), _as_index(slot)) @mgpu.fori(_as_index(num_steps), accs) def _(step, accs): - slot = arith_dialect.remui(step, _as_index(num_stages)) + slot = arith_dialect.remui(step, _as_index(max_concurrent_steps)) if grid_mapping.num_inputs: # Only wait if async copies were issued. barriers[slot].wait() + # We need to make sure the output copy is complete before the kernel starts + # writing to the output window. + launch_ctx.await_async_copy(max_concurrent_steps - 1) args = [ mgpu.memref_slice(buffers_smem[idx], slot) @@ -485,13 +490,14 @@ def _(step, accs): for idx in range(grid_mapping.num_outputs): store(idx, step, slot) - next_step = arith_dialect.addi(step, _as_index(num_stages)) + next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps)) next_step_in_bounds = arith_dialect.cmpi( arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps) ) + next_slot = slot # (x + y) % y == x % y with mgpu.when(next_step_in_bounds), mgpu.single_thread(): for idx in range(grid_mapping.num_inputs): - fetch(idx, next_step, slot) + fetch(idx, next_step, next_slot) barriers[slot].arrive_expect_tx(in_transfer_bytes) return list(new_accs) @@ -540,7 +546,7 @@ def _(step, accs): (*in_structs_smem, *out_structs_smem), *extra_smem_scratch, ( - mgpu.Barrier(arrival_count=1, num_barriers=num_stages), + mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps), *extra_barriers, ), ), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7c2bb12b4607..3018e0383188 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -96,8 +96,8 @@ def kernel(x_ref, o_ref, scratch_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) - @parameterized.product(num_stages=[1, 2, 3]) - def test_add_one_grid_pipelined(self, num_stages): + @parameterized.product(max_concurrent_steps=[1, 2, 3, 4]) + def test_add_one_grid_pipelined(self, max_concurrent_steps): @functools.partial( pl.pallas_call, @@ -106,9 +106,9 @@ def test_add_one_grid_pipelined(self, num_stages): out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), compiler_params=plgpu.GPUCompilerParams( dimension_semantics=["parallel", "sequential"], - num_stages=num_stages, + max_concurrent_steps=max_concurrent_steps, ), - grid=(2, 1), + grid=(2, 4), ) def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] + 1.0 @@ -488,7 +488,7 @@ def kernel(a_ref, b_ref, o_ref, acc_ref): grid=(grid_m, grid_n, grid_k), compiler_params=plgpu.GPUCompilerParams( dimension_semantics=["parallel", "parallel", "sequential"], - num_stages=2, + max_concurrent_steps=2, ), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) From 5cef547eab6edb8a9970f912e637bf45d69887b6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 26 Sep 2024 08:20:15 -0700 Subject: [PATCH 684/702] Added support for `lax.cond_p` to Pallas Mosaic GPU lowering PiperOrigin-RevId: 679156819 --- jax/_src/pallas/mosaic_gpu/lowering.py | 60 ++++++++++++++++++- .../mosaic/gpu/fragmented_array.py | 4 +- jax/experimental/mosaic/gpu/utils.py | 4 +- tests/pallas/mosaic_gpu_test.py | 21 +++++++ 4 files changed, 85 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 902624192dcf..cc19d96aa194 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -34,6 +34,7 @@ from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect +from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils @@ -677,6 +678,22 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): ) +@register_lowering_rule(lax.select_n_p) +def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): + if len(cases) != 2: + raise NotImplementedError( + "Mosaic GPU lowering only supports select_n with 2 cases, got" + f" {len(cases)}" + ) + pred_aval, *cases_avals = ctx.avals_in + [out_aval] = ctx.avals_out + pred = _ensure_fa(pred, pred_aval.dtype) + cases = _bcast(*cases, *cases_avals, out_aval) + # ``select`` expects the first case to be the true branch, but ``select_n`` + # orders the cases in reverse. + return pred.select(*reversed(cases)) + + @register_lowering_rule(lax.broadcast_in_dim_p) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, @@ -712,6 +729,16 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), lax.div_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x / y), + lax.rem_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x % y), + lax.and_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x & y), + lax.or_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x | y), + lax.xor_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x ^ y), + lax.gt_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x > y), + lax.lt_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x < y), + lax.ge_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x >= y), + lax.le_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x <= y), + lax.eq_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x == y), + lax.ne_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x != y), }) @@ -909,13 +936,41 @@ def _scan_lowering_rule( return for_out +@register_lowering_rule(lax.cond_p) +def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): + index_aval, *_arg_avals = ctx.avals_in + switch_op = scf_dialect.IndexSwitchOp( + map(mgpu_utils.dtype_to_ir_type, ctx.avals_out), + _as_index(_ensure_ir_value(index, index_aval.dtype)), + ir.DenseI64ArrayAttr.get(range(len(branches) - 1)), + num_caseRegions=len(branches) - 1, + ) + + # ``RegionSequence`` in MLIR does not support slicing, so the + # auto-generated Python bindings for ``caseRegions`` fail at runtime! + # We convert it to a list to work around that. + regions = list(switch_op.regions) + # Move the default region to the back. + regions = regions[1:] + regions[:1] + for branch, region in zip(branches, regions): + with ir.InsertionPoint(region.blocks.append()): + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args + ) + scf_dialect.yield_([ + _ensure_ir_value(out, aval.dtype) + for out, aval in zip(outs, ctx.avals_out) + ]) + return list(switch_op.results) + + def _bcast( x: ir.Value, y: ir.Value, x_aval: jax_core.ShapedArray, y_aval: jax_core.ShapedArray, out_aval: jax_core.ShapedArray, -) -> ir.Value: +) -> tuple[mgpu.FragmentedArray, mgpu.FragmentedArray]: if not isinstance(x, mgpu.FragmentedArray): x_dtype = x_aval.dtype if x_aval.weak_type: @@ -935,6 +990,7 @@ def _bcast( def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: if isinstance(x, mgpu.FragmentedArray): + assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype) return x elif isinstance(x, (np.number, np.ndarray, int, float)): return mgpu.FragmentedArray.splat( @@ -944,12 +1000,14 @@ def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: ) elif isinstance(x, ir.Value): if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)): + assert x.type == mgpu_utils.dtype_to_ir_type(dtype) return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype)) raise NotImplementedError(f"Unsupported type: {type(x)}") def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: if isinstance(x, ir.Value): + assert x.type == mgpu_utils.dtype_to_ir_type(dtype) return x elif isinstance(x, (np.number, np.ndarray, int, float)): return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ea45a7dcb7b9..ae6c40b9416d 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -784,13 +784,13 @@ def broadcast_minor(self, n): _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed ) - def select(self, x, y): + def select(self, on_true, on_false): if ( not ir.IntegerType.isinstance(self.mlir_dtype) or ir.IntegerType(self.mlir_dtype).width != 1 ): raise NotImplementedError - return self._pointwise(arith.select, x, y) + return self._pointwise(arith.select, on_true, on_false) def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): """Call a function for each value and index.""" diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 4701a72490c8..b5c22734b0d4 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -951,6 +951,8 @@ def dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type: def is_signed(dtype: jax.typing.DTypeLike) -> bool | None: - if jnp.issubdtype(dtype, jnp.integer): + if jnp.issubdtype(dtype, jnp.bool_): + return False + elif jnp.issubdtype(dtype, jnp.integer): return jnp.issubdtype(dtype, jnp.signedinteger) return None diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 3018e0383188..267ddd0c97d5 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -376,6 +376,27 @@ def kernel(o_ref): kernel(), jnp.full([256], 5.0, dtype=jnp.float32) ) + def test_cond(self): + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + acc = x_ref[...].sum() + jax.lax.cond( + acc % 2 == 0, + lambda: pl.debug_print("acc * 2: {}", acc * 2), + lambda: pl.debug_print("acc: {}", acc), + ) + o_ref[...] = jnp.broadcast_to(acc, o_ref.shape) + + x = jnp.arange(256) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn("acc * 2:", output()) + @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): # TensorCores can only fuse transposes of 16-bit values, and RHS From 7b53c2f39d4bd2c971b7364c8f1ba5af2d284908 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 26 Sep 2024 08:38:46 -0700 Subject: [PATCH 685/702] Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class. Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API. PiperOrigin-RevId: 679163106 --- CHANGELOG.md | 4 ++++ docs/errors.rst | 1 + jax/errors.py | 5 +++++ jax/lib/xla_client.py | 30 ++++++++++++++++++++---------- tests/errors_test.py | 16 ++++++++++++---- tests/package_structure_test.py | 2 +- 6 files changed, 43 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 063324f0179c..a424e645144c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * New Functionality * This release includes wheels for Python 3.13. Free-threading mode is not yet supported. + * `jax.errors.JaxRuntimeError` has been added as a public alias for the + formerly private `XlaRuntimeError` type. * Breaking changes * `jax_pmap_no_rank_reduction` flag is set to `True` by default. @@ -32,6 +34,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. in an error. * Internal pretty-printing tools `jax.core.pp_*` have been removed, after being deprecated in JAX v0.4.30. + * `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use + `jax.errors.JaxRuntimeError` instead. * Deletion: * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation diff --git a/docs/errors.rst b/docs/errors.rst index 96e14ed8d817..9965d6698bd4 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -9,6 +9,7 @@ along with representative examples of how one might fix them. .. currentmodule:: jax.errors .. autoclass:: ConcretizationTypeError .. autoclass:: KeyReuseError +.. autoclass:: JaxRuntimeError .. autoclass:: NonConcreteBooleanIndexError .. autoclass:: TracerArrayConversionError .. autoclass:: TracerBoolConversionError diff --git a/jax/errors.py b/jax/errors.py index 2a811661d1ae..6da7b717cb5f 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -26,4 +26,9 @@ UnexpectedTracerError as UnexpectedTracerError, KeyReuseError as KeyReuseError, ) + +from jax._src.lib import xla_client as _xc +JaxRuntimeError = _xc.XlaRuntimeError +del _xc + from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 7422e9fcc56d..9c57d82e837b 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -37,26 +37,36 @@ Traceback = _xc.Traceback XlaBuilder = _xc.XlaBuilder XlaComputation = _xc.XlaComputation -XlaRuntimeError = _xc.XlaRuntimeError _deprecations = { - # Added Aug 5 2024 - "_xla" : ( - "jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.", - _xc._xla - ), - "bfloat16" : ( - "jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.", - _xc.bfloat16 - ), + # Added Aug 5 2024 + "_xla": ( + "jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.", + _xc._xla, + ), + "bfloat16": ( + "jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.", + _xc.bfloat16, + ), + # Added Sep 26 2024 + "XlaRuntimeError": ( + ( + "jax.lib.xla_client.XlaRuntimeError is deprecated; use" + " jax.errors.JaxRuntimeError." + ), + _xc.XlaRuntimeError, + ), } import typing as _typing + if _typing.TYPE_CHECKING: _xla = _xc._xla bfloat16 = _xc.bfloat16 + XlaRuntimeError = _xc.XlaRuntimeError else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing diff --git a/tests/errors_test.py b/tests/errors_test.py index fa2dec95f0fa..7dfc4e51a6de 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -394,11 +394,19 @@ def test_grad_norm(self): class CustomErrorsTest(jtu.JaxTestCase): + @jtu.sample_product( - errorclass=[ - errorclass for errorclass in dir(jax.errors) - if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError'] - ], + errorclass=[ + errorclass + for errorclass in dir(jax.errors) + if errorclass.endswith('Error') + and errorclass + not in [ + 'JaxIndexError', + 'JAXTypeError', + 'JaxRuntimeError', + ] + ], ) def testErrorsURL(self, errorclass): class FakeTracer(core.Tracer): diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index e9944ec084af..71d48c2b121c 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -31,7 +31,7 @@ class PackageStructureTest(jtu.JaxTestCase): @parameterized.parameters([ # TODO(jakevdp): expand test to other public modules. - _mod("jax.errors"), + _mod("jax.errors", exclude=["JaxRuntimeError"]), _mod("jax.nn.initializers"), _mod( "jax.tree_util", From e62a50cd34bf88d87f36a3930b2dc6a0496e75d6 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 26 Sep 2024 08:44:58 -0700 Subject: [PATCH 686/702] #sdy add JAX Shardy support for shard_map. For example the following JAX program: ```py devices = np.array(jax.devices()[:8]) mesh = Mesh(devices, axis_names=('x')) a = jax.device_put( jnp.arange(8 * 8).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P('x', None))) @jax.jit @partial( shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): axis_size = lax.psum(1, 'x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) print(jax.jit(fwd).lower(a).as_text()) ``` prints: ```cpp module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { sdy.mesh @mesh = <["x"=8]> func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32> return %0 : tensor<8x8xi32> } func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) { %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) { %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32> sdy.return %1 : tensor<1x8xi32> } : (tensor<8x8xi32>) -> tensor<8x8xi32> return %0 : tensor<8x8xi32> } } ``` PiperOrigin-RevId: 679165100 --- jax/experimental/shard_map.py | 64 ++++++++++++++++++++++++++++ tests/BUILD | 5 +++ tests/shard_map_test.py | 78 ++++++++++++++++++++++++++++------- 3 files changed, 132 insertions(+), 15 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 10d4874d7329..0dace1977dc0 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -51,6 +51,8 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, special, control_flow, ann) +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo, sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3, as_hashable_function, memoize, partition_list, merge_lists, split_list, subs_list2) @@ -643,9 +645,71 @@ def _rule_missing(prim: core.Primitive, *_, **__): # Lowering +def _shardy_shard_map_sharding( + ctx: mlir.LoweringRuleContext, mesh, names, aval_in + ) -> ir.Attribute: + axes = {name: i for i, ns in names.items() for name in ns} + ns = _make_scoped_manual_sharding(ctx, mesh, axes) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + return ns._to_sdy_sharding(aval_in.ndim).build() + + +def _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto): + in_avals_ = [v.aval for v in jaxpr.invars] + if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): + # Nested `ManualComputationOp`s cannot refer to axes that are already + # manual. So figure out what axes are free thus far and get the new axis + # context. + free_axis = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes + new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axis - auto) + else: + new_axis_context = sharding_impls.SPMDAxisContext( + mesh, frozenset(mesh.axis_names) - auto) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + args = (*ctx.dim_var_values, *in_nodes) + + manual_axes = sub_ctx.axis_context.manual_axes + mesh_shape = mesh.shape + manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes]) + if manual_axes_size == 1: + # No need for a `ManualComputationOp` if all manual axes are size 1. + out_nodes, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args, + dim_var_values=ctx.dim_var_values) + return out_nodes + + in_shardings = sdy.TensorShardingPerValueAttr.get(map( + partial(_shardy_shard_map_sharding, ctx, mesh), + in_names, ctx.avals_in)) + out_shardings = sdy.TensorShardingPerValueAttr.get(map( + partial(_shardy_shard_map_sharding, ctx, mesh), + out_names, ctx.avals_out)) + output_types = map(mlir.aval_to_ir_type, ctx.avals_out) + manual_computation_op = sdy.ManualComputationOp( + output_types, args, in_shardings, out_shardings, + sdy.ManualAxesAttr.get( + ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) + block = ir.Block.create_at_start( + manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) + with ir.InsertionPoint(block), core.extend_axis_env_nd( + tuple(mesh.shape.items())): + out_nodes_, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, + dim_var_values=ctx.dim_var_values) + sdy.ReturnOp([ir.Value(x) for x in out_nodes_]) + + return manual_computation_op.results + + def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): del check_rep, rewrite + if config.use_shardy_partitioner.value: + return _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto) in_avals_ = [v.aval for v in jaxpr.invars] out_avals_ = [x.aval for x in jaxpr.outvars] in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, diff --git a/tests/BUILD b/tests/BUILD index 46345b6475d9..e580fb0ae363 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1346,6 +1346,11 @@ jax_multiplatform_test( jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], + enable_configs = [ + "gpu_2gpu_shardy", + "tpu_v3_2x2_shardy", + "tpu_v4_2x2_shardy", + ], shard_count = { "cpu": 50, "gpu": 10, diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 397f2d94c7f7..0c1155ddf1ab 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -848,6 +848,10 @@ def test_shmap_abstract_mesh_errors(self): @parameterized.parameters([True, False]) @jtu.run_on_devices('cpu', 'gpu', 'tpu') def test_debug_print_jit(self, jit): + if config.use_shardy_partitioner.value: + self.skipTest( + 'TODO(b/364547005): debug prints not supported by Shardy yet' + ) mesh = Mesh(jax.devices(), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) @@ -1229,13 +1233,18 @@ def foo(x): return x hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo')) - self.assertIn("call @shmap_body", hlo_str) - self.assertIn("call @shmap_body_0", hlo_str) - self.assertIn("%arg0: tensor<1xf32>", hlo_str) - self.assertIn("\"[None]\"", hlo_str) - self.assertIn("%arg1: tensor<1xf32>", hlo_str) - self.assertIn("\"[('i',)]\"", hlo_str) - self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", hlo_str) + if config.use_shardy_partitioner.value: + self.assertEqual(2, hlo_str.count('sdy.manual_computation')) + else: + self.assertIn('call @shmap_body', hlo_str) + self.assertIn('call @shmap_body_0', hlo_str) + self.assertIn('%arg0: tensor<1xf32>', hlo_str) + self.assertIn('"[None]"', hlo_str) + self.assertIn('%arg1: tensor<1xf32>', hlo_str) + self.assertIn('"[(\'i\',)]"', hlo_str) + self.assertIn( + '-> (tensor<1xf32> {jax.result_info = "[(\'i\',)]"})', hlo_str + ) def test_rewrite_process_call(self): def f(x): @@ -1759,10 +1768,18 @@ def f(x): v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertIn( - 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}', - f.lower(v).as_text('hlo'), - ) + if config.use_shardy_partitioner.value: + self.assertIn( + 'in_shardings=[<@mesh, [{"i"}, {}]>] out_shardings=[<@mesh, [{"i"},' + ' {}]>] manual_axes={"i"}', + f.lower(v).as_text(), + ) + else: + self.assertIn( + 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual,' + ' replicated}}', + f.lower(v).as_text('hlo'), + ) self.assertAllClose(v*v, f(v), check_dtypes=False) def test_sharded_prng_with_abstract_mesh(self): @@ -1909,6 +1926,11 @@ def f(): self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) def test_partial_auto_of_pjit_different_mesh(self): + if config.use_shardy_partitioner.value: + self.skipTest( + 'Shardy requires the mesh axis names to be the same across ' + 'the entire computation.' + ) mesh = jtu.create_mesh((2, 2), ('i', 'j')) mesh2 = jax.sharding.Mesh(mesh.devices, ('k', 'l')) @@ -1977,10 +1999,14 @@ def f(x): xs = jnp.arange(16.) ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs) - self.assertIn( - '{jax.result_info = "[(\'i\', \'j\', \'k\', \'a\')]"}', - ir.as_text() - ) + if config.use_shardy_partitioner.value: + self.assertIn( + 'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text() + ) + else: + self.assertIn( + "{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text() + ) def test_vmap_spmd_axis_name_error(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) @@ -2609,5 +2635,27 @@ def fwd(a): self.assertEqual(c.addressable_data(0).shape, (4, 2)) +@jtu.with_config(jax_use_shardy_partitioner=True) +class SdyIntegrationTest(jtu.JaxTestCase): + # Verify we can lower to a `ManualComputationOp`. + def test_shardy_collective_permute(self): + mesh = jtu.create_mesh((2,), ('x',)) + a = jax.device_put( + jnp.arange(8 * 8).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None)), + ) + + @jax.jit + @partial( + shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + axis_size = lax.psum(1, 'x') + perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] + return lax.ppermute(a, 'x', perm=perm) + + self.assertIn('sdy.manual_computation', jax.jit(fwd).lower(a).as_text()) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 6072f97961c7535c8aed2b04f382b8ad33382444 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Thu, 26 Sep 2024 21:38:14 +0530 Subject: [PATCH 687/702] Raise ValueError when axis1==axis2 for jnp.trace --- jax/_src/numpy/lax_numpy.py | 4 ++++ tests/lax_numpy_test.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ca70ff35dd9f..70fea964c321 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6702,6 +6702,10 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int util.check_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") + + if _canonicalize_axis(axis1, ndim(a)) == _canonicalize_axis(axis2, ndim(a)): + raise ValueError(f"axis1 and axis2 can not be same. axis1={axis1} and axis2={axis2}") + dtypes.check_user_dtype_supported(dtype, "trace") a_shape = shape(a) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c6d56885a6a8..6f8167df9c29 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2725,6 +2725,11 @@ def np_fun(arg): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def testTraceSameAxesError(self): + a = jnp.arange(1, 13).reshape(2, 3, 2) + with self.assertRaisesRegex(ValueError, r"axis1 and axis2 can not be same"): + jnp.trace(a, axis1=1, axis2=-2) + @jtu.sample_product( ashape=[(15,), (16,), (17,)], vshape=[(), (5,), (5, 5)], From a3284bd8a3f7b50db644c76215ce9718396726f3 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 26 Sep 2024 09:12:42 -0700 Subject: [PATCH 688/702] #sdy Add CPU targets in JAX. PiperOrigin-RevId: 679174535 --- tests/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index e580fb0ae363..c93f18dbb815 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -234,6 +234,7 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ + "cpu_shardy", "gpu_2gpu_shardy", "tpu_v3_2x2_shardy", "tpu_v4_2x2_shardy", @@ -1347,6 +1348,7 @@ jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], enable_configs = [ + "cpu_shardy", "gpu_2gpu_shardy", "tpu_v3_2x2_shardy", "tpu_v4_2x2_shardy", From 076287fb5cf7d107bcfe6e1a0e471b71a0c06e5a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 26 Sep 2024 09:14:38 -0700 Subject: [PATCH 689/702] [Pallas/MGPU] Implement block spec evaluation correctly The preivous implementation made some surprising assumptions about the contents of the block specs and wasn't correct in general. The new implementation handles all the cases and seems to be sufficient to finally run the matmul example with multiple k steps while producing correct results (it's also shorter!). PiperOrigin-RevId: 679175212 --- jax/_src/pallas/mosaic_gpu/lowering.py | 92 +++++++++++--------------- tests/pallas/mosaic_gpu_test.py | 9 +-- 2 files changed, 40 insertions(+), 61 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index cc19d96aa194..92cd385be2a3 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -185,7 +185,7 @@ class LoweringError(Exception): # pylint: disable=g-bad-exception-name def _eval_index_map( module_ctx: ModuleContext, launch_ctx: mgpu.LaunchContext, - idx: ir.Value, + idx: Sequence[ir.Value], block_mapping: pallas_core.BlockMapping, ) -> Sequence[ir.Value]: block_indices = lower_jaxpr_to_mosaic_gpu( @@ -238,10 +238,7 @@ def lower_jaxpr_to_module( jaxpr, [True] * len(jaxpr.outvars), instantiate=True ) - grid = grid_mapping.grid - if len(grid) < 3: - grid += (1,) * (3 - len(grid)) - block = (128,) + (1,) * (len(grid) - 1) + block = (128, 1, 1) params = compiler_params.get("mosaic_gpu", {}) approx_math = params.get("approx_math", False) max_concurrent_steps = params.get("max_concurrent_steps", 1) @@ -256,8 +253,25 @@ def lower_jaxpr_to_module( sequential_axes = tuple( i for i, s in enumerate(dimension_semantics) if s == "sequential" ) - assert all(grid[axis] for axis in sequential_axes) - assert all(block[axis] == 1 for axis in sequential_axes) + + grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes] + if len(grid) < 3: + grid += (1,) * (3 - len(grid)) + else: + raise NotImplementedError( + "Only <=3D grids are supported in Mosaic GPU lowering." + ) + # Compute the number of steps along each sequential axis. + if sequential_axes: + # TODO(slebedev): Support multiple sequential axes. + if len(sequential_axes) > 1: + raise NotImplementedError( + "Multiple sequential axes are not supported in Mosaic GPU lowering." + ) + [sequential_axis] = sequential_axes + num_steps = grid_mapping.grid[sequential_axis] + else: + num_steps = 1 in_in_smem, out_in_smem = util.split_list( [ @@ -268,10 +282,9 @@ def lower_jaxpr_to_module( ) in_structs_gmem = [*grid_mapping.in_shapes] - in_block_shapes = [ - bm.block_shape - for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] - ] + in_block_mappings, out_block_mappings = util.split_list( + block_mappings, [grid_mapping.num_inputs] + ) in_structs_smem = [ jax.ShapeDtypeStruct( [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype @@ -283,8 +296,7 @@ def lower_jaxpr_to_module( ) ] in_gmem_transforms = [ - cast(gpu_core.MemoryRefTransform, bm.transforms) - + cast(gpu_core.MemoryRefTransform, bm.transforms) for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] ] in_swizzles = map( @@ -322,17 +334,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): ) barriers, *extra_barriers = barriers + parallel_count = it.count() + program_ids_template = [ + _program_id(next(parallel_count)) if i not in sequential_axes else None + for i in range(len(grid_mapping.grid)) + ] module_ctx = ModuleContext( name_and_src_info.name, grid_mapping, approx_math, runtime_smem ) - program_ids = map(_program_id, range(len(grid_mapping.grid))) - start_indices = map( - partial(_eval_index_map, module_ctx, launch_ctx, program_ids), - block_mappings, - ) - in_start_indices, out_start_indices = util.split_list( - start_indices, [grid_mapping.num_inputs] - ) smem_scratch_it = iter(scratch_buffers_smem) scratch_buffers_template = [] @@ -385,20 +394,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): ) def gmem_slice( - start_indices: Sequence[ir.Value], step: ir.Value, - shape: Sequence[int], + block_mapping: pallas_core.BlockMapping, ) -> Sequence[mgpu.DynamicSlice]: + assert len(sequential_axes) == 1 + program_ids = [step if i is None else i for i in program_ids_template] + idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping) return tuple( - mgpu.ds( - arith_dialect.addi( - start_index, arith_dialect.muli(step, _as_index(dim)) - ) - if axis in sequential_axes - else start_index, - dim, - ) - for axis, (start_index, dim) in enumerate(zip(start_indices, shape)) + mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape) ) def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: @@ -410,11 +413,7 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: launch_ctx.async_copy( src_ref=in_buffers_gmem[idx], dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot), - gmem_slice=gmem_slice( - in_start_indices[idx], - step, - in_block_shapes[idx], - ), + gmem_slice=gmem_slice(step, in_block_mappings[idx]), barrier=barriers[slot], gmem_transform=tuple(gmem_transforms), swizzle=in_swizzles[idx], @@ -430,27 +429,11 @@ def store(idx: int, step: ir.Value, slot: ir.Value) -> None: launch_ctx.async_copy( src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot), dst_ref=out_buffers_gmem[idx], - gmem_slice=gmem_slice( - out_start_indices[idx], - step, - ir.MemRefType(out_buffers_smem[idx].type).shape[1:], - ), + gmem_slice=gmem_slice(step, out_block_mappings[idx]), swizzle=None, uniform=False, ) - # Compute the number of steps along each sequential axis. - if sequential_axes: - # TODO(slebedev): Support multiple sequential axes. - if len(sequential_axes) > 1: - raise NotImplementedError( - "Multiple sequential axes are not supported in Mosaic GPU lowering." - ) - [sequential_axis] = sequential_axes - num_steps = grid_mapping.grid[sequential_axis] - else: - num_steps = 1 - with mgpu.single_thread(): for slot in range(min(max_concurrent_steps, num_steps)): barriers[slot].arrive_expect_tx(in_transfer_bytes) @@ -619,6 +602,7 @@ def write_env(var: jax_core.Var, val): @register_lowering_rule(primitives.program_id_p) def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): + # TODO(apaszke): Sequential axis should be handled specially!! del ctx # Unused. return _program_id(axis) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 267ddd0c97d5..b35658ed4845 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -462,13 +462,8 @@ def test_realistic_matmul(self): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize - # TODO(apaszke): Make the grid and tile sizes larger - # grid_m, grid_k, grid_n = 132, 10, 4 - # TODO(apaszke): Increasing grid_k causes th test to fail. - # It seems like our pipelining implementation has a number of races. - grid_m, grid_k, grid_n = 2, 1, 2 - # tile_m = tile_n = 128 - tile_m = tile_n = 64 + grid_m, grid_k, grid_n = 132, 10, 4 + tile_m = tile_n = 128 tile_k = elems_128b m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n def kernel(a_ref, b_ref, o_ref, acc_ref): From dd2ee8c7b25bcecc4dacf93b60eb596ebee8a33d Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 26 Sep 2024 09:53:48 -0700 Subject: [PATCH 690/702] [Pallas/MGPU] Skip outgoing TMA when the output is being revisited Otherwise we end up with programs that race on writes to the same GMEM location. PiperOrigin-RevId: 679189227 --- jax/_src/pallas/mosaic_gpu/lowering.py | 70 +++++++++++++++++++------ jax/experimental/mosaic/gpu/__init__.py | 1 + jax/experimental/mosaic/gpu/core.py | 10 ++-- jax/experimental/mosaic/gpu/utils.py | 24 +++++---- 4 files changed, 75 insertions(+), 30 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 92cd385be2a3..a64d4cc9f9b1 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -404,6 +404,8 @@ def gmem_slice( mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape) ) + is_memory_thread = mgpu.single_thread_predicate(per_block=True) + def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: if not in_in_smem[idx]: return @@ -419,36 +421,67 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: swizzle=in_swizzles[idx], arrive=False, # The caller must do ``arrive_expect_tx`` manually! uniform=False, + predicate=is_memory_thread, ) - def store(idx: int, step: ir.Value, slot: ir.Value) -> None: + def store( + idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value + ) -> ir.Value: if not out_in_smem[idx]: return + # We have to do some work to make sure that consecutive stores are not + # going to be writing to the same location, or else we'll end up with + # multiple concurrent writes and a racy program. + # TODO(apaszke,slebedev): In most cases output index maps depend only on + # parallel grid axes and in that case we can simply move the store to + # happen after the loop. + # TODO(apaszke,slebedev): This still diverges significantly from the TPU + # semantics in that it will move on to the next SMEM output slice even if + # it's not storing the previous one. + store_slice = gmem_slice(step, out_block_mappings[idx]) + strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset() + base_offset = _as_index(0) + for stride, slc in zip(strides, store_slice): + base_offset = arith_dialect.addi( + base_offset, arith_dialect.muli(slc.base, _as_index(stride)) + ) + base_offset_changed = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset + ) + is_last_step = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1) + ) + do_store = arith_dialect.andi( + is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step) + ) # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. launch_ctx.async_copy( src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot), dst_ref=out_buffers_gmem[idx], - gmem_slice=gmem_slice(step, out_block_mappings[idx]), + gmem_slice=store_slice, swizzle=None, uniform=False, + predicate=do_store, ) + return base_offset - with mgpu.single_thread(): - for slot in range(min(max_concurrent_steps, num_steps)): - barriers[slot].arrive_expect_tx(in_transfer_bytes) - for idx in range(grid_mapping.num_inputs): - fetch(idx, _as_index(slot), _as_index(slot)) + for slot in range(min(max_concurrent_steps, num_steps)): + barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread) + for idx in range(grid_mapping.num_inputs): + fetch(idx, _as_index(slot), _as_index(slot)) - @mgpu.fori(_as_index(num_steps), accs) - def _(step, accs): + last_store_offsets = [_as_index(-1)] * grid_mapping.num_outputs + @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets)) + def _(step, carry): + accs, last_store_offsets = carry slot = arith_dialect.remui(step, _as_index(max_concurrent_steps)) if grid_mapping.num_inputs: # Only wait if async copies were issued. barriers[slot].wait() # We need to make sure the output copy is complete before the kernel starts # writing to the output window. - launch_ctx.await_async_copy(max_concurrent_steps - 1) + launch_ctx.await_async_copy(max_concurrent_steps - 1, await_read_only=True) args = [ mgpu.memref_slice(buffers_smem[idx], slot) @@ -468,23 +501,26 @@ def _(step, accs): new_accs = lower_jaxpr_to_mosaic_gpu( module_ctx, launch_ctx, lowered_jaxpr, args ) - mgpu.commit_shared() - with mgpu.single_thread(): - for idx in range(grid_mapping.num_outputs): - store(idx, step, slot) + # TODO(apaszke): Elide this if we're not going to perform any stores + mgpu.commit_shared() + new_store_offsets = [] + for idx in range(grid_mapping.num_outputs): + new_store_offsets.append( + store(idx, step, slot, last_store_offsets[idx]) + ) next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps)) next_step_in_bounds = arith_dialect.cmpi( arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps) ) next_slot = slot # (x + y) % y == x % y - with mgpu.when(next_step_in_bounds), mgpu.single_thread(): + with mgpu.when(next_step_in_bounds): + barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread) for idx in range(grid_mapping.num_inputs): fetch(idx, next_step, next_slot) - barriers[slot].arrive_expect_tx(in_transfer_bytes) - return list(new_accs) + return list(new_accs), new_store_offsets launch_ctx.await_async_copy(0) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 8057f97a7dfc..f5944c862480 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -51,6 +51,7 @@ memref_unfold, memref_unsqueeze, single_thread, + single_thread_predicate, thread_idx, tile_shape, warp_idx, diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 9a03afeb4a8c..22a996efd64a 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -347,6 +347,7 @@ def async_copy( arrive: bool | None = None, uniform: bool = True, collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, + predicate: ir.Value | None = None, ): index = ir.IndexType.get() i16 = ir.IntegerType.get_signless(16) @@ -503,14 +504,17 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): barrier_ptr = barrier.get_ptr() with uniform_ctx(): if arrive: - nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, transfer_bytes) + nvvm.mbarrier_arrive_expect_tx_shared( + barrier_ptr, transfer_bytes, predicate=predicate + ) nvvm.cp_async_bulk_tensor_shared_cluster_global( - smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], multicast_mask=multicast_mask, + smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], + multicast_mask=multicast_mask, predicate=predicate ) else: with uniform_ctx(): nvvm.cp_async_bulk_tensor_global_shared_cta( - tma_desc, smem_ptr, rev_dyn_base_indices + tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate ) nvvm.cp_async_bulk_commit_group() diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index b5c22734b0d4..a59ddbea5565 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -229,6 +229,15 @@ class ThreadSubset(enum.IntEnum): _ONCE_PER: ThreadSubset | None = None +def single_thread_predicate(per_block=True): + warp = warp_idx() + if not per_block: + warp = arith.remui(warp, c(4, warp.type)) + first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) + elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) + return arith.andi(first_warp, elected) + + @contextlib.contextmanager def single_thread(per_block=True): """Runs the context only from a single thread. @@ -244,16 +253,10 @@ def single_thread(per_block=True): yield return - warp = warp_idx() - if not per_block: - warp = arith.remui(warp, c(4, warp.type)) - first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) - elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) - should_run = arith.andi(first_warp, elected) - if_op = scf.IfOp(should_run) prev_scope = _ONCE_PER _ONCE_PER = scope try: + if_op = scf.IfOp(single_thread_predicate(per_block)) with ir.InsertionPoint(if_op.then_block): yield scf.YieldOp([]) @@ -610,14 +613,15 @@ def arrive(self): i64 = ir.IntegerType.get_signless(64) nvvm.mbarrier_arrive_shared(i64, self.get_ptr()) - def arrive_expect_tx(self, bytes: int | ir.Value): + def arrive_expect_tx( + self, bytes: int | ir.Value, predicate: ir.Value | None = None + ): if isinstance(bytes, int): bytes = c(bytes, ir.IntegerType.get_signless(32)) elif ir.IndexType.isinstance(bytes.type): i32 = ir.IntegerType.get_signless(32) bytes = arith.index_cast(i32, bytes) - - nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes) + nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes, predicate=predicate) def get_ptr(self): ptr = ir.Type.parse("!llvm.ptr<3>") From 140a8c70b44dadaa2089886faf5da8cf6c7fe548 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 26 Sep 2024 10:12:53 -0700 Subject: [PATCH 691/702] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0e732d65bdf8fb158c7b01e18139e5ba59ca7025. PiperOrigin-RevId: 679196598 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 4ad4e48c02d1..ca84599d6cff 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "a473d30392e2cea68dc90f95377de3f568ea2055" -XLA_SHA256 = "30324095a4d9454b5a8fdf0397b62cfd6f06155a077ce93cf75b64fb78f98fc0" +XLA_COMMIT = "0e732d65bdf8fb158c7b01e18139e5ba59ca7025" +XLA_SHA256 = "16e4aeca04ce94bd0fcfa32990d76be3779c026c2b649478bf27d0db0679e65c" def repo(): tf_http_archive( From 0e082f978b5af5e1a50cf89b04e3b76065d02bb4 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 26 Sep 2024 10:16:20 -0700 Subject: [PATCH 692/702] Deprecate jax.lib.xla_client.Device. jax.Device is a longstanding public name for this class. PiperOrigin-RevId: 679197718 --- CHANGELOG.md | 1 + jax/lib/xla_client.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a424e645144c..079e055aa994 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. in an error. * Internal pretty-printing tools `jax.core.pp_*` have been removed, after being deprecated in JAX v0.4.30. + * `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead. * `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use `jax.errors.JaxRuntimeError` instead. diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 9c57d82e837b..a51625eb072e 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -25,7 +25,6 @@ ArrayImpl = _xc.ArrayImpl Client = _xc.Client CompileOptions = _xc.CompileOptions -Device = _xc.Device DeviceAssignment = _xc.DeviceAssignment FftType = _xc.FftType Frame = _xc.Frame @@ -49,6 +48,10 @@ _xc.bfloat16, ), # Added Sep 26 2024 + "Device" : ( + "jax.lib.xla_client.Device is deprecated; use jax.Device instead.", + _xc.Device + ), "XlaRuntimeError": ( ( "jax.lib.xla_client.XlaRuntimeError is deprecated; use" @@ -63,6 +66,7 @@ if _typing.TYPE_CHECKING: _xla = _xc._xla bfloat16 = _xc.bfloat16 + Device = _xc.Device XlaRuntimeError = _xc.XlaRuntimeError else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr From 7e6fa3ed285ca4fc997ece1077b2319d505dd8c7 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Thu, 26 Sep 2024 23:27:23 +0530 Subject: [PATCH 693/702] Improve docs for jax.numpy: sinh, cosh and tanh --- jax/_src/numpy/ufuncs.py | 151 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 148 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index fd0a9e502b1d..dc265b8e87e1 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -710,14 +710,111 @@ def arctan(x: ArrayLike, /) -> Array: """ return lax.atan(*promote_args_inexact('arctan', x)) -@implements(np.sinh, module='numpy') + @partial(jit, inline=True) def sinh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise hyperbolic sine of input. + + JAX implementation of :obj:`numpy.sinh`. + + The hyperbolic sine is defined by: + + .. math:: + + sinh(x) = \frac{e^x - e^{-x}}{2} + + Args: + x: input array or scalar. + + Returns: + An array containing the hyperbolic sine of each element of ``x``, promoting + to inexact dtype. + + Note: + ``jnp.sinh`` is equivalent to computing ``-1j * jnp.sin(1j * x)``. + + See also: + - :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the + input. + - :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the + input. + - :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic + sine of the input. + + Examples: + >>> x = jnp.array([[-2, 3, 5], + ... [0, -1, 4]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.sinh(x) + Array([[-3.627, 10.018, 74.203], + [ 0. , -1.175, 27.29 ]], dtype=float32) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.sin(1j * x) + Array([[-3.627+0.j, 10.018-0.j, 74.203-0.j], + [ 0. -0.j, -1.175+0.j, 27.29 -0.j]], dtype=complex64, weak_type=True) + + For complex-valued input: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.sinh(3-2j) + Array(-4.169-9.154j, dtype=complex64, weak_type=True) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.sin(1j * (3-2j)) + Array(-4.169-9.154j, dtype=complex64, weak_type=True) + """ return lax.sinh(*promote_args_inexact('sinh', x)) -@implements(np.cosh, module='numpy') + @partial(jit, inline=True) def cosh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise hyperbolic cosine of input. + + JAX implementation of :obj:`numpy.cosh`. + + The hyperbolic cosine is defined by: + + .. math:: + + cosh(x) = \frac{e^x + e^{-x}}{2} + + Args: + x: input array or scalar. + + Returns: + An array containing the hyperbolic cosine of each element of ``x``, promoting + to inexact dtype. + + Note: + ``jnp.cosh`` is equivalent to computing ``jnp.cos(1j * x)``. + + See also: + - :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input. + - :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the + input. + - :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic + cosine of the input. + + Examples: + >>> x = jnp.array([[3, -1, 0], + ... [4, 7, -5]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cosh(x) + Array([[ 10.068, 1.543, 1. ], + [ 27.308, 548.317, 74.21 ]], dtype=float32) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cos(1j * x) + Array([[ 10.068+0.j, 1.543+0.j, 1. +0.j], + [ 27.308+0.j, 548.317+0.j, 74.21 +0.j]], dtype=complex64, weak_type=True) + + For complex-valued input: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cosh(5+1j) + Array(40.096+62.44j, dtype=complex64, weak_type=True) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cos(1j * (5+1j)) + Array(40.096+62.44j, dtype=complex64, weak_type=True) + """ return lax.cosh(*promote_args_inexact('cosh', x)) @implements(np.arcsinh, module='numpy') @@ -735,9 +832,57 @@ def arccosh(x: ArrayLike, /) -> Array: result = _where(real(result) < 0, lax.neg(result), result) return result -@implements(np.tanh, module='numpy') + @partial(jit, inline=True) def tanh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise hyperbolic tangent of input. + + JAX implementation of :obj:`numpy.tanh`. + + The hyperbolic tangent is defined by: + + .. math:: + + tanh(x) = \frac{sinh(x)}{cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}} + + Args: + x: input array or scalar. + + Returns: + An array containing the hyperbolic tangent of each element of ``x``, promoting + to inexact dtype. + + Note: + ``jnp.tanh`` is equivalent to computing ``-1j * jnp.tan(1j * x)``. + + See also: + - :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input. + - :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the + input. + - :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic + tangent of the input. + + Examples: + >>> x = jnp.array([[-1, 0, 1], + ... [3, -2, 5]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.tanh(x) + Array([[-0.762, 0. , 0.762], + [ 0.995, -0.964, 1. ]], dtype=float32) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.tan(1j * x) + Array([[-0.762+0.j, 0. -0.j, 0.762-0.j], + [ 0.995-0.j, -0.964+0.j, 1. -0.j]], dtype=complex64, weak_type=True) + + For complex-valued input: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.tanh(2-5j) + Array(1.031+0.021j, dtype=complex64, weak_type=True) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.tan(1j * (2-5j)) + Array(1.031+0.021j, dtype=complex64, weak_type=True) + """ return lax.tanh(*promote_args_inexact('tanh', x)) @implements(np.arctanh, module='numpy') From 9f4e8d00392d8aa1706e02d7053db2812bb4c6c8 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 26 Sep 2024 13:56:54 -0700 Subject: [PATCH 694/702] [XLA:Mosaic][Pallas] Enable vector.ExtractOp for non-zero indices. PiperOrigin-RevId: 679283281 --- .../tpu/transforms/apply_vector_layout.cc | 47 +++++++++++++++---- tests/pallas/pallas_error_handling_test.py | 20 ++++---- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index a1714fc8090b..f6e1c7918646 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3555,24 +3555,53 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, op.erase(); return success(); } else { - for (int64_t i : extract_op.getStaticPosition()) { - if (i != 0) { - return op.emitOpError( - "Not implemented: Only 0 indices supported for scalar results"); - } - } + // TODO(b/367459476): Support non-zero offsets. if (layout_in.offsets() != LayoutOffsets{0, 0}) { return op.emitOpError("Not implemented: Unsupported layout"); } + auto [sub_tile, lane_tile] = layout_in.tiling(); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array vregs, disassemble(builder, layout_in, extract_op.getVector(), ctx.target_shape)); TPU_ASSERT_GT_OP(vregs.num_elements(), 0); + + SmallVector indices(extract_op.getStaticPosition()); + auto vreg_slice = layout_in.vregSlice(ctx.target_shape); + std::array position = {0, 0}; + SmallVector vreg_index(indices); + // TODO(b/367459476): Support non-VREG-aligned tiling. + CHECK_EQ(lane_tile, ctx.target_shape[1]); + layout_in.insertImplicit(indices, static_cast(0)); + layout_in.insertImplicit(vreg_index, static_cast(0)); + int i = *(indices.end()-2); + int j = *(indices.end()-1); + *(vreg_index.end() -2) = i / vreg_slice[0]; + *(vreg_index.end() -1) = j / vreg_slice[1]; + layout_in.eraseImplicit(vreg_index); + position[0] = ((j % vreg_slice[1]) / lane_tile * sub_tile + ) + i % sub_tile; + position[1] = j % lane_tile; + + TPU_ASSERT_LT_OP(vreg_index, vregs.dimensions()); + Value extracted_vreg = vregs(vreg_index); + + // Invert the offsets to get the rotation amount. + position[0] = (ctx.target_shape[0] - position[0]) % ctx.target_shape[0]; + position[1] = (ctx.target_shape[1] - position[1]) % ctx.target_shape[1]; + auto res_vreg_ty = extracted_vreg.getType(); + Value shift = builder.create( + builder.getIntegerAttr(builder.getI32Type(), position[0])); + Value rotated_vreg = builder.create( + res_vreg_ty, extracted_vreg, shift, 0, /*stride*/nullptr, nullptr); + shift = builder.create( + builder.getIntegerAttr(builder.getI32Type(), position[1])); + rotated_vreg = builder.create( + res_vreg_ty, rotated_vreg, shift, 1, /*stride*/nullptr, nullptr); extract_op.replaceAllUsesWith( - builder - .create(op.getLoc(), *vregs.data(), - ArrayRef{0, 0}) + builder.create( + op.getLoc(), rotated_vreg, + ArrayRef{0, 0}) .getResult()); } extract_op.erase(); diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py index 06b4bd3e3a4f..34b0ff1492a4 100644 --- a/tests/pallas/pallas_error_handling_test.py +++ b/tests/pallas/pallas_error_handling_test.py @@ -44,22 +44,22 @@ def setUp(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Test only works on TPU.") - def test_vector_extract_nonzero(self): - input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) - out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32) + def test_non_singular_stride(self): + input_arr = jax.random.uniform( + jax.random.key(0), (8, 128), dtype=jnp.float32) + out_shape = jax.ShapeDtypeStruct((8, 16), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), ) @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) def test_kernel(input_ref, output_ref): - val = input_ref[...] - x = val[0, 0] + val[0, 1] - output_ref[0, 0] = x + x = input_ref[:, ::8] + output_ref[...] = x # Test that a Mosaic error is raised. This assert is a guard against # underlying changes in Mosaic. @@ -67,7 +67,7 @@ def test_kernel(input_ref, output_ref): # the test example to force a different error. with self.assertRaisesRegex( error_handling.MosaicError, - "Not implemented: Only 0 indices supported for scalar results", + "Not Implemented: Stride on last dim is not 1", ): test_kernel(input_arr) @@ -78,7 +78,7 @@ def test_kernel(input_ref, output_ref): except error_handling.MosaicError as e: tb_string = traceback.format_tb(e.__traceback__) tb_string = "".join(tb_string) - self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n") + self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") @jax.jit def kernel_in_jitted_fn(x): @@ -91,7 +91,7 @@ def kernel_in_jitted_fn(x): except error_handling.MosaicError as e: tb_string = traceback.format_tb(e.__traceback__) tb_string = "".join(tb_string) - self.assertEndsWith(tb_string, "x = val[0, 0] + val[0, 1]\n") + self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") def test_invalid_smem_vmem_verification_error(self): input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) From ab4590ce0a415d750ea92d62701fdba17165b297 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Fri, 27 Sep 2024 01:30:21 -0700 Subject: [PATCH 695/702] [Pallas TPU] Add a note in the Pallas Quickstart documentation about the instructions of running the existing example on TPU This fixes https://github.com/jax-ml/jax/issues/22817 This changes is originally proposed by @justinjfu in the comments of the above issue. This PR is related to https://github.com/jax-ml/jax/pull/23885. PiperOrigin-RevId: 679487218 --- docs/pallas/quickstart.ipynb | 29 +++++++++++++++++++++++++++++ docs/pallas/quickstart.md | 17 +++++++++++++++++ docs/pallas/tpu/pipelining.ipynb | 8 ++++++++ docs/pallas/tpu/pipelining.md | 2 ++ 4 files changed, 56 insertions(+) diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 5a8608f494c3..0e759a493a61 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -282,6 +282,35 @@ "On TPUs, programs are executed in a combination of parallel and sequential\n", "(depending on the architecture) so there are slightly different considerations.\n", "\n", + "To call the above kernel on TPU, run:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "796f928c", + "metadata": {}, + "outputs": [], + "source": [ + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "def iota(size: int):\n", + " return pl.pallas_call(iota_kernel,\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n", + " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", + " grid=(size,))()\n", + "iota(8)" + ] + }, + { + "cell_type": "markdown", + "id": "68f97b4e", + "metadata": {}, + "source": [ + "TPUs distinguish between vector and scalar memory spaces and in this case the\n", + "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", + "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n", + "\n", "You can read more details at {ref}`pallas_grid`." ] }, diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index b8f9254f21d9..a8b13ea38eaf 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -188,6 +188,23 @@ operations like matrix multiplications really quickly. On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations. +To call the above kernel on TPU, run: + +```{code-cell} ipython3 +from jax.experimental.pallas import tpu as pltpu + +def iota(size: int): + return pl.pallas_call(iota_kernel, + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), + grid=(size,))() +iota(8) +``` + +TPUs distinguish between vector and scalar memory spaces and in this case the +output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is +a scalar. For more details read {ref}`pallas_tpu_pipelining`. + You can read more details at {ref}`pallas_grid`. +++ diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 2a3aa9d114de..b5f2c652b5a5 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -1,5 +1,13 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "7704d3bb", + "metadata": {}, + "source": [ + "(pallas_tpu_pipelining)=" + ] + }, { "cell_type": "markdown", "metadata": { diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 507eab658a39..19150b3832fa 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -11,6 +11,8 @@ kernelspec: name: python3 --- +(pallas_tpu_pipelining)= + +++ {"id": "teoJ_fUwlu0l"} # Pipelining From ea6ee4d7feb8a9e3f490830ef4b867797726be6b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 27 Sep 2024 02:07:23 -0700 Subject: [PATCH 696/702] Removed unused imports in `jax.experimental.mosaic.gpu.core` PiperOrigin-RevId: 679498378 --- jax/experimental/mosaic/gpu/core.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 22a996efd64a..bf5ec0dfc8af 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -19,21 +19,15 @@ import dataclasses import functools import hashlib -import itertools import math import os import pathlib -import subprocess -import tempfile import time from typing import Any, Generic, TypeVar import weakref import jax -from jax._src import config -from jax._src import core as jax_core from jax._src.interpreters import mlir -from jax._src.lib import xla_client from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import builtin @@ -42,7 +36,6 @@ from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm -from jaxlib.mlir.passmanager import PassManager import numpy as np from . import profiler From ea86251a60fd604f59e537572ad3223289b86767 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 27 Sep 2024 02:14:55 -0700 Subject: [PATCH 697/702] [Pallas:TPU] Fix lowering of convert_element_type(int32) -> bool. We need to add a condition on vector type since both operands of arith::CmpIOp must have same type. PiperOrigin-RevId: 679500783 --- jax/_src/pallas/mosaic/lowering.py | 15 ++++++++++++--- tests/pallas/ops_test.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index c9ff21c49689..46cbe8e4758b 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1635,7 +1635,8 @@ def _convert_element_type_lowering_rule( del weak_type del sharding out_aval = ctx.avals_out[0] - old_dtype = ctx.avals_in[0].dtype + in_aval = ctx.avals_in[0] + old_dtype = in_aval.dtype out_type = aval_to_ir_type(out_aval) if old_dtype == new_dtype: @@ -1680,8 +1681,16 @@ def _convert_element_type_lowering_rule( predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) const_type = _dtype_to_ir_type(old_dtype) const_zero = ir.IntegerAttr.get(const_type, 0) - const_zero = arith.ConstantOp(const_type, const_zero) - return arith.CmpIOp(predicate, x, const_zero).result + if in_aval.shape: + in_type = aval_to_ir_type(in_aval, is_kernel_boundary=False) + vector_zeros = arith.ConstantOp( + in_type, + ir.DenseElementsAttr.get_splat(in_type, const_zero), + ) + return arith.CmpIOp(predicate, x, vector_zeros).result + return arith.CmpIOp( + predicate, x, arith.ConstantOp(const_type, const_zero) + ).result return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype), multiple_results=False)(ctx, x) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index ce9403bbc890..7bba9f01bec9 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -561,7 +561,7 @@ def test_cast(self, from_dtype, to_dtype, data): self.skipTest("Not supported: bad canonicalization") if from_dtype == "bool" and to_dtype in {"int16", "int8"}: self.skipTest("Not supported: cannot extend to sub-32 bit types") - if from_dtype in {"int32", "bfloat16", "float32"} and to_dtype == "bool": + if from_dtype in {"bfloat16", "float32"} and to_dtype == "bool": self.skipTest("Not supported: unsupported relayout") if from_dtype == "bool" and to_dtype in {"int32", "bfloat16", "float32"}: self.skipTest("Not supported: unsupported relayout") From 3ae48621dd89bf01d4e165cff99e8160133fef60 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 27 Sep 2024 02:27:44 -0700 Subject: [PATCH 698/702] Fixed Pallas Mosaic GPU test following recent changes PiperOrigin-RevId: 679504036 --- jax/_src/pallas/mosaic_gpu/lowering.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index a64d4cc9f9b1..bdb22442c6b7 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -357,6 +357,10 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): *aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype) ) ) + case gpu_core.AbstractMemoryRef() if isinstance( + aval.dtype, gpu_core.BarrierType + ): + pass case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM: scratch_buffers_template.append(next(smem_scratch_it)) should_discharge.append(False) @@ -397,7 +401,7 @@ def gmem_slice( step: ir.Value, block_mapping: pallas_core.BlockMapping, ) -> Sequence[mgpu.DynamicSlice]: - assert len(sequential_axes) == 1 + assert len(sequential_axes) <= 1 program_ids = [step if i is None else i for i in program_ids_template] idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping) return tuple( @@ -428,7 +432,7 @@ def store( idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value ) -> ir.Value: if not out_in_smem[idx]: - return + return _as_index(-1) # We have to do some work to make sure that consecutive stores are not # going to be writing to the same location, or else we'll end up with From afaf8b823dd2bd3884c636f49ee45a591ead2cba Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 27 Sep 2024 02:42:43 -0700 Subject: [PATCH 699/702] Run Pallas Mosaic GPU tests on internal CI PiperOrigin-RevId: 679508320 --- tests/pallas/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index ba82b8c4223c..b1db085a8c3a 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -205,7 +205,6 @@ jax_multiplatform_test( env = { "JAX_PALLAS_USE_MOSAIC_GPU": "1", }, - tags = ["notap"], deps = [ "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep From 5740ab3b02fd5daa455e6aa4915e5bc9604d6d21 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 27 Sep 2024 03:11:42 -0700 Subject: [PATCH 700/702] [Pallas/MGPU] Skip output transfers when they don't depend on sequenital dims Note that thanks to the previous revisiting-related checks we weren't doing the transfers anyway, but this way we can also avoid having to pay for the checks. PiperOrigin-RevId: 679516275 --- jax/_src/pallas/mosaic_gpu/lowering.py | 80 +++++++++++++++++--------- 1 file changed, 53 insertions(+), 27 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index bdb22442c6b7..0d0ac41d11e3 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -201,6 +201,11 @@ def _eval_index_map( return tuple(result) +def _uses_arguments(cjaxpr: jax_core.ClosedJaxpr) -> list[bool]: + jaxpr = cjaxpr.jaxpr + return pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1] + + def lower_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, @@ -270,8 +275,13 @@ def lower_jaxpr_to_module( ) [sequential_axis] = sequential_axes num_steps = grid_mapping.grid[sequential_axis] + out_sequential_invariant = [ + not _uses_arguments(bm.index_map_jaxpr)[sequential_axis] + for bm in grid_mapping.block_mappings_output + ] else: num_steps = 1 + out_sequential_invariant = [True] * len(grid_mapping.out_shapes) in_in_smem, out_in_smem = util.split_list( [ @@ -429,36 +439,42 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: ) def store( - idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value - ) -> ir.Value: + idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value | None + ) -> ir.Value | None: if not out_in_smem[idx]: return _as_index(-1) - # We have to do some work to make sure that consecutive stores are not - # going to be writing to the same location, or else we'll end up with - # multiple concurrent writes and a racy program. - # TODO(apaszke,slebedev): In most cases output index maps depend only on - # parallel grid axes and in that case we can simply move the store to - # happen after the loop. - # TODO(apaszke,slebedev): This still diverges significantly from the TPU - # semantics in that it will move on to the next SMEM output slice even if - # it's not storing the previous one. store_slice = gmem_slice(step, out_block_mappings[idx]) - strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset() - base_offset = _as_index(0) - for stride, slc in zip(strides, store_slice): - base_offset = arith_dialect.addi( - base_offset, arith_dialect.muli(slc.base, _as_index(stride)) + if out_sequential_invariant[idx]: + assert prev_base_offset is None + do_store = None # Lack of predicate defaults to True. + base_offset = None + else: + assert prev_base_offset is not None + # We have to do some work to make sure that consecutive stores are not + # going to be writing to the same location, or else we'll end up with + # multiple concurrent writes and a racy program. + # TODO(apaszke,slebedev): In most cases output index maps depend only on + # parallel grid axes and in that case we can simply move the store to + # happen after the loop. + # TODO(apaszke,slebedev): This still diverges significantly from the TPU + # semantics in that it will move on to the next SMEM output slice even if + # it's not storing the previous one. + strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset() + base_offset = _as_index(0) + for stride, slc in zip(strides, store_slice): + base_offset = arith_dialect.addi( + base_offset, arith_dialect.muli(slc.base, _as_index(stride)) + ) + base_offset_changed = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset + ) + is_last_step = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1) + ) + do_store = arith_dialect.andi( + is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step) ) - base_offset_changed = arith_dialect.cmpi( - arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset - ) - is_last_step = arith_dialect.cmpi( - arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1) - ) - do_store = arith_dialect.andi( - is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step) - ) # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. launch_ctx.async_copy( src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot), @@ -475,7 +491,7 @@ def store( for idx in range(grid_mapping.num_inputs): fetch(idx, _as_index(slot), _as_index(slot)) - last_store_offsets = [_as_index(-1)] * grid_mapping.num_outputs + last_store_offsets = [None if inv else _as_index(-1) for inv in out_sequential_invariant] @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets)) def _(step, carry): accs, last_store_offsets = carry @@ -510,8 +526,11 @@ def _(step, carry): mgpu.commit_shared() new_store_offsets = [] for idx in range(grid_mapping.num_outputs): + last_offset = last_store_offsets[idx] new_store_offsets.append( - store(idx, step, slot, last_store_offsets[idx]) + store(idx, step, slot, last_offset) + if not out_sequential_invariant[idx] + else last_offset # Only store if the output can depend on the step. ) next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps)) @@ -526,6 +545,13 @@ def _(step, carry): return list(new_accs), new_store_offsets + # Outputs invariant to the sequential axis are never written from inside the + # loop. This is the only place where we store them. + last_slot = _as_index((num_steps - 1) % max_concurrent_steps) + for idx in range(grid_mapping.num_outputs): + if out_sequential_invariant[idx]: + store(idx, _as_index(0), last_slot, None) + launch_ctx.await_async_copy(0) scratch_avals = [ From 26632fd344b91af0412d69e4e8d1259f3d87ba4c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 27 Sep 2024 06:14:50 -0700 Subject: [PATCH 701/702] Replace disable_backends with enable_backends on jax_multiplatform_test. Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled". We change the relationship between enable_backends, disable_configs, enable_configs to be the following: * `enable_backends` selects a set of initial test configurations to enable, based off backend only. * `disable_configs` then prunes that set of test configurations, removing elements from the set. * `enable_configs` then adds additional configurations to the set. Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target. PiperOrigin-RevId: 679563155 --- benchmarks/mosaic/BUILD | 18 +-- docs/cuda_custom_call/BUILD | 5 +- jax/experimental/mosaic/gpu/examples/BUILD | 9 +- jaxlib/jax.bzl | 13 +- tests/BUILD | 87 +++--------- tests/mosaic/BUILD | 31 ++--- tests/pallas/BUILD | 148 ++++----------------- 7 files changed, 76 insertions(+), 235 deletions(-) diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 727e347e5a64..39c7aa5f3395 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -28,25 +28,11 @@ package( jax_generate_backend_suites() -DISABLED_BACKENDS = [ - "cpu", - "tpu", -] - -DISABLED_CONFIGS = [ - "gpu_v100", - "gpu_a100", - "gpu_p100", - "gpu_p100_x32", - "gpu_x32", - "gpu_pjrt_c_api", -] - jax_multiplatform_test( name = "matmul_bench", srcs = ["matmul_bench.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], tags = ["notap"], deps = [ "//jax:mosaic_gpu", diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD index 0089b6b9fb0d..4954ce3db4fa 100644 --- a/docs/cuda_custom_call/BUILD +++ b/docs/cuda_custom_call/BUILD @@ -32,10 +32,7 @@ jax_multiplatform_test( name = "cuda_custom_call_test", srcs = ["cuda_custom_call_test.py"], data = [":foo"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], tags = ["notap"], deps = [ "//jax:extend", diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index 57f78cb2c5c8..fe1a7e9180ac 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -13,7 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_library") -load("//jaxlib:jax.bzl", "jax_py_test", "py_deps") +load("//jaxlib:jax.bzl", "jax_multiplatform_test", "py_deps") licenses(["notice"]) @@ -48,18 +48,17 @@ py_library( ], ) -jax_py_test( +jax_multiplatform_test( name = "run_matmul", srcs = ["matmul.py"], + enable_backends = [], + enable_configs = ["gpu_h100"], main = "matmul.py", tags = [ "manual", "notap", - "requires-gpu-sm90-only", ], deps = [ - "//jax", "//jax:mosaic_gpu", - "//learning/brain/research/jax:gpu_support", ] + py_deps("numpy"), ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 65ec572c7ee2..2e37e694b506 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -231,15 +231,22 @@ def jax_multiplatform_test( shard_count = None, deps = [], data = [], - disable_backends = None, # buildifier: disable=unused-variable + enable_backends = None, backend_variant_args = {}, # buildifier: disable=unused-variable backend_tags = {}, # buildifier: disable=unused-variable disable_configs = None, # buildifier: disable=unused-variable - enable_configs = None, # buildifier: disable=unused-variable + enable_configs = [], config_tags_overrides = None, # buildifier: disable=unused-variable tags = [], main = None, pjrt_c_api_bypass = False): # buildifier: disable=unused-variable + # enable_configs and disable_configs do not do anything in OSS, only in Google's CI. + # The order in which `enable_backends`, `enable_configs`, and `disable_configs` are applied is + # as follows: + # 1. `enable_backends` is applied first, enabling all test configs for the given backends. + # 2. `disable_configs` is applied second, disabling the named test configs. + # 3. `enable_configs` is applied last, enabling the named test configs. + if main == None: if len(srcs) == 1: main = srcs[0] @@ -256,7 +263,7 @@ def jax_multiplatform_test( "--jax_platform_name=" + backend, ] test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) - if disable_backends and backend in disable_backends: + if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]): test_tags += ["manual"] if backend == "gpu": test_tags += tf_cuda_tests_tags() diff --git a/tests/BUILD b/tests/BUILD index c93f18dbb815..df9a28236e6a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -66,7 +66,10 @@ jax_py_test( jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], - disable_backends = ["tpu"], + enable_backends = [ + "cpu", + "gpu", + ], tags = ["multiaccelerator"], deps = py_deps("tensorflow_core"), ) @@ -160,10 +163,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "gpu_memory_flags_test_no_preallocation", srcs = ["gpu_memory_flags_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], env = { "XLA_PYTHON_CLIENT_PREALLOCATE": "0", }, @@ -173,10 +173,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "gpu_memory_flags_test", srcs = ["gpu_memory_flags_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], env = { "XLA_PYTHON_CLIENT_PREALLOCATE": "1", }, @@ -273,10 +270,7 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"}, tags = [ "config-cuda-only", @@ -290,10 +284,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "mock_gpu_test", srcs = ["mock_gpu_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], tags = [ "config-cuda-only", ], @@ -556,11 +547,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "lax_metal_test", srcs = ["lax_metal_test.py"], - disable_backends = [ - "cpu", - "gpu", - "tpu", - ], + enable_backends = ["metal"], tags = ["notap"], deps = [ "//jax:internal_test_util", @@ -649,10 +636,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "metadata_test", srcs = ["metadata_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) jax_py_test( @@ -672,10 +656,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "multi_device_test", srcs = ["multi_device_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) jax_multiplatform_test( @@ -734,10 +715,7 @@ jax_multiplatform_test( name = "polynomial_test", srcs = ["polynomial_test.py"], # No implementation of nonsymmetric Eigendecomposition. - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], shard_count = { "cpu": 10, }, @@ -753,25 +731,18 @@ jax_multiplatform_test( jax_multiplatform_test( name = "heap_profiler_test", srcs = ["heap_profiler_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) jax_multiplatform_test( name = "profiler_test", srcs = ["profiler_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], - disable_backends = ["tpu"], # The following cases are disabled because they time out in Google's CI, mostly because the # CUDA kernels in Torch take a very long time to compile. disable_configs = [ @@ -779,6 +750,10 @@ jax_multiplatform_test( "gpu_a100", # Pytorch A100 build times out in Google's CI. "gpu_h100", # Pytorch H100 build times out in Google's CI. ], + enable_backends = [ + "cpu", + "gpu", + ], tags = [ "not_build:arm", # TODO(b/355237462): Re-enable once MSAN issue is addressed. @@ -1019,16 +994,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "sparse_nm_test", srcs = ["sparse_nm_test.py"], - config_tags_overrides = { - "gpu_a100": { - "ondemand": False, # Include in presubmit. - }, - }, - disable_backends = [ - "cpu", - "gpu", - "tpu", - ], + enable_backends = [], enable_configs = [ "gpu_a100", "gpu_h100", @@ -1386,13 +1352,10 @@ jax_multiplatform_test( jax_multiplatform_test( name = "experimental_rnn_test", srcs = ["experimental_rnn_test.py"], - disable_backends = [ - "tpu", - "cpu", - ], disable_configs = [ "gpu_a100", # Numerical precision problems. ], + enable_backends = ["gpu"], shard_count = 15, deps = [ "//jax:rnn", @@ -1505,10 +1468,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], - disable_backends = [ - "tpu", - "cpu", - ], + enable_backends = ["gpu"], shard_count = { "gpu": 4, }, @@ -1542,10 +1502,7 @@ jax_py_test( jax_multiplatform_test( name = "cudnn_fusion_test", srcs = ["cudnn_fusion_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], enable_configs = [ "gpu_a100", "gpu_h100", diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 4d33e228b906..3d1348371f07 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -28,31 +28,16 @@ package( jax_generate_backend_suites() -DISABLED_BACKENDS = [ - "cpu", - "tpu", -] - -DISABLED_CONFIGS = [ - "gpu_a100", - "gpu_a100_x32", - "gpu_p100", - "gpu_p100_x32", - "gpu_pjrt_c_api", - "gpu_v100", - "gpu_x32", -] - jax_multiplatform_test( name = "gpu_test", srcs = ["gpu_test.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], enable_configs = [ "gpu_h100", "gpu_h100_2gpu", ], shard_count = 4, + tags = ["multiaccelerator"], deps = [ "//jax:mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), @@ -61,8 +46,8 @@ jax_multiplatform_test( jax_multiplatform_test( name = "matmul_test", srcs = ["matmul_test.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], shard_count = 5, deps = [ "//jax:mosaic_gpu", @@ -73,8 +58,8 @@ jax_multiplatform_test( jax_multiplatform_test( name = "flash_attention", srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py", tags = ["notap"], deps = [ @@ -85,8 +70,8 @@ jax_multiplatform_test( jax_multiplatform_test( name = "flash_attention_test", srcs = ["flash_attention_test.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:flash_attention", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index b1db085a8c3a..044f82067510 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -38,11 +38,9 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_p100", - "gpu_p100_x32", + enable_backends = [ + "cpu", + "tpu", ], enable_configs = [ "gpu_a100_x32", @@ -75,9 +73,6 @@ jax_multiplatform_test( "gpu_p100_x32", "gpu_h100", ], - shard_count = { - "tpu": 1, - }, deps = [ "//jax:pallas", "//jax:pallas_tpu", @@ -130,8 +125,9 @@ jax_multiplatform_test( srcs = [ "indexing_test.py", ], - disable_backends = [ - "gpu", + enable_backends = [ + "cpu", + "tpu", ], tags = [ "noasan", # Times out. @@ -154,14 +150,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", @@ -186,19 +175,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_backends = [ - "cpu", - "tpu", - ], - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_a100_x32", - "gpu_p100", - "gpu_p100_x32", - "gpu_h100", - ], + enable_backends = [], enable_configs = [ "gpu_h100_x32", ], @@ -220,15 +197,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - "gpu_pjrt_c_api", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", @@ -251,15 +220,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - "gpu_pjrt_c_api", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", ], @@ -303,10 +264,7 @@ jax_multiplatform_test( srcs = [ "pallas_error_handling_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:pallas", "//jax:pallas_tpu", @@ -321,10 +279,7 @@ jax_multiplatform_test( srcs = [ "tpu_all_gather_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:pallas_tpu_ops", ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), @@ -335,10 +290,7 @@ jax_multiplatform_test( srcs = [ "tpu_gmm_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], shard_count = 50, tags = [ "noasan", # Times out. @@ -360,10 +312,7 @@ jax_multiplatform_test( srcs = ["tpu_pallas_test.py"], # The flag is necessary for ``pl.debug_print`` tests to work on TPU. args = ["--logtostderr"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -376,8 +325,9 @@ jax_multiplatform_test( srcs = [ "tpu_ops_test.py", ], - disable_backends = [ - "gpu", + enable_backends = [ + "cpu", + "tpu", ], deps = [ "//jax:pallas", @@ -390,10 +340,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_distributed_test", srcs = ["tpu_pallas_distributed_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -404,10 +351,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_pipeline_test", srcs = ["tpu_pallas_pipeline_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], shard_count = 5, tags = [ "noasan", # Times out. @@ -424,10 +368,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_async_test", srcs = ["tpu_pallas_async_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], tags = [ ], deps = [ @@ -438,10 +379,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_mesh_test", srcs = ["tpu_pallas_mesh_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], tags = [ "noasan", "nomsan", @@ -458,10 +396,7 @@ jax_multiplatform_test( srcs = [ "tpu_pallas_random_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:pallas", "//jax:pallas_tpu", @@ -474,10 +409,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], shard_count = 5, tags = [ "noasan", # Times out. @@ -494,10 +426,7 @@ jax_multiplatform_test( srcs = [ "tpu_splash_attention_kernel_test.py", ], - disable_backends = [ - "gpu", - "cpu", - ], + enable_backends = ["tpu"], shard_count = 24, tags = [ "noasan", # Times out. @@ -514,8 +443,9 @@ jax_multiplatform_test( srcs = [ "tpu_splash_attention_mask_test.py", ], - disable_backends = [ - "gpu", + enable_backends = [ + "cpu", + "tpu", ], deps = [ "//jax:pallas_tpu_ops", @@ -532,17 +462,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_backends = [ - "tpu", - ], - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_p100", - "gpu_p100_x32", - "gpu_a100", - "gpu_h100", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", @@ -565,17 +485,7 @@ jax_multiplatform_test( "ondemand": False, # Include in presubmit. }, }, - disable_backends = [ - "tpu", - ], - disable_configs = [ - "gpu_v100", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", From 5a1d0a6c2637432d49df3bf36f4c3255294be86e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 27 Sep 2024 08:52:42 -0700 Subject: [PATCH 702/702] Include the sdy MLIR dialect in jaxlib. We're seeing test failures from tests assuming that this dialect exists. But given we plan to enable it at some point, we may as well just include it in the build. The size impact is small (around 400K uncompressed). PiperOrigin-RevId: 679608092 --- jaxlib/tools/build_wheel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 52a17c451aea..3c40c2d11fb5 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -351,6 +351,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}",