CI: 01/08/25 upstream sync #196

Merged
merged 34 commits into from
Jan 8, 2025
34 commits
f124232
Add activation offloading examples in a subsection
zhenying-liu Dec 18, 2024
522a8fd
consolidate the code example
zhenying-liu Dec 19, 2024
5e3a692
Merge branch 'jax-ml:main' into activation-offloading-doc
zhenying-liu Jan 5, 2025
f881f50
Update the advanced autodiff tutorial and replace some vmap with grad
Exferro Dec 12, 2024
5b80892
[Mosaic GPU] Use `num_q_heads=2` in `flash_attention.py`
andportnoy Jan 7, 2025
fdb6af8
Clean up `backend_or_name` vs. `platforms` in lowering code.
gnecula Jan 7, 2025
57c2afe
Merge pull request #25441 from Exferro:fixed_advanced_autodiff_doc
Google-ML-Automation Jan 7, 2025
4023810
[AutoPGLE] FIx PGLE kokoro test failures.
Google-ML-Automation Jan 7, 2025
8c9a539
[Pallas] Fix pallas_call lowering mutating compiler params during Tri…
justinjfu Jan 7, 2025
00c363e
Update XLA dependency to use revision
Google-ML-Automation Jan 7, 2025
62656b3
Add an example demonstrating input-output aliasing with the FFI.
dfm Nov 21, 2024
64c0f62
Sort manual axes when lowering `jax.shard_map` to `sdy.manual_computa…
ZixuanJiang Jan 7, 2025
f1777d5
Merge pull request #25042 from dfm:ffi-example-input-output-alias
Google-ML-Automation Jan 7, 2025
f6c9e87
[array api] update test suite to latest commit
jakevdp Jan 7, 2025
392a851
Increase the minimum SciPy version to 1.11.1.
hawkinsp Jan 8, 2025
7be127f
[Pallas] Improvements to core_map
sharadmv Jan 8, 2025
755d6cd
[sharding_in_types] Aval sharding under full auto mode should contain…
yashk2810 Jan 8, 2025
6d08f36
Merge pull request #25761 from jakevdp:array-api-update
Google-ML-Automation Jan 8, 2025
21fb171
Merge branch 'jax-ml:main' into activation-offloading-doc
zhenying-liu Jan 8, 2025
1bd781d
Add JAX events that have time spans, not only durations.
Google-ML-Automation Jan 8, 2025
81db321
Merge pull request #25594 from zhenying-liu:activation-offloading-doc
Google-ML-Automation Jan 8, 2025
90201ce
Removed leftover mentions of xmap from the code
superbobry Jan 8, 2025
4718121
Merge pull request #25754 from andportnoy:patch-4
Google-ML-Automation Jan 8, 2025
bf94389
[Mosaic] Use tpu::CreateMask for getX32VmaskByPaddingEnd.
WindQAQ Jan 8, 2025
e954930
[Mosaic TPU] Add support for true divide in bf16 on TPUv6
apaszke Jan 8, 2025
5fd1b2f
[Mosaic TPU] Add support for second minor broadcasts with packed types
apaszke Jan 8, 2025
f96339b
[Mosaic TPU] Be much more aggressive in inferring large 2nd minor lay…
apaszke Jan 8, 2025
51b9fe3
[JAX] Add a new jax_num_cpu_devices flag that allows the user to spec…
hawkinsp Jan 8, 2025
f1f98af
[pallas:mosaic_gpu] Fix the tests following the changes to `pl.core_map`
superbobry Jan 8, 2025
3fa5572
Port tests away from setUpClass and setUpModule to setUp alone.
hawkinsp Jan 8, 2025
0389d61
Add a unittest test extension that runs test cases in parallel using …
hawkinsp Jan 8, 2025
5c097c8
#sdy Move Shardy mesh lift inlining pass after verification.
bartchr808 Jan 8, 2025
5511949
Update XLA dependency to use revision
Google-ML-Automation Jan 8, 2025
8d94998
Use our own XLA
charleshofer Jan 8, 2025
2 changes: 1 addition & 1 deletion .github/workflows/jax-array-api.yml
@@ -28,7 +28,7 @@ jobs:
with:
repository: data-apis/array-api-tests
# TODO(jakevdp) update this to a stable release/tag when available.
ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04
ref: 'f7a74a685d78d98203fa991fc19a5d1fda57c212' # Latest commit as of 2025-01-07
submodules: 'true'
path: 'array-api-tests'
- name: Set up Python ${{ matrix.python-version }}
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
supported version until June 2025.
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
the case of many arguments ({jax-issue}`#25214`).
3 changes: 1 addition & 2 deletions ci/run_pytest_gpu.sh
@@ -56,6 +56,5 @@ echo "Running GPU tests..."
"$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \
tests examples \
--deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \
--deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \
--deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \
--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric
--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric
10 changes: 5 additions & 5 deletions docs/advanced-autodiff.md
@@ -247,7 +247,7 @@ perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

### Hessian-vector products with `jax.grad`-of-`jax.grad`

One thing you can do with higher-order {func}`jax.vmap` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)

A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)).

@@ -259,11 +259,11 @@ for any $v \in \mathbb{R}^n$.

The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.

Luckily, {func}`jax.vmap` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:
Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:

$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$,

where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.vmap` is efficient.
where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient.

In JAX code, you can just write this:
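
A minimal sketch of that identity (assuming `grad` and `jnp` are imported as in the rest of this tutorial):

```python
import jax.numpy as jnp
from jax import grad

def hvp(f, x, v):
  # Differentiate g(x) = <grad f(x), v>, a scalar-valued function of x.
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

# Example: f(x) = sum(x**3) has Hessian diag(6x), so hvp returns 6 * x * v.
f = lambda x: jnp.sum(x ** 3)
x, v = jnp.arange(1.0, 4.0), jnp.ones(3)
print(hvp(f, x, v))  # [ 6. 12. 18.]
```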

@@ -357,7 +357,7 @@ To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfw

### Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff)

JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.vmap` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.
JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.
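
As a small, hedged illustration (assuming only the public `jax.jvp` API and the standard imports), forward mode is exposed directly as `jax.jvp`:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x
x = jnp.arange(3.0)
v = jnp.ones_like(x)           # a tangent vector
y, u = jax.jvp(f, (x,), (v,))  # y = f(x); u = ∂f(x) v, without materializing ∂f(x)
print(y, u)
```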


#### JVPs in math
@@ -473,7 +473,7 @@ vjp :: (a -> b) -> a -> (b, CT b -> CT a)

where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`.

This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.vmap` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
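
As a hedged sketch of that last point (using only the public `jax.vjp` API), the gradient of a scalar-valued function is a single VJP evaluation with the cotangent set to one:

```python
import jax
import jax.numpy as jnp

def grad_via_vjp(f, x):
  y, f_vjp = jax.vjp(f, x)        # one forward pass builds the pullback
  (g,) = f_vjp(jnp.ones_like(y))  # pull back the scalar cotangent 1.0
  return g

f = lambda x: jnp.sum(jnp.sin(x) ** 2)
x = jnp.arange(4.0)
assert jnp.allclose(grad_via_vjp(f, x), jax.grad(f)(x))
```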

There's a cost, though: although the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).

8 changes: 1 addition & 7 deletions docs/aot.md
@@ -73,14 +73,8 @@ see the {ref}`export` APIs.
See the {mod}`jax.stages` documentation for more details on what functionality
the lowering and compiled functions provide.

In place of `jax.jit` above, you can also `lower(...)` the result of
{func}`jax.pmap`, as well as `pjit` and `xmap` (from
{mod}`jax.experimental.pjit` and {mod}`jax.experimental.maps` respectively). In
each case, you can `compile()` the result similarly.

All optional arguments to `jit`---such as `static_argnums`---are respected in
the corresponding lowering, compilation, and execution. Again the same goes for
`pmap`, `pjit`, and `xmap`.
the corresponding lowering, compilation, and execution.

In the example above, we can replace the arguments to `lower` with any objects
that have `shape` and `dtype` attributes:
55 changes: 55 additions & 0 deletions docs/gradient-checkpointing.md
@@ -354,6 +354,61 @@ print_saved_residuals(loss_checkpoint2, params, x, y)

Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.

#### Custom policies for offload

When checkpointing to save accelerator memory, you can offload residuals to CPU memory instead of recomputing them. The policy `jax.checkpoint_policies.offload_dot_with_no_batch_dims` offloads the results of matrix multiplications that have no batch dimensions to the CPU.

```{code-cell}
import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax.ad_checkpoint import checkpoint

# Offload the results of matmuls with no batch dimensions from device memory
# to pinned host memory instead of rematerializing them on the backward pass.
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
    "device", "pinned_host")

@functools.partial(checkpoint, policy=policy)
def f(x):
  x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
  x = jnp.sin(x)
  x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
  x = jnp.sin(x)
  x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
  x = jnp.sin(x)
  x = jnp.sum(x)
  return x
```

One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPU memory. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which takes four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offload source, and the offload destination. Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and all other names or unnamed operations are recomputed. For example, given checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed.

```{code-cell}
import math

import numpy as np
from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax.sharding import NamedSharding, PartitionSpec as P

# A 1D mesh spanning two devices (requires at least two devices).
mesh = jax.make_mesh((2,), ("x",))
shape = (256, 128)
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
s = NamedSharding(mesh, P("x"))
inp = jax.device_put(np_inp, s)

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
    offload_src='device', offload_dst='pinned_host')

@functools.partial(checkpoint, policy=policy)
def f(x):
  def g(ys, _):
    y, _ = ys
    y = checkpoint_name(jnp.sin(y), "y")  # saved on the device
    z = checkpoint_name(jnp.sin(y), "z")  # offloaded to pinned host memory
    z = z.T
    w = checkpoint_name(jnp.sin(z), "w")  # recomputed on the backward pass
    return (w.T, jnp.sum(w)), None
  _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
  return scan_out
```

The code defines a function `f` that applies checkpointing with a custom policy. The policy determines which computations can be saved or offloaded during execution. Inside `f`, a nested function `g` performs the core computations, and `jax.lax.scan` applies `g` repeatedly over the input data.

#### List of policies

The policies are:
30 changes: 28 additions & 2 deletions examples/ffi/src/jax_ffi_example/cpu_examples.cc
@@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
Counter, CounterImpl,
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());

// --------
// Aliasing
// --------
//
// This example demonstrates how input-output aliasing works. The handler
// doesn't do anything except to check that the input and output pointers
// address the same data.

ffi::Error AliasingImpl(ffi::AnyBuffer input,
ffi::Result<ffi::AnyBuffer> output) {
if (input.element_type() != output->element_type() ||
input.element_count() != output->element_count()) {
return ffi::Error::InvalidArgument(
"The input and output data types and sizes must match.");
}
if (input.untyped_data() != output->untyped_data()) {
return ffi::Error::InvalidArgument(
"When aliased, the input and output buffers should point to the same "
"data.");
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
Aliasing, AliasingImpl,
ffi::Ffi::Bind().Arg<ffi::AnyBuffer>().Ret<ffi::AnyBuffer>());

// Boilerplate for exposing handlers to Python
NB_MODULE(_cpu_examples, m) {
m.def("registrations", []() {
@@ -111,9 +138,8 @@ NB_MODULE(_cpu_examples, m) {
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
registrations["dictionary_attr"] =
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));

registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));

registrations["aliasing"] = nb::capsule(reinterpret_cast<void *>(Aliasing));
return registrations;
});
}
6 changes: 6 additions & 0 deletions examples/ffi/src/jax_ffi_example/cpu_examples.py
@@ -39,3 +39,9 @@ def dictionary_attr(**kwargs):
def counter(index):
return jax.ffi.ffi_call(
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))


def aliasing(x):
return jax.ffi.ffi_call(
"aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype),
input_output_aliases={0: 0})(x)
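
For context, a short usage sketch of the wrapper above (the `jax_ffi_example.cpu_examples` import path is assumed from this example's directory layout). `input_output_aliases={0: 0}` tells XLA that the output may reuse the buffer of argument 0, which the FFI handler then verifies:

```python
import jax.numpy as jnp
from jax_ffi_example import cpu_examples  # assumed import path for these examples

x = jnp.linspace(0, 0.5, 10)
# The handler checks that the output buffer aliases the input buffer and
# returns the (unchanged) values, as the test below also verifies.
assert jnp.allclose(cpu_examples.aliasing(x), x)
```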
13 changes: 12 additions & 1 deletion examples/ffi/tests/cpu_examples_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import absltest, parameterized

import jax
import jax.numpy as jnp
@@ -91,5 +91,16 @@ def counter_fun(x):
self.assertEqual(counter_fun(0)[1], 3)


class AliasingTests(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cpu"]):
self.skipTest("Unsupported platform")

@parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),))
def test_basic(self, x):
self.assertAllClose(cpu_examples.aliasing(x), x)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
6 changes: 5 additions & 1 deletion jax/_src/compiler.py
@@ -33,6 +33,7 @@
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
import numpy as np
@@ -198,7 +199,10 @@ def get_compile_options(
logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
if env_options_overrides is None:
env_options_overrides = {}
env_options_overrides['xla_gpu_enable_command_buffer'] = ''
if xla_extension_version > 302:
env_options_overrides['xla_gpu_enable_command_buffer'] = ''
else:
env_options_overrides['xla_gpu_graph_min_graph_size'] = '100000'

if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
8 changes: 3 additions & 5 deletions jax/_src/core.py
@@ -39,7 +39,7 @@
from jax._src import effects
from jax._src import compute_on
from jax._src import mesh as mesh_lib
from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
from jax._src.partition_spec import PartitionSpec as P
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
@@ -1659,12 +1659,12 @@ def _maybe_modify_sharding(sharding):

new_spec = []
for s in sharding.spec:
if s is None or isinstance(s, UnconstrainedSingleton):
if s is None:
new_spec.append(s)
else:
temp_s = s[0] if isinstance(s, tuple) else s
new_spec.append(
P.UNCONSTRAINED
None
if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
return sharding.with_spec(new_spec)

@@ -1762,8 +1762,6 @@ def _get_shape_sharding_str(shape, spec):
for s1, s2 in zip(shape, spec):
if s2 is None:
out.append(f"{s1}")
elif isinstance(s2, UnconstrainedSingleton):
out.append(f"{s1}")
elif isinstance(s2, tuple):
ss = ','.join(s for s in s2)
out.append(f"{s1}@({ss})")
2 changes: 1 addition & 1 deletion jax/_src/custom_partitioning.py
@@ -190,7 +190,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
closed_jaxpr,
name="tmp_xla_computation",
platforms=module_context.platforms,
backend_or_name=module_context.backend_or_name,
backend=module_context.backend,
axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
)
result_sharding = _pack_result_sharding(result_shape, result_shardings)
31 changes: 16 additions & 15 deletions jax/_src/dispatch.py
@@ -22,39 +22,38 @@
import enum
from functools import partial
import itertools
import time
from typing import Any, NamedTuple
import logging
import threading

import numpy as np
import time
from typing import Any, NamedTuple

import jax
from jax._src import api
from jax._src import array
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import api
from jax._src import array
from jax._src import dtypes
from jax._src import lib
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.abstract_arrays import array_types
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src import lib
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.monitoring import record_event_duration_secs
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.monitoring import record_event_duration_secs, record_event_time_span
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, NamedSharding, TransferToMemoryKind,
from jax._src.sharding_impls import ( NamedSharding,
SingleDeviceSharding, TransferToMemoryKind,
is_single_device_sharding)
from jax._src.layout import Layout, DeviceLocalLayout
import numpy as np


JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@@ -177,12 +176,14 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
start_time = time.time()
yield
elapsed_time = time.time() - start_time
end_time = time.time()
elapsed_time = end_time - start_time
if logger.isEnabledFor(log_priority):
logger.log(log_priority, fmt.format(
fun_name=fun_name, elapsed_time=elapsed_time))
if event is not None:
record_event_duration_secs(event, elapsed_time)
record_event_time_span(event, start_time, end_time)


def should_tuple_args(num_args: int, platform: str) -> bool:
2 changes: 1 addition & 1 deletion jax/_src/export/_export.py
@@ -896,7 +896,7 @@ def is_token(typ, attrs):
with ir.InsertionPoint(entry_block):
# Make a context just for lowering the dimension value computations
module_context = mlir.ModuleContext(
backend_or_name="cpu", platforms=["cpu"],
backend=None, platforms=["cpu"],
axis_context=sharding_impls.ShardingContext(0),
keepalives=[], channel_iterator=itertools.count(1),
host_callbacks=[], module=wrapped_module, context=context,