Merge pull request #196 from ROCm/ci-upstream-sync-82_1
CI: 01/08/25 upstream sync
github-actions[bot] authored Jan 8, 2025
2 parents bc06c93 + 8d94998 commit 90eab82
Showing 75 changed files with 958 additions and 533 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/jax-array-api.yml
@@ -28,7 +28,7 @@ jobs:
      with:
        repository: data-apis/array-api-tests
        # TODO(jakevdp) update this to a stable release/tag when available.
        ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04
        ref: 'f7a74a685d78d98203fa991fc19a5d1fda57c212' # Latest commit as of 2025-01-07
        submodules: 'true'
        path: 'array-api-tests'
    - name: Set up Python ${{ matrix.python-version }}
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes:
  * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
    supported version until June 2025.
  * The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum
    supported version until June 2025.
  * {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
    `optimize='optimal'`. This avoids exponentially-scaling trace-time in
    the case of many arguments ({jax-issue}`#25214`).
3 changes: 1 addition & 2 deletions ci/run_pytest_gpu.sh
@@ -56,6 +56,5 @@ echo "Running GPU tests..."
"$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \
tests examples \
--deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \
--deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \
--deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \
--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric
--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric
10 changes: 5 additions & 5 deletions docs/advanced-autodiff.md
@@ -247,7 +247,7 @@ perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

### Hessian-vector products with `jax.grad`-of-`jax.grad`

One thing you can do with higher-order {func}`jax.vmap` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)

A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)).

@@ -259,11 +259,11 @@ for any $v \in \mathbb{R}^n$.

The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.

Luckily, {func}`jax.vmap` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:
Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:

$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$,

where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.vmap` is efficient.
where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient.

In JAX code, you can just write this:
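A minimal sketch of such a Hessian-vector product helper (illustrative only; it assumes `jax` and `jax.numpy` are imported, and the names are made up for the example):

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
  # Differentiate the scalar-valued function x -> <grad f(x), v>;
  # this gives a Hessian-vector product without materializing the Hessian.
  return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)

# Illustrative usage:
f = lambda w: jnp.sum(jnp.sin(w) ** 2)
print(hvp(f, jnp.arange(3.0), jnp.ones(3)))
```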

@@ -357,7 +357,7 @@ To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfw

### Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff)

JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.vmap` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.
JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.


#### JVPs in math
@@ -473,7 +473,7 @@ vjp :: (a -> b) -> a -> (b, CT b -> CT a)

where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`.

This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.vmap` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
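As an illustrative sketch (not from the changed file; the function and variable names are made up), a single `jax.vjp` call yields the full gradient of a scalar-valued function:

```python
import jax
import jax.numpy as jnp

def loss(w):
  return jnp.sum(jnp.tanh(w) ** 2)     # scalar output, vector input

w = jnp.arange(4.0)
y, pullback = jax.vjp(loss, w)         # forward pass plus a reverse-mode closure
(grad_w,) = pullback(jnp.ones_like(y)) # one reverse pass produces the whole gradient
```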

There's a cost, though: although the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).

8 changes: 1 addition & 7 deletions docs/aot.md
@@ -73,14 +73,8 @@ see the {ref}`export` APIs.
See the {mod}`jax.stages` documentation for more details on what functionality
the lowering and compiled functions provide.
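As an illustrative sketch (not part of this file's text), the staged objects expose inspection helpers; exact output is backend-dependent:

```python
import jax
import jax.numpy as jnp

lowered = jax.jit(lambda x: x ** 2).lower(jnp.float32(3.0))
print(lowered.as_text()[:200])    # MLIR/StableHLO text of the lowered computation
compiled = lowered.compile()
print(compiled.cost_analysis())   # backend-dependent cost estimates (may be empty)
```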

In place of `jax.jit` above, you can also `lower(...)` the result of
{func}`jax.pmap`, as well as `pjit` and `xmap` (from
{mod}`jax.experimental.pjit` and {mod}`jax.experimental.maps` respectively). In
each case, you can `compile()` the result similarly.

All optional arguments to `jit`---such as `static_argnums`---are respected in
the corresponding lowering, compilation, and execution. Again the same goes for
`pmap`, `pjit`, and `xmap`.
the corresponding lowering, compilation, and execution.

In the example above, we can replace the arguments to `lower` with any objects
that have `shape` and `dtype` attributes:
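A minimal sketch of this idea, using `jax.ShapeDtypeStruct` as the spec object (illustrative only; the function and variable names here are made up):

```python
import jax
import jax.numpy as jnp

def f(x, y):
  return 2 * x + y

# Abstract specs stand in for concrete arguments when tracing ahead of time.
i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
compiled = jax.jit(f).lower(i32_scalar, i32_scalar).compile()
print(compiled(jnp.int32(1), jnp.int32(2)))   # -> 4
```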
55 changes: 55 additions & 0 deletions docs/gradient-checkpointing.md
@@ -354,6 +354,61 @@ print_saved_residuals(loss_checkpoint2, params, x, y)

Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.

#### Custom policies for offload

You may consider offloading to CPU memory instead of recomputing when checkpointing to save accelerator memory. `jax.checkpoint_policies.offload_dot_with_no_batch_dims` can offload the results of matrix multiplications with no batch dimensions to the CPU.

```{code-cell}
from jax.ad_checkpoint import checkpoint
def checkpoint_offload_dot_with_no_batch_dims(self):
  policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
      "device", "pinned_host")
  @functools.partial(checkpoint, policy=policy)
  def f(x):
    x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
    x = jnp.sin(x)
    x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
    x = jnp.sin(x)
    x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
    x = jnp.sin(x)
    x = jnp.sum(x)
    return x
```

One of JAX's checkpoint policies allows specified checkpoint names to be offloaded to CPUs. This policy is implemented through `jax.checkpoint_policies.save_and_offload_only_these_names`, which has four arguments: `names_which_can_be_saved`, `names_which_can_be_offloaded`, the offloading source, and destination. Names listed in `names_which_can_be_saved` are kept on the device, names listed in `names_which_can_be_offloaded` are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names `y`, `z`, and `w`, `y` can be saved on the device, `z` can be offloaded to CPU memory, and `w` can be recomputed.

```{code-cell}
from jax.ad_checkpoint import checkpoint, checkpoint_name
from jax._src import test_util as jtu
def checkpoint_names_saved_offloaded_recomputed(self):
  mesh = jtu.create_mesh((2,), ("x",))
  shape = (256, 128)
  np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  s = NamedSharding(mesh, P("x"))
  inp = jax.device_put(np_inp, s)
  policy = jax.checkpoint_policies.save_and_offload_only_these_names(
      names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
      offload_src='device', offload_dst='pinned_host')
  @functools.partial(checkpoint, policy=policy)
  def f(x):
    def g(ys, _):
      y, _ = ys
      y = checkpoint_name(jnp.sin(y), "y")
      z = checkpoint_name(jnp.sin(y), "z")
      z = z.T
      w = checkpoint_name(jnp.sin(z), "w")
      return (w.T, jnp.sum(w)), None
    _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
    return scan_out
```

The code defines a function `f` that applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside `f`, there is a nested function `g` that performs the core computations. The `jax.lax.scan` function is used to apply `g` repeatedly over the input data.
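As a reminder of the `jax.lax.scan` signature this relies on, here is a generic illustrative sketch (unrelated to the checkpoint policy itself; the names are made up):

```python
import jax
import jax.numpy as jnp

def step(carry, x):
  carry = carry + x
  return carry, carry            # (new carry, per-step output)

total, partials = jax.lax.scan(step, jnp.float32(0.0), jnp.arange(5.0))
# total == 10.0, partials == [0., 1., 3., 6., 10.]
```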

#### List of policies

The policies are:
30 changes: 28 additions & 2 deletions examples/ffi/src/jax_ffi_example/cpu_examples.cc
@@ -103,6 +103,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
    Counter, CounterImpl,
    ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());

// --------
// Aliasing
// --------
//
// This example demonstrates how input-output aliasing works. The handler
// doesn't do anything except to check that the input and output pointers
// address the same data.

ffi::Error AliasingImpl(ffi::AnyBuffer input,
                        ffi::Result<ffi::AnyBuffer> output) {
  if (input.element_type() != output->element_type() ||
      input.element_count() != output->element_count()) {
    return ffi::Error::InvalidArgument(
        "The input and output data types and sizes must match.");
  }
  if (input.untyped_data() != output->untyped_data()) {
    return ffi::Error::InvalidArgument(
        "When aliased, the input and output buffers should point to the same "
        "data.");
  }
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    Aliasing, AliasingImpl,
    ffi::Ffi::Bind().Arg<ffi::AnyBuffer>().Ret<ffi::AnyBuffer>());

// Boilerplate for exposing handlers to Python
NB_MODULE(_cpu_examples, m) {
  m.def("registrations", []() {
@@ -111,9 +138,8 @@ NB_MODULE(_cpu_examples, m) {
        nb::capsule(reinterpret_cast<void *>(ArrayAttr));
    registrations["dictionary_attr"] =
        nb::capsule(reinterpret_cast<void *>(DictionaryAttr));

    registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));

    registrations["aliasing"] = nb::capsule(reinterpret_cast<void *>(Aliasing));
    return registrations;
  });
}
6 changes: 6 additions & 0 deletions examples/ffi/src/jax_ffi_example/cpu_examples.py
@@ -39,3 +39,9 @@ def dictionary_attr(**kwargs):
def counter(index):
  return jax.ffi.ffi_call(
      "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))


def aliasing(x):
  return jax.ffi.ffi_call(
      "aliasing", jax.ShapeDtypeStruct(x.shape, x.dtype),
      input_output_aliases={0: 0})(x)
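An illustrative usage sketch, assuming the `_cpu_examples` extension is built and its handlers registered as in the surrounding examples (the variable names are made up):

```python
import jax.numpy as jnp
from jax_ffi_example import cpu_examples

x = jnp.linspace(0.0, 0.5, 10)
# input_output_aliases={0: 0} tells XLA the output may reuse x's buffer.
y = cpu_examples.aliasing(x)
assert jnp.allclose(y, x)
```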
13 changes: 12 additions & 1 deletion examples/ffi/tests/cpu_examples_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import absltest, parameterized

import jax
import jax.numpy as jnp
@@ -91,5 +91,16 @@ def counter_fun(x):
    self.assertEqual(counter_fun(0)[1], 3)


class AliasingTests(jtu.JaxTestCase):
  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cpu"]):
      self.skipTest("Unsupported platform")

  @parameterized.parameters((jnp.linspace(0, 0.5, 10),), (jnp.int32(6),))
  def test_basic(self, x):
    self.assertAllClose(cpu_examples.aliasing(x), x)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
6 changes: 5 additions & 1 deletion jax/_src/compiler.py
@@ -33,6 +33,7 @@
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
import numpy as np
@@ -198,7 +199,10 @@ def get_compile_options(
    logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
    if env_options_overrides is None:
      env_options_overrides = {}
    env_options_overrides['xla_gpu_enable_command_buffer'] = ''
    if xla_extension_version > 302:
      env_options_overrides['xla_gpu_enable_command_buffer'] = ''
    else:
      env_options_overrides['xla_gpu_graph_min_graph_size'] = '100000'

  if env_options_overrides is not None:
    # Some overrides are passed directly on build_options.
8 changes: 3 additions & 5 deletions jax/_src/core.py
@@ -39,7 +39,7 @@
from jax._src import effects
from jax._src import compute_on
from jax._src import mesh as mesh_lib
from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
from jax._src.partition_spec import PartitionSpec as P
from jax._src.errors import (
    ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
    TracerIntegerConversionError, UnexpectedTracerError)
@@ -1659,12 +1659,12 @@ def _maybe_modify_sharding(sharding):

  new_spec = []
  for s in sharding.spec:
    if s is None or isinstance(s, UnconstrainedSingleton):
    if s is None:
      new_spec.append(s)
    else:
      temp_s = s[0] if isinstance(s, tuple) else s
      new_spec.append(
          P.UNCONSTRAINED
          None
          if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
  return sharding.with_spec(new_spec)

@@ -1762,8 +1762,6 @@ def _get_shape_sharding_str(shape, spec):
  for s1, s2 in zip(shape, spec):
    if s2 is None:
      out.append(f"{s1}")
    elif isinstance(s2, UnconstrainedSingleton):
      out.append(f"{s1}")
    elif isinstance(s2, tuple):
      ss = ','.join(s for s in s2)
      out.append(f"{s1}@({ss})")
2 changes: 1 addition & 1 deletion jax/_src/custom_partitioning.py
@@ -190,7 +190,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
      closed_jaxpr,
      name="tmp_xla_computation",
      platforms=module_context.platforms,
      backend_or_name=module_context.backend_or_name,
      backend=module_context.backend,
      axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
  )
  result_sharding = _pack_result_sharding(result_shape, result_shardings)
31 changes: 16 additions & 15 deletions jax/_src/dispatch.py
@@ -22,39 +22,38 @@
import enum
from functools import partial
import itertools
import time
from typing import Any, NamedTuple
import logging
import threading

import numpy as np
import time
from typing import Any, NamedTuple

import jax
from jax._src import api
from jax._src import array
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import api
from jax._src import array
from jax._src import dtypes
from jax._src import lib
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.abstract_arrays import array_types
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src import lib
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.monitoring import record_event_duration_secs
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.monitoring import record_event_duration_secs, record_event_time_span
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, NamedSharding, TransferToMemoryKind,
from jax._src.sharding_impls import ( NamedSharding,
SingleDeviceSharding, TransferToMemoryKind,
is_single_device_sharding)
from jax._src.layout import Layout, DeviceLocalLayout
import numpy as np


JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@@ -177,12 +176,14 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
  log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
  start_time = time.time()
  yield
  elapsed_time = time.time() - start_time
  end_time = time.time()
  elapsed_time = end_time - start_time
  if logger.isEnabledFor(log_priority):
    logger.log(log_priority, fmt.format(
        fun_name=fun_name, elapsed_time=elapsed_time))
  if event is not None:
    record_event_duration_secs(event, elapsed_time)
    record_event_time_span(event, start_time, end_time)


def should_tuple_args(num_args: int, platform: str) -> bool:
2 changes: 1 addition & 1 deletion jax/_src/export/_export.py
@@ -896,7 +896,7 @@ def is_token(typ, attrs):
  with ir.InsertionPoint(entry_block):
    # Make a context just for lowering the dimension value computations
    module_context = mlir.ModuleContext(
        backend_or_name="cpu", platforms=["cpu"],
        backend=None, platforms=["cpu"],
        axis_context=sharding_impls.ShardingContext(0),
        keepalives=[], channel_iterator=itertools.count(1),
        host_callbacks=[], module=wrapped_module, context=context,